imdigitalashish commited on
Commit
c2d866c
·
verified ·
1 Parent(s): 9494b3c

Upload folder using huggingface_hub

Browse files
__init__.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""AshishOCR model package - Custom OCR model based on vision-language architecture."""

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoProcessor, AutoImageProcessor

from .configuration_ashish_ocr import (
    AshishOcrConfig,
    AshishOcrTextConfig,
    AshishOcrVisionConfig,
)
from .modeling_ashish_ocr import (
    AshishOcrPreTrainedModel,
    AshishOcrTextModel,
    AshishOcrVisionEncoder,
    AshishOcrForConditionalGeneration,
)
from .processing_ashish_ocr import (
    AshishOcrImageProcessor,
    AshishOcrProcessor,
)

# Register model with transformers Auto classes.
# Each model_type string is mapped to its config class so the Auto* factories
# can resolve checkpoints whose config declares these model types.
AutoConfig.register("ashish_ocr", AshishOcrConfig)
AutoConfig.register("ashish_ocr_text", AshishOcrTextConfig)
AutoConfig.register("ashish_ocr_vision", AshishOcrVisionConfig)

# Both generic and causal-LM entry points resolve to the same multimodal model.
AutoModel.register(AshishOcrConfig, AshishOcrForConditionalGeneration)
AutoModelForCausalLM.register(AshishOcrConfig, AshishOcrForConditionalGeneration)

AutoImageProcessor.register(AshishOcrConfig, AshishOcrImageProcessor)
AutoProcessor.register(AshishOcrConfig, AshishOcrProcessor)

# Public API of the package.
__all__ = [
    "AshishOcrConfig",
    "AshishOcrTextConfig",
    "AshishOcrVisionConfig",
    "AshishOcrPreTrainedModel",
    "AshishOcrTextModel",
    "AshishOcrVisionEncoder",
    "AshishOcrForConditionalGeneration",
    "AshishOcrImageProcessor",
    "AshishOcrProcessor",
]
config.json CHANGED
@@ -2,6 +2,14 @@
2
  "architectures": [
3
  "AshishOcrForConditionalGeneration"
4
  ],
 
 
 
 
 
 
 
 
5
  "model_type": "ashish_ocr",
6
  "text_config": {
7
  "model_type": "ashish_ocr_text",
 
2
  "architectures": [
3
  "AshishOcrForConditionalGeneration"
4
  ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_ashish_ocr.AshishOcrConfig",
7
+ "AutoModel": "modeling_ashish_ocr.AshishOcrForConditionalGeneration",
8
+ "AutoModelForCausalLM": "modeling_ashish_ocr.AshishOcrForConditionalGeneration",
9
+ "AutoModelForVision2Seq": "modeling_ashish_ocr.AshishOcrForConditionalGeneration",
10
+ "AutoProcessor": "processing_ashish_ocr.AshishOcrProcessor",
11
+ "AutoImageProcessor": "processing_ashish_ocr.AshishOcrImageProcessor"
12
+ },
13
  "model_type": "ashish_ocr",
14
  "text_config": {
15
  "model_type": "ashish_ocr_text",
configuration_ashish_ocr.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """AshishOCR model configuration"""
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+
5
+
6
class AshishOcrVisionConfig(PretrainedConfig):
    """Configuration class for AshishOCR vision encoder.

    Holds the hyper-parameters of the ViT-style vision tower: transformer
    width/depth, patchification sizes, and the spatial-merge projection that
    maps vision features to ``out_hidden_size`` for the text decoder.
    """

    model_type = "ashish_ocr_vision"

    def __init__(
        self,
        hidden_size=1024,
        depth=24,
        num_heads=16,
        attention_bias=True,
        intermediate_size=4096,
        hidden_act="silu",
        hidden_dropout_prob=0.0,
        initializer_range=0.02,
        image_size=336,
        patch_size=14,
        out_hidden_size=1536,
        rms_norm_eps=1e-05,
        spatial_merge_size=2,
        temporal_patch_size=2,
        **kwargs,
    ):
        # Assignments happen after super().__init__(**kwargs) so explicit
        # parameters win over anything PretrainedConfig pulls from kwargs.
        super().__init__(**kwargs)
        # Transformer geometry.
        self.hidden_size = hidden_size
        self.depth = depth
        self.num_heads = num_heads
        self.attention_bias = attention_bias
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.initializer_range = initializer_range
        # Patchification: image_size / patch_size patches per side;
        # temporal_patch_size frames are folded into each patch.
        self.image_size = image_size
        self.patch_size = patch_size
        # Width of the merged features handed to the language model
        # (matches AshishOcrTextConfig.hidden_size by default).
        self.out_hidden_size = out_hidden_size
        self.rms_norm_eps = rms_norm_eps
        # spatial_merge_size**2 neighbouring patches are merged into one token.
        self.spatial_merge_size = spatial_merge_size
        self.temporal_patch_size = temporal_patch_size
44
+
45
+
46
class AshishOcrTextConfig(PretrainedConfig):
    """Configuration class for AshishOCR text decoder.

    Defaults describe a 16-layer decoder with grouped-query attention
    (16 query heads / 8 key-value heads, head_dim 128), SwiGLU MLPs and a
    59392-token vocabulary.
    """

    model_type = "ashish_ocr_text"

    def __init__(
        self,
        vocab_size=59392,
        hidden_size=1536,
        intermediate_size=4608,
        num_hidden_layers=16,
        num_attention_heads=16,
        num_key_value_heads=8,
        head_dim=128,
        hidden_act="silu",
        max_position_embeddings=131072,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        tie_word_embeddings=False,
        attention_bias=False,
        attention_dropout=0.0,
        pad_token_id=59246,
        eos_token_id=None,
        num_nextn_predict_layers=1,
        rope_parameters=None,
        dtype="bfloat16",
        **kwargs,
    ):
        # Assignments happen after super().__init__(**kwargs) so explicit
        # parameters win over anything PretrainedConfig pulls from kwargs.
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # num_key_value_heads < num_attention_heads enables grouped-query attention.
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.tie_word_embeddings = tie_word_embeddings
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.pad_token_id = pad_token_id
        # When not supplied, both known end-of-sequence token ids are accepted.
        self.eos_token_id = eos_token_id if eos_token_id is not None else [59246, 59253]
        self.num_nextn_predict_layers = num_nextn_predict_layers
        # NOTE(review): rope_parameters is stored but not consumed by the
        # modeling code visible in this commit — confirm intended use.
        self.rope_parameters = rope_parameters
        self.dtype = dtype
96
+
97
+
98
class AshishOcrConfig(PretrainedConfig):
    """Configuration class for AshishOCR multimodal model.

    Composes a text-decoder config and a vision-encoder config, plus the
    special token ids that mark image/video content inside a token sequence.
    """

    model_type = "ashish_ocr"
    sub_configs = {"text_config": AshishOcrTextConfig, "vision_config": AshishOcrVisionConfig}

    def __init__(
        self,
        text_config=None,
        vision_config=None,
        image_start_token_id=59256,
        image_end_token_id=59257,
        video_start_token_id=59258,
        video_end_token_id=59259,
        image_token_id=59280,
        video_token_id=59281,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Missing sub-configs fall back to the sub-config class defaults.
        if text_config is None:
            text_config = {}
        if vision_config is None:
            vision_config = {}

        self.text_config = AshishOcrTextConfig(**text_config)
        self.vision_config = AshishOcrVisionConfig(**vision_config)

        # Delimiter tokens wrapping visual content in the token stream.
        self.image_start_token_id = image_start_token_id
        self.image_end_token_id = image_end_token_id
        self.video_start_token_id = video_start_token_id
        self.video_end_token_id = video_end_token_id
        # Placeholder tokens that are replaced by vision embeddings at runtime.
        self.image_token_id = image_token_id
        self.video_token_id = video_token_id

        # Inherit key parameters from text config
        # (some transformers utilities expect these on the top-level config).
        self.vocab_size = self.text_config.vocab_size
        self.hidden_size = self.text_config.hidden_size
modeling_ashish_ocr.py ADDED
@@ -0,0 +1,579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """AshishOCR model implementation."""
2
+
3
+ import math
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.nn import CrossEntropyLoss
10
+
11
+ from transformers.activations import ACT2FN
12
+ from transformers.cache_utils import Cache, DynamicCache
13
+ from transformers.modeling_outputs import (
14
+ BaseModelOutputWithPast,
15
+ CausalLMOutputWithPast,
16
+ )
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ from transformers.utils import logging
19
+
20
+ from .configuration_ashish_ocr import AshishOcrConfig, AshishOcrTextConfig, AshishOcrVisionConfig
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
class AshishOcrRMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean subtraction), computed in fp32."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learnable per-channel scale, initialised to the identity.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        # Normalise in float32 for numerical stability regardless of input dtype.
        x = hidden_states.to(torch.float32)
        mean_square = x.pow(2).mean(-1, keepdim=True)
        normalized = x * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)
38
+
39
+
40
class AshishOcrRotaryEmbedding(nn.Module):
    """Produces cos/sin tables for rotary position embeddings (RoPE)."""

    def __init__(self, dim, max_position_embeddings=131072, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Inverse frequency of each even channel pair: base^(-2i/dim).
        exponents = torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents), persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        # angles[b, s, d/2] = position_ids[b, s] * inv_freq[d/2]
        angles = position_ids[:, :, None].float() * self.inv_freq[None, None, :].float()
        # Duplicate so cos/sin span the full head dimension (paired halves).
        table = torch.cat((angles, angles), dim=-1)
        # x is used only to pick the output dtype.
        return table.cos().to(dtype=x.dtype), table.sin().to(dtype=x.dtype)
58
+
59
+
60
def rotate_half(x):
    """Split the last dim into halves (a, b) and return (-b, a)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
64
+
65
+
66
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Rotate query/key tensors by the given cos/sin tables.

    ``unsqueeze_dim`` inserts the head axis so cos/sin broadcast across heads.
    """
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    rotated_q = q * cos + rotate_half(q) * sin
    rotated_k = k * cos + rotate_half(k) * sin
    return rotated_q, rotated_k
72
+
73
+
74
class AshishOcrMLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward block: down(act(gate(x)) * up(x))."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # Gate and up projections share shapes; no biases anywhere.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
86
+
87
+
88
class AshishOcrAttention(nn.Module):
    """Grouped-query self-attention with rotary position embeddings.

    Query heads (``num_attention_heads``) outnumber key/value heads
    (``num_key_value_heads``); k/v are repeat-interleaved to match before the
    score computation. Supports an incremental KV cache via ``past_key_value``.
    """

    def __init__(self, config: AshishOcrTextConfig, layer_idx: int):
        super().__init__()
        self.config = config
        # layer_idx identifies this layer's slot in the shared Cache object.
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        # How many query heads share each key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.attention_dropout = config.attention_dropout

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.rotary_emb = AshishOcrRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=config.max_position_embeddings,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """Run attention over ``hidden_states``.

        ``attention_mask`` is additive (``-inf`` blocks a position) and must be
        broadcastable to (bsz, num_heads, q_len, kv_len). Returns
        (attn_output, attn_weights-or-None, past_key_value).
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Rotary tables depend only on position_ids; value_states supplies dtype.
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # Append this step's k/v and get back the full cached sequences.
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

        # Repeat k/v heads for grouped query attention
        key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
        value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Softmax in float32 for stability, then cast back.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
156
+
157
+
158
class AshishOcrDecoderLayer(nn.Module):
    """One pre-norm decoder block: self-attention then gated MLP, each wrapped
    in an RMSNorm + residual connection."""

    def __init__(self, config: AshishOcrTextConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = AshishOcrAttention(config, layer_idx)
        self.mlp = AshishOcrMLP(config)
        self.input_layernorm = AshishOcrRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = AshishOcrRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        # Self-attention sub-block (pre-norm, residual add).
        attn_input = self.input_layernorm(hidden_states)
        attn_out, attn_weights, present_key_value = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block (pre-norm, residual add).
        hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))

        # Tuple layout mirrors transformers convention:
        # (hidden_states[, attn_weights][, present_key_value]).
        outputs = (hidden_states,)
        if output_attentions:
            outputs = outputs + (attn_weights,)
        if use_cache:
            outputs = outputs + (present_key_value,)
        return outputs
203
+
204
+
205
+ # ==================== Vision Encoder ====================
206
+
207
class AshishOcrVisionMLP(nn.Module):
    """Two-layer feed-forward for vision blocks: fc1 -> activation -> fc2."""

    def __init__(self, config: AshishOcrVisionConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.act = ACT2FN[config.hidden_act]
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states):
        # Expand to intermediate width, apply the configured nonlinearity,
        # then project back to hidden width.
        return self.fc2(self.act(self.fc1(hidden_states)))
219
+
220
+
221
+ class AshishOcrVisionAttention(nn.Module):
222
+ def __init__(self, config: AshishOcrVisionConfig):
223
+ super().__init__()
224
+ self.num_heads = config.num_heads
225
+ self.head_dim = config.hidden_size // config.num_heads
226
+
227
+ self.qkv = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.attention_bias)
228
+ self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
229
+
230
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
231
+ bsz, seq_len, _ = hidden_states.size()
232
+
233
+ qkv = self.qkv(hidden_states)
234
+ qkv = qkv.reshape(bsz, seq_len, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
235
+ q, k, v = qkv.unbind(0)
236
+
237
+ attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
238
+ attn_weights = F.softmax(attn_weights, dim=-1)
239
+ attn_output = torch.matmul(attn_weights, v)
240
+
241
+ attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, -1)
242
+ attn_output = self.proj(attn_output)
243
+ return attn_output
244
+
245
+
246
class AshishOcrVisionBlock(nn.Module):
    """Pre-norm vision transformer block: attention then MLP, with residuals."""

    def __init__(self, config: AshishOcrVisionConfig):
        super().__init__()
        self.norm1 = AshishOcrRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attn = AshishOcrVisionAttention(config)
        self.norm2 = AshishOcrRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = AshishOcrVisionMLP(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Attention sub-block with residual, then MLP sub-block with residual.
        hidden_states = hidden_states + self.attn(self.norm1(hidden_states))
        return hidden_states + self.mlp(self.norm2(hidden_states))
258
+
259
+
260
+ class AshishOcrPatchEmbed(nn.Module):
261
+ def __init__(self, config: AshishOcrVisionConfig):
262
+ super().__init__()
263
+ self.patch_size = config.patch_size
264
+ self.temporal_patch_size = config.temporal_patch_size
265
+ self.proj = nn.Conv3d(
266
+ 3,
267
+ config.hidden_size,
268
+ kernel_size=(config.temporal_patch_size, config.patch_size, config.patch_size),
269
+ stride=(config.temporal_patch_size, config.patch_size, config.patch_size),
270
+ bias=False,
271
+ )
272
+
273
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
274
+ # hidden_states: (B, C, T, H, W)
275
+ hidden_states = self.proj(hidden_states)
276
+ hidden_states = hidden_states.flatten(2).transpose(1, 2) # (B, N, D)
277
+ return hidden_states
278
+
279
+
280
class AshishOcrPatchMerger(nn.Module):
    """Merge spatial_merge_size x spatial_merge_size neighbouring patches into
    one token and project to ``out_hidden_size`` via a small MLP.

    Token count shrinks by spatial_merge_size**2; feature width grows by the
    same factor before the projection.
    """

    def __init__(self, config: AshishOcrVisionConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.out_hidden_size = config.out_hidden_size
        self.spatial_merge_size = config.spatial_merge_size

        # Norm over the concatenated merged features, then a 2-layer GELU MLP.
        self.mlp = nn.Sequential(
            AshishOcrRMSNorm(config.hidden_size * config.spatial_merge_size ** 2, eps=config.rms_norm_eps),
            nn.Linear(config.hidden_size * config.spatial_merge_size ** 2, config.out_hidden_size, bias=False),
            nn.GELU(),
            nn.Linear(config.out_hidden_size, config.out_hidden_size, bias=False),
        )

    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
        """Merge patches according to the per-item (t, h, w) grid layout.

        ``grid_thw`` may be (batch, 3) or a single (3,) shared layout.
        NOTE(review): torch.stack below requires every batch item to yield the
        same merged token count — confirm callers only batch uniform grids.
        Assumes h and w are divisible by spatial_merge_size — TODO confirm.
        """
        # Merge spatial patches
        batch_size = hidden_states.shape[0]
        merged_states = []

        for b in range(batch_size):
            t, h, w = grid_thw[b].tolist() if grid_thw.dim() > 1 else grid_thw.tolist()
            # Drop any padding tokens beyond this item's real patch count.
            states = hidden_states[b, :t*h*w]
            states = states.view(t, h, w, -1)

            # Merge spatial patches
            h_new = h // self.spatial_merge_size
            w_new = w // self.spatial_merge_size
            # Group each merge window's patches, then flatten the window into
            # the channel dimension: (t*h_new*w_new, merge**2 * hidden).
            states = states.view(t, h_new, self.spatial_merge_size, w_new, self.spatial_merge_size, -1)
            states = states.permute(0, 1, 3, 2, 4, 5).contiguous()
            states = states.view(t * h_new * w_new, -1)
            merged_states.append(states)

        hidden_states = torch.stack(merged_states, dim=0)
        hidden_states = self.mlp(hidden_states)
        return hidden_states
315
+
316
+
317
class AshishOcrVisionEncoder(nn.Module):
    """Vision tower: patch embedding, a stack of transformer blocks and an
    optional spatial patch merger (applied only when grid_thw is supplied)."""

    def __init__(self, config: AshishOcrVisionConfig):
        super().__init__()
        self.config = config
        self.patch_embed = AshishOcrPatchEmbed(config)
        self.blocks = nn.ModuleList(AshishOcrVisionBlock(config) for _ in range(config.depth))
        self.merger = AshishOcrPatchMerger(config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        grid_thw: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        features = self.patch_embed(pixel_values)

        for layer in self.blocks:
            features = layer(features)

        # Without a grid layout the raw (unmerged) patch features are returned.
        if grid_thw is not None:
            features = self.merger(features, grid_thw)
        return features
339
+
340
+
341
+ # ==================== Main Model ====================
342
+
343
class AshishOcrPreTrainedModel(PreTrainedModel):
    """Shared base class wiring AshishOCR models into the transformers runtime
    (config binding, weight init, device-map split points)."""

    config_class = AshishOcrConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Modules that must not be split across devices by accelerate.
    _no_split_modules = ["AshishOcrDecoderLayer", "AshishOcrVisionBlock"]

    def _init_weights(self, module):
        # Composite configs carry the init range on text_config; plain
        # sub-configs fall back to 0.02.
        if hasattr(self.config, 'text_config'):
            std = self.config.text_config.initializer_range
        else:
            std = 0.02
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
357
+
358
+
359
class AshishOcrTextModel(AshishOcrPreTrainedModel):
    """Decoder-only text transformer for AshishOCR.

    Embeds tokens (or consumes precomputed ``inputs_embeds``), runs the stack
    of decoder layers under a causal attention mask, and applies a final
    RMSNorm. Supports incremental decoding through a ``DynamicCache``.
    """

    def __init__(self, config: AshishOcrTextConfig):
        super().__init__(config)
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [AshishOcrDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = AshishOcrRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be provided.
        ``attention_mask`` is a 2-D padding mask (1 = attend, 0 = ignore)
        covering cached + current tokens. Always returns a
        ``BaseModelOutputWithPast``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = inputs_embeds.shape[:2]

        if past_key_values is None:
            past_key_values = DynamicCache()
        # Tokens already cached from previous generation steps (0 on prefill).
        past_length = past_key_values.get_seq_length()

        if position_ids is None:
            # Offset by the cache length so rotary embeddings stay aligned
            # during incremental generation.
            position_ids = torch.arange(
                past_length, past_length + seq_length, device=inputs_embeds.device
            ).unsqueeze(0)

        # Create causal mask (covers cached + current key positions).
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, past_length + seq_length), device=inputs_embeds.device
            )

        causal_mask = self._prepare_attention_mask(
            attention_mask, seq_length, inputs_embeds.dtype, inputs_embeds.device, past_length
        )

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _prepare_attention_mask(self, attention_mask, seq_length, dtype, device, past_length=0):
        """Build an additive mask of shape (batch, 1, seq_length, kv_length).

        0.0 marks attendable positions and -inf blocked ones. Uses
        ``masked_fill`` throughout: the previous ``(1 - mask) * float("-inf")``
        formulation produced NaN for every *attended* position (0 * -inf = NaN),
        which poisoned the softmax whenever a padding mask was supplied.
        """
        kv_length = past_length + seq_length
        # Causal part: query i may attend to key positions 0 .. past_length + i.
        blocked = torch.triu(
            torch.ones((seq_length, kv_length), dtype=torch.bool, device=device),
            diagonal=past_length + 1,
        )
        mask = torch.zeros((seq_length, kv_length), dtype=torch.float32, device=device)
        mask = mask.masked_fill(blocked, float("-inf"))
        mask = mask[None, None, :, :]

        # Fold in the 2-D padding mask (1 = attend, 0 = padded).
        if attention_mask is not None and attention_mask.dim() == 2:
            pad = attention_mask.to(device=device)
            if pad.shape[-1] < kv_length:
                # Callers sometimes pass a mask covering only the current
                # tokens; treat the cached prefix as fully attendable.
                pad = torch.cat(
                    [pad.new_ones(pad.shape[0], kv_length - pad.shape[-1]), pad], dim=-1
                )
            mask = mask.expand(pad.shape[0], 1, seq_length, kv_length).clone()
            mask = mask.masked_fill(pad[:, None, None, :] == 0, float("-inf"))

        return mask.to(dtype)
451
+
452
+
453
class AshishOcrForConditionalGeneration(AshishOcrPreTrainedModel):
    """Vision-language OCR model: vision encoder + text decoder + LM head.

    Image/video features from the vision tower are scattered into the token
    embedding sequence at the positions of the image/video placeholder tokens,
    then decoded autoregressively.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: AshishOcrConfig):
        super().__init__(config)
        self.config = config

        self.visual = AshishOcrVisionEncoder(config.vision_config)
        self.model = AshishOcrTextModel(config.text_config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)

        # Placeholder token ids whose embeddings get replaced by vision features.
        self.image_token_id = config.image_token_id
        self.video_token_id = config.video_token_id

        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Forward pass with optional vision fusion and LM loss.

        When ``labels`` is given, returns a next-token cross-entropy loss
        (logits shifted left against labels shifted right).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)

        # Fuse image features: replace each image placeholder token's embedding
        # with one row of the vision output. Requires input_ids to locate the
        # placeholders, so fusion is skipped when only inputs_embeds is given.
        if pixel_values is not None and input_ids is not None:
            image_embeds = self.visual(pixel_values, image_grid_thw)
            image_mask = input_ids == self.image_token_id
            inputs_embeds = inputs_embeds.clone()
            inputs_embeds[image_mask] = image_embeds.view(-1, image_embeds.shape[-1])

        # Same fusion for video features at video placeholder positions.
        if pixel_values_videos is not None and input_ids is not None:
            video_embeds = self.visual(pixel_values_videos, video_grid_thw)
            video_mask = input_ids == self.video_token_id
            inputs_embeds = inputs_embeds.clone()
            inputs_embeds[video_mask] = video_embeds.view(-1, video_embeds.shape[-1])

        outputs = self.model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        # Compute the loss in float32 for stability.
        logits = logits.float()

        loss = None
        if labels is not None:
            # Standard causal shift: predict token t+1 from position t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        **kwargs,
    ):
        """Assemble per-step inputs for `generate()`.

        After the prefill step only the last token is fed, and the vision
        inputs must be dropped: their embeddings were already fused into the
        KV cache, and re-passing them against a 1-token sequence that contains
        no placeholder tokens would make the masked scatter in `forward` fail
        with a shape mismatch.
        """
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
            pixel_values = None
            pixel_values_videos = None
            image_grid_thw = None
            video_grid_thw = None

        model_inputs = {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
            "inputs_embeds": inputs_embeds,
            "pixel_values": pixel_values,
            "pixel_values_videos": pixel_values_videos,
            "image_grid_thw": image_grid_thw,
            "video_grid_thw": video_grid_thw,
        }
        return model_inputs
preprocessor_config.json CHANGED
@@ -1,4 +1,8 @@
1
  {
 
 
 
 
2
  "size": {"shortest_edge": 12544, "longest_edge": 9633792},
3
  "do_rescale": true,
4
  "patch_size": 14,
@@ -6,6 +10,6 @@
6
  "merge_size": 2,
7
  "image_mean": [0.48145466, 0.4578275, 0.40821073],
8
  "image_std": [0.26862954, 0.26130258, 0.27577711],
9
- "image_processor_type": "Ashish46VImageProcessor",
10
- "processor_class": "Ashish46VProcessor"
11
  }
 
1
  {
2
+ "auto_map": {
3
+ "AutoImageProcessor": "processing_ashish_ocr.AshishOcrImageProcessor",
4
+ "AutoProcessor": "processing_ashish_ocr.AshishOcrProcessor"
5
+ },
6
  "size": {"shortest_edge": 12544, "longest_edge": 9633792},
7
  "do_rescale": true,
8
  "patch_size": 14,
 
10
  "merge_size": 2,
11
  "image_mean": [0.48145466, 0.4578275, 0.40821073],
12
  "image_std": [0.26862954, 0.26130258, 0.27577711],
13
+ "image_processor_type": "AshishOcrImageProcessor",
14
+ "processor_class": "AshishOcrProcessor"
15
  }
processing_ashish_ocr.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """AshishOCR processor for handling image and text inputs."""
2
+
3
+ from typing import List, Optional, Union
4
+
5
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
6
+ from transformers.image_utils import ImageInput
7
+ from transformers.processing_utils import ProcessorMixin
8
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
9
+
10
+
11
class AshishOcrImageProcessor(BaseImageProcessor):
    """Image processor for the AshishOCR model.

    Resizes each image so its shortest edge matches ``size["shortest_edge"]``,
    snaps both dimensions to multiples of ``patch_size``, rescales and
    normalizes pixel values, then duplicates the single frame along a
    temporal axis to fit the 3D patch-embedding convolution.
    """

    model_input_names = ["pixel_values", "image_grid_thw"]

    def __init__(
        self,
        do_resize: bool = True,
        size: dict = None,
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255,
        do_normalize: bool = True,
        image_mean: list = None,
        image_std: list = None,
        min_pixels: int = 56 * 56,
        max_pixels: int = 28 * 28 * 1280,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        merge_size: int = 2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.do_resize = do_resize
        # CLIP-style defaults when the config does not provide size/mean/std.
        self.size = size if size is not None else {"shortest_edge": 336}
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
        self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.merge_size = merge_size

    def preprocess(
        self,
        images: ImageInput,
        **kwargs,
    ) -> BatchFeature:
        """Preprocess one image (path, PIL image, or array) or a list of them.

        Args:
            images: A single image or list of images. Strings are treated as
                file paths; non-PIL inputs are converted via ``numpy``.

        Returns:
            BatchFeature with ``pixel_values`` of shape (N, C, T, H, W) and
            ``image_grid_thw`` of shape (N, 3) holding (t, h, w) patch counts.

        Raises:
            ValueError: if ``images`` is an empty list.
        """
        # Heavy deps imported lazily so the module can be inspected without them.
        import numpy as np
        import torch
        from PIL import Image

        if not isinstance(images, list):
            images = [images]
        if not images:
            # Fail early with a clear message instead of an opaque
            # torch.stack error on an empty tensor list below.
            raise ValueError("`images` must contain at least one image.")

        processed_images = []
        grid_thw = []

        for image in images:
            if isinstance(image, str):
                image = Image.open(image).convert("RGB")
            elif not isinstance(image, Image.Image):
                image = Image.fromarray(np.array(image))

            # NOTE(review): `do_resize` is ignored here — resizing always
            # happens; honoring the flag would need grid bookkeeping changes.
            width, height = image.size
            target_size = self.size.get("shortest_edge", 336)

            # Scale so the shortest edge hits target_size, keeping aspect ratio.
            if width < height:
                new_width = target_size
                new_height = int(height * target_size / width)
            else:
                new_height = target_size
                new_width = int(width * target_size / height)

            # Snap to multiples of patch_size; clamp so extreme aspect ratios
            # never round a dimension down to zero (PIL rejects 0-sized resizes).
            new_width = max((new_width // self.patch_size) * self.patch_size, self.patch_size)
            new_height = max((new_height // self.patch_size) * self.patch_size, self.patch_size)

            image = image.resize((new_width, new_height), Image.BILINEAR)

            image_array = np.array(image).astype(np.float32)

            if self.do_rescale:
                image_array = image_array * self.rescale_factor

            if self.do_normalize:
                image_array = (image_array - np.array(self.image_mean)) / np.array(self.image_std)

            # HWC -> CHW, then add a temporal axis for the 3D conv:
            # (C, H, W) -> (C, T, H, W) by repeating the frame.
            image_tensor = torch.tensor(image_array).permute(2, 0, 1)
            image_tensor = image_tensor.unsqueeze(1).repeat(1, self.temporal_patch_size, 1, 1)

            processed_images.append(image_tensor)

            # Grid size in patches; t stays 1 because the temporal copies
            # are collapsed by the temporal patching.
            grid_thw.append([1, new_height // self.patch_size, new_width // self.patch_size])

        # NOTE(review): torch.stack assumes every image resized to identical
        # dimensions; batches mixing aspect ratios will fail here — confirm
        # callers only batch same-shaped images.
        pixel_values = torch.stack(processed_images, dim=0)
        image_grid_thw = torch.tensor(grid_thw, dtype=torch.long)

        return BatchFeature(data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw})
112
+
113
+
114
class AshishOcrProcessor(ProcessorMixin):
    """Processor pairing the AshishOCR image processor with a tokenizer.

    Routes `images` to the image processor and `text` to the tokenizer,
    merging both results into a single :class:`BatchFeature`.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AshishOcrImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
        super().__init__(image_processor, tokenizer)

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        images: ImageInput = None,
        videos: ImageInput = None,
        padding: bool = False,
        truncation: bool = None,
        max_length: int = None,
        return_tensors: str = None,
        **kwargs,
    ) -> BatchFeature:
        """Process text and/or images into model-ready features.

        Either argument may be omitted; whatever is provided is encoded and
        merged into the returned :class:`BatchFeature`.
        """
        features = BatchFeature()

        if images is not None:
            features.update(self.image_processor(images, **kwargs))

        if text is not None:
            tokenized = self.tokenizer(
                text,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=return_tensors,
                **kwargs,
            )
            features.update(tokenized)

        return features

    def batch_decode(self, *args, **kwargs):
        """Forward to the tokenizer's ``batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward to the tokenizer's ``decode``."""
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """Deduplicated union of tokenizer and image-processor input names."""
        combined = self.tokenizer.model_input_names + self.image_processor.model_input_names
        return list(dict.fromkeys(combined))
tokenizer_config.json CHANGED
@@ -44,6 +44,6 @@
44
  "model_max_length": 655380,
45
  "pad_token": "<|endoftext|>",
46
  "padding_side": "left",
47
- "processor_class": "Ashish46VProcessor",
48
  "tokenizer_class": "TokenizersBackend"
49
  }
 
44
  "model_max_length": 655380,
45
  "pad_token": "<|endoftext|>",
46
  "padding_side": "left",
47
+ "processor_class": "AshishOcrProcessor",
48
  "tokenizer_class": "TokenizersBackend"
49
  }