ThomasDh-C committed on
Commit
cdab2b6
·
verified ·
1 Parent(s): 02238c5

Transformers

Browse files
box_annotator.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Union, Tuple
2
+
3
+ import cv2
4
+ import numpy as np
5
+
6
+ from supervision.detection.core import Detections
7
+ from supervision.draw.color import Color, ColorPalette
8
+
9
+
10
class BoxAnnotator:
    """
    A class for drawing bounding boxes on an image using detections provided.

    Attributes:
        color (Union[Color, ColorPalette]): The color to draw the bounding box,
            can be a single color or a color palette
        thickness (int): The thickness of the bounding box lines, default is 2
        text_color (Color): The color of the text on the bounding box, default is white
        text_scale (float): The scale of the text on the bounding box, default is 0.5
        text_thickness (int): The thickness of the text on the bounding box,
            default is 1
        text_padding (int): The padding around the text on the bounding box,
            default is 5
        avoid_overlap (bool): If True, label placement tries several candidate
            positions (via get_optimal_label_pos) so the label background does
            not cover other detections or run off the image, default is True
    """

    def __init__(
        self,
        color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
        thickness: int = 3,  # 1 for seeclick 2 for mind2web and 3 for demo
        text_color: Color = Color.BLACK,
        text_scale: float = 0.5,  # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
        text_thickness: int = 2,  # 1, # 2 for demo
        text_padding: int = 10,
        avoid_overlap: bool = True,
    ):
        self.color: Union[Color, ColorPalette] = color
        self.thickness: int = thickness
        # NOTE(review): text_color is stored but annotate() below picks
        # black/white from the box color's luminance instead of using it.
        self.text_color: Color = text_color
        self.text_scale: float = text_scale
        self.text_thickness: int = text_thickness
        self.text_padding: int = text_padding
        self.avoid_overlap: bool = avoid_overlap

    def annotate(
        self,
        scene: np.ndarray,
        detections: Detections,
        labels: Optional[List[str]] = None,
        skip_label: bool = False,
        image_size: Optional[Tuple[int, int]] = None,
    ) -> np.ndarray:
        """
        Draws bounding boxes on the frame using the detections provided.

        Args:
            scene (np.ndarray): The image on which the bounding boxes will be drawn
            detections (Detections): The detections for which the
                bounding boxes will be drawn
            labels (Optional[List[str]]): An optional list of labels
                corresponding to each detection. If `labels` are not provided,
                corresponding `class_id` will be used as label.
            skip_label (bool): Is set to `True`, skips bounding box label annotation.
            image_size (Optional[Tuple[int, int]]): (width, height) of the image;
                only consulted by the overlap-avoiding label placement.
        Returns:
            np.ndarray: The image with the bounding boxes drawn on it

        Example:
            ```python
            import supervision as sv

            classes = ['person', ...]
            image = ...
            detections = sv.Detections(...)

            box_annotator = sv.BoxAnnotator()
            labels = [
                f"{classes[class_id]} {confidence:0.2f}"
                for _, _, confidence, class_id, _ in detections
            ]
            annotated_frame = box_annotator.annotate(
                scene=image.copy(),
                detections=detections,
                labels=labels
            )
            ```
        """
        font = cv2.FONT_HERSHEY_SIMPLEX
        for i in range(len(detections)):
            x1, y1, x2, y2 = detections.xyxy[i].astype(int)
            class_id = (
                detections.class_id[i] if detections.class_id is not None else None
            )
            # Palette color is keyed by class id when available, else by index.
            idx = class_id if class_id is not None else i
            color = (
                self.color.by_idx(idx)
                if isinstance(self.color, ColorPalette)
                else self.color
            )
            cv2.rectangle(
                img=scene,
                pt1=(x1, y1),
                pt2=(x2, y2),
                color=color.as_bgr(),
                thickness=self.thickness,
            )
            if skip_label:
                continue

            # Fall back to the class id when labels are missing or mismatched
            # in length with the detections.
            text = (
                f"{class_id}"
                if (labels is None or len(detections) != len(labels))
                else labels[i]
            )

            text_width, text_height = cv2.getTextSize(
                text=text,
                fontFace=font,
                fontScale=self.text_scale,
                thickness=self.text_thickness,
            )[0]

            if not self.avoid_overlap:
                # Fixed placement: label background sits just above the
                # detection's top-left corner.
                text_x = x1 + self.text_padding
                text_y = y1 - self.text_padding

                text_background_x1 = x1
                text_background_y1 = y1 - 2 * self.text_padding - text_height

                text_background_x2 = x1 + 2 * self.text_padding + text_width
                text_background_y2 = y1
                # text_x = x1 - self.text_padding - text_width
                # text_y = y1 + self.text_padding + text_height
                # text_background_x1 = x1 - 2 * self.text_padding - text_width
                # text_background_y1 = y1
                # text_background_x2 = x1
                # text_background_y2 = y1 + 2 * self.text_padding + text_height
            else:
                text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 = get_optimal_label_pos(self.text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size)

            cv2.rectangle(
                img=scene,
                pt1=(text_background_x1, text_background_y1),
                pt2=(text_background_x2, text_background_y2),
                color=color.as_bgr(),
                thickness=cv2.FILLED,
            )
            # import pdb; pdb.set_trace()
            # Pick black or white text for contrast with the filled label
            # background, using Rec. 601 luma weights on the RGB box color.
            box_color = color.as_rgb()
            luminance = 0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2]
            text_color = (0, 0, 0) if luminance > 160 else (255, 255, 255)
            cv2.putText(
                img=scene,
                text=text,
                org=(text_x, text_y),
                fontFace=font,
                fontScale=self.text_scale,
                # color=self.text_color.as_rgb(),
                color=text_color,
                thickness=self.text_thickness,
                lineType=cv2.LINE_AA,
            )
        return scene
163
+
164
+
165
def box_area(box):
    """Return the area of a box given in (x1, y1, x2, y2) format."""
    width = box[2] - box[0]
    height = box[3] - box[1]
    return width * height
167
+
168
def intersection_area(box1, box2):
    """Return the overlap area of two (x1, y1, x2, y2) boxes, 0 if disjoint."""
    left = max(box1[0], box2[0])
    top = max(box1[1], box2[1])
    right = min(box1[2], box2[2])
    bottom = min(box1[3], box2[3])
    # Clamp negative extents to zero so disjoint boxes yield no area.
    overlap_w = max(0, right - left)
    overlap_h = max(0, bottom - top)
    return overlap_w * overlap_h
174
+
175
def IoU(box1, box2, return_max=True):
    """Compute the overlap score between two (x1, y1, x2, y2) boxes.

    Args:
        box1, box2: boxes in xyxy format.
        return_max: if True, return the max of the IoU and the two
            intersection-over-single-box-area ratios (so containment of a
            small box in a large one also scores high); otherwise return
            the plain IoU.

    Returns:
        float: overlap score in [0, 1].
    """
    intersection = intersection_area(box1, box2)
    # Hoist the area computations (the original recomputed box_area up to
    # three times per call).
    area1 = box_area(box1)
    area2 = box_area(box2)
    # Epsilon guards against ZeroDivisionError when both boxes are
    # degenerate (zero area => zero union); matches the sibling IoU in
    # modelling_omniparser.remove_overlap_new.
    union = area1 + area2 - intersection + 1e-6
    if area1 > 0 and area2 > 0:
        ratio1 = intersection / area1
        ratio2 = intersection / area2
    else:
        ratio1, ratio2 = 0, 0
    if return_max:
        return max(intersection / union, ratio1, ratio2)
    else:
        return intersection / union
187
+
188
+
189
def get_optimal_label_pos(text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size):
    """Pick a label position whose background box overlaps neither the
    detections nor the image border.

    Candidates are tried in order: 'top left', 'outer left', 'outer right',
    'top right'. If every candidate overlaps (IoU > 0.3 with any detection,
    or sticking out of the image), the last candidate (top right) is
    returned as a fallback.

    Args:
        text_padding (int): padding around the text inside the label box.
        text_width (int): text width from cv2.getTextSize.
        text_height (int): text height from cv2.getTextSize.
        x1, y1, x2, y2 (int): the detection's bounding box in pixel xyxy.
        detections (Detections): all detections, used for the overlap test.
        image_size (Tuple[int, int]): (width, height) of the image.

    Returns:
        Tuple of 6 ints: (text_x, text_y, bg_x1, bg_y1, bg_x2, bg_y2) —
        the text origin for cv2.putText plus the background rectangle.
    """

    def is_overlap(bg_x1, bg_y1, bg_x2, bg_y2):
        # Label background overlaps some detection box beyond the 0.3 threshold?
        for i in range(len(detections)):
            detection = detections.xyxy[i].astype(int)
            if IoU([bg_x1, bg_y1, bg_x2, bg_y2], detection) > 0.3:
                return True
        # Or it sticks out of the image?
        return bg_x1 < 0 or bg_x2 > image_size[0] or bg_y1 < 0 or bg_y2 > image_size[1]

    # Each candidate is (text_x, text_y, bg_x1, bg_y1, bg_x2, bg_y2).
    # The original spelled out these four cases as copy-pasted blocks.
    candidates = [
        # top left: above the box, flush with its left edge
        (x1 + text_padding, y1 - text_padding,
         x1, y1 - 2 * text_padding - text_height,
         x1 + 2 * text_padding + text_width, y1),
        # outer left: to the left of the box, top-aligned
        (x1 - text_padding - text_width, y1 + text_padding + text_height,
         x1 - 2 * text_padding - text_width, y1,
         x1, y1 + 2 * text_padding + text_height),
        # outer right: to the right of the box, top-aligned
        (x2 + text_padding, y1 + text_padding + text_height,
         x2, y1,
         x2 + 2 * text_padding + text_width, y1 + 2 * text_padding + text_height),
        # top right: above the box, flush with its right edge
        (x2 - text_padding - text_width, y1 - text_padding,
         x2 - 2 * text_padding - text_width, y1 - 2 * text_padding - text_height,
         x2, y1),
    ]
    for candidate in candidates:
        if not is_overlap(*candidate[2:]):
            return candidate
    # Every position overlapped; keep the last one tried (top right),
    # matching the original fall-through behavior.
    return candidates[-1]
config.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "box_threshold": 0.05,
3
+ "model_type": "omniparserv2",
4
+ "transformers_version": "4.48.1"
5
+ }
configuration_omniparser.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import List
3
+
4
class OmniparserConfig(PretrainedConfig):
    """Hugging Face configuration for the OmniParser v2 model.

    Args:
        box_threshold (float): Minimum YOLO confidence for keeping an icon
            detection, default 0.05 (mirrors config.json).
    """
    # Must match the string registered with AutoConfig at the bottom of
    # modelling_omniparser.py and the "model_type" field in config.json.
    model_type = "omniparserv2"

    def __init__(
        self,
        box_threshold: float = 0.05,
        **kwargs,
    ):
        # Assign before delegating so PretrainedConfig serialization
        # (to_dict/save_pretrained) picks the attribute up.
        self.box_threshold = box_threshold
        super().__init__(**kwargs)
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9bbcbf73561f6bc5d0a17ea6a2081feed2d1304e87602d8c502d9a5c4bd85576
3
+ size 16
modelling_omniparser.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ from transformers import AutoProcessor, AutoModelForCausalLM
3
+ from .configuration_omniparser import OmniparserConfig
4
+ from ultralytics import YOLO
5
+ from torchvision.transforms import ToPILImage
6
+ from torchvision.ops import box_convert
7
+ import base64
8
+ import supervision as sv
9
+ from .box_annotator import BoxAnnotator
10
+ import os
11
+ import torch
12
+ import numpy as np
13
+ from PIL import Image
14
+ import io
15
+ import cv2
16
+ import easyocr
17
+ from typing import List, Dict
18
+ from dataclasses import dataclass
19
+ from transformers.modeling_outputs import ModelOutput
20
+ import json
21
+
22
@dataclass
class OmniparserOutput(ModelOutput):
    """
    Output type of [`OmniparserModel`].

    Args:
        annotated_image (`str`):
            The image with bounding boxes and labels drawn on it, returned as base64 encoded string.
        parsed_content_list (`List[Dict]`):
            Lists of detected elements with their properties (type, bbox, interactivity, content, source).
    """
    # Base64-encoded PNG of the annotated screenshot.
    annotated_image: str = None
    # One dict per detected UI element with keys:
    # 'type', 'bbox', 'interactivity', 'content', 'source'.
    parsed_content_list: List[Dict] = None
35
+
36
+
37
class OmniparserModel(PreTrainedModel):
    """OmniParser v2 screen-parsing model.

    Combines three components: EasyOCR for text detection, a YOLO model for
    icon/widget detection, and a Florence-2 captioner for describing icons.
    forward() takes a screenshot and returns an annotated image (base64 PNG)
    plus a structured list of the detected UI elements.
    """
    config_class = OmniparserConfig
    # Florence caption weights are loaded manually in __init__ from a
    # subfolder, so their keys are expected to be missing at load time.
    _keys_to_ignore_on_load_missing = [r"caption_model\..*"]

    def __init__(self, config):
        super().__init__(config)
        self._device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Weight paths are resolved relative to this file (repo snapshot layout).
        current_file_directory = os.path.dirname(os.path.abspath(__file__))
        self.icon_detect_path = os.path.join(current_file_directory, 'icon_detect/model.pt')
        icon_caption_path = os.path.join(current_file_directory, 'icon_caption')
        caption_model_processor = self.get_caption_model_processor(model_name_or_path=icon_caption_path)
        self.caption_model, self.caption_model_processor = caption_model_processor['model'], caption_model_processor['processor']
        self.reader = easyocr.Reader(['en'])
        self.to_pil = ToPILImage()
        # The YOLO detector is created lazily in forward(); see comment there.
        self.som_model_inited = False

    def forward(self, image_source):
        """Parse a screenshot.

        Args:
            image_source: a PIL image or a filesystem path string.

        Returns:
            OmniparserOutput with the base64 annotated image and the parsed
            element list.

        Raises:
            ValueError: if image_source is None.
        """
        # Lazy load YOLO model as autoload calls .eval which breaks YOLO unless init in forward
        if not self.som_model_inited:
            self.som_model = YOLO(self.icon_detect_path)
            self.som_model_inited = True

        if image_source is None:
            raise ValueError("No image provided")
        if isinstance(image_source, str):
            # A string is treated as a path on disk.
            image_source = Image.open(image_source)
        image_source = image_source.convert("RGB")

        # Scan for text
        ocr_text, ocr_bbox = self.check_ocr_box(image_source, easyocr_args={'text_threshold': 0.8})
        # Yolo + merge text and YOLO + florence
        annotated_image, parsed_content_list = self.get_som_labeled_img(image_source,
                                                                        self.som_model,
                                                                        box_threshold=self.config.box_threshold,
                                                                        ocr_bbox=ocr_bbox,
                                                                        ocr_text=ocr_text)

        return OmniparserOutput(
            annotated_image=annotated_image,
            parsed_content_list=parsed_content_list
        )

    def get_caption_model_processor(self, model_name_or_path):
        """Load the Florence-2 captioner and its processor.

        The processor always comes from the public Florence-2-base repo; the
        weights come from model_name_or_path (the local icon_caption folder).
        fp16 on CUDA, fp32 on CPU.
        """
        processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
        if self._device_type == 'cpu':
            model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True)
        else:
            model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True).to(self._device_type)
        # NOTE(review): .to(self._device_type) here is redundant for the CUDA
        # branch (already moved above) but harmless.
        return {'model': model.to(self._device_type), 'processor': processor}

    def check_ocr_box(self, image_source, easyocr_args):
        """Run EasyOCR over the image.

        Args:
            image_source: PIL image.
            easyocr_args: kwargs forwarded to easyocr.Reader.readtext.

        Returns:
            (texts, boxes): parallel lists of recognized strings and their
            pixel-space xyxy boxes.
        """
        image_np = np.array(image_source)
        w, h = image_source.size  # NOTE(review): unused here; kept as-is
        result = self.reader.readtext(image_np, **easyocr_args)
        # Each EasyOCR result item is (4-point polygon, text, confidence).
        coord = [item[0] for item in result]
        text = [item[1] for item in result]
        bb = [self.get_xyxy(item) for item in coord]
        return text, bb

    def get_som_labeled_img(self, image_source,
                            model,
                            box_threshold,
                            ocr_bbox,
                            ocr_text,
                            iou_threshold=0.7,
                            prompt=None,
                            batch_size=128):
        """Detect icons with YOLO, merge with OCR boxes, caption icons with
        Florence, and draw the Set-of-Marks annotated image.

        Args:
            image_source: PIL image.
            model: the YOLO detector.
            box_threshold: YOLO confidence threshold.
            ocr_bbox: pixel-space xyxy OCR boxes (or empty/None).
            ocr_text: recognized strings parallel to ocr_bbox.
            iou_threshold: overlap threshold for de-duplicating YOLO boxes.
            prompt: optional Florence prompt (defaults to "<CAPTION>").
            batch_size: Florence caption batch size.

        Returns:
            (base64 PNG string, list of element dicts with normalized bboxes).
        """
        w, h = image_source.size
        # Scale drawing parameters with image size (3200px reference).
        box_overlay_ratio = max(w, h) / 3200
        draw_bbox_config = {
            'text_scale': 0.8 * box_overlay_ratio,
            'text_thickness': max(int(2 * box_overlay_ratio), 1),
            'text_padding': max(int(3 * box_overlay_ratio), 1),
            'thickness': max(int(3 * box_overlay_ratio), 1),
        }
        xyxy, logits, phrases = self.predict_yolo(yolo_model=model, image=image_source, box_threshold=box_threshold, iou_threshold=0.1)
        # Normalize YOLO boxes to [0, 1].
        xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
        image_source = np.asarray(image_source)
        phrases = [str(i) for i in range(len(phrases))]

        # merge text and YOLO boxes
        ocr_bbox_elem = None
        if ocr_bbox:
            ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
            ocr_bbox = ocr_bbox.tolist()
            # Drop boxes that collapse to zero pixel area after int rounding.
            ocr_bbox_elem = [{'type': 'text', 'bbox': box, 'interactivity': False, 'content': txt, 'source': 'box_ocr_content_ocr'} for box, txt in zip(ocr_bbox, ocr_text) if self.int_box_area(box, w, h) > 0]
        xyxy_elem = [{'type': 'icon', 'bbox': box, 'interactivity': True, 'content': None, 'source': 'box_yolo_content_yolo'} for box in xyxy.tolist() if self.int_box_area(box, w, h) > 0]
        filtered_boxes = self.remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)

        # sort the filtered_boxes so that ocr first, yolo last
        filtered_boxes_elem = sorted(filtered_boxes, key=lambda x: x['content'] is None)
        # get the index of the first yolo box
        starting_idx = next((i for i, box in enumerate(filtered_boxes_elem) if box['content'] is None), -1)
        filtered_boxes = torch.tensor([box['bbox'] for box in filtered_boxes_elem])

        # get local semantics using florence on yolo boxes
        # TODO: fix that should not be using starting_idx ... instead use a where content is none ... O(n) rather than using sort
        parsed_content_icon = self.get_parsed_content_icon(filtered_boxes, starting_idx, image_source, prompt=prompt, batch_size=batch_size)
        ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
        icon_start = len(ocr_text)
        parsed_content_icon_ls = []
        # fill the filtered_boxes_elem None content with parsed_content_icon in order
        for i, box in enumerate(filtered_boxes_elem):
            if box['content'] is None:
                box['content'] = parsed_content_icon.pop(0)
        # NOTE(review): the pop(0) loop above normally empties
        # parsed_content_icon, so this loop and parsed_content_merged are
        # usually dead — the merged list is never returned. Confirm intent.
        for i, txt in enumerate(parsed_content_icon):
            parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}")
        parsed_content_merged = ocr_text + parsed_content_icon_ls

        # draw boxes
        filtered_boxes = box_convert(boxes=filtered_boxes, in_fmt="xyxy", out_fmt="cxcywh")
        phrases = [i for i in range(len(filtered_boxes))]
        annotated_frame, _ = self.annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, **draw_bbox_config)

        # Encode the annotated image as a base64 PNG string.
        pil_img = Image.fromarray(annotated_frame)
        buffered = io.BytesIO()
        pil_img.save(buffered, format="PNG")
        encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
        # Annotation must not change the image dimensions.
        assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]

        return encoded_image, filtered_boxes_elem

    @staticmethod
    def predict_yolo(yolo_model, image, box_threshold, iou_threshold):
        """ Use huggingface model to replace the original model

        Returns:
            (boxes, conf, phrases): pixel-space xyxy boxes tensor, confidence
            tensor, and string indices as placeholder phrases.
        """
        result = yolo_model.predict(
            source=image,
            conf=box_threshold,
            iou=iou_threshold,
            verbose=False
        )
        boxes = result[0].boxes.xyxy
        conf = result[0].boxes.conf
        phrases = [str(i) for i in range(len(boxes))]

        return boxes, conf, phrases

    @staticmethod
    def remove_overlap_new(boxes, iou_threshold, ocr_bbox):
        '''
        Merge YOLO icon boxes with OCR text boxes, dropping duplicates.

        boxes format: [{'type': 'icon', 'bbox':[x,y], 'interactivity':True, 'content':None }, ...]
        ocr_bbox format: [{'type': 'text', 'bbox':[x,y], 'interactivity':False, 'content':str }, ...]

        Keeps the smaller of two heavily-overlapping icon boxes; absorbs OCR
        boxes contained inside an icon box as that icon's content; drops icon
        boxes contained inside an OCR box.
        '''
        assert ocr_bbox is None or isinstance(ocr_bbox, List)

        def box_area(box):
            # Area of an xyxy box.
            return (box[2] - box[0]) * (box[3] - box[1])

        def intersection_area(box1, box2):
            # Overlap area of two xyxy boxes (0 if disjoint).
            x1 = max(box1[0], box2[0])
            y1 = max(box1[1], box2[1])
            x2 = min(box1[2], box2[2])
            y2 = min(box1[3], box2[3])
            return max(0, x2 - x1) * max(0, y2 - y1)

        def IoU(box1, box2):
            # Max of IoU and both containment ratios; epsilon avoids
            # division by zero for degenerate boxes.
            intersection = intersection_area(box1, box2)
            union = box_area(box1) + box_area(box2) - intersection + 1e-6
            if box_area(box1) > 0 and box_area(box2) > 0:
                ratio1 = intersection / box_area(box1)
                ratio2 = intersection / box_area(box2)
            else:
                ratio1, ratio2 = 0, 0
            return max(intersection / union, ratio1, ratio2)

        def is_inside(box1, box2):
            # return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
            # "Inside" means >80% of box1's area overlaps box2.
            intersection = intersection_area(box1, box2)
            ratio1 = intersection / box_area(box1)
            return ratio1 > 0.80

        filtered_boxes = []
        if ocr_bbox:
            filtered_boxes.extend(ocr_bbox)
        for i, box1_elem in enumerate(boxes):
            box1 = box1_elem['bbox']
            is_valid_box = True
            for j, box2_elem in enumerate(boxes):
                # keep the smaller box
                box2 = box2_elem['bbox']
                if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
                    is_valid_box = False
                    break
            if is_valid_box:
                if ocr_bbox:
                    box_added = False
                    ocr_labels = ''
                    for box3_elem in ocr_bbox:
                        if not box_added:
                            box3 = box3_elem['bbox']
                            if is_inside(box3, box1):  # ocr inside icon
                                # delete the box3_elem from ocr_bbox
                                try:
                                    # gather all ocr labels
                                    ocr_labels += box3_elem['content'] + ' '
                                    filtered_boxes.remove(box3_elem)
                                # NOTE(review): bare except swallows any
                                # error here (e.g. element already removed).
                                except:
                                    continue
                            elif is_inside(box1, box3):  # icon inside ocr, don't added this icon box, no need to check other ocr bbox bc no overlap between ocr bbox, icon can only be in one ocr box
                                box_added = True
                                break
                            else:
                                continue
                    if not box_added:
                        if ocr_labels:
                            filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': ocr_labels, 'source':'box_yolo_content_ocr'})
                        else:
                            filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None, 'source':'box_yolo_content_yolo'})
                else:
                    # NOTE(review): appends the raw bbox list rather than an
                    # element dict — inconsistent with the schema used above.
                    # Only reachable when ocr_bbox is falsy; confirm intent.
                    filtered_boxes.append(box1)
        return filtered_boxes

    @torch.inference_mode()
    def get_parsed_content_icon(self, filtered_boxes, starting_idx, image_source, prompt=None, batch_size=128):
        """Caption the icon (non-OCR) crops with the Florence model.

        Args:
            filtered_boxes: tensor of normalized xyxy boxes, OCR boxes first.
            starting_idx: index of the first icon box (OCR boxes are skipped).
            image_source: HxWxC numpy image the boxes refer to.
            prompt: Florence prompt, defaults to "<CAPTION>".
            batch_size: crops per generation batch.

        Returns:
            List of caption strings, one per successfully cropped icon box.
        """
        # Number of samples per batch, --> 128 roughly takes 4 GB of GPU memory for florence v2 model
        non_ocr_boxes = filtered_boxes[starting_idx:]
        cropped_pil_image = []
        for i, coord in enumerate(non_ocr_boxes):
            try:
                # Boxes are normalized; scale back to pixel coordinates.
                xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
                ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
                cropped_image = image_source[ymin:ymax, xmin:xmax, :]
                cropped_image = cv2.resize(cropped_image, (64, 64))
                cropped_pil_image.append(self.to_pil(cropped_image))
            # NOTE(review): bare except silently drops crops that fail
            # (e.g. empty slice); captions then misalign with dropped boxes.
            except:
                continue

        caption_model, caption_model_processor = self.caption_model, self.caption_model_processor
        if not prompt:
            prompt = "<CAPTION>"

        generated_texts = []
        device = caption_model.device
        for i in range(0, len(cropped_pil_image), batch_size):
            batch = cropped_pil_image[i:i+batch_size]
            if caption_model.device.type == 'cuda':
                # fp16 inputs to match the fp16 CUDA model.
                inputs = caption_model_processor(images=batch, text=[prompt]*len(batch), return_tensors="pt", do_resize=False).to(device=device, dtype=torch.float16)
            else:
                inputs = caption_model_processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
            generated_ids = caption_model.generate(input_ids=inputs["input_ids"],pixel_values=inputs["pixel_values"],max_new_tokens=20,num_beams=1, do_sample=False, early_stopping=False)
            generated_text = caption_model_processor.batch_decode(generated_ids, skip_special_tokens=True)
            generated_text = [gen.strip() for gen in generated_text]
            generated_texts.extend(generated_text)

        return generated_texts

    @staticmethod
    def annotate(image_source, boxes, logits, phrases, text_scale,
                 text_padding=5, text_thickness=2, thickness=3):
        """
        This function annotates an image with bounding boxes and labels.

        Parameters:
        image_source (np.ndarray): The source image to be annotated.
        boxes (torch.Tensor): A tensor containing bounding box coordinates. in cxcywh format, pixel scale
        logits (torch.Tensor): A tensor containing confidence scores for each bounding box.
        phrases (List[str]): A list of labels for each bounding box.
        text_scale (float): The scale of the text to be displayed. 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web

        Returns:
        np.ndarray: The annotated image.
        """
        h, w, _ = image_source.shape
        # Boxes arrive normalized cxcywh; scale to pixels then convert.
        boxes = boxes * torch.Tensor([w, h, w, h])
        xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
        xywh = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xywh").numpy()
        detections = sv.Detections(xyxy=xyxy)

        # NOTE(review): labels are built from range(), not from phrases —
        # labels are always "0", "1", ...; phrases only keys label_coordinates.
        labels = [f"{phrase}" for phrase in range(boxes.shape[0])]

        box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding,text_thickness=text_thickness,thickness=thickness) # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
        annotated_frame = image_source.copy()
        annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w,h))

        label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
        return annotated_frame, label_coordinates

    @staticmethod
    def int_box_area(box, w, h):
        """Pixel area of a normalized xyxy box after scaling to (w, h)
        and truncating to ints — 0 means the box collapses to nothing."""
        x1, y1, x2, y2 = box
        int_box = [int(x1*w), int(y1*h), int(x2*w), int(y2*h)]
        area = (int_box[2] - int_box[0]) * (int_box[3] - int_box[1])
        return area

    @staticmethod
    def get_xyxy(input):
        """Convert an EasyOCR 4-point polygon to an int (x1, y1, x2, y2) box
        using its first (top-left) and third (bottom-right) corners."""
        x, y, xp, yp = input[0][0], input[0][1], input[2][0], input[2][1]
        x, y, xp, yp = int(x), int(y), int(xp), int(yp)
        return x, y, xp, yp
328
+
329
+
330
class OmniparserForCausalLM(OmniparserModel):
    """OmniParser model with a causal language modeling interface."""
    _keys_to_ignore_on_load_missing = [r"caption_model\..*"]

    def __init__(self, config):
        super().__init__(config)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Required by HF generate() - passes image_source from kwargs to forward."""
        return {"image_source": kwargs.get("image_source", None)}

    def generate(self, image_source=None, **kwargs):
        """Process image and return formatted text output of detected elements.

        Returns:
            str: JSON-encoded list of parsed element dicts.
        """
        # NOTE(review): this fully overrides PreTrainedModel.generate, so
        # prepare_inputs_for_generation above is not exercised by this path
        # and extra **kwargs are accepted but ignored.
        outputs = super().forward(image_source=image_source)
        res = json.dumps(outputs.parsed_content_list)
        return res
346
+
347
# Register the model with Auto classes so that
# AutoModel.from_pretrained(..., trust_remote_code=True) resolves the
# "omniparserv2" model_type in config.json to the classes defined above.
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
AutoConfig.register("omniparserv2", OmniparserConfig)
AutoModel.register(OmniparserConfig, OmniparserModel)
AutoModelForCausalLM.register(OmniparserConfig, OmniparserForCausalLM)