Manus AI commited on
Commit
bfec147
·
1 Parent(s): faa79f8

Fix Florence2LanguageConfig AttributeError and pin transformers version

Browse files
Files changed (2) hide show
  1. requirements.txt +1 -1
  2. util/utils.py +566 -559
requirements.txt CHANGED
@@ -3,7 +3,7 @@ easyocr
3
  torchvision
4
  supervision==0.18.0
5
  openai
6
- transformers
7
  ultralytics
8
  azure-identity
9
  numpy
 
3
  torchvision
4
  supervision==0.18.0
5
  openai
6
+ transformers==4.45.2
7
  ultralytics
8
  azure-identity
9
  numpy
util/utils.py CHANGED
@@ -1,559 +1,566 @@
1
- # from ultralytics import YOLO
2
- import os
3
- import io
4
- import base64
5
- import time
6
- from PIL import Image, ImageDraw, ImageFont
7
- import json
8
- import requests
9
- # utility function
10
- import os
11
- from openai import AzureOpenAI
12
-
13
- import json
14
- import sys
15
- import os
16
- import cv2
17
- import numpy as np
18
- # %matplotlib inline
19
- from matplotlib import pyplot as plt
20
- import easyocr
21
- from paddleocr import PaddleOCR
22
- _reader = None
23
- _paddle_ocr = None
24
-
25
- def get_easyocr_reader():
26
- global _reader
27
- if _reader is None:
28
- import easyocr
29
- _reader = easyocr.Reader(['en'])
30
- return _reader
31
-
32
- def get_paddle_ocr():
33
- global _paddle_ocr
34
- if _paddle_ocr is None:
35
- from paddleocr import PaddleOCR
36
- _paddle_ocr = PaddleOCR(
37
- lang='en',
38
- use_angle_cls=False,
39
- use_gpu=False,
40
- show_log=False,
41
- max_batch_size=1024,
42
- use_dilation=True,
43
- det_db_score_mode='slow',
44
- rec_batch_num=1024)
45
- return _paddle_ocr
46
- import time
47
- import base64
48
-
49
- import os
50
- import ast
51
- import torch
52
- from typing import Tuple, List, Union
53
- from torchvision.ops import box_convert
54
- import re
55
- from torchvision.transforms import ToPILImage
56
- import supervision as sv
57
- import torchvision.transforms as T
58
- from util.box_annotator import BoxAnnotator
59
-
60
-
61
- def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
62
- if not device:
63
- device = "cuda" if torch.cuda.is_available() else "cpu"
64
- if model_name == "blip2":
65
- from transformers import Blip2Processor, Blip2ForConditionalGeneration
66
- processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
67
- if device == 'cpu':
68
- model = Blip2ForConditionalGeneration.from_pretrained(
69
- model_name_or_path, device_map=None, torch_dtype=torch.float32
70
- )
71
- else:
72
- model = Blip2ForConditionalGeneration.from_pretrained(
73
- model_name_or_path, device_map=None, torch_dtype=torch.float16
74
- ).to(device)
75
- elif model_name == "florence2":
76
- from transformers import AutoProcessor, AutoModelForCausalLM
77
- processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
78
- if device == 'cpu':
79
- model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True)
80
- else:
81
- model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True).to(device)
82
- return {'model': model.to(device), 'processor': processor}
83
-
84
-
85
- def get_yolo_model(model_path):
86
- from ultralytics import YOLO
87
- # Load the model.
88
- model = YOLO(model_path)
89
- return model
90
-
91
-
92
- @torch.inference_mode()
93
- def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=None):
94
- # Number of samples per batch, --> 256 roughly takes 23 GB of GPU memory for florence model
95
- to_pil = ToPILImage()
96
- if starting_idx:
97
- non_ocr_boxes = filtered_boxes[starting_idx:]
98
- else:
99
- non_ocr_boxes = filtered_boxes
100
- croped_pil_image = []
101
- for i, coord in enumerate(non_ocr_boxes):
102
- try:
103
- xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
104
- ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
105
- cropped_image = image_source[ymin:ymax, xmin:xmax, :]
106
- cropped_image = cv2.resize(cropped_image, (64, 64))
107
- croped_pil_image.append(to_pil(cropped_image))
108
- except:
109
- continue
110
-
111
- model, processor = caption_model_processor['model'], caption_model_processor['processor']
112
- if not prompt:
113
- if 'florence' in model.config.name_or_path:
114
- prompt = "<CAPTION>"
115
- else:
116
- prompt = "The image shows"
117
-
118
- generated_texts = []
119
- device = model.device
120
- # batch_size = 64
121
- for i in range(0, len(croped_pil_image), batch_size):
122
- start = time.time()
123
- batch = croped_pil_image[i:i+batch_size]
124
- t1 = time.time()
125
- if model.device.type == 'cuda':
126
- inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt", do_resize=False).to(device=device, dtype=torch.float16)
127
- else:
128
- inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
129
- # if 'florence' in model.config.name_or_path:
130
- generated_ids = model.generate(input_ids=inputs["input_ids"],pixel_values=inputs["pixel_values"],max_new_tokens=20,num_beams=1, do_sample=False)
131
- # else:
132
- # generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1) # temperature=0.01, do_sample=True,
133
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
134
- generated_text = [gen.strip() for gen in generated_text]
135
- generated_texts.extend(generated_text)
136
-
137
- return generated_texts
138
-
139
-
140
-
141
- def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor):
142
- to_pil = ToPILImage()
143
- if ocr_bbox:
144
- non_ocr_boxes = filtered_boxes[len(ocr_bbox):]
145
- else:
146
- non_ocr_boxes = filtered_boxes
147
- croped_pil_image = []
148
- for i, coord in enumerate(non_ocr_boxes):
149
- xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
150
- ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
151
- cropped_image = image_source[ymin:ymax, xmin:xmax, :]
152
- croped_pil_image.append(to_pil(cropped_image))
153
-
154
- model, processor = caption_model_processor['model'], caption_model_processor['processor']
155
- device = model.device
156
- messages = [{"role": "user", "content": "<|image_1|>\ndescribe the icon in one sentence"}]
157
- prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
158
-
159
- batch_size = 5 # Number of samples per batch
160
- generated_texts = []
161
-
162
- for i in range(0, len(croped_pil_image), batch_size):
163
- images = croped_pil_image[i:i+batch_size]
164
- image_inputs = [processor.image_processor(x, return_tensors="pt") for x in images]
165
- inputs ={'input_ids': [], 'attention_mask': [], 'pixel_values': [], 'image_sizes': []}
166
- texts = [prompt] * len(images)
167
- for i, txt in enumerate(texts):
168
- input = processor._convert_images_texts_to_inputs(image_inputs[i], txt, return_tensors="pt")
169
- inputs['input_ids'].append(input['input_ids'])
170
- inputs['attention_mask'].append(input['attention_mask'])
171
- inputs['pixel_values'].append(input['pixel_values'])
172
- inputs['image_sizes'].append(input['image_sizes'])
173
- max_len = max([x.shape[1] for x in inputs['input_ids']])
174
- for i, v in enumerate(inputs['input_ids']):
175
- inputs['input_ids'][i] = torch.cat([processor.tokenizer.pad_token_id * torch.ones(1, max_len - v.shape[1], dtype=torch.long), v], dim=1)
176
- inputs['attention_mask'][i] = torch.cat([torch.zeros(1, max_len - v.shape[1], dtype=torch.long), inputs['attention_mask'][i]], dim=1)
177
- inputs_cat = {k: torch.concatenate(v).to(device) for k, v in inputs.items()}
178
-
179
- generation_args = {
180
- "max_new_tokens": 25,
181
- "temperature": 0.01,
182
- "do_sample": False,
183
- }
184
- generate_ids = model.generate(**inputs_cat, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
185
- # # remove input tokens
186
- generate_ids = generate_ids[:, inputs_cat['input_ids'].shape[1]:]
187
- response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
188
- response = [res.strip('\n').strip() for res in response]
189
- generated_texts.extend(response)
190
-
191
- return generated_texts
192
-
193
- def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
194
- assert ocr_bbox is None or isinstance(ocr_bbox, List)
195
-
196
- def box_area(box):
197
- return (box[2] - box[0]) * (box[3] - box[1])
198
-
199
- def intersection_area(box1, box2):
200
- x1 = max(box1[0], box2[0])
201
- y1 = max(box1[1], box2[1])
202
- x2 = min(box1[2], box2[2])
203
- y2 = min(box1[3], box2[3])
204
- return max(0, x2 - x1) * max(0, y2 - y1)
205
-
206
- def IoU(box1, box2):
207
- intersection = intersection_area(box1, box2)
208
- union = box_area(box1) + box_area(box2) - intersection + 1e-6
209
- if box_area(box1) > 0 and box_area(box2) > 0:
210
- ratio1 = intersection / box_area(box1)
211
- ratio2 = intersection / box_area(box2)
212
- else:
213
- ratio1, ratio2 = 0, 0
214
- return max(intersection / union, ratio1, ratio2)
215
-
216
- def is_inside(box1, box2):
217
- # return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
218
- intersection = intersection_area(box1, box2)
219
- ratio1 = intersection / box_area(box1)
220
- return ratio1 > 0.95
221
-
222
- boxes = boxes.tolist()
223
- filtered_boxes = []
224
- if ocr_bbox:
225
- filtered_boxes.extend(ocr_bbox)
226
- # print('ocr_bbox!!!', ocr_bbox)
227
- for i, box1 in enumerate(boxes):
228
- # if not any(IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2) for j, box2 in enumerate(boxes) if i != j):
229
- is_valid_box = True
230
- for j, box2 in enumerate(boxes):
231
- # keep the smaller box
232
- if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
233
- is_valid_box = False
234
- break
235
- if is_valid_box:
236
- # add the following 2 lines to include ocr bbox
237
- if ocr_bbox:
238
- # only add the box if it does not overlap with any ocr bbox
239
- if not any(IoU(box1, box3) > iou_threshold and not is_inside(box1, box3) for k, box3 in enumerate(ocr_bbox)):
240
- filtered_boxes.append(box1)
241
- else:
242
- filtered_boxes.append(box1)
243
- return torch.tensor(filtered_boxes)
244
-
245
-
246
- def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
247
- '''
248
- ocr_bbox format: [{'type': 'text', 'bbox':[x,y], 'interactivity':False, 'content':str }, ...]
249
- boxes format: [{'type': 'icon', 'bbox':[x,y], 'interactivity':True, 'content':None }, ...]
250
-
251
- '''
252
- assert ocr_bbox is None or isinstance(ocr_bbox, List)
253
-
254
- def box_area(box):
255
- return (box[2] - box[0]) * (box[3] - box[1])
256
-
257
- def intersection_area(box1, box2):
258
- x1 = max(box1[0], box2[0])
259
- y1 = max(box1[1], box2[1])
260
- x2 = min(box1[2], box2[2])
261
- y2 = min(box1[3], box2[3])
262
- return max(0, x2 - x1) * max(0, y2 - y1)
263
-
264
- def IoU(box1, box2):
265
- intersection = intersection_area(box1, box2)
266
- union = box_area(box1) + box_area(box2) - intersection + 1e-6
267
- if box_area(box1) > 0 and box_area(box2) > 0:
268
- ratio1 = intersection / box_area(box1)
269
- ratio2 = intersection / box_area(box2)
270
- else:
271
- ratio1, ratio2 = 0, 0
272
- return max(intersection / union, ratio1, ratio2)
273
-
274
- def is_inside(box1, box2):
275
- # return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
276
- intersection = intersection_area(box1, box2)
277
- ratio1 = intersection / box_area(box1)
278
- return ratio1 > 0.80
279
-
280
- # boxes = boxes.tolist()
281
- filtered_boxes = []
282
- if ocr_bbox:
283
- filtered_boxes.extend(ocr_bbox)
284
- # print('ocr_bbox!!!', ocr_bbox)
285
- for i, box1_elem in enumerate(boxes):
286
- box1 = box1_elem['bbox']
287
- is_valid_box = True
288
- for j, box2_elem in enumerate(boxes):
289
- # keep the smaller box
290
- box2 = box2_elem['bbox']
291
- if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
292
- is_valid_box = False
293
- break
294
- if is_valid_box:
295
- if ocr_bbox:
296
- # keep yolo boxes + prioritize ocr label
297
- box_added = False
298
- ocr_labels = ''
299
- for box3_elem in ocr_bbox:
300
- if not box_added:
301
- box3 = box3_elem['bbox']
302
- if is_inside(box3, box1): # ocr inside icon
303
- # box_added = True
304
- # delete the box3_elem from ocr_bbox
305
- try:
306
- # gather all ocr labels
307
- ocr_labels += box3_elem['content'] + ' '
308
- filtered_boxes.remove(box3_elem)
309
- except:
310
- continue
311
- # break
312
- elif is_inside(box1, box3): # icon inside ocr, don't added this icon box, no need to check other ocr bbox bc no overlap between ocr bbox, icon can only be in one ocr box
313
- box_added = True
314
- break
315
- else:
316
- continue
317
- if not box_added:
318
- if ocr_labels:
319
- filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': ocr_labels,})
320
- else:
321
- filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None, })
322
- else:
323
- filtered_boxes.append(box1)
324
- return filtered_boxes # torch.tensor(filtered_boxes)
325
-
326
-
327
- def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
328
- transform = T.Compose(
329
- [
330
- T.RandomResize([800], max_size=1333),
331
- T.ToTensor(),
332
- T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
333
- ]
334
- )
335
- image_source = Image.open(image_path).convert("RGB")
336
- image = np.asarray(image_source)
337
- image_transformed, _ = transform(image_source, None)
338
- return image, image_transformed
339
-
340
-
341
- def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str], text_scale: float,
342
- text_padding=5, text_thickness=2, thickness=3) -> np.ndarray:
343
- """
344
- This function annotates an image with bounding boxes and labels.
345
-
346
- Parameters:
347
- image_source (np.ndarray): The source image to be annotated.
348
- boxes (torch.Tensor): A tensor containing bounding box coordinates. in cxcywh format, pixel scale
349
- logits (torch.Tensor): A tensor containing confidence scores for each bounding box.
350
- phrases (List[str]): A list of labels for each bounding box.
351
- text_scale (float): The scale of the text to be displayed. 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
352
-
353
- Returns:
354
- np.ndarray: The annotated image.
355
- """
356
- h, w, _ = image_source.shape
357
- boxes = boxes * torch.Tensor([w, h, w, h])
358
- xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
359
- xywh = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xywh").numpy()
360
- detections = sv.Detections(xyxy=xyxy)
361
-
362
- labels = [f"{phrase}" for phrase in range(boxes.shape[0])]
363
-
364
- box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding,text_thickness=text_thickness,thickness=thickness) # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
365
- annotated_frame = image_source.copy()
366
- annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w,h))
367
-
368
- label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
369
- return annotated_frame, label_coordinates
370
-
371
-
372
- def predict(model, image, caption, box_threshold, text_threshold):
373
- """ Use huggingface model to replace the original model
374
- """
375
- model, processor = model['model'], model['processor']
376
- device = model.device
377
-
378
- inputs = processor(images=image, text=caption, return_tensors="pt").to(device)
379
- with torch.no_grad():
380
- outputs = model(**inputs)
381
-
382
- results = processor.post_process_grounded_object_detection(
383
- outputs,
384
- inputs.input_ids,
385
- box_threshold=box_threshold, # 0.4,
386
- text_threshold=text_threshold, # 0.3,
387
- target_sizes=[image.size[::-1]]
388
- )[0]
389
- boxes, logits, phrases = results["boxes"], results["scores"], results["labels"]
390
- return boxes, logits, phrases
391
-
392
-
393
- def predict_yolo(model, image, box_threshold, imgsz, scale_img, iou_threshold=0.7):
394
- """ Use huggingface model to replace the original model
395
- """
396
- # model = model['model']
397
- if scale_img:
398
- result = model.predict(
399
- source=image,
400
- conf=box_threshold,
401
- imgsz=imgsz,
402
- iou=iou_threshold, # default 0.7
403
- )
404
- else:
405
- result = model.predict(
406
- source=image,
407
- conf=box_threshold,
408
- iou=iou_threshold, # default 0.7
409
- )
410
- boxes = result[0].boxes.xyxy#.tolist() # in pixel space
411
- conf = result[0].boxes.conf
412
- phrases = [str(i) for i in range(len(boxes))]
413
-
414
- return boxes, conf, phrases
415
-
416
- def int_box_area(box, w, h):
417
- x1, y1, x2, y2 = box
418
- int_box = [int(x1*w), int(y1*h), int(x2*w), int(y2*h)]
419
- area = (int_box[2] - int_box[0]) * (int_box[3] - int_box[1])
420
- return area
421
-
422
- def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=64):
423
- """Process either an image path or Image object
424
-
425
- Args:
426
- image_source: Either a file path (str) or PIL Image object
427
- ...
428
- """
429
- if isinstance(image_source, str):
430
- image_source = Image.open(image_source).convert("RGB")
431
-
432
- w, h = image_source.size
433
- if not imgsz:
434
- imgsz = (h, w)
435
- # print('image size:', w, h)
436
- xyxy, logits, phrases = predict_yolo(model=model, image=image_source, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1)
437
- xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
438
- image_source = np.asarray(image_source)
439
- phrases = [str(i) for i in range(len(phrases))]
440
-
441
- # annotate the image with labels
442
- if ocr_bbox:
443
- ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
444
- ocr_bbox=ocr_bbox.tolist()
445
- else:
446
- print('no ocr bbox!!!')
447
- ocr_bbox = None
448
-
449
- ocr_bbox_elem = [{'type': 'text', 'bbox':box, 'interactivity':False, 'content':txt,} for box, txt in zip(ocr_bbox, ocr_text) if int_box_area(box, w, h) > 0]
450
- xyxy_elem = [{'type': 'icon', 'bbox':box, 'interactivity':True, 'content':None} for box in xyxy.tolist() if int_box_area(box, w, h) > 0]
451
- filtered_boxes = remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)
452
-
453
- # sort the filtered_boxes so that the one with 'content': None is at the end, and get the index of the first 'content': None
454
- filtered_boxes_elem = sorted(filtered_boxes, key=lambda x: x['content'] is None)
455
- # get the index of the first 'content': None
456
- starting_idx = next((i for i, box in enumerate(filtered_boxes_elem) if box['content'] is None), -1)
457
- filtered_boxes = torch.tensor([box['bbox'] for box in filtered_boxes_elem])
458
- print('len(filtered_boxes):', len(filtered_boxes), starting_idx)
459
-
460
- # get parsed icon local semantics
461
- time1 = time.time()
462
- if use_local_semantics:
463
- caption_model = caption_model_processor['model']
464
- if 'phi3_v' in caption_model.config.model_type:
465
- parsed_content_icon = get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor)
466
- else:
467
- parsed_content_icon = get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=prompt,batch_size=batch_size)
468
- ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
469
- icon_start = len(ocr_text)
470
- parsed_content_icon_ls = []
471
- # fill the filtered_boxes_elem None content with parsed_content_icon in order
472
- for i, box in enumerate(filtered_boxes_elem):
473
- if box['content'] is None:
474
- box['content'] = parsed_content_icon.pop(0)
475
- for i, txt in enumerate(parsed_content_icon):
476
- parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}")
477
- parsed_content_merged = ocr_text + parsed_content_icon_ls
478
- else:
479
- ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
480
- parsed_content_merged = ocr_text
481
- print('time to get parsed content:', time.time()-time1)
482
-
483
- filtered_boxes = box_convert(boxes=filtered_boxes, in_fmt="xyxy", out_fmt="cxcywh")
484
-
485
- phrases = [i for i in range(len(filtered_boxes))]
486
-
487
- # draw boxes
488
- if draw_bbox_config:
489
- annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, **draw_bbox_config)
490
- else:
491
- annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, text_scale=text_scale, text_padding=text_padding)
492
-
493
- pil_img = Image.fromarray(annotated_frame)
494
- buffered = io.BytesIO()
495
- pil_img.save(buffered, format="PNG")
496
- encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
497
- if output_coord_in_ratio:
498
- label_coordinates = {k: [v[0]/w, v[1]/h, v[2]/w, v[3]/h] for k, v in label_coordinates.items()}
499
- assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]
500
-
501
- return encoded_image, label_coordinates, filtered_boxes_elem
502
-
503
-
504
- def get_xywh(input):
505
- x, y, w, h = input[0][0], input[0][1], input[2][0] - input[0][0], input[2][1] - input[0][1]
506
- x, y, w, h = int(x), int(y), int(w), int(h)
507
- return x, y, w, h
508
-
509
- def get_xyxy(input):
510
- x, y, xp, yp = input[0][0], input[0][1], input[2][0], input[2][1]
511
- x, y, xp, yp = int(x), int(y), int(xp), int(yp)
512
- return x, y, xp, yp
513
-
514
- def get_xywh_yolo(input):
515
- x, y, w, h = input[0], input[1], input[2] - input[0], input[3] - input[1]
516
- x, y, w, h = int(x), int(y), int(w), int(h)
517
- return x, y, w, h
518
-
519
- def check_ocr_box(image_source: Union[str, Image.Image], display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False):
520
- if isinstance(image_source, str):
521
- image_source = Image.open(image_source)
522
- if image_source.mode == 'RGBA':
523
- # Convert RGBA to RGB to avoid alpha channel issues
524
- image_source = image_source.convert('RGB')
525
- image_np = np.array(image_source)
526
- w, h = image_source.size
527
- if use_paddleocr:
528
- if easyocr_args is None:
529
- text_threshold = 0.5
530
- else:
531
- text_threshold = easyocr_args['text_threshold']
532
- p_ocr = get_paddle_ocr()
533
- result = p_ocr.ocr(image_np, cls=False)[0]
534
- coord = [item[0] for item in result if item[1][1] > text_threshold]
535
- text = [item[1][0] for item in result if item[1][1] > text_threshold]
536
- else: # EasyOCR
537
- if easyocr_args is None:
538
- easyocr_args = {}
539
- e_reader = get_easyocr_reader()
540
- result = e_reader.readtext(image_np, **easyocr_args)
541
- coord = [item[0] for item in result]
542
- text = [item[1] for item in result]
543
- if display_img:
544
- opencv_img = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
545
- bb = []
546
- for item in coord:
547
- x, y, a, b = get_xywh(item)
548
- bb.append((x, y, a, b))
549
- cv2.rectangle(opencv_img, (x, y), (x+a, y+b), (0, 255, 0), 2)
550
- # matplotlib expects RGB
551
- plt.imshow(cv2.cvtColor(opencv_img, cv2.COLOR_BGR2RGB))
552
- else:
553
- if output_bb_format == 'xywh':
554
- bb = [get_xywh(item) for item in coord]
555
- elif output_bb_format == 'xyxy':
556
- bb = [get_xyxy(item) for item in coord]
557
- return (text, bb), goal_filtering
558
-
559
-
 
 
 
 
 
 
 
 
1
+ # from ultralytics import YOLO
2
+ import os
3
+ import io
4
+ import base64
5
+ import time
6
+ from PIL import Image, ImageDraw, ImageFont
7
+ import json
8
+ import requests
9
+ # utility function
10
+ import os
11
+ from openai import AzureOpenAI
12
+
13
+ import json
14
+ import sys
15
+ import os
16
+ import cv2
17
+ import numpy as np
18
+ # %matplotlib inline
19
+ from matplotlib import pyplot as plt
20
+ import easyocr
21
+ from paddleocr import PaddleOCR
22
+ _reader = None
23
+ _paddle_ocr = None
24
+
25
+ def get_easyocr_reader():
26
+ global _reader
27
+ if _reader is None:
28
+ import easyocr
29
+ _reader = easyocr.Reader(['en'])
30
+ return _reader
31
+
32
+ def get_paddle_ocr():
33
+ global _paddle_ocr
34
+ if _paddle_ocr is None:
35
+ from paddleocr import PaddleOCR
36
+ _paddle_ocr = PaddleOCR(
37
+ lang='en',
38
+ use_angle_cls=False,
39
+ use_gpu=False,
40
+ show_log=False,
41
+ max_batch_size=1024,
42
+ use_dilation=True,
43
+ det_db_score_mode='slow',
44
+ rec_batch_num=1024)
45
+ return _paddle_ocr
46
+ import time
47
+ import base64
48
+
49
+ import os
50
+ import ast
51
+ import torch
52
+ from typing import Tuple, List, Union
53
+ from torchvision.ops import box_convert
54
+ import re
55
+ from torchvision.transforms import ToPILImage
56
+ import supervision as sv
57
+ import torchvision.transforms as T
58
+ from util.box_annotator import BoxAnnotator
59
+
60
+
61
+ def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
62
+ if not device:
63
+ device = "cuda" if torch.cuda.is_available() else "cpu"
64
+ if model_name == "blip2":
65
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
66
+ processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
67
+ if device == 'cpu':
68
+ model = Blip2ForConditionalGeneration.from_pretrained(
69
+ model_name_or_path, device_map=None, torch_dtype=torch.float32
70
+ )
71
+ else:
72
+ model = Blip2ForConditionalGeneration.from_pretrained(
73
+ model_name_or_path, device_map=None, torch_dtype=torch.float16
74
+ ).to(device)
75
+ elif model_name == "florence2":
76
+ from transformers import AutoProcessor, AutoModelForCausalLM
77
+ processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
78
+ if device == 'cpu':
79
+ model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True)
80
+ else:
81
+ model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True).to(device)
82
+
83
+ # Fix for 'Florence2LanguageConfig' object has no attribute 'forced_bos_token_id'
84
+ if not hasattr(model.config, 'forced_bos_token_id'):
85
+ model.config.forced_bos_token_id = None
86
+ if hasattr(model.config, 'text_config') and not hasattr(model.config.text_config, 'forced_bos_token_id'):
87
+ model.config.text_config.forced_bos_token_id = None
88
+
89
+ return {'model': model.to(device), 'processor': processor}
90
+
91
+
92
+ def get_yolo_model(model_path):
93
+ from ultralytics import YOLO
94
+ # Load the model.
95
+ model = YOLO(model_path)
96
+ return model
97
+
98
+
99
+ @torch.inference_mode()
100
+ def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=None):
101
+ # Number of samples per batch, --> 256 roughly takes 23 GB of GPU memory for florence model
102
+ to_pil = ToPILImage()
103
+ if starting_idx:
104
+ non_ocr_boxes = filtered_boxes[starting_idx:]
105
+ else:
106
+ non_ocr_boxes = filtered_boxes
107
+ croped_pil_image = []
108
+ for i, coord in enumerate(non_ocr_boxes):
109
+ try:
110
+ xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
111
+ ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
112
+ cropped_image = image_source[ymin:ymax, xmin:xmax, :]
113
+ cropped_image = cv2.resize(cropped_image, (64, 64))
114
+ croped_pil_image.append(to_pil(cropped_image))
115
+ except:
116
+ continue
117
+
118
+ model, processor = caption_model_processor['model'], caption_model_processor['processor']
119
+ if not prompt:
120
+ if 'florence' in model.config.name_or_path:
121
+ prompt = "<CAPTION>"
122
+ else:
123
+ prompt = "The image shows"
124
+
125
+ generated_texts = []
126
+ device = model.device
127
+ # batch_size = 64
128
+ for i in range(0, len(croped_pil_image), batch_size):
129
+ start = time.time()
130
+ batch = croped_pil_image[i:i+batch_size]
131
+ t1 = time.time()
132
+ if model.device.type == 'cuda':
133
+ inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt", do_resize=False).to(device=device, dtype=torch.float16)
134
+ else:
135
+ inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
136
+ # if 'florence' in model.config.name_or_path:
137
+ generated_ids = model.generate(input_ids=inputs["input_ids"],pixel_values=inputs["pixel_values"],max_new_tokens=20,num_beams=1, do_sample=False)
138
+ # else:
139
+ # generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1) # temperature=0.01, do_sample=True,
140
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
141
+ generated_text = [gen.strip() for gen in generated_text]
142
+ generated_texts.extend(generated_text)
143
+
144
+ return generated_texts
145
+
146
+
147
+
148
+ def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor):
149
+ to_pil = ToPILImage()
150
+ if ocr_bbox:
151
+ non_ocr_boxes = filtered_boxes[len(ocr_bbox):]
152
+ else:
153
+ non_ocr_boxes = filtered_boxes
154
+ croped_pil_image = []
155
+ for i, coord in enumerate(non_ocr_boxes):
156
+ xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
157
+ ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
158
+ cropped_image = image_source[ymin:ymax, xmin:xmax, :]
159
+ croped_pil_image.append(to_pil(cropped_image))
160
+
161
+ model, processor = caption_model_processor['model'], caption_model_processor['processor']
162
+ device = model.device
163
+ messages = [{"role": "user", "content": "<|image_1|>\ndescribe the icon in one sentence"}]
164
+ prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
165
+
166
+ batch_size = 5 # Number of samples per batch
167
+ generated_texts = []
168
+
169
+ for i in range(0, len(croped_pil_image), batch_size):
170
+ images = croped_pil_image[i:i+batch_size]
171
+ image_inputs = [processor.image_processor(x, return_tensors="pt") for x in images]
172
+ inputs ={'input_ids': [], 'attention_mask': [], 'pixel_values': [], 'image_sizes': []}
173
+ texts = [prompt] * len(images)
174
+ for i, txt in enumerate(texts):
175
+ input = processor._convert_images_texts_to_inputs(image_inputs[i], txt, return_tensors="pt")
176
+ inputs['input_ids'].append(input['input_ids'])
177
+ inputs['attention_mask'].append(input['attention_mask'])
178
+ inputs['pixel_values'].append(input['pixel_values'])
179
+ inputs['image_sizes'].append(input['image_sizes'])
180
+ max_len = max([x.shape[1] for x in inputs['input_ids']])
181
+ for i, v in enumerate(inputs['input_ids']):
182
+ inputs['input_ids'][i] = torch.cat([processor.tokenizer.pad_token_id * torch.ones(1, max_len - v.shape[1], dtype=torch.long), v], dim=1)
183
+ inputs['attention_mask'][i] = torch.cat([torch.zeros(1, max_len - v.shape[1], dtype=torch.long), inputs['attention_mask'][i]], dim=1)
184
+ inputs_cat = {k: torch.concatenate(v).to(device) for k, v in inputs.items()}
185
+
186
+ generation_args = {
187
+ "max_new_tokens": 25,
188
+ "temperature": 0.01,
189
+ "do_sample": False,
190
+ }
191
+ generate_ids = model.generate(**inputs_cat, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
192
+ # # remove input tokens
193
+ generate_ids = generate_ids[:, inputs_cat['input_ids'].shape[1]:]
194
+ response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
195
+ response = [res.strip('\n').strip() for res in response]
196
+ generated_texts.extend(response)
197
+
198
+ return generated_texts
199
+
200
+ def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
201
+ assert ocr_bbox is None or isinstance(ocr_bbox, List)
202
+
203
+ def box_area(box):
204
+ return (box[2] - box[0]) * (box[3] - box[1])
205
+
206
+ def intersection_area(box1, box2):
207
+ x1 = max(box1[0], box2[0])
208
+ y1 = max(box1[1], box2[1])
209
+ x2 = min(box1[2], box2[2])
210
+ y2 = min(box1[3], box2[3])
211
+ return max(0, x2 - x1) * max(0, y2 - y1)
212
+
213
+ def IoU(box1, box2):
214
+ intersection = intersection_area(box1, box2)
215
+ union = box_area(box1) + box_area(box2) - intersection + 1e-6
216
+ if box_area(box1) > 0 and box_area(box2) > 0:
217
+ ratio1 = intersection / box_area(box1)
218
+ ratio2 = intersection / box_area(box2)
219
+ else:
220
+ ratio1, ratio2 = 0, 0
221
+ return max(intersection / union, ratio1, ratio2)
222
+
223
+ def is_inside(box1, box2):
224
+ # return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
225
+ intersection = intersection_area(box1, box2)
226
+ ratio1 = intersection / box_area(box1)
227
+ return ratio1 > 0.95
228
+
229
+ boxes = boxes.tolist()
230
+ filtered_boxes = []
231
+ if ocr_bbox:
232
+ filtered_boxes.extend(ocr_bbox)
233
+ # print('ocr_bbox!!!', ocr_bbox)
234
+ for i, box1 in enumerate(boxes):
235
+ # if not any(IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2) for j, box2 in enumerate(boxes) if i != j):
236
+ is_valid_box = True
237
+ for j, box2 in enumerate(boxes):
238
+ # keep the smaller box
239
+ if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
240
+ is_valid_box = False
241
+ break
242
+ if is_valid_box:
243
+ # add the following 2 lines to include ocr bbox
244
+ if ocr_bbox:
245
+ # only add the box if it does not overlap with any ocr bbox
246
+ if not any(IoU(box1, box3) > iou_threshold and not is_inside(box1, box3) for k, box3 in enumerate(ocr_bbox)):
247
+ filtered_boxes.append(box1)
248
+ else:
249
+ filtered_boxes.append(box1)
250
+ return torch.tensor(filtered_boxes)
251
+
252
+
253
+ def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
254
+ '''
255
+ ocr_bbox format: [{'type': 'text', 'bbox':[x,y], 'interactivity':False, 'content':str }, ...]
256
+ boxes format: [{'type': 'icon', 'bbox':[x,y], 'interactivity':True, 'content':None }, ...]
257
+
258
+ '''
259
+ assert ocr_bbox is None or isinstance(ocr_bbox, List)
260
+
261
+ def box_area(box):
262
+ return (box[2] - box[0]) * (box[3] - box[1])
263
+
264
+ def intersection_area(box1, box2):
265
+ x1 = max(box1[0], box2[0])
266
+ y1 = max(box1[1], box2[1])
267
+ x2 = min(box1[2], box2[2])
268
+ y2 = min(box1[3], box2[3])
269
+ return max(0, x2 - x1) * max(0, y2 - y1)
270
+
271
+ def IoU(box1, box2):
272
+ intersection = intersection_area(box1, box2)
273
+ union = box_area(box1) + box_area(box2) - intersection + 1e-6
274
+ if box_area(box1) > 0 and box_area(box2) > 0:
275
+ ratio1 = intersection / box_area(box1)
276
+ ratio2 = intersection / box_area(box2)
277
+ else:
278
+ ratio1, ratio2 = 0, 0
279
+ return max(intersection / union, ratio1, ratio2)
280
+
281
+ def is_inside(box1, box2):
282
+ # return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
283
+ intersection = intersection_area(box1, box2)
284
+ ratio1 = intersection / box_area(box1)
285
+ return ratio1 > 0.80
286
+
287
+ # boxes = boxes.tolist()
288
+ filtered_boxes = []
289
+ if ocr_bbox:
290
+ filtered_boxes.extend(ocr_bbox)
291
+ # print('ocr_bbox!!!', ocr_bbox)
292
+ for i, box1_elem in enumerate(boxes):
293
+ box1 = box1_elem['bbox']
294
+ is_valid_box = True
295
+ for j, box2_elem in enumerate(boxes):
296
+ # keep the smaller box
297
+ box2 = box2_elem['bbox']
298
+ if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
299
+ is_valid_box = False
300
+ break
301
+ if is_valid_box:
302
+ if ocr_bbox:
303
+ # keep yolo boxes + prioritize ocr label
304
+ box_added = False
305
+ ocr_labels = ''
306
+ for box3_elem in ocr_bbox:
307
+ if not box_added:
308
+ box3 = box3_elem['bbox']
309
+ if is_inside(box3, box1): # ocr inside icon
310
+ # box_added = True
311
+ # delete the box3_elem from ocr_bbox
312
+ try:
313
+ # gather all ocr labels
314
+ ocr_labels += box3_elem['content'] + ' '
315
+ filtered_boxes.remove(box3_elem)
316
+ except:
317
+ continue
318
+ # break
319
+ elif is_inside(box1, box3): # icon inside ocr, don't added this icon box, no need to check other ocr bbox bc no overlap between ocr bbox, icon can only be in one ocr box
320
+ box_added = True
321
+ break
322
+ else:
323
+ continue
324
+ if not box_added:
325
+ if ocr_labels:
326
+ filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': ocr_labels,})
327
+ else:
328
+ filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None, })
329
+ else:
330
+ filtered_boxes.append(box1)
331
+ return filtered_boxes # torch.tensor(filtered_boxes)
332
+
333
+
334
+ def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
335
+ transform = T.Compose(
336
+ [
337
+ T.RandomResize([800], max_size=1333),
338
+ T.ToTensor(),
339
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
340
+ ]
341
+ )
342
+ image_source = Image.open(image_path).convert("RGB")
343
+ image = np.asarray(image_source)
344
+ image_transformed, _ = transform(image_source, None)
345
+ return image, image_transformed
346
+
347
+
348
+ def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str], text_scale: float,
349
+ text_padding=5, text_thickness=2, thickness=3) -> np.ndarray:
350
+ """
351
+ This function annotates an image with bounding boxes and labels.
352
+
353
+ Parameters:
354
+ image_source (np.ndarray): The source image to be annotated.
355
+ boxes (torch.Tensor): A tensor containing bounding box coordinates. in cxcywh format, pixel scale
356
+ logits (torch.Tensor): A tensor containing confidence scores for each bounding box.
357
+ phrases (List[str]): A list of labels for each bounding box.
358
+ text_scale (float): The scale of the text to be displayed. 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
359
+
360
+ Returns:
361
+ np.ndarray: The annotated image.
362
+ """
363
+ h, w, _ = image_source.shape
364
+ boxes = boxes * torch.Tensor([w, h, w, h])
365
+ xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
366
+ xywh = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xywh").numpy()
367
+ detections = sv.Detections(xyxy=xyxy)
368
+
369
+ labels = [f"{phrase}" for phrase in range(boxes.shape[0])]
370
+
371
+ box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding,text_thickness=text_thickness,thickness=thickness) # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
372
+ annotated_frame = image_source.copy()
373
+ annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w,h))
374
+
375
+ label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
376
+ return annotated_frame, label_coordinates
377
+
378
+
379
+ def predict(model, image, caption, box_threshold, text_threshold):
380
+ """ Use huggingface model to replace the original model
381
+ """
382
+ model, processor = model['model'], model['processor']
383
+ device = model.device
384
+
385
+ inputs = processor(images=image, text=caption, return_tensors="pt").to(device)
386
+ with torch.no_grad():
387
+ outputs = model(**inputs)
388
+
389
+ results = processor.post_process_grounded_object_detection(
390
+ outputs,
391
+ inputs.input_ids,
392
+ box_threshold=box_threshold, # 0.4,
393
+ text_threshold=text_threshold, # 0.3,
394
+ target_sizes=[image.size[::-1]]
395
+ )[0]
396
+ boxes, logits, phrases = results["boxes"], results["scores"], results["labels"]
397
+ return boxes, logits, phrases
398
+
399
+
400
+ def predict_yolo(model, image, box_threshold, imgsz, scale_img, iou_threshold=0.7):
401
+ """ Use huggingface model to replace the original model
402
+ """
403
+ # model = model['model']
404
+ if scale_img:
405
+ result = model.predict(
406
+ source=image,
407
+ conf=box_threshold,
408
+ imgsz=imgsz,
409
+ iou=iou_threshold, # default 0.7
410
+ )
411
+ else:
412
+ result = model.predict(
413
+ source=image,
414
+ conf=box_threshold,
415
+ iou=iou_threshold, # default 0.7
416
+ )
417
+ boxes = result[0].boxes.xyxy#.tolist() # in pixel space
418
+ conf = result[0].boxes.conf
419
+ phrases = [str(i) for i in range(len(boxes))]
420
+
421
+ return boxes, conf, phrases
422
+
423
+ def int_box_area(box, w, h):
424
+ x1, y1, x2, y2 = box
425
+ int_box = [int(x1*w), int(y1*h), int(x2*w), int(y2*h)]
426
+ area = (int_box[2] - int_box[0]) * (int_box[3] - int_box[1])
427
+ return area
428
+
429
+ def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=64):
430
+ """Process either an image path or Image object
431
+
432
+ Args:
433
+ image_source: Either a file path (str) or PIL Image object
434
+ ...
435
+ """
436
+ if isinstance(image_source, str):
437
+ image_source = Image.open(image_source).convert("RGB")
438
+
439
+ w, h = image_source.size
440
+ if not imgsz:
441
+ imgsz = (h, w)
442
+ # print('image size:', w, h)
443
+ xyxy, logits, phrases = predict_yolo(model=model, image=image_source, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1)
444
+ xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
445
+ image_source = np.asarray(image_source)
446
+ phrases = [str(i) for i in range(len(phrases))]
447
+
448
+ # annotate the image with labels
449
+ if ocr_bbox:
450
+ ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
451
+ ocr_bbox=ocr_bbox.tolist()
452
+ else:
453
+ print('no ocr bbox!!!')
454
+ ocr_bbox = None
455
+
456
+ ocr_bbox_elem = [{'type': 'text', 'bbox':box, 'interactivity':False, 'content':txt,} for box, txt in zip(ocr_bbox, ocr_text) if int_box_area(box, w, h) > 0]
457
+ xyxy_elem = [{'type': 'icon', 'bbox':box, 'interactivity':True, 'content':None} for box in xyxy.tolist() if int_box_area(box, w, h) > 0]
458
+ filtered_boxes = remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)
459
+
460
+ # sort the filtered_boxes so that the one with 'content': None is at the end, and get the index of the first 'content': None
461
+ filtered_boxes_elem = sorted(filtered_boxes, key=lambda x: x['content'] is None)
462
+ # get the index of the first 'content': None
463
+ starting_idx = next((i for i, box in enumerate(filtered_boxes_elem) if box['content'] is None), -1)
464
+ filtered_boxes = torch.tensor([box['bbox'] for box in filtered_boxes_elem])
465
+ print('len(filtered_boxes):', len(filtered_boxes), starting_idx)
466
+
467
+ # get parsed icon local semantics
468
+ time1 = time.time()
469
+ if use_local_semantics:
470
+ caption_model = caption_model_processor['model']
471
+ if 'phi3_v' in caption_model.config.model_type:
472
+ parsed_content_icon = get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor)
473
+ else:
474
+ parsed_content_icon = get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=prompt,batch_size=batch_size)
475
+ ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
476
+ icon_start = len(ocr_text)
477
+ parsed_content_icon_ls = []
478
+ # fill the filtered_boxes_elem None content with parsed_content_icon in order
479
+ for i, box in enumerate(filtered_boxes_elem):
480
+ if box['content'] is None:
481
+ box['content'] = parsed_content_icon.pop(0)
482
+ for i, txt in enumerate(parsed_content_icon):
483
+ parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}")
484
+ parsed_content_merged = ocr_text + parsed_content_icon_ls
485
+ else:
486
+ ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
487
+ parsed_content_merged = ocr_text
488
+ print('time to get parsed content:', time.time()-time1)
489
+
490
+ filtered_boxes = box_convert(boxes=filtered_boxes, in_fmt="xyxy", out_fmt="cxcywh")
491
+
492
+ phrases = [i for i in range(len(filtered_boxes))]
493
+
494
+ # draw boxes
495
+ if draw_bbox_config:
496
+ annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, **draw_bbox_config)
497
+ else:
498
+ annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, text_scale=text_scale, text_padding=text_padding)
499
+
500
+ pil_img = Image.fromarray(annotated_frame)
501
+ buffered = io.BytesIO()
502
+ pil_img.save(buffered, format="PNG")
503
+ encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
504
+ if output_coord_in_ratio:
505
+ label_coordinates = {k: [v[0]/w, v[1]/h, v[2]/w, v[3]/h] for k, v in label_coordinates.items()}
506
+ assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]
507
+
508
+ return encoded_image, label_coordinates, filtered_boxes_elem
509
+
510
+
511
+ def get_xywh(input):
512
+ x, y, w, h = input[0][0], input[0][1], input[2][0] - input[0][0], input[2][1] - input[0][1]
513
+ x, y, w, h = int(x), int(y), int(w), int(h)
514
+ return x, y, w, h
515
+
516
+ def get_xyxy(input):
517
+ x, y, xp, yp = input[0][0], input[0][1], input[2][0], input[2][1]
518
+ x, y, xp, yp = int(x), int(y), int(xp), int(yp)
519
+ return x, y, xp, yp
520
+
521
+ def get_xywh_yolo(input):
522
+ x, y, w, h = input[0], input[1], input[2] - input[0], input[3] - input[1]
523
+ x, y, w, h = int(x), int(y), int(w), int(h)
524
+ return x, y, w, h
525
+
526
+ def check_ocr_box(image_source: Union[str, Image.Image], display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False):
527
+ if isinstance(image_source, str):
528
+ image_source = Image.open(image_source)
529
+ if image_source.mode == 'RGBA':
530
+ # Convert RGBA to RGB to avoid alpha channel issues
531
+ image_source = image_source.convert('RGB')
532
+ image_np = np.array(image_source)
533
+ w, h = image_source.size
534
+ if use_paddleocr:
535
+ if easyocr_args is None:
536
+ text_threshold = 0.5
537
+ else:
538
+ text_threshold = easyocr_args['text_threshold']
539
+ p_ocr = get_paddle_ocr()
540
+ result = p_ocr.ocr(image_np, cls=False)[0]
541
+ coord = [item[0] for item in result if item[1][1] > text_threshold]
542
+ text = [item[1][0] for item in result if item[1][1] > text_threshold]
543
+ else: # EasyOCR
544
+ if easyocr_args is None:
545
+ easyocr_args = {}
546
+ e_reader = get_easyocr_reader()
547
+ result = e_reader.readtext(image_np, **easyocr_args)
548
+ coord = [item[0] for item in result]
549
+ text = [item[1] for item in result]
550
+ if display_img:
551
+ opencv_img = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
552
+ bb = []
553
+ for item in coord:
554
+ x, y, a, b = get_xywh(item)
555
+ bb.append((x, y, a, b))
556
+ cv2.rectangle(opencv_img, (x, y), (x+a, y+b), (0, 255, 0), 2)
557
+ # matplotlib expects RGB
558
+ plt.imshow(cv2.cvtColor(opencv_img, cv2.COLOR_BGR2RGB))
559
+ else:
560
+ if output_bb_format == 'xywh':
561
+ bb = [get_xywh(item) for item in coord]
562
+ elif output_bb_format == 'xyxy':
563
+ bb = [get_xyxy(item) for item in coord]
564
+ return (text, bb), goal_filtering
565
+
566
+