SamMikaelson committed on
Commit 48e5cd4 · verified · 1 Parent(s): 6dd535b

Add modeling_deepseekocr.py for standalone loading
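
For reference, a minimal standalone-loading sketch (the repo id, file names, and prompt below are placeholder assumptions; infer() is the entry point defined at the bottom of this file):

    import torch
    from transformers import AutoModel, AutoTokenizer

    repo = "your-namespace/DeepSeek-OCR"  # hypothetical repo id hosting this file
    tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo, trust_remote_code=True)
    model = model.eval().cuda().to(torch.bfloat16)

    # streams the transcription and, with save_results=True, writes result.mmd
    # and annotated images under output_path
    res = model.infer(tokenizer, prompt="<image>\nFree OCR. ", image_file="page.png",
                      output_path="./out", base_size=1024, image_size=640,
                      crop_mode=True, save_results=True)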

Files changed (1)
  1. modeling_deepseekocr.py +1037 -0
modeling_deepseekocr.py ADDED
@@ -0,0 +1,1037 @@
+ from .modeling_deepseekv2 import DeepseekV2Model, DeepseekV2ForCausalLM
+ from .configuration_deepseek_v2 import DeepseekV2Config
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+ from typing import List, Optional, Tuple, Union
+ from transformers.cache_utils import Cache
+ import requests
+ from PIL import Image, ImageOps, ImageDraw, ImageFont
+ from io import BytesIO
+ import torch
+ import torch.nn as nn
+ from torch.nn import CrossEntropyLoss
+ from torchvision import transforms
+ from torchvision.transforms.functional import InterpolationMode
+ import os
+ from .deepencoder import build_sam_vit_b, build_clip_l, MlpProjector
+ from addict import Dict
+ from transformers import TextStreamer
+ from .conversation import get_conv_template
+ from abc import ABC
+ import math
+ import re
+ from tqdm import tqdm
+ import numpy as np
+ import time
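+
+ # Editor's note: this module pairs a SAM-ViT-B encoder with a CLIP-L encoder
+ # (build_sam_vit_b / build_clip_l from deepencoder) on top of a DeepseekV2
+ # decoder; their concatenated features are projected to 1280-dim embeddings
+ # and spliced into the prompt at the <image> placeholder positions.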
+
+
+ def load_image(image_path):
+     """Open an image and apply its EXIF orientation; fall back to a plain open."""
+     try:
+         image = Image.open(image_path)
+         corrected_image = ImageOps.exif_transpose(image)
+         return corrected_image
+     except Exception as e:
+         print(f"error: {e}")
+         try:
+             return Image.open(image_path)
+         except Exception:
+             return None
+
+
+ def re_match(text):
+     """Collect <|ref|>...<|/ref|><|det|>...<|/det|> grounding spans from the model output."""
+     pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
+     matches = re.findall(pattern, text, re.DOTALL)
+
+     matches_image = []
+     matches_other = []
+     for a_match in matches:
+         if '<|ref|>image<|/ref|>' in a_match[0]:
+             matches_image.append(a_match[0])
+         else:
+             matches_other.append(a_match[0])
+     return matches, matches_image, matches_other
+
+
+ def extract_coordinates_and_label(ref_text, image_width, image_height):
+     """Parse one regex match into (label, [[x1, y1, x2, y2], ...]); coordinates lie on a 0-999 grid."""
+     try:
+         label_type = ref_text[1]
+         cor_list = eval(ref_text[2])  # the <|det|> payload is a Python-style list literal
+     except Exception as e:
+         print(e)
+         return None
+
+     return (label_type, cor_list)
+
+
+ def draw_bounding_boxes(image, refs, output_path):
+     """Draw predicted layout boxes on a copy of the image; crop 'image' regions to output_path/images."""
+     image_width, image_height = image.size
+
+     img_draw = image.copy()
+     draw = ImageDraw.Draw(img_draw)
+
+     overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
+     draw2 = ImageDraw.Draw(overlay)
+
+     # prefer a scalable font; fall back to PIL's built-in bitmap font
+     try:
+         font = ImageFont.truetype("DejaVuSans.ttf", 20)
+     except IOError:
+         font = ImageFont.load_default()
+
+     img_idx = 0
+
+     for i, ref in enumerate(refs):
+         try:
+             result = extract_coordinates_and_label(ref, image_width, image_height)
+             if result:
+                 label_type, points_list = result
+
+                 color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))
+                 color_a = color + (20, )
+                 for points in points_list:
+                     x1, y1, x2, y2 = points
+
+                     # rescale from the 0-999 normalized grid to pixel coordinates
+                     x1 = int(x1 / 999 * image_width)
+                     y1 = int(y1 / 999 * image_height)
+                     x2 = int(x2 / 999 * image_width)
+                     y2 = int(y2 / 999 * image_height)
+
+                     if label_type == 'image':
+                         try:
+                             cropped = image.crop((x1, y1, x2, y2))
+                             cropped.save(f"{output_path}/images/{img_idx}.jpg")
+                         except Exception as e:
+                             print(e)
+                         img_idx += 1
+
+                     try:
+                         if label_type == 'title':
+                             draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
+                         else:
+                             draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
+                         draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
+
+                         text_x = x1
+                         text_y = max(0, y1 - 15)
+
+                         text_bbox = draw.textbbox((0, 0), label_type, font=font)
+                         text_width = text_bbox[2] - text_bbox[0]
+                         text_height = text_bbox[3] - text_bbox[1]
+                         draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
+                                        fill=(255, 255, 255, 30))
+                         draw.text((text_x, text_y), label_type, font=font, fill=color)
+                     except Exception:
+                         pass
+         except Exception:
+             continue
+
+     img_draw.paste(overlay, (0, 0), overlay)
+     return img_draw
+
+
+ def process_image_with_refs(image, ref_texts, output_path):
+     result_image = draw_bounding_boxes(image, ref_texts, output_path)
+     return result_image
+
+
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+     best_ratio_diff = float('inf')
+     best_ratio = (1, 1)
+     area = width * height
+     for ratio in target_ratios:
+         target_aspect_ratio = ratio[0] / ratio[1]
+         ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+         if ratio_diff < best_ratio_diff:
+             best_ratio_diff = ratio_diff
+             best_ratio = ratio
+         elif ratio_diff == best_ratio_diff:
+             # on a tie, prefer the tiling with more tiles when the image is large enough
+             if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                 best_ratio = ratio
+     return best_ratio
+
+
+ def dynamic_preprocess(image, min_num=2, max_num=9, image_size=640, use_thumbnail=False):
+     orig_width, orig_height = image.size
+     aspect_ratio = orig_width / orig_height
+
+     # enumerate candidate (columns, rows) tilings with min_num <= i*j <= max_num
+     target_ratios = set(
+         (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
+         i * j <= max_num and i * j >= min_num)
+     target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+     # find the closest aspect ratio to the target
+     target_aspect_ratio = find_closest_aspect_ratio(
+         aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+     # calculate the target width and height
+     target_width = image_size * target_aspect_ratio[0]
+     target_height = image_size * target_aspect_ratio[1]
+     blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+     # resize the image, then split it into image_size x image_size tiles
+     resized_img = image.resize((target_width, target_height))
+     processed_images = []
+     for i in range(blocks):
+         box = (
+             (i % (target_width // image_size)) * image_size,
+             (i // (target_width // image_size)) * image_size,
+             ((i % (target_width // image_size)) + 1) * image_size,
+             ((i // (target_width // image_size)) + 1) * image_size
+         )
+         split_img = resized_img.crop(box)
+         processed_images.append(split_img)
+     assert len(processed_images) == blocks
+     if use_thumbnail and len(processed_images) != 1:
+         thumbnail_img = image.resize((image_size, image_size))
+         processed_images.append(thumbnail_img)
+     return processed_images, target_aspect_ratio
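+ # Editor's note, a worked example: a 1280x960 page (aspect ratio 1.33) picks the
+ # closest admissible tiling (3, 2) under the defaults, so it is resized to
+ # 1920x1280 and split into six 640x640 tiles.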
+
+
+ def normalize_transform(mean, std):
+     if mean is None and std is None:
+         transform = None
+     elif mean is None and std is not None:
+         mean = [0.] * len(std)
+         transform = transforms.Normalize(mean=mean, std=std)
+     elif mean is not None and std is None:
+         std = [1.] * len(mean)
+         transform = transforms.Normalize(mean=mean, std=std)
+     else:
+         transform = transforms.Normalize(mean=mean, std=std)
+
+     return transform
+
+
+
+ def format_messages(
+     conversations: List[Dict[str, str]],
+     sft_format: str = "deepseek",
+     system_prompt: str = "",
+ ):
+     """
+     Applies the SFT template to the conversation.
+
+     Args:
+         conversations (List[Dict]): A list of messages.
+         sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek".
+         system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "".
+
+     Returns:
+         sft_prompt (str): The formatted text.
+     """
+
+     conv = get_conv_template(sft_format)
+     conv.set_system_message(system_prompt)
+     for message in conversations:
+         conv.append_message(message["role"], message["content"].strip())
+     sft_prompt = conv.get_prompt().strip()
+
+     return sft_prompt
+
+
+ def text_encode(tokenizer, text: str, bos: bool = True, eos: bool = False):
+     t = tokenizer.encode(text, add_special_tokens=False)
+     bos_id = 0  # hardcoded special-token ids for this tokenizer
+     eos_id = 1
+     if bos:
+         t = [bos_id] + t
+     if eos:
+         t = t + [eos_id]
+
+     return t
+
+ def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]:
+     """
+
+     Args:
+         conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is:
+             [
+                 {
+                     "role": "User",
+                     "content": "<image_placeholder>\nExtract all information from this image and convert them into markdown format.",
+                     "images": ["./examples/table_datasets.png"]
+                 },
+                 {"role": "Assistant", "content": ""},
+             ]
+
+     Returns:
+         pil_images (List[PIL.Image.Image]): the list of PIL images.
+
+     """
+
+     pil_images = []
+
+     for message in conversations:
+         if "images" not in message:
+             continue
+
+         for image_path in message["images"]:
+             pil_img = load_image(image_path)
+             pil_img = pil_img.convert("RGB")
+             pil_images.append(pil_img)
+
+     return pil_images
+
+
+ class BaseTransform(ABC):
+
+     def set_rng(self, *args, **kwargs):
+         pass
+
+     def __call__(self, *args, **kwargs) -> torch.Tensor:
+         pass
+
+     @property
+     def default_shape(self):
+         raise NotImplementedError
+
+
+ class BasicImageTransform(BaseTransform):
+     def __init__(
+         self,
+         mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
+         std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
+         normalize: bool = True
+     ):
+         self.mean = mean
+         self.std = std
+
+         transform_pipelines = [
+             transforms.ToTensor()
+         ]
+
+         norm = normalize_transform(mean, std) if normalize else nn.Identity()
+         if norm is not None:
+             transform_pipelines.append(norm)
+
+         self.transform = transforms.Compose(transform_pipelines)
+
+     def __call__(self, x):
+         x = self.transform(x)
+         return x
+
+ class NoEOSTextStreamer(TextStreamer):
+     def on_finalized_text(self, text: str, stream_end: bool = False):
+         # replace the EOS marker with a newline instead of printing it verbatim
+         eos_text = self.tokenizer.decode([self.tokenizer.eos_token_id], skip_special_tokens=False)
+         text = text.replace(eos_text, "\n")
+         print(text, flush=True, end="")
+
+
+ class DeepseekOCRConfig(DeepseekV2Config):
+     model_type = "DeepseekOCR"
+
+
+ class DeepseekOCRModel(DeepseekV2Model):
+     config_class = DeepseekOCRConfig
+
+     def __init__(self, config: DeepseekV2Config):
+         super(DeepseekOCRModel, self).__init__(config)
+
+         self.sam_model = build_sam_vit_b()
+         self.vision_model = build_clip_l()
+         n_embed = 1280
+         self.projector = MlpProjector(Dict(projector_type="linear", input_dim=2048, n_embed=n_embed))
+         embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
+         self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
+         # "seperator" spelling kept: the name must match the checkpoint's state dict
+         self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
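+
+     # Editor's note: image_newline is appended after every row of vision tokens,
+     # and view_seperator closes an image's token block (appended after the global
+     # view); both are learned 1280-dim embeddings consumed in forward() below.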
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         images: Optional[torch.FloatTensor] = None,
+         images_seq_mask: Optional[torch.FloatTensor] = None,
+         images_spatial_crop: Optional[torch.FloatTensor] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPast]:
+
+         if inputs_embeds is None:
+             inputs_embeds = self.get_input_embeddings()(input_ids)
+
+         sam_model = getattr(self, 'sam_model', None)
+         vision_model = getattr(self, 'vision_model', None)
+
+         # run the vision tower only on prefill (or training) steps that carry real pixels
+         if sam_model is not None and (input_ids.shape[1] != 1 or self.training) and torch.sum(images[0][1]).item() != 0:
+
+             idx = 0
+             for image, crop_shape in zip(images, images_spatial_crop):
+                 images_in_this_batch = []
+                 patches = image[0]     # local 640x640 tiles
+                 image_ori = image[1]   # padded global view
+
+                 with torch.no_grad():
+                     if torch.sum(patches).item() != 0:
+                         # local views: SAM features fused with CLIP features, then projected
+                         local_features_1 = sam_model(patches)
+                         local_features_2 = vision_model(patches, local_features_1)
+                         local_features = torch.cat((local_features_2[:, 1:], local_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
+                         local_features = self.projector(local_features)
+
+                         # global view: same two-encoder pipeline on the padded full image
+                         global_features_1 = sam_model(image_ori)
+                         global_features_2 = vision_model(image_ori, global_features_1)
+                         global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
+                         global_features = self.projector(global_features)
+
+                         print('=====================')
+                         print('BASE: ', global_features.shape)
+                         print('PATCHES: ', local_features.shape)
+                         print('=====================')
+
+                         _, hw, n_dim = global_features.shape
+                         h = w = int(hw ** 0.5)
+
+                         _2, hw2, n_dim2 = local_features.shape
+                         h2 = w2 = int(hw2 ** 0.5)
+
+                         width_crop_num, height_crop_num = crop_shape[0], crop_shape[1]
+
+                         # global view: reshape to a grid and append a newline embedding per row
+                         global_features = global_features.view(h, w, n_dim)
+                         global_features = torch.cat(
+                             [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1
+                         )
+                         global_features = global_features.view(-1, n_dim)
+
+                         # local views: stitch per-tile grids back into one large grid, then add newlines
+                         local_features = local_features.view(height_crop_num, width_crop_num, h2, w2, n_dim2).permute(0, 2, 1, 3, 4).reshape(height_crop_num * h2, width_crop_num * w2, n_dim2)
+                         local_features = torch.cat(
+                             [local_features, self.image_newline[None, None, :].expand(height_crop_num * h2, 1, n_dim2)], dim=1
+                         )
+                         local_features = local_features.view(-1, n_dim2)
+
+                         global_local_features = torch.cat([local_features, global_features, self.view_seperator[None, :]], dim=0)
+
+                     else:
+                         # no local tiles: only the global view is encoded
+                         global_features_1 = sam_model(image_ori)
+                         global_features_2 = vision_model(image_ori, global_features_1)
+                         global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
+                         global_features = self.projector(global_features)
+
+                         print('=====================')
+                         print('BASE: ', global_features.shape)
+                         print('NO PATCHES')
+                         print('=====================')
+
+                         _, hw, n_dim = global_features.shape
+                         h = w = int(hw ** 0.5)
+
+                         global_features = global_features.view(h, w, n_dim)
+                         global_features = torch.cat(
+                             [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1
+                         )
+                         global_features = global_features.view(-1, n_dim)
+
+                         global_local_features = torch.cat([global_features, self.view_seperator[None, :]], dim=0)
+
+                     images_in_this_batch.append(global_local_features)
+
+                 if images_in_this_batch:
+                     images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
+                     # write the vision embeddings into the <image> token positions
+                     inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch)
+
+                 idx += 1
+
+         return super(DeepseekOCRModel, self).forward(
+             input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids=position_ids,
+             output_attentions=output_attentions, output_hidden_states=output_hidden_states,
+             return_dict=return_dict
+         )
+
+
+ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
+     config_class = DeepseekOCRConfig
+
+     def __init__(self, config):
+         super(DeepseekV2ForCausalLM, self).__init__(config)
+         self.model = DeepseekOCRModel(config)
+
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_model(self):
+         return self.model
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         images: Optional[torch.FloatTensor] = None,
+         images_seq_mask: Optional[torch.FloatTensor] = None,
+         images_spatial_crop: Optional[torch.FloatTensor] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.model(
+             input_ids=input_ids,
+             past_key_values=past_key_values,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             images=images,
+             images_seq_mask=images_seq_mask,
+             images_spatial_crop=images_spatial_crop,
+             return_dict=return_dict
+         )
+
+         hidden_states = outputs[0]
+         logits = self.lm_head(hidden_states)
+         logits = logits.float()
+
+         loss = None
+         if labels is not None:
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             loss_fct = CrossEntropyLoss()
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+             # Enable model parallelism
+             shift_labels = shift_labels.to(shift_logits.device)
+             loss = loss_fct(shift_logits, shift_labels)
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+     def prepare_inputs_for_generation(
+         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+     ):
+         # Omit tokens covered by past_key_values
+         past_length = 0
+         if past_key_values is not None:
+             if isinstance(past_key_values, Cache):
+                 cache_length = past_key_values.get_seq_length()
+                 past_length = past_key_values.seen_tokens
+                 max_cache_length = past_key_values.get_max_length()
+             else:
+                 cache_length = past_length = past_key_values[0][0].shape[2]
+                 max_cache_length = None
+
+             # Keep only the unprocessed tokens:
+             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+             #     some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as input)
+             if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                 input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
+             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+             #     input_ids based on the past_length.
+             elif past_length < input_ids.shape[1]:
+                 input_ids = input_ids[:, past_length:]
+             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+             # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+             if (
+                 max_cache_length is not None
+                 and attention_mask is not None
+                 and cache_length + input_ids.shape[1] > max_cache_length
+             ):
+                 attention_mask = attention_mask[:, -max_cache_length:]
+
+         position_ids = kwargs.get("position_ids", None)
+         if attention_mask is not None and position_ids is None:
+             # create position_ids on the fly for batch generation
+             position_ids = attention_mask.long().cumsum(-1) - 1
+             position_ids.masked_fill_(attention_mask == 0, 1)
+             if past_key_values:
+                 position_ids = position_ids[:, -input_ids.shape[1]:]
+
+         # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
+         # same goes for position ids. Could also help with continued generation.
+         cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
+
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and past_key_values is None:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             model_inputs = {"input_ids": input_ids}
+
+         model_inputs.update(
+             {
+                 "position_ids": position_ids,
+                 "past_key_values": past_key_values,
+                 "use_cache": kwargs.get("use_cache"),
+                 "attention_mask": attention_mask,
+                 "images": kwargs.get("images", None),
+                 "images_seq_mask": kwargs.get("images_seq_mask", None),
+                 "images_spatial_crop": kwargs.get("images_spatial_crop", None),
+             }
+         )
+         return model_inputs
+
+     def disable_torch_init(self):
+         """
+         Disable the redundant torch default initialization to accelerate model creation.
+         """
+         setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+         setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+     def infer(self, tokenizer, prompt='', image_file='', output_path='', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False):
+         self.disable_torch_init()
+
+         os.makedirs(output_path, exist_ok=True)
+         os.makedirs(f'{output_path}/images', exist_ok=True)
+
+         if prompt and image_file:
+             conversation = [
+                 {
+                     "role": "<|User|>",
+                     # other prompt examples:
+                     # "content": "<image>\n<|grounding|>Given the layout of the image. ",
+                     # "content": "<image>\nFree OCR. ",
+                     # "content": "<image>\nParse the figure. ",
+                     # "content": "<image>\nExtract the text in the image. ",
+                     "content": f'{prompt}',
+                     "images": [f'{image_file}'],
+                 },
+                 {"role": "<|Assistant|>", "content": ""},
+             ]
+         elif prompt:
+             conversation = [
+                 {
+                     "role": "<|User|>",
+                     "content": f'{prompt}',
+                 },
+                 {"role": "<|Assistant|>", "content": ""},
+             ]
+         else:
+             raise ValueError('prompt must not be empty!')
+
+         prompt = format_messages(conversations=conversation, sft_format='plain', system_prompt='')
+
+         patch_size = 16
+         downsample_ratio = 4
+         images = load_pil_images(conversation)
+
+         valid_img_tokens = 0
+         ratio = 1
+
+         image_draw = images[0].copy()
+
+         w, h = image_draw.size
+         # short side / long side of the original image; discounts padded area in token counts
+         ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h)))
+
+         image_transform = BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True)
+         images_seq_mask = []
+
+         image_token = '<image>'
+         image_token_id = 128815
+         text_splits = prompt.split(image_token)
+
+         images_list, images_crop_list, images_seq_mask = [], [], []
+         tokenized_str = []
+         images_spatial_crop = []
+         for text_sep, image in zip(text_splits, images):
+
+             tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False)
+             tokenized_str += tokenized_sep
+             images_seq_mask += [False] * len(tokenized_sep)
+
+             if crop_mode:
+                 if image.size[0] <= 640 and image.size[1] <= 640:
+                     crop_ratio = [1, 1]
+                 else:
+                     images_crop_raw, crop_ratio = dynamic_preprocess(image)
+
+                 """process the global view"""
+                 global_view = ImageOps.pad(image, (base_size, base_size),
+                                            color=tuple(int(x * 255) for x in image_transform.mean))
+
+                 if base_size == 1024:
+                     valid_img_tokens += int(256 * ratio)
+                 elif base_size == 1280:
+                     valid_img_tokens += int(400 * ratio)
+
+                 images_list.append(image_transform(global_view).to(torch.bfloat16))
+
+                 width_crop_num, height_crop_num = crop_ratio
+
+                 images_spatial_crop.append([width_crop_num, height_crop_num])
+
+                 if width_crop_num > 1 or height_crop_num > 1:
+                     """process the local views"""
+                     for i in range(len(images_crop_raw)):
+                         images_crop_list.append(image_transform(images_crop_raw[i]).to(torch.bfloat16))
+
+                 if image_size == 640:
+                     valid_img_tokens += len(images_crop_list) * 100
+
+                 num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
+                 num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio)
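+                 # e.g. with the defaults: base_size=1024 -> 1024 // 16 / 4 = 16 queries per
+                 # side (16x16 = 256 global tokens), and image_size=640 -> 10 per side
+                 # (10x10 = 100 tokens per local tile), matching the valid_img_tokens counts above.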
+
+                 """add image tokens: one placeholder id per vision token, plus one per
+                 row newline and one for the view separator; local tiles follow the same
+                 row-wise layout"""
+                 tokenized_image = ([image_token_id] * num_queries_base + [image_token_id]) * num_queries_base
+                 tokenized_image += [image_token_id]
+                 if width_crop_num > 1 or height_crop_num > 1:
+                     tokenized_image += ([image_token_id] * (num_queries * width_crop_num) + [image_token_id]) * (
+                         num_queries * height_crop_num)
+                 tokenized_str += tokenized_image
+                 images_seq_mask += [True] * len(tokenized_image)
+
+             else:
+                 """process the global view"""
+                 if image_size <= 640:
+                     print('directly resize')
+                     image = image.resize((image_size, image_size))
+
+                 global_view = ImageOps.pad(image, (image_size, image_size),
+                                            color=tuple(int(x * 255) for x in image_transform.mean))
+                 images_list.append(image_transform(global_view).to(torch.bfloat16))
+
+                 if base_size == 1024:
+                     valid_img_tokens += int(256 * ratio)
+                 elif base_size == 1280:
+                     valid_img_tokens += int(400 * ratio)
+                 elif base_size == 640:
+                     valid_img_tokens += int(100 * 1)
+                 elif base_size == 512:
+                     valid_img_tokens += int(64 * 1)
+
+                 width_crop_num, height_crop_num = 1, 1
+                 images_spatial_crop.append([width_crop_num, height_crop_num])
+
+                 """add image tokens"""
+                 num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
+                 tokenized_image = ([image_token_id] * num_queries + [image_token_id]) * num_queries
+                 tokenized_image += [image_token_id]
+                 tokenized_str += tokenized_image
+                 images_seq_mask += [True] * len(tokenized_image)
+
+         """process the last text split"""
+         tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False)
+         tokenized_str += tokenized_sep
+         images_seq_mask += [False] * len(tokenized_sep)
+
+         """add the bos token"""
+         bos_id = 0
+         tokenized_str = [bos_id] + tokenized_str
+         images_seq_mask = [False] + images_seq_mask
+
+         input_ids = torch.LongTensor(tokenized_str)
+         images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
+
+         if len(images_list) == 0:
+             # text-only prompt: feed zero tensors so the vision branch is skipped
+             images_ori = torch.zeros((1, 3, image_size, image_size))
+             images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
+             images_crop = torch.zeros((1, 3, base_size, base_size))
+         else:
+             images_ori = torch.stack(images_list, dim=0)
+             images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
+             if images_crop_list:
+                 images_crop = torch.stack(images_crop_list, dim=0)
+             else:
+                 images_crop = torch.zeros((1, 3, base_size, base_size))
+
+         if not eval_mode:
+             streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
+             with torch.autocast("cuda", dtype=torch.bfloat16):
+                 with torch.no_grad():
+                     output_ids = self.generate(
+                         input_ids.unsqueeze(0).cuda(),
+                         images=[(images_crop.cuda(), images_ori.cuda())],
+                         images_seq_mask=images_seq_mask.unsqueeze(0).cuda(),
+                         images_spatial_crop=images_spatial_crop,
+                         temperature=0.0,
+                         eos_token_id=tokenizer.eos_token_id,
+                         streamer=streamer,
+                         max_new_tokens=8192,
+                         no_repeat_ngram_size=20,
+                         use_cache=True
+                     )
+         else:
+             with torch.autocast("cuda", dtype=torch.bfloat16):
+                 with torch.no_grad():
+                     output_ids = self.generate(
+                         input_ids.unsqueeze(0).cuda(),
+                         images=[(images_crop.cuda(), images_ori.cuda())],
+                         images_seq_mask=images_seq_mask.unsqueeze(0).cuda(),
+                         images_spatial_crop=images_spatial_crop,
+                         temperature=0.0,
+                         eos_token_id=tokenizer.eos_token_id,
+                         max_new_tokens=8192,
+                         no_repeat_ngram_size=35,
+                         use_cache=True
+                     )
+
+         if '<image>' in conversation[0]['content'] and eval_mode:
+             outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).shape[1]:])
+             stop_str = '<|end▁of▁sentence|>'
+             if outputs.endswith(stop_str):
+                 outputs = outputs[:-len(stop_str)]
+             outputs = outputs.strip()
+             return outputs
+
+         if '<image>' in conversation[0]['content'] and test_compress:
+             outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).shape[1]:])
+             pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
+             print('=' * 50)
+             print('image size: ', (w, h))
+             print('valid image tokens: ', int(valid_img_tokens))
+             print('output texts tokens (valid): ', pure_texts_outputs_token_length)
+             print('compression ratio: ', round(pure_texts_outputs_token_length / valid_img_tokens, 2))
+             print('=' * 50)
+
+         if '<image>' in conversation[0]['content'] and save_results:
+             outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).shape[1]:])
+             stop_str = '<|end▁of▁sentence|>'
+
+             print('=' * 15 + 'save results:' + '=' * 15)
+
+             if outputs.endswith(stop_str):
+                 outputs = outputs[:-len(stop_str)]
+             outputs = outputs.strip()
+
+             matches_ref, matches_images, matches_other = re_match(outputs)
+             result = process_image_with_refs(image_draw, matches_ref, output_path)
+
+             # replace image grounding spans with markdown links to the cropped images
+             for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
+                 outputs = outputs.replace(a_match_image, '![](images/' + str(idx) + '.jpg)\n')
+
+             # strip the remaining grounding spans from the text output
+             for idx, a_match_other in enumerate(tqdm(matches_other, desc="other")):
+                 outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')
+
+             with open(f'{output_path}/result.mmd', 'w', encoding='utf-8') as afile:
+                 afile.write(outputs)
+
+             if 'line_type' in outputs:
+                 # re-plot parsed geometry (lines and labelled endpoints) from the model's dict output
+                 import matplotlib.pyplot as plt
+                 lines = eval(outputs)['Line']['line']
+                 line_type = eval(outputs)['Line']['line_type']
+                 endpoints = eval(outputs)['Line']['line_endpoint']
+
+                 fig, ax = plt.subplots(figsize=(3, 3), dpi=200)
+                 ax.set_xlim(-15, 15)
+                 ax.set_ylim(-15, 15)
+
+                 for idx, line in enumerate(lines):
+                     try:
+                         p0 = eval(line.split(' -- ')[0])
+                         p1 = eval(line.split(' -- ')[-1])
+                         # both dashed ('--') and solid entries in line_type are drawn solid here
+                         ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')
+                         ax.scatter(p0[0], p0[1], s=5, color='k')
+                         ax.scatter(p1[0], p1[1], s=5, color='k')
+                     except Exception:
+                         pass
+
+                 for endpoint in endpoints:
+                     label = endpoint.split(': ')[0]
+                     (x, y) = eval(endpoint.split(': ')[1])
+                     ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points',
+                                 fontsize=5, fontweight='light')
+
+                 plt.savefig(f'{output_path}/geo.jpg')
+                 plt.close()
+
+             result.save(f"{output_path}/result_with_boxes.jpg")