AnthonyDi commited on
Commit
5b4cdc4
·
verified ·
1 Parent(s): 8238159

(Trained with Unsloth)

Browse files
Files changed (4) hide show
  1. config.json +9 -13
  2. deepencoderv2.py +1015 -0
  3. modeling_deepseekocr2.py +1051 -0
  4. tokenizer_config.json +1 -1
config.json CHANGED
@@ -1,12 +1,12 @@
1
  {
2
  "architectures": [
3
- "DeepseekOCRForCausalLM"
4
  ],
5
  "attention_bias": false,
6
  "attention_dropout": 0.0,
7
  "auto_map": {
8
- "AutoConfig": "modeling_deepseekocr.DeepseekOCRConfig",
9
- "AutoModel": "modeling_deepseekocr.DeepseekOCRForCausalLM"
10
  },
11
  "aux_loss_alpha": 0.001,
12
  "bos_token_id": 0,
@@ -65,7 +65,7 @@
65
  "lm_head": true,
66
  "max_position_embeddings": 8192,
67
  "mlp_bias": false,
68
- "model_type": "DeepseekOCR",
69
  "moe_intermediate_size": 896,
70
  "n_group": 1,
71
  "n_routed_experts": 64,
@@ -77,7 +77,7 @@
77
  "num_key_value_heads": 10,
78
  "pad_token_id": 2,
79
  "projector_config": {
80
- "input_dim": 2048,
81
  "model_type": "mlp_projector",
82
  "n_embed": 1280,
83
  "projector_type": "linear"
@@ -96,22 +96,18 @@
96
  "topk_group": 1,
97
  "topk_method": "greedy",
98
  "transformers_version": "4.56.2",
99
- "unsloth_version": "2026.1.3",
100
  "use_cache": true,
101
  "use_mla": false,
102
  "v_head_dim": 0,
103
  "vision_config": {
104
  "image_size": 1024,
105
  "mlp_ratio": 3.7362,
106
- "model_name": "deeplip_b_l",
107
  "model_type": "vision",
108
  "width": {
109
- "clip-l-14-224": {
110
- "heads": 16,
111
- "image_size": 224,
112
- "layers": 24,
113
- "patch_size": 14,
114
- "width": 1024
115
  },
116
  "sam_vit_b": {
117
  "downsample_channels": [
 
1
  {
2
  "architectures": [
3
+ "DeepseekOCR2ForCausalLM"
4
  ],
5
  "attention_bias": false,
6
  "attention_dropout": 0.0,
7
  "auto_map": {
8
+ "AutoConfig": "modeling_deepseekocr2.DeepseekOCR2Config",
9
+ "AutoModel": "modeling_deepseekocr2.DeepseekOCR2ForCausalLM"
10
  },
11
  "aux_loss_alpha": 0.001,
12
  "bos_token_id": 0,
 
65
  "lm_head": true,
66
  "max_position_embeddings": 8192,
67
  "mlp_bias": false,
68
+ "model_type": "DeepseekOCR2",
69
  "moe_intermediate_size": 896,
70
  "n_group": 1,
71
  "n_routed_experts": 64,
 
77
  "num_key_value_heads": 10,
78
  "pad_token_id": 2,
79
  "projector_config": {
80
+ "input_dim": 896,
81
  "model_type": "mlp_projector",
82
  "n_embed": 1280,
83
  "projector_type": "linear"
 
96
  "topk_group": 1,
97
  "topk_method": "greedy",
98
  "transformers_version": "4.56.2",
99
+ "unsloth_version": "2026.1.4",
100
  "use_cache": true,
101
  "use_mla": false,
102
  "v_head_dim": 0,
103
  "vision_config": {
104
  "image_size": 1024,
105
  "mlp_ratio": 3.7362,
106
+ "model_name": "deepencoderv2",
107
  "model_type": "vision",
108
  "width": {
109
+ "qwen2-0-5b": {
110
+ "dim": 896
 
 
 
 
111
  },
112
  "sam_vit_b": {
113
  "downsample_channels": [
deepencoderv2.py ADDED
@@ -0,0 +1,1015 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import copy
5
+
6
+
7
+ from typing import Optional, Tuple
8
+
9
+ # from megatron.model import LayerNorm
10
+
11
+ import transformers
12
+
13
+
14
+ from typing import Optional, Tuple, Type
15
+ from functools import partial
16
+
17
+
18
+
19
class MlpProjector(nn.Module):
    """Projects vision-encoder features into the language-model embedding space.

    The concrete architecture is chosen by ``cfg.projector_type``; some
    variants additionally downsample tokens 4-to-1 or fuse two feature
    streams (high/low resolution).  ``cfg`` is assumed to be an
    attribute-dict style config (supports both ``cfg.attr`` and
    ``cfg.get(key, default)``) — TODO confirm against the config class.
    """

    def __init__(self, cfg):

        super().__init__()

        self.cfg = cfg

        if cfg.projector_type == "identity":
            # Pass features through unchanged.
            modules = nn.Identity()

        elif cfg.projector_type == "linear":
            # Single linear map: input_dim -> n_embed.
            modules = nn.Linear(cfg.input_dim, cfg.n_embed)

        elif cfg.projector_type == "mlp_gelu":
            # Linear followed by (depth - 1) GELU+Linear refinement stages.
            mlp_depth = cfg.get("depth", 1)
            modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "normlayer_downsample_mlp_gelu":
            # Like downsample_mlp_gelu but with a LayerNorm on the
            # concatenated (input_dim * ratio^2) features first.
            mlp_depth = cfg.get("depth", 1)
            mlp_ratio = cfg.get("mlp_ratio", 1)
            modules = [
                nn.LayerNorm(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio),
                nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)
            ]
            for _ in range(1, mlp_depth - 1):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
            modules.append(nn.GELU())
            modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "downsample_mlp_gelu":
            # Input features are spatially downsampled (ratio^2 tokens
            # concatenated channel-wise in forward) before the MLP.
            mlp_depth = cfg.get("depth", 1)
            mlp_ratio = cfg.get("mlp_ratio", 1)
            modules = [nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)]
            for _ in range(1, mlp_depth - 1):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
            modules.append(nn.GELU())
            modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
            # Two input streams, each projected to half of n_embed, then
            # concatenated and refined by a shared MLP tail.
            mlp_depth = cfg.get("depth", 1)
            self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
            self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)

            modules = []
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "hybrid_split_feature_mlp_gelu":
            # Channel-split variant: input_dim is a pair [high_dim, low_dim];
            # channel_div controls how much of n_embed each half receives.
            mlp_depth = cfg.get("depth", 1)
            channel_div = cfg.get("channel_div", 0.5)
            self.high_up_proj = nn.Linear(cfg.input_dim[0], int(cfg.n_embed * channel_div))
            self.low_up_proj = nn.Linear(cfg.input_dim[1], cfg.n_embed - int(cfg.n_embed * channel_div))

            modules = []
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "low_high_split_mlp_gelu":
            # Independent MLP tails for each half-width stream.
            # NOTE(review): high_layers wraps the same submodules that are
            # also kept as self.layers (shared parameters), while low_layers
            # is a deep copy — confirm this sharing is intentional.
            mlp_depth = cfg.get("depth", 1)
            modules = []
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed // 2, cfg.n_embed // 2))
            modules = nn.Sequential(*modules)
            self.high_layers = nn.Sequential(*modules)
            self.low_layers = copy.deepcopy(modules)

        else:
            raise ValueError(f"Unknown projector type: {cfg.projector_type}")

        if cfg.get("token_pooling", False):
            # 2x2 token pooling: 4 neighboring tokens concatenated then
            # projected back to input_dim.
            self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim)

        if cfg.get("conv_fusion_high_low_features", False):
            self.fusion_layer = nn.Linear(cfg.input_dim, cfg.input_dim)
        self.layers = modules

    def forward(self, x):
        if self.cfg.get("token_pooling", False):
            # x: (B, H*W, C) with a square token grid — TODO confirm callers.
            batch_size, wxh, channels = x.shape
            w = h = int(wxh**0.5)
            x = x.view(batch_size, w, h, channels)
            x = x.permute(0, 3, 1, 2)
            # import ipdb; ipdb.set_trace()
            # Extract non-overlapping 2x2 patches.
            patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
            batch_size, channels, h_patches, w_patches, _, _ = patches.size()
            # Concatenate along the channel dimension.
            patches = patches.contiguous().view(batch_size, channels, h_patches * w_patches, -1)

            # Pass through the linear layer.
            patches = patches.permute(0, 2, 1, 3).contiguous()
            patches = patches.view(batch_size, h_patches * w_patches, channels * 4)

            x = self.token_pooling_layer(patches)

        if self.cfg.get("conv_fusion_high_low_features", False):
            # Fuse the two leading feature streams into one.
            x = self.fusion_layer(x[:, 0]) + x[:, 1]

        if self.cfg.projector_type == 'low_high_hybrid_split_mlp_gelu':
            # x is a pair (high_features, low_features).
            high_x, low_x = x[0], x[1]
            high_x = self.high_up_proj(high_x)
            low_x = self.low_up_proj(low_x)
            x = torch.concat([high_x, low_x], dim=-1)

        if self.cfg.projector_type == 'hybrid_split_feature_mlp_gelu':
            # Split along channels at input_dim[0], project each half.
            high_x = x[...,:self.cfg.input_dim[0]]
            low_x = x[...,self.cfg.input_dim[0]:]
            high_x = self.high_up_proj(high_x)
            low_x = self.low_up_proj(low_x)
            x = torch.concat([high_x, low_x], dim=-1)

        if self.cfg.projector_type == 'low_high_split_mlp_gelu':
            # Each stream gets its own MLP tail; returns early (skips
            # self.layers below).
            high_x, low_x = x[0], x[1]
            high_x = self.high_layers(high_x)
            low_x = self.low_layers(low_x)
            x = torch.concat([high_x, low_x], dim=-1)
            return x

        if self.cfg.projector_type == 'downsample_mlp_gelu' or self.cfg.projector_type == 'normlayer_downsample_mlp_gelu':
            bs, hw, input_dim = x.shape
            h = w = int((hw) ** 0.5)

            """compute padding"""
            if h % self.cfg.downsample_ratio:
                pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio
            else:
                pad = 0
            x = x.reshape(bs, h, w, input_dim)
            if pad > 0:
                # Zero-pad bottom/right so H and W divide downsample_ratio.
                x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)

            """4 to 1 concat"""
            x = x.permute(0, 3, 1, 2)  # B, C, H, W
            # Concatenate each ratio x ratio tile channel-wise.
            x = F.unfold(x, kernel_size=self.cfg.downsample_ratio, stride=self.cfg.downsample_ratio, padding=0)  # B, C*4, HW // 4
            x = x.permute(0, 2, 1)

        return self.layers(x)

    @staticmethod
    def get_flops_per_sample(cfg):
        """Estimate training FLOPs per token (fwd + bwd, hence the * 3)."""
        if cfg.projector_type == "linear":
            fwd = 2 * cfg.input_dim * cfg.n_embed

        elif "mlp_gelu" in cfg.projector_type :
            mlp_depth = cfg.get("depth", 1)
            downsample_ratio = cfg.get("downsample_ratio", 1)
            input_dim = sum(cfg.input_dim) if isinstance(cfg.input_dim, list) else cfg.input_dim
            input_dim = input_dim * downsample_ratio * downsample_ratio
            fwd = 2 * input_dim * cfg.n_embed + (mlp_depth - 1) * 2 * cfg.n_embed * cfg.n_embed
        else:
            fwd = 0

        # Backward pass is roughly 2x forward cost.
        return fwd * 3
+
186
+
187
+ #===================qwen2================================
188
+
189
class CustomQwen2Decoder(nn.Module):
    """
    Qwen2 decoder stack repurposed as a visual encoder.

    The sequence mixes non-causal attention (image tokens attend to all
    image tokens) and causal attention (query tokens attend to all image
    tokens plus preceding query tokens), selected per position by
    ``token_type_ids``: 0 = non-causal, 1 = causal.
    """

    def __init__(
        self,
        decoder_layer: int = 24,
        max_position_embeddings: int = 131072,
        hidden_dimension: int = 896,
        num_attention_heads: int = 14,
        num_key_value_heads: int = 2,
        intermediate_size: int = 4864,
        vocab_size: int = 151936,
        attn_implementation: str = "sdpa",  # ⭐
        rms_norm_eps: float = 1e-06,
        rope_theta: float = 1000000.0,
        attention_dropout: float = 0.0,
        hidden_act: str = "silu",
        initializer_range: float = 0.02,
    ):
        super().__init__()

        # The custom mask is a dense additive 4D tensor; flash_attention_2
        # cannot consume it, so only 'sdpa' or 'eager' are allowed.
        if attn_implementation == "flash_attention_2":
            raise ValueError(
                "CustomQwen2Decoder do not support flash_attention_2,"
                "new attention mask needs 'sdpa' or 'eager'"
            )

        # Resolve the HF classes by name (keeps the transformers import soft).
        Qwen2Model = getattr(transformers.models.qwen2.modeling_qwen2, 'Qwen2Model')
        Qwen2Config = getattr(transformers, 'Qwen2Config')

        # Build a Qwen2 configuration from the constructor arguments.
        config = Qwen2Config(
            hidden_size=hidden_dimension,
            num_hidden_layers=decoder_layer,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            intermediate_size=intermediate_size,
            max_position_embeddings=max_position_embeddings,
            vocab_size=vocab_size,
            rms_norm_eps=rms_norm_eps,
            rope_theta=rope_theta,
            attention_dropout=attention_dropout,
            hidden_act=hidden_act,
            initializer_range=initializer_range,
            _attn_implementation=attn_implementation,  # ⭐
        )

        # Instantiate a subclass of Qwen2Model that overrides the causal mask.
        self.model = self._create_custom_model(Qwen2Model, config)

        # Inputs always arrive as embeddings, so the token embedding table
        # is unused and deleted to save memory.
        del self.model.embed_tokens

    def _create_custom_model(self, Qwen2Model, config):
        """Build a Qwen2Model subclass whose attention mask mixes
        non-causal (image) and causal (query) attention per token type."""

        class CustomQwen2ModelInner(Qwen2Model):

            def forward(
                self,
                input_ids=None,
                attention_mask=None,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=None,
                token_type_ids=None,  # ⭐
                use_cache=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                cache_position=None,
            ):
                # Stash token_type_ids so _update_causal_mask (called from
                # the parent forward, which has no such parameter) can use it.
                self._current_token_type_ids = token_type_ids

                outputs = super().forward(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    inputs_embeds=inputs_embeds,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                    cache_position=cache_position,
                )

                return outputs

            def _update_causal_mask(
                self,
                attention_mask,
                input_tensor,
                cache_position,
                past_key_values,
                output_attentions,
            ):
                """Replace the default causal mask with the custom 4D mask."""
                dtype, device = input_tensor.dtype, input_tensor.device
                min_dtype = torch.finfo(dtype).min
                batch_size, sequence_length = input_tensor.shape[0], input_tensor.shape[1]

                token_type_ids = self._current_token_type_ids

                # Build the additive (B, 1, S, S) attention mask.
                causal_mask = self._create_custom_4d_mask(
                    sequence_length=sequence_length,
                    dtype=dtype,
                    device=device,
                    batch_size=batch_size,
                    token_type_ids=token_type_ids,
                )

                # Fold an optional 2D padding mask into the additive mask.
                if attention_mask is not None and attention_mask.dim() == 2:
                    padding_mask = attention_mask[:, None, None, :].to(dtype=dtype)
                    padding_mask = (1.0 - padding_mask) * min_dtype
                    causal_mask = causal_mask + padding_mask

                return causal_mask

            def _create_custom_4d_mask(
                self,
                sequence_length,
                dtype,
                device,
                batch_size,
                token_type_ids,
            ):
                """Additive mask: 0.0 where attention is allowed, dtype-min
                where it is blocked. Image (type 0) tokens attend to all
                image tokens; text/query (type 1) tokens attend to all image
                tokens and to query tokens up to their own position."""
                min_dtype = torch.finfo(dtype).min

                masks = []
                for b in range(batch_size):
                    # Start fully blocked, then open the allowed pairs.
                    mask = torch.full(
                        (sequence_length, sequence_length),
                        fill_value=min_dtype,
                        dtype=dtype,
                        device=device
                    )

                    type_ids = token_type_ids[b]

                    image_positions = (type_ids == 0).nonzero(as_tuple=True)[0]
                    text_positions = (type_ids == 1).nonzero(as_tuple=True)[0]

                    # Non-causal block: image <-> image fully visible.
                    if len(image_positions) > 0:
                        mask[image_positions[:, None], image_positions] = 0.0

                    # Causal block: each query sees all images and earlier queries.
                    for i, text_pos in enumerate(text_positions):
                        if len(image_positions) > 0:
                            mask[text_pos, image_positions] = 0.0
                        mask[text_pos, text_positions[:i+1]] = 0.0

                    masks.append(mask)

                # (B, S, S) -> (B, 1, S, S): broadcast over attention heads.
                mask = torch.stack(masks, dim=0).unsqueeze(1)
                return mask

        return CustomQwen2ModelInner(config)

    def forward(
        self,
        inputs_embeds,
        token_type_ids,
        attention_mask=None,
        **kwargs
    ):
        """
        Args:
            inputs_embeds: [batch_size, seq_len, hidden_dim]
            token_type_ids: [batch_size, seq_len], 0=non-causal, 1=causal
            attention_mask: [batch_size, seq_len], optional
        """
        return self.model(
            inputs_embeds=inputs_embeds,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            **kwargs
        )
376
+
377
+
378
+
379
+
380
+
381
+ # batch_size = 2
382
+ # inputs_embeds = torch.randn(batch_size, 512, 896).cuda()
383
+
384
+ # inputs_embeds = torch.randn(batch_size, 512, 896).cuda()
385
+ # token_type_ids = torch.cat([
386
+ # torch.zeros(batch_size, 256, dtype=torch.long),
387
+ # torch.ones(batch_size, 256, dtype=torch.long),
388
+ # ], dim=1).cuda()
389
+
390
+ # # start = time.time()
391
+ # with torch.no_grad():
392
+ # outputs_sdpa = decoder_sdpa(inputs_embeds, token_type_ids)
393
+ # print(outputs_sdpa[0].shape)
394
+ # print(f"SDPA time: {time.time() - start:.4f}s")
395
+
396
+
397
+
398
class Qwen2Decoder2Encoder(nn.Module):
    """
    Use a decoder-only Qwen2 stack as a visual encoder.

    The flattened image features attend to each other non-causally while a
    bank of learned query embeddings attends causally over them (see
    ``CustomQwen2Decoder``); the query outputs form the encoded sequence.

    Only feature maps that flatten to 144 or 256 tokens are supported,
    matching the two learned query banks (``query_768`` / ``query_1024``).
    """

    def __init__(
        self,
        decoder_layer: int,
        hidden_dimension: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        intermediate_size: int,
        max_query: int,
    ):
        super().__init__()

        # SDPA is required: the custom block mask is a dense additive 4D tensor.
        self.model = CustomQwen2Decoder(
            decoder_layer=decoder_layer,
            hidden_dimension=hidden_dimension,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            intermediate_size=intermediate_size,
            attn_implementation="sdpa",
        )

        # Learned query banks: 144 queries (768px inputs) / 256 (1024px).
        self.query_768 = nn.Embedding(144, hidden_dimension)
        self.query_1024 = nn.Embedding(256, hidden_dimension)

        # self.query_refixation = nn.Embedding(int(math.sqrt(max_query)), hidden_dimension)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: feature map with spatial dims flattening to 144 or 256 tokens
               (e.g. (B, C, H, W); flattened internally to (B, H*W, C)).

        Returns:
            Tensor of shape (B, n_query, hidden_dimension) — the outputs at
            the causal query positions.
        """
        x = x.flatten(2).transpose(1, 2)  # (B, n_query, C)

        bs, n_query, _ = x.shape

        if n_query == 144:
            param_img = self.query_768.weight
        elif n_query == 256:
            param_img = self.query_1024.weight
        else:
            # Fix: previously an unsupported token count fell through and
            # raised an opaque UnboundLocalError on `param_img`.
            raise ValueError(
                f"Unsupported number of image tokens: {n_query} (expected 144 or 256)"
            )

        batch_query_imgs = param_img.unsqueeze(0).expand(
            bs, -1, -1
        )  # (batch_size, num_queries, hidden_size)

        # Image tokens first, query tokens appended after.
        x_combined = torch.cat([x, batch_query_imgs], dim=1)

        # 0 = image token (non-causal), 1 = query token (causal).
        # Fix: allocate on the input's device so mask construction does not
        # mix CPU index tensors with CUDA tensors.
        token_type_ids = torch.cat([
            torch.zeros(bs, n_query, dtype=torch.long, device=x.device),
            torch.ones(bs, n_query, dtype=torch.long, device=x.device),
        ], dim=1)

        y = self.model(x_combined, token_type_ids)[0]

        # Keep only the causal flow query outputs.
        y = y[:, n_query:, :]

        return y
466
+
467
+
468
def build_qwen2_decoder_as_encoder(
    decoder_layer=24,
    hidden_dimension=896,
    num_attention_heads=14,
    num_key_value_heads=2,
    intermediate_size=4864,
    max_query = 400,
    checkpoint=None,
):
    """Construct a ``Qwen2Decoder2Encoder`` and optionally load weights.

    Args:
        decoder_layer: number of Qwen2 decoder layers.
        hidden_dimension: model hidden size.
        num_attention_heads / num_key_value_heads: attention head counts.
        intermediate_size: MLP intermediate size.
        max_query: forwarded to the encoder (currently unused there).
        checkpoint: optional path to a ``torch.save``-d state dict.

    Returns:
        The (optionally checkpoint-initialized) Qwen2Decoder2Encoder.
    """
    decoder_as_encoder = Qwen2Decoder2Encoder(
        decoder_layer=decoder_layer,
        hidden_dimension = hidden_dimension,
        num_attention_heads = num_attention_heads,
        num_key_value_heads = num_key_value_heads,
        intermediate_size = intermediate_size,
        max_query = max_query
    )

    if checkpoint is not None:
        # Fix: map_location="cpu" so GPU-saved checkpoints load on CPU-only
        # hosts; the caller moves the module to its device afterwards.
        state_dict = torch.load(checkpoint, map_location="cpu")

        decoder_as_encoder.load_state_dict(state_dict, strict=True)
        # Debug trace of which checkpoint was loaded.
        print(checkpoint)
    return decoder_as_encoder
498
+
499
+
500
+
501
+
502
+ #=========================Sam-Vary=================================
503
+
504
+
505
def get_abs_pos_sam(abs_pos, tgt_size):
    """Resize SAM absolute position embeddings to a target grid size.

    Args:
        abs_pos: position embeddings of shape (1, S, S, C).
        tgt_size: desired spatial side length.

    Returns:
        Embeddings of shape (1, tgt_size, tgt_size, C); the input tensor is
        returned unchanged when it already matches.
    """
    if abs_pos.size(1) == tgt_size:
        return abs_pos

    orig_dtype = abs_pos.dtype

    # Interpolate in NCHW layout; bicubic is computed in float32 because
    # antialiased interpolation is not supported for all reduced dtypes.
    grid = abs_pos.permute(0, 3, 1, 2).to(torch.float32)
    grid = F.interpolate(
        grid,
        size=(tgt_size, tgt_size),
        mode='bicubic',
        antialias=True,
        align_corners=False,
    ).to(orig_dtype)

    # Back to channels-last (1, tgt, tgt, C).
    return grid.permute(0, 2, 3, 1)
525
+
526
+
527
+
528
+
529
class MLPBlock(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> Linear."""

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        # Expand to mlp_dim, apply the activation, project back.
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.act(self.lin1(x))
        return self.lin2(hidden)
543
+
544
+
545
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
546
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
547
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
    """LayerNorm over the channel dimension of an NCHW tensor."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Per-channel affine parameters; eps guards the rsqrt.
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        normed = (x - mean) / torch.sqrt(var + self.eps)
        # Broadcast the (C,) affine params over the spatial dims.
        return self.weight[:, None, None] * normed + self.bias[:, None, None]
560
+
561
+
562
+ # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
563
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
    ) -> None:
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (list): Indexes for blocks using global attention.
        """
        super().__init__()
        self.img_size = img_size

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            self.pos_embed = nn.Parameter(
                torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
            )

        self.blocks = nn.ModuleList()
        for i in range(depth):
            # Blocks in global_attn_indexes use global attention
            # (window_size 0); all others use windowed attention.
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=(img_size // patch_size, img_size // patch_size),
            )
            self.blocks.append(block)

        # SAM neck: 1x1 projection to out_chans, then a 3x3 refinement,
        # each followed by channel LayerNorm.
        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_dim,
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )

        # Extra strided convs downsample 4x spatially and widen channels
        # 256 -> 512 -> 896 to match the Qwen2 hidden size.
        self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
        self.net_3 = nn.Conv2d(512, 896, kernel_size=3, stride=2, padding=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            # x = x + self.pos_embed
            # Resize the pretrain-size pos embedding to the token grid.
            x = x + get_abs_pos_sam(self.pos_embed, x.size(1))

        for blk in self.blocks:
            x = blk(x)

        # (B, H, W, C) -> (B, C, H, W) for the convolutional neck.
        x = self.neck(x.permute(0, 3, 1, 2))
        x2 = self.net_2(x)
        x3 = self.net_3(x2.clone())

        return x3
669
+
670
+
671
class Block(nn.Module):
    """ViT block: (optionally windowed) attention and an MLP, each wrapped
    in pre-norm and a residual connection."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If it equals 0, then
                use global attention.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)

        # Windowed blocks compute rel-pos tables over the window extent;
        # global blocks (window_size == 0) use the full input resolution.
        attn_extent = input_size if window_size == 0 else (window_size, window_size)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=attn_extent,
        )

        self.norm2 = norm_layer(dim)
        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)

        self.window_size = window_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        out = self.norm1(x)

        windowed = self.window_size > 0
        if windowed:
            # Remember the original extent so padding can be removed.
            full_h, full_w = out.shape[1], out.shape[2]
            out, pad_hw = window_partition(out, self.window_size)

        out = self.attn(out)

        if windowed:
            out = window_unpartition(out, self.window_size, pad_hw, (full_h, full_w))

        out = residual + out
        return out + self.mlp(self.norm2(out))
735
+
736
+
737
class Attention(nn.Module):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE(review): self.scale is defined but unused in forward —
        # scaled_dot_product_attention applies its own 1/sqrt(d) scaling.
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert (
                input_size is not None
            ), "Input size must be provided if using relative positional encoding."
            # initialize relative positional embeddings
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, H, W, C) feature map; attention runs over the H*W tokens.
        B, H, W, _ = x.shape
        # qkv with shape (3, B, nHead, H * W, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # q, k, v with shape (B * nHead, H * W, C)
        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)

        rel_h, rel_w = None, None
        if self.use_rel_pos:
            # Decomposed relative position terms (per-axis) following ViTDet;
            # add_decomposed_rel_pos is defined elsewhere in this module.
            rel_h, rel_w = add_decomposed_rel_pos(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))

        # Reshape to (B, nHead, H*W, head_dim) for SDPA.
        q = q.view(B, self.num_heads, H * W, -1)
        k = k.view(B, self.num_heads, H * W, -1)
        v = v.view(B, self.num_heads, H * W, -1)

        if self.use_rel_pos:
            # Combine the per-axis terms into a dense additive attention bias
            # of shape (B, nHead, H*W, H*W) and let SDPA consume it.
            rel_h = rel_h.view(B, self.num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
            rel_w = rel_w.view(B, self.num_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
            attn_bias = (rel_h + rel_w).view(B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4))
            x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
            # x = _attention_rel_h_rel_w(q, k, v, rel_h, rel_w)
        else:
            x = torch.nn.functional.scaled_dot_product_attention(q, k, v)

        # Back to the (B, H, W, C) feature-map layout.
        x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)

        x = self.proj(x)

        return x
805
+
806
+
807
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """
    Split a feature map into non-overlapping windows, zero-padding the
    bottom/right edges so both spatial dims divide window_size.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    # Round H and W up to the next multiple of window_size.
    Hp = ((H + window_size - 1) // window_size) * window_size
    Wp = ((W + window_size - 1) // window_size) * window_size
    if (Hp, Wp) != (H, W):
        # F.pad dims are listed last-first: (C_lo, C_hi, W_lo, W_hi, H_lo, H_hi).
        x = F.pad(x, (0, 0, 0, Wp - W, 0, Hp - H))

    grid_h = Hp // window_size
    grid_w = Wp // window_size
    x = x.view(B, grid_h, window_size, grid_w, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)
829
+
830
+
831
def window_unpartition(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    """Reassemble windows into a single feature map and strip any padding.

    Args:
        windows: tensor of shape [B * num_windows, window_size, window_size, C].
        window_size: side length of each window.
        pad_hw: padded spatial size (Hp, Wp) used during partitioning.
        hw: original spatial size (H, W) before padding.

    Returns:
        Tensor of shape [B, H, W, C].
    """
    padded_h, padded_w = pad_hw
    height, width = hw
    windows_per_image = (padded_h // window_size) * (padded_w // window_size)
    batch = windows.shape[0] // windows_per_image

    grid = windows.view(
        batch,
        padded_h // window_size,
        padded_w // window_size,
        window_size,
        window_size,
        -1,
    )
    x = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, padded_h, padded_w, -1)

    # Drop the padded rows/columns, if any were added during partitioning.
    if padded_h > height or padded_w > width:
        x = x[:, :height, :width, :].contiguous()
    return x
854
+
855
+
856
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """Select relative positional embeddings for every (query, key) pair.

    Args:
        q_size: spatial extent of the queries.
        k_size: spatial extent of the keys.
        rel_pos: relative position table of shape (L, C).

    Returns:
        Tensor of shape (q_size, k_size, C) holding the embedding for each
        relative offset between query and key positions.
    """
    needed = 2 * max(q_size, k_size) - 1
    if rel_pos.shape[0] == needed:
        table = rel_pos
    else:
        # Linearly resample the table along its first axis to the needed
        # length; do the interpolation in float32 then restore the dtype.
        orig_dtype = rel_pos.dtype
        as_float = rel_pos.to(torch.float32)
        resampled = F.interpolate(
            as_float.reshape(1, as_float.shape[0], -1).permute(0, 2, 1),
            size=needed,
            mode="linear",
        ).to(orig_dtype)
        table = resampled.reshape(-1, needed).permute(1, 0)

    # If q and k differ in extent, scale coordinates by the length ratio so
    # both are expressed on a common grid before taking their difference.
    q_scale = max(k_size / q_size, 1.0)
    k_scale = max(q_size / k_size, 1.0)
    q_idx = torch.arange(q_size, device=rel_pos.device)[:, None] * q_scale
    k_idx = torch.arange(k_size, device=rel_pos.device)[None, :] * k_scale
    offsets = (q_idx - k_idx) + (k_size - 1) * k_scale

    return table[offsets.long()]
889
+
890
+
891
def add_decomposed_rel_pos(
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute decomposed relative positional terms from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950

    Args:
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        (rel_h, rel_w): the height term of shape (B, q_h * q_w, k_h, 1) and the
        width term of shape (B, q_h * q_w, 1, k_w). The caller broadcasts and
        sums these to form the attention bias. (Original docstring incorrectly
        described the return value as a full attention map.)
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    # Per-axis relative-position tables indexed for every (query, key) pair.
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    # Contract the channel dim against each axis table.
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
    # Flatten spatial query dims and add a broadcast axis for the other side
    # (the original did an unsqueeze followed by a reshape; one reshape suffices).
    rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1)
    rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w)

    return rel_h, rel_w
926
+
927
+
928
class PatchEmbed(nn.Module):
    """Embed an image into patch tokens with a single strided convolution."""

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project [B, C, H, W] pixels to [B, H', W', embed_dim] tokens."""
        embedded = self.proj(x)
        # Move channels last so downstream attention sees [B, H, W, C].
        return embedded.permute(0, 2, 3, 1)
960
+
961
+
962
def build_sam_vit_b(checkpoint=None):
    """Construct the SAM ViT-B image encoder, optionally loading *checkpoint*."""
    vit_b_spec = dict(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
    )
    return _build_sam(checkpoint=checkpoint, **vit_b_spec)
970
+
971
def build_sam_fast_vit_b(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16):
    """Build the ViT-B encoder in eval mode, cast to *dtype* and torch.compile it."""
    encoder = build_sam_vit_b(checkpoint).eval().to(dtype)
    return torch.compile(encoder, mode=compile_mode)
976
+
977
+
978
def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    """Assemble the SAM image encoder (ImageEncoderViT) used as the vision tower.

    Args:
        encoder_embed_dim: ViT embedding width.
        encoder_depth: number of transformer blocks.
        encoder_num_heads: attention heads per block.
        encoder_global_attn_indexes: block indices that use global (non-windowed)
            attention.
        checkpoint: optional path to a torch state-dict to load.

    Returns:
        The encoder in eval mode, with weights loaded if *checkpoint* is given.
    """
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    # 1024 / 16 = 64x64 patch grid (kept for reference; not used below).
    image_embedding_size = image_size // vit_patch_size
    image_encoder=ImageEncoderViT(
        depth=encoder_depth,
        embed_dim=encoder_embed_dim,
        img_size=image_size,
        mlp_ratio=4,
        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
        num_heads=encoder_num_heads,
        patch_size=vit_patch_size,
        qkv_bias=True,
        use_rel_pos=True,
        global_attn_indexes=encoder_global_attn_indexes,
        window_size=14,
        out_chans=prompt_embed_dim,
    )
    image_encoder.eval()
    if checkpoint is not None:
        state_dict = torch.load(checkpoint)
        # Keep only keys under the vision tower and strip a fixed 30-character
        # prefix from each key.
        # NOTE(review): k[30:] assumes every matching key carries exactly the
        # same 30-char prefix — confirm against the checkpoint's key layout.
        image_encoder.load_state_dict({k[30:]: v for k, v in state_dict.items() if 'vision_tower_high' in k}, strict=True)
        print(checkpoint)
    return image_encoder
modeling_deepseekocr2.py ADDED
@@ -0,0 +1,1051 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import re
4
+ from tqdm import tqdm
5
+ from abc import ABC
6
+ from typing import List, Optional, Tuple, Union
7
+
8
+ from addict import Dict
9
+ from PIL import Image, ImageOps, ImageDraw, ImageFont
10
+ import numpy as np
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.nn import CrossEntropyLoss
15
+ from torchvision import transforms
16
+
17
+ from transformers.cache_utils import Cache
18
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
19
+ from transformers import DeepseekV2Model, DeepseekV2ForCausalLM
20
+ from transformers import DeepseekV2Config
21
+ from transformers.models.deepseek_v2.modeling_deepseek_v2 import (
22
+ DeepseekV2Attention,
23
+ DeepseekV2MLP,
24
+ DeepseekV2MoE,
25
+ DeepseekV2RMSNorm,
26
+ DeepseekV2DecoderLayer,
27
+ )
28
+ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding
29
+ from transformers import TextStreamer
30
+ from .deepencoderv2 import build_sam_vit_b, build_qwen2_decoder_as_encoder, MlpProjector
31
+ from .conversation import get_conv_template
32
+
33
+ torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
34
+
35
def load_image(image_path):
    """Open an image from disk with its EXIF orientation applied.

    Args:
        image_path: path to the image file.

    Returns:
        A PIL image with EXIF rotation normalized; falls back to the raw image
        if EXIF handling fails, or None if the file cannot be opened at all.
    """
    try:
        image = Image.open(image_path)
        # Apply any EXIF orientation tag so pixels match the intended view.
        corrected_image = ImageOps.exif_transpose(image)
        return corrected_image
    except Exception as e:
        print(f"error: {e}")
        try:
            # EXIF handling failed; retry a plain open without transposition.
            # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
            return Image.open(image_path)
        except Exception:
            return None
50
+
51
+
52
def re_match(text):
    """Find every ``<|ref|>...<|/ref|><|det|>...<|/det|>`` span in *text*.

    Returns:
        A tuple ``(matches, image_refs, other_refs)``: the raw regex tuples
        (full span, label, coordinates text), the full spans whose label is
        "image", and all remaining full spans.
    """
    pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
    matches = re.findall(pattern, text, re.DOTALL)

    image_refs = []
    other_refs = []
    for triple in matches:
        full_span = triple[0]
        bucket = image_refs if '<|ref|>image<|/ref|>' in full_span else other_refs
        bucket.append(full_span)
    return matches, image_refs, other_refs
67
+
68
+
69
def extract_coordinates_and_label(ref_text, image_width, image_height):
    """Parse one regex match tuple into ``(label, coordinate_list)``.

    Args:
        ref_text: a 3-tuple ``(full_span, label, coords_text)`` from ``re_match``.
        image_width: unused; kept for interface compatibility.
        image_height: unused; kept for interface compatibility.

    Returns:
        ``(label, coordinate_list)`` on success, otherwise None.
    """
    import ast

    try:
        label_type = ref_text[1]
        # Security: the coordinates come from model output. literal_eval only
        # accepts Python literals, unlike eval() which would execute arbitrary code.
        cor_list = ast.literal_eval(ref_text[2])
    except Exception as e:
        print(e)
        return None

    return (label_type, cor_list)
79
+
80
+
81
def draw_bounding_boxes(image, refs, ouput_path):
    """Draw a labelled box for each parsed ref on a copy of *image*.

    Regions labelled "image" are additionally cropped from the source and
    saved under ``{ouput_path}/images/`` as numbered JPEGs.

    Args:
        image: source PIL image.
        refs: iterable of regex match tuples as produced by ``re_match``.
        ouput_path: directory whose ``images/`` subfolder receives the crops.

    Returns:
        A new PIL image with outlines, translucent fills and labels drawn.
    """
    image_width, image_height = image.size

    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)

    # Separate RGBA overlay so box fills can be translucent.
    overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
    draw2 = ImageDraw.Draw(overlay)

    font = ImageFont.load_default()

    img_idx = 0  # running index for saved "image" crops

    for i, ref in enumerate(refs):
        try:
            result = extract_coordinates_and_label(ref, image_width, image_height)
            if result:
                label_type, points_list = result

                # Random colour per ref; alpha 20 gives a faint overlay fill.
                color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))

                color_a = color + (20, )
                for points in points_list:
                    x1, y1, x2, y2 = points

                    # Coordinates are emitted on a 0-999 grid; scale to pixels.
                    x1 = int(x1 / 999 * image_width)
                    y1 = int(y1 / 999 * image_height)

                    x2 = int(x2 / 999 * image_width)
                    y2 = int(y2 / 999 * image_height)

                    if label_type == 'image':
                        # Save the referenced region as a standalone crop;
                        # the index advances even if saving fails.
                        try:
                            cropped = image.crop((x1, y1, x2, y2))
                            cropped.save(f"{ouput_path}/images/{img_idx}.jpg")
                        except Exception as e:
                            print(e)
                            pass
                        img_idx += 1

                    try:
                        if label_type == 'title':
                            # Titles get a thicker outline than other labels.
                            draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
                            draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
                        else:
                            draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
                            draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
                        text_x = x1
                        text_y = max(0, y1 - 15)  # label sits just above the box

                        # Light backing rectangle behind the label text.
                        text_bbox = draw.textbbox((0, 0), label_type, font=font)
                        text_width = text_bbox[2] - text_bbox[0]
                        text_height = text_bbox[3] - text_bbox[1]
                        draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
                                    fill=(255, 255, 255, 30))

                        draw.text((text_x, text_y), label_type, font=font, fill=color)
                    except:
                        pass
        except:
            continue
    # Composite the translucent fills over the outlined image.
    img_draw.paste(overlay, (0, 0), overlay)
    return img_draw
151
+
152
+
153
def process_image_with_refs(image, ref_texts, output_path):
    """Render all reference bounding boxes onto a copy of *image*."""
    return draw_bounding_boxes(image, ref_texts, output_path)
158
+
159
+
160
+
161
+
162
+
163
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the (cols, rows) grid from *target_ratios* closest to *aspect_ratio*.

    Ties are broken in favour of a later candidate only when the source image
    area exceeds half the area that candidate's tile grid would cover.
    """
    smallest_diff = float('inf')
    chosen = (1, 1)
    source_area = width * height
    for candidate in target_ratios:
        cols, rows = candidate
        diff = abs(aspect_ratio - cols / rows)
        if diff < smallest_diff:
            smallest_diff = diff
            chosen = candidate
        elif diff == smallest_diff and source_area > 0.5 * image_size * image_size * cols * rows:
            chosen = candidate
    return chosen
178
+
179
+
180
def dynamic_preprocess(image, min_num=2, max_num=6, image_size=768, use_thumbnail=False):
    """Tile *image* into a grid of square crops matching its aspect ratio.

    Args:
        image: source PIL image.
        min_num: minimum number of tiles in the chosen grid.
        max_num: maximum number of tiles in the chosen grid.
        image_size: side length of each tile.
        use_thumbnail: if True and more than one tile was produced, append a
            full-image thumbnail resized to (image_size, image_size).

    Returns:
        (tiles, grid): the list of cropped PIL tiles and the chosen
        (cols, rows) grid.
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Enumerate every (cols, rows) grid whose tile count lies in [min_num, max_num],
    # ordered by total tile count.
    candidate_grids = sorted(
        set(
            (i, j)
            for n in range(min_num, max_num + 1)
            for i in range(1, n + 1)
            for j in range(1, n + 1)
            if min_num <= i * j <= max_num
        ),
        key=lambda grid: grid[0] * grid[1],
    )

    # Choose the grid whose shape best matches the image's aspect ratio.
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, candidate_grids, orig_width, orig_height, image_size)

    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # Resize once, then slice the grid out of the resized canvas row by row.
    resized_img = image.resize((target_width, target_height))
    cols = target_width // image_size
    processed_images = []
    for idx in range(blocks):
        left = (idx % cols) * image_size
        top = (idx // cols) * image_size
        tile = resized_img.crop((left, top, left + image_size, top + image_size))
        processed_images.append(tile)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))
    return processed_images, target_aspect_ratio
219
+
220
+
221
+
222
def normalize_transform(mean, std):
    """Build a torchvision Normalize, filling in neutral stats when one is missing.

    Returns None when both *mean* and *std* are None; otherwise a missing mean
    defaults to zeros and a missing std defaults to ones.
    """
    if mean is None and std is None:
        return None
    if mean is None:
        mean = [0.] * len(std)
    elif std is None:
        std = [1.] * len(mean)
    return transforms.Normalize(mean=mean, std=std)
235
+
236
+
237
+
238
def format_messages(
    conversations: List[Dict[str, str]],
    sft_format: str = "deepseek",
    system_prompt: str = "",
):
    """Render a list of role/content messages into a single SFT prompt string.

    Args:
        conversations (List[Dict]): messages, each with "role" and "content".
        sft_format (str, optional): name of the conversation template. Defaults
            to "deepseek".
        system_prompt (str, optional): system message injected into the
            template. Defaults to "".

    Returns:
        sft_prompt (str): the fully formatted, whitespace-stripped prompt.
    """
    template = get_conv_template(sft_format)
    template.set_system_message(system_prompt)
    for entry in conversations:
        template.append_message(entry["role"], entry["content"].strip())
    return template.get_prompt().strip()
262
+
263
+
264
def text_encode(tokenizer, text: str, bos: bool = True, eos: bool = False):
    """Tokenize *text*, optionally wrapping it with BOS/EOS token ids.

    NOTE(review): BOS id 0 and EOS id 1 are hard-coded here rather than read
    from the tokenizer — confirm they match the model's special tokens.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    if bos:
        token_ids = [0] + token_ids
    if eos:
        token_ids = token_ids + [1]
    return token_ids
274
+
275
def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]:
    """Load every image referenced in *conversations* as an RGB PIL image.

    Args:
        conversations (List[Dict[str, str]]): messages; each may carry an
            "images" key listing file paths, e.g.:
            [
                {
                    "role": "User",
                    "content": "<image_placeholder>\nExtract all information from this image and convert them into markdown format.",
                    "images": ["./examples/table_datasets.png"]
                },
                {"role": "Assistant", "content": ""},
            ]

    Returns:
        pil_images (List[PIL.Image.Image]): the loaded images, in order.
    """
    pil_images = []
    for message in conversations:
        if "images" not in message:
            continue
        for image_path in message["images"]:
            loaded = load_image(image_path)
            pil_images.append(loaded.convert("RGB"))
    return pil_images
312
+
313
+
314
class BaseTransform(ABC):
    """Minimal interface for tensor-producing transforms."""

    def set_rng(self, *args, **kwargs):
        """Hook for seeding randomness; no-op by default."""
        pass

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        """Apply the transform; concrete subclasses override this."""
        pass

    @property
    def default_shape(self):
        """Output shape of the transform; subclasses must define it."""
        raise NotImplementedError
325
+
326
+
327
class BasicImageTransform(BaseTransform):
    """ToTensor pipeline followed by optional channel normalization."""

    def __init__(
        self,
        mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        normalize: bool = True
    ):
        self.mean = mean
        self.std = std

        pipeline = [transforms.ToTensor()]

        # nn.Identity() when normalization is disabled; normalize_transform
        # may itself return None (both stats absent), in which case we skip it.
        norm_step = normalize_transform(mean, std) if normalize else nn.Identity()
        if norm_step is not None:
            pipeline.append(norm_step)

        self.transform = transforms.Compose(pipeline)

    def __call__(self, x):
        """Apply the composed transform to a PIL image or ndarray."""
        return self.transform(x)
350
+
351
class NoEOSTextStreamer(TextStreamer):
    """TextStreamer that prints a newline in place of the EOS token."""

    def on_finalized_text(self, text: str, stream_end: bool = False):
        # Decode the EOS marker once and strip it from the streamed chunk so
        # the terminal never shows the raw special token.
        eos_text = self.tokenizer.decode([self.tokenizer.eos_token_id], skip_special_tokens=False)
        cleaned = text.replace(eos_text, "\n")
        print(cleaned, flush=True, end="")
357
+
358
def decoder_layer_init(self, config: DeepseekV2Config, layer_idx: int):
    """Replacement ``__init__`` for DeepseekV2DecoderLayer.

    Selects the attention implementation from ``config.use_mla``: MLA layers
    keep DeepseekV2Attention, otherwise a standard LlamaAttention is used.
    MLP layers before ``config.first_k_dense_replace`` stay dense; later
    layers use the MoE block.
    """
    nn.Module.__init__(self)
    self.hidden_size = config.hidden_size

    if config.use_mla:
        self.self_attn = DeepseekV2Attention(config=config, layer_idx=layer_idx)
    else:
        # LlamaAttention reads head_dim off the config; derive it here.
        config.head_dim = config.hidden_size // config.num_attention_heads
        self.self_attn = LlamaAttention(config, layer_idx)
    self.mlp = DeepseekV2MoE(config) if layer_idx >= config.first_k_dense_replace else DeepseekV2MLP(config)

    self.input_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.post_attention_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)


# Monkey-patch: every DeepseekV2DecoderLayer constructed after this point
# uses the attention/MLP selection above.
DeepseekV2DecoderLayer.__init__ = decoder_layer_init
374
+
375
class DeepseekOCR2Config(DeepseekV2Config):
    """DeepseekV2 config registered under the DeepseekOCR2 model type."""
    # Model type string used by AutoConfig/AutoModel to resolve this class.
    model_type = "DeepseekOCR2"
377
+
378
class DeepseekOCR2Model(DeepseekV2Model):
    """DeepseekV2 language model extended with a SAM + Qwen2 vision encoder.

    Image pixels are encoded by a SAM ViT-B trunk, refined by a Qwen2 decoder
    used as an encoder, linearly projected into the LM embedding space, and
    scattered into the token embedding sequence at positions flagged by
    ``images_seq_mask``.
    """
    config_class = DeepseekOCR2Config

    def __init__(self, config: DeepseekV2Config):
        super(DeepseekOCR2Model, self).__init__(config)

        # Vision tower: SAM ViT-B features refined by a Qwen2 stack.
        self.sam_model = build_sam_vit_b()
        self.qwen2_model = build_qwen2_decoder_as_encoder()
        n_embed = 1280
        # Linear projection from the Qwen2 feature width (896) to the LM width.
        self.projector = MlpProjector(Dict(projector_type="linear", input_dim=896, n_embed=n_embed))
        embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
        # Learned separator token appended after each image's feature block.
        self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)


    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the language model, first splicing projected vision features
        into ``inputs_embeds`` wherever ``images_seq_mask`` is set.

        ``images`` is expected as per-sample pairs ``(patches, image_ori)``:
        local crop tiles and the resized global view.
        """

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)
            # Clone so vision features can be written in without mutating the
            # embedding lookup result in place.
            inputs_embeds = inputs_embeds.clone()

        sam_model = getattr(self, 'sam_model', None)
        qwen2_model = getattr(self, 'qwen2_model', None)

        # Encode images only on the prefill pass (seq len > 1) or in training,
        # and only when the batch actually carries a non-empty global image.
        if sam_model is not None and (input_ids.shape[1] != 1 or self.training) and torch.sum(images[0][1]).item() != 0:

            idx = 0

            for image, crop_shape in zip(images, images_spatial_crop):
                images_in_this_batch = []

                patches = image[0]    # local crop tiles
                image_ori = image[1]  # global (resized full) view

                # NOTE(review): vision feature extraction runs under no_grad,
                # so the vision towers receive no gradient even in training.
                with torch.no_grad():

                    if torch.sum(patches).item() != 0:
                        # Tiled path: encode both the local crops and the
                        # global view, then concatenate local-first.
                        crop_flag = 1
                        local_features_1 = sam_model(patches)

                        local_features_2 = qwen2_model(local_features_1)
                        local_features = local_features_2
                        local_features = self.projector(local_features)


                        global_features_1 = sam_model(image_ori)
                        global_features_2 = qwen2_model(global_features_1)
                        global_features = global_features_2
                        global_features = self.projector(global_features)

                        _, hw, n_dim = global_features.shape

                        _2, hw2, n_dim2 = local_features.shape

                        # Flatten token dims so all views line up on dim 0.
                        global_features = global_features.view(-1, n_dim)

                        local_features = local_features.view(-1, n_dim2)

                        # Local tokens, then global tokens, then the separator.
                        global_local_features = torch.cat([local_features, global_features, self.view_seperator[None, :]], dim=0)

                    else:
                        # Global-only path: no crop tiles for this image.
                        global_features_1 = sam_model(image_ori)
                        global_features_2 = qwen2_model(global_features_1)
                        global_features = global_features_2
                        global_features = self.projector(global_features)
                        _, hw, n_dim = global_features.shape

                        global_features = global_features.view(-1, n_dim)

                        global_local_features = torch.cat([global_features, self.view_seperator[None, :]], dim=0)

                    images_in_this_batch.append(global_local_features)


                if images_in_this_batch:
                    images_in_this_batch = torch.cat(images_in_this_batch, dim=0)

                    images_in_this_batch = images_in_this_batch.to(
                        device=inputs_embeds.device, dtype=inputs_embeds.dtype
                    )
                    # Scatter the vision tokens into the flagged embedding slots.
                    mask = images_seq_mask[idx].unsqueeze(-1).to(inputs_embeds.device)  # bool [T, 1]
                    updated_row = inputs_embeds[idx].masked_scatter(mask, images_in_this_batch)
                    inputs_embeds[idx] = updated_row

                idx += 1


        # Delegate to the text-only DeepseekV2 forward with the (possibly
        # image-augmented) embeddings; input_ids must be None alongside embeds.
        return super(DeepseekOCR2Model, self).forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
528
+
529
+
530
+ class DeepseekOCR2ForCausalLM(DeepseekV2ForCausalLM):
531
+
532
+ config_class = DeepseekOCR2Config
533
+ # supports_gradient_checkpointing = True
534
+
535
    def __init__(self, config):
        """Build the OCR causal-LM: multimodal backbone plus an untied LM head.

        Calls PreTrainedModel's init directly (skipping DeepseekV2ForCausalLM's
        own __init__) so the backbone can be the multimodal model.
        """
        super(DeepseekV2ForCausalLM, self).__init__(config)
        self.model = DeepseekOCR2Model(config)

        self.vocab_size = config.vocab_size

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
547
+
548
    def get_model(self):
        """Return the underlying multimodal DeepseekOCR2Model backbone."""
        return self.model
550
+
551
+
552
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,

    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Causal-LM forward: run the multimodal backbone, apply the LM head,
        and compute a shifted cross-entropy loss when *labels* are given.

        Image-related arguments are forwarded unchanged to the backbone, which
        handles splicing vision features into the embedding sequence.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict



        outputs = self.model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            images=images,
            images_seq_mask = images_seq_mask,
            images_spatial_crop = images_spatial_crop,
            return_dict=return_dict

        )


        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        # Compute loss/decoding scores in float32 for numerical stability.
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
627
+
628
+
629
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Assemble per-step generation inputs, trimming tokens already cached.

        Also threads the image tensors through kwargs so the multimodal
        backbone sees them on the prefill step.
        """
        # Omit tokens covered by past_key_values
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.get_seq_length()
                max_cache_length = None
            else:
                # Legacy tuple cache: length is the key tensor's seq dim.
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            # NOTE(review): max_cache_length is always None above, so this
            # branch is currently dead — confirm whether it should read a
            # real limit from the cache.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
        # same goes for position ids. Could also help with continued generation.
        cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                # Vision inputs forwarded so the backbone can splice features
                # during the prefill pass.
                "images": kwargs.get("images", None),
                "images_seq_mask": kwargs.get("images_seq_mask", None),
                "images_spatial_crop": kwargs.get("images_spatial_crop", None),
            }
        )
        return model_inputs
703
+
704
+
705
def disable_torch_init(self):
    """Make model construction fast by skipping torch's default weight init.

    Replaces ``reset_parameters`` on ``nn.Linear`` and ``nn.LayerNorm`` with a
    no-op: the random initialisation is redundant here because the real
    weights are loaded from a checkpoint afterwards.
    """
    import torch

    for patched_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        patched_cls.reset_parameters = lambda self: None
714
+
715
def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False):
    """Run single-image OCR/VQA inference end to end.

    Builds an ``<image>``-interleaved token sequence, preprocesses the image
    into a padded global view (plus optional local crops), calls
    ``self.generate`` on CUDA, and optionally prints compression stats or
    saves post-processed results to ``output_path``.

    Args:
        tokenizer: HuggingFace tokenizer used for encoding/decoding.
        prompt: user prompt text; may contain the ``<image>`` placeholder.
        image_file: path to the input image (paired with ``<image>``).
        output_path: directory where result files are written (always created).
        base_size: side length the global view is padded to.
        image_size: side length of local crops / the non-crop global view.
        crop_mode: when True, images larger than 768px are tiled into crops.
        test_compress: print image-token vs. text-token compression stats.
        save_results: write result.mmd and annotated images to ``output_path``.
        eval_mode: no streaming; return the decoded output string instead.

    Returns:
        The decoded output string when ``eval_mode`` is True and the prompt
        contains ``<image>``; otherwise ``None`` (output goes to stdout/files).
    """
    self.disable_torch_init()

    # NOTE(review): raises on the default output_path='' — confirm callers
    # always pass a real directory.
    os.makedirs(output_path, exist_ok=True)
    os.makedirs(f'{output_path}/images', exist_ok=True)

    # Build a two-turn conversation; the image is attached only when given.
    if prompt and image_file:
        conversation = [
            {
                "role": "<|User|>",
                # "content": "<image>\n<|grounding|>Given the layout of the image. ",
                "content": f'{prompt}',
                # "content": "君不见黄河之水天上来的下一句是什么?",
                # "content": "<image>\nFree OCR. ",
                # "content": "<image>\nParse the figure. ",
                # "content": "<image>\nExtract the text in the image. ",
                "images": [f'{image_file}'],
            },
            {"role": "<|Assistant|>", "content": ""},
        ]

    elif prompt:
        # Text-only conversation (no "images" entry).
        conversation = [
            {
                "role": "<|User|>",
                # "content": "<image>\n<|grounding|>Given the layout of the image. ",
                "content": f'{prompt}',
                # "content": "君不见黄河之水天上来的下一句是什么?",
                # "content": "<image>\nFree OCR. ",
                # "content": "<image>\nParse the figure. ",
                # "content": "<image>\nExtract the text in the image. ",
                # "images": [f'{image_file}'],
            },
            {"role": "<|Assistant|>", "content": ""},
        ]
    else:
        # NOTE(review): assert is stripped under `python -O`; raising
        # ValueError would be a safer validation.
        assert False, f'prompt is none!'

    prompt = format_messages(conversations=conversation, sft_format='plain', system_prompt='')

    patch_size = 16        # vision-encoder patch edge, in pixels
    downsample_ratio = 4   # projector downsampling of the patch grid
    images = load_pil_images(conversation)

    valid_img_tokens = 0   # "effective" image-token count for compression stats
    ratio = 1

    # NOTE(review): IndexError for text-only prompts (images == []) — confirm
    # infer() is only called with an image attached.
    image_draw = images[0].copy()

    w,h = image_draw.size
    # print(w, h)
    # Fraction of the padded square covered by the image (1.0 for square input).
    ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h)))


    image_transform=BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True)
    images_seq_mask = []

    image_token = '<image>'
    image_token_id = 128815  # id of the <image> placeholder token
    text_splits = prompt.split(image_token)

    images_list, images_crop_list, images_seq_mask = [], [], []
    tokenized_str = []
    images_spatial_crop = []
    # Interleave text segments with their following image. text_splits has one
    # more element than images; the trailing text is encoded after the loop.
    for text_sep, image in zip(text_splits, images):

        tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False)
        tokenized_str += tokenized_sep
        images_seq_mask += [False] * len(tokenized_sep)

        if crop_mode:

            # Small images keep a single global view; larger ones get tiled.
            if image.size[0] <= 768 and image.size[1] <= 768:
                crop_ratio = [1, 1]

            else:
                if crop_mode:  # NOTE(review): redundant nested check — always True here.
                    # best_width, best_height = select_best_resolution(image.size, self.candidate_resolutions)
                    images_crop_raw, crop_ratio = dynamic_preprocess(image)
                else:
                    # best_width, best_height = self.image_size, self.image_size
                    crop_ratio = [1, 1]

            """process the global view"""
            # image = image.resize((base_size, base_size))
            # Pad to a base_size square using the normalisation mean as fill colour.
            global_view = ImageOps.pad(image, (base_size, base_size),
                                       color=tuple(int(x * 255) for x in image_transform.mean))

            if base_size == 1024:
                valid_img_tokens += int(256 * ratio)
            elif base_size == 1280:
                valid_img_tokens += int(400 * ratio)
            # elif base_size == 640:
            #     valid_img_tokens += int(100 * ratio)


            images_list.append(image_transform(global_view).to(torch_dtype))

            # global_view_tensor = image_transform(global_view).to(torch_dtype)

            width_crop_num, height_crop_num = crop_ratio

            images_spatial_crop.append([width_crop_num, height_crop_num])


            if width_crop_num > 1 or height_crop_num > 1:
                """process the local views"""

                for i in range(len(images_crop_raw)):
                    images_crop_list.append(image_transform(images_crop_raw[i]).to(torch_dtype))

            # NOTE(review): counts *all* crops accumulated so far (across
            # images), not just the current image's — confirm intended.
            if image_size == 768:
                valid_img_tokens += len(images_crop_list) * 144

            # Query tokens per side for local crops and for the global view.
            num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
            num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio)



            """add image tokens"""
            # Global-view placeholder grid plus one separator token, then
            # (optionally) the local-crop placeholder grid.
            tokenized_image = ([image_token_id] * num_queries_base) * num_queries_base
            tokenized_image += [image_token_id]
            if width_crop_num > 1 or height_crop_num > 1:
                tokenized_image += ([image_token_id] * (num_queries * width_crop_num)) * (
                    num_queries * height_crop_num)
            tokenized_str += tokenized_image
            images_seq_mask += [True] * len(tokenized_image)
            # num_image_tokens.append(len(tokenized_image))

        else:
            # best_width, best_height = self.image_size, self.image_size
            # print(image.size, (best_width, best_height)) # check the select_best_resolutions func

            """process the global view"""
            if image_size <= 768:
                print('directly resize')
                image = image.resize((image_size, image_size))
            # else:
            global_view = ImageOps.pad(image, (image_size, image_size),
                                       color=tuple(int(x * 255) for x in image_transform.mean))
            images_list.append(image_transform(global_view).to(torch_dtype))

            if base_size == 1024:
                valid_img_tokens += int(256 * ratio)
            elif base_size == 1280:
                valid_img_tokens += int(400 * ratio)
            elif base_size == 640:
                valid_img_tokens += int(100 * 1)
            elif base_size == 512:
                valid_img_tokens += int(64 * 1)
            elif base_size == 768:
                valid_img_tokens += int(144 * 1)

            width_crop_num, height_crop_num = 1, 1

            images_spatial_crop.append([width_crop_num, height_crop_num])


            """add image tokens"""
            num_queries = math.ceil((image_size // patch_size) / downsample_ratio)

            tokenized_image = ([image_token_id] * num_queries) * num_queries
            tokenized_image += [image_token_id]
            # tokenized_image += ([self.image_token_id] * (num_queries * width_crop_num) + [self.image_token_id]) * (
            #     num_queries * height_crop_num)
            tokenized_str += tokenized_image
            images_seq_mask += [True] * len(tokenized_image)
            # num_image_tokens.append(len(tokenized_image))


    """process the last text split"""
    tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False)
    tokenized_str += tokenized_sep
    images_seq_mask += [False] * len(tokenized_sep)

    """add the bos tokens"""
    bos_id = 0
    tokenized_str = [bos_id] + tokenized_str
    images_seq_mask = [False] + images_seq_mask



    input_ids = torch.LongTensor(tokenized_str)


    images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)

    # Fall back to all-zero dummy tensors when the prompt has no image.
    if len(images_list) == 0:
        images_ori = torch.zeros((1, 3, image_size, image_size))
        images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
        images_crop = torch.zeros((1, 3, base_size, base_size))

    else:
        images_ori = torch.stack(images_list, dim=0)
        images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
        if images_crop_list:
            images_crop = torch.stack(images_crop_list, dim=0)
        else:
            images_crop = torch.zeros((1, 3, base_size, base_size))


    if not eval_mode:
        # Interactive path: stream decoded tokens to stdout while generating.
        streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
        with torch.autocast("cuda", dtype=torch_dtype):
            with torch.no_grad():
                output_ids = self.generate(
                    input_ids.unsqueeze(0).cuda(),
                    images=[(images_crop.cuda(), images_ori.cuda())],
                    images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
                    images_spatial_crop = images_spatial_crop,
                    # do_sample=False,
                    # num_beams = 1,
                    temperature=0.0,
                    eos_token_id=tokenizer.eos_token_id,
                    streamer=streamer,
                    max_new_tokens=8192,
                    no_repeat_ngram_size = 20,
                    use_cache = True
                )

    else:
        # Eval path: identical call without streaming; larger n-gram window.
        with torch.autocast("cuda", dtype=torch_dtype):
            with torch.no_grad():
                output_ids = self.generate(
                    input_ids.unsqueeze(0).cuda(),
                    images=[(images_crop.cuda(), images_ori.cuda())],
                    images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
                    images_spatial_crop = images_spatial_crop,
                    # do_sample=False,
                    # num_beams = 1,
                    temperature=0.0,
                    eos_token_id=tokenizer.eos_token_id,
                    max_new_tokens=8192,
                    no_repeat_ngram_size = 35,
                    use_cache = True
                )


    if '<image>' in conversation[0]['content'] and eval_mode:
        # Decode only the newly generated tokens (skip the prompt length).
        outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
        stop_str = '<|end▁of▁sentence|>'
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        # re_match
        outputs = outputs.strip()

        return outputs

    if '<image>' in conversation[0]['content'] and test_compress:
        outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
        pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
        print('='*50)
        print('image size: ', (w, h))
        print('valid image tokens: ', int(valid_img_tokens))
        print('output texts tokens (valid): ', pure_texts_outputs_token_length)
        print('compression ratio: ', round(pure_texts_outputs_token_length/valid_img_tokens, 2))
        print('='*50)


    if '<image>' in conversation[0]['content'] and save_results:
        outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
        stop_str = '<|end▁of▁sentence|>'

        print('='*15 + 'save results:' + '='*15)

        # # # # conv.messages[-1][-1] = outputs
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()

        # Extract grounding refs / embedded image spans from the raw output.
        matches_ref, matches_images, mathes_other = re_match(outputs)
        # print(matches_ref)
        result = process_image_with_refs(image_draw, matches_ref, output_path)


        # Replace extracted image spans with markdown links to the saved crops.
        for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
            outputs = outputs.replace(a_match_image, '![](images/' + str(idx) + '.jpg)\n')

        for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
            outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')


        # if 'structural formula' in conversation[0]['content']:
        #     outputs = '<smiles>' + outputs + '</smiles>'
        with open(f'{output_path}/result.mmd', 'w', encoding = 'utf-8') as afile:
            afile.write(outputs)

        if 'line_type' in outputs:
            # Geometry output: re-plot the predicted line drawing.
            # NOTE(review): eval() on model output is unsafe for untrusted
            # inputs — consider ast.literal_eval / json instead.
            import matplotlib.pyplot as plt
            lines = eval(outputs)['Line']['line']

            line_type = eval(outputs)['Line']['line_type']
            # print(lines)

            endpoints = eval(outputs)['Line']['line_endpoint']

            fig, ax = plt.subplots(figsize=(3,3), dpi=200)
            ax.set_xlim(-15, 15)
            ax.set_ylim(-15, 15)

            for idx, line in enumerate(lines):
                try:
                    p0 = eval(line.split(' -- ')[0])
                    p1 = eval(line.split(' -- ')[-1])

                    # NOTE(review): both branches draw an identical solid line;
                    # the '--' (dashed) style is never actually applied.
                    if line_type[idx] == '--':
                        ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')
                    else:
                        ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth = 0.8, color = 'k')

                    ax.scatter(p0[0], p0[1], s=5, color = 'k')
                    ax.scatter(p1[0], p1[1], s=5, color = 'k')
                except:
                    # Best-effort: skip malformed line specs instead of failing
                    # the whole plot. NOTE(review): bare except also hides bugs.
                    pass

            for endpoint in endpoints:

                label = endpoint.split(': ')[0]
                (x, y) = eval(endpoint.split(': ')[1])
                ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points',
                            fontsize=5, fontweight='light')


            plt.savefig(f'{output_path}/geo.jpg')
            plt.close()

        result.save(f"{output_path}/result_with_boxes.jpg")
tokenizer_config.json CHANGED
@@ -6655,7 +6655,7 @@
6655
  "legacy": true,
6656
  "model_max_length": 1000000000000000019884624838656,
6657
  "pad_token": "<|▁pad▁|>",
6658
- "padding_side": "left",
6659
  "tokenizer_class": "LlamaTokenizerFast",
6660
  "unk_token": null,
6661
  "use_default_system_prompt": false
 
6655
  "legacy": true,
6656
  "model_max_length": 1000000000000000019884624838656,
6657
  "pad_token": "<|▁pad▁|>",
6658
+ "padding_side": "right",
6659
  "tokenizer_class": "LlamaTokenizerFast",
6660
  "unk_token": null,
6661
  "use_default_system_prompt": false