Transformers
Browse files- box_annotator.py +262 -0
- config.json +5 -0
- configuration_omniparser.py +13 -0
- model.safetensors +3 -0
- modelling_omniparser.py +351 -0
box_annotator.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Union, Tuple
|
| 2 |
+
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from supervision.detection.core import Detections
|
| 7 |
+
from supervision.draw.color import Color, ColorPalette
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class BoxAnnotator:
    """
    A class for drawing bounding boxes on an image using detections provided.

    Attributes:
        color (Union[Color, ColorPalette]): The color to draw the bounding box,
            can be a single color or a color palette
        thickness (int): The thickness of the bounding box lines, default is 2
        text_color (Color): The color of the text on the bounding box, default is white
        text_scale (float): The scale of the text on the bounding box, default is 0.5
        text_thickness (int): The thickness of the text on the bounding box,
            default is 1
        text_padding (int): The padding around the text on the bounding box,
            default is 5
        avoid_overlap (bool): When True, label placement is delegated to
            `get_optimal_label_pos` so labels avoid other detection boxes.
    """

    def __init__(
        self,
        color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
        thickness: int = 3,  # 1 for seeclick 2 for mind2web and 3 for demo
        text_color: Color = Color.BLACK,
        text_scale: float = 0.5,  # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
        text_thickness: int = 2,  # 1, # 2 for demo
        text_padding: int = 10,
        avoid_overlap: bool = True,
    ):
        self.color: Union[Color, ColorPalette] = color
        self.thickness: int = thickness
        self.text_color: Color = text_color
        self.text_scale: float = text_scale
        self.text_thickness: int = text_thickness
        self.text_padding: int = text_padding
        self.avoid_overlap: bool = avoid_overlap

    def annotate(
        self,
        scene: np.ndarray,
        detections: Detections,
        labels: Optional[List[str]] = None,
        skip_label: bool = False,
        image_size: Optional[Tuple[int, int]] = None,
    ) -> np.ndarray:
        """
        Draws bounding boxes on the frame using the detections provided.

        Args:
            scene (np.ndarray): The image on which the bounding boxes will be drawn
            detections (Detections): The detections for which the
                bounding boxes will be drawn
            labels (Optional[List[str]]): An optional list of labels
                corresponding to each detection. If `labels` are not provided,
                corresponding `class_id` will be used as label.
            skip_label (bool): If set to `True`, skips bounding box label annotation.
            image_size (Optional[Tuple[int, int]]): (width, height) of the image;
                used by `get_optimal_label_pos` to keep labels inside the frame
                when `avoid_overlap` is enabled.
        Returns:
            np.ndarray: The image with the bounding boxes drawn on it

        Example:
            ```python
            import supervision as sv

            classes = ['person', ...]
            image = ...
            detections = sv.Detections(...)

            box_annotator = sv.BoxAnnotator()
            labels = [
                f"{classes[class_id]} {confidence:0.2f}"
                for _, _, confidence, class_id, _ in detections
            ]
            annotated_frame = box_annotator.annotate(
                scene=image.copy(),
                detections=detections,
                labels=labels
            )
            ```
        """
        font = cv2.FONT_HERSHEY_SIMPLEX
        for i in range(len(detections)):
            x1, y1, x2, y2 = detections.xyxy[i].astype(int)
            class_id = (
                detections.class_id[i] if detections.class_id is not None else None
            )
            # Use class_id to pick a palette color; fall back to the loop index.
            idx = class_id if class_id is not None else i
            color = (
                self.color.by_idx(idx)
                if isinstance(self.color, ColorPalette)
                else self.color
            )
            cv2.rectangle(
                img=scene,
                pt1=(x1, y1),
                pt2=(x2, y2),
                color=color.as_bgr(),
                thickness=self.thickness,
            )
            if skip_label:
                continue

            # Fall back to the class id when labels are missing or mismatched.
            text = (
                f"{class_id}"
                if (labels is None or len(detections) != len(labels))
                else labels[i]
            )

            text_width, text_height = cv2.getTextSize(
                text=text,
                fontFace=font,
                fontScale=self.text_scale,
                thickness=self.text_thickness,
            )[0]

            if not self.avoid_overlap:
                # Fixed placement: label sits just above the box's top-left corner.
                text_x = x1 + self.text_padding
                text_y = y1 - self.text_padding

                text_background_x1 = x1
                text_background_y1 = y1 - 2 * self.text_padding - text_height

                text_background_x2 = x1 + 2 * self.text_padding + text_width
                text_background_y2 = y1
                # text_x = x1 - self.text_padding - text_width
                # text_y = y1 + self.text_padding + text_height
                # text_background_x1 = x1 - 2 * self.text_padding - text_width
                # text_background_y1 = y1
                # text_background_x2 = x1
                # text_background_y2 = y1 + 2 * self.text_padding + text_height
            else:
                # Try several candidate positions and pick one that does not
                # cover other detections or leave the image bounds.
                text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 = get_optimal_label_pos(self.text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size)

            cv2.rectangle(
                img=scene,
                pt1=(text_background_x1, text_background_y1),
                pt2=(text_background_x2, text_background_y2),
                color=color.as_bgr(),
                thickness=cv2.FILLED,
            )
            # import pdb; pdb.set_trace()
            # Pick black or white text depending on the background's perceived
            # brightness (ITU-R BT.601 luma weights) so labels stay readable.
            box_color = color.as_rgb()
            luminance = 0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2]
            text_color = (0, 0, 0) if luminance > 160 else (255, 255, 255)
            cv2.putText(
                img=scene,
                text=text,
                org=(text_x, text_y),
                fontFace=font,
                fontScale=self.text_scale,
                # color=self.text_color.as_rgb(),
                color=text_color,
                thickness=self.text_thickness,
                lineType=cv2.LINE_AA,
            )
        return scene
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def box_area(box):
    """Return the area of an xyxy box as width * height."""
    width = box[2] - box[0]
    height = box[3] - box[1]
    return width * height
|
| 167 |
+
|
| 168 |
+
def intersection_area(box1, box2):
    """Return the area of overlap between two xyxy boxes (0 when disjoint)."""
    left = max(box1[0], box2[0])
    top = max(box1[1], box2[1])
    right = min(box1[2], box2[2])
    bottom = min(box1[3], box2[3])
    # Negative extents mean the boxes do not intersect; clamp to zero.
    overlap_w = max(0, right - left)
    overlap_h = max(0, bottom - top)
    return overlap_w * overlap_h
|
| 174 |
+
|
| 175 |
+
def IoU(box1, box2, return_max=True):
    """Intersection-over-union of two xyxy boxes.

    Args:
        box1, box2: boxes in [x1, y1, x2, y2] form.
        return_max: when True, return the maximum of the IoU and the two
            containment ratios (intersection over each box's own area), which
            flags the case where one box is mostly inside the other.

    Returns:
        float: IoU (or the max of IoU and containment ratios); 0 when both
        boxes are degenerate.
    """
    intersection = intersection_area(box1, box2)
    union = box_area(box1) + box_area(box2) - intersection
    # Bug fix: if both boxes have zero area the union is 0 and the division
    # below raised ZeroDivisionError. Degenerate boxes cannot overlap, so 0
    # is the correct result (matches the epsilon-guarded sibling in
    # modelling_omniparser.remove_overlap_new).
    if union == 0:
        return 0
    if box_area(box1) > 0 and box_area(box2) > 0:
        ratio1 = intersection / box_area(box1)
        ratio2 = intersection / box_area(box2)
    else:
        ratio1, ratio2 = 0, 0
    if return_max:
        return max(intersection / union, ratio1, ratio2)
    else:
        return intersection / union
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def get_optimal_label_pos(text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size):
    """Choose a label placement that avoids covering other detections.

    Candidate positions are tried in order: 'top left', 'outer left',
    'outer right', 'top right'. A candidate is rejected when its background
    rectangle overlaps any detection box with IoU > 0.3 or falls outside the
    image. If every candidate is rejected, the last candidate ('top right')
    is returned as a fallback.

    Args:
        text_padding (int): padding around the label text.
        text_width, text_height (int): rendered text dimensions.
        x1, y1, x2, y2 (int): the detection box being labeled.
        detections: supervision Detections; `.xyxy[i]` gives each box.
        image_size (Tuple[int, int]): (width, height) of the image.

    Returns:
        (text_x, text_y, bg_x1, bg_y1, bg_x2, bg_y2): text origin and the
        label's background rectangle.
    """

    def get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size):
        # A candidate is unusable if its background rectangle overlaps any
        # detection box by more than the 0.3 threshold...
        label_box = [text_background_x1, text_background_y1, text_background_x2, text_background_y2]
        for i in range(len(detections)):
            detection = detections.xyxy[i].astype(int)
            if IoU(label_box, detection) > 0.3:
                return True
        # ...or if it spills outside the image bounds.
        return (
            text_background_x1 < 0
            or text_background_x2 > image_size[0]
            or text_background_y1 < 0
            or text_background_y2 > image_size[1]
        )

    pad = text_padding
    # Each candidate is (text_x, text_y, bg_x1, bg_y1, bg_x2, bg_y2).
    # The four copy-pasted stanzas of the original are collapsed into a
    # data-driven list tried in the same order.
    candidates = [
        # top left: label sits above the box's top-left corner
        (x1 + pad, y1 - pad,
         x1, y1 - 2 * pad - text_height,
         x1 + 2 * pad + text_width, y1),
        # outer left: label hangs to the left of the box
        (x1 - pad - text_width, y1 + pad + text_height,
         x1 - 2 * pad - text_width, y1,
         x1, y1 + 2 * pad + text_height),
        # outer right: label hangs to the right of the box
        (x2 + pad, y1 + pad + text_height,
         x2, y1,
         x2 + 2 * pad + text_width, y1 + 2 * pad + text_height),
        # top right: label sits above the box's top-right corner
        (x2 - pad - text_width, y1 - pad,
         x2 - 2 * pad - text_width, y1 - 2 * pad - text_height,
         x2, y1),
    ]

    for text_x, text_y, bg_x1, bg_y1, bg_x2, bg_y2 in candidates:
        if not get_is_overlap(detections, bg_x1, bg_y1, bg_x2, bg_y2, image_size):
            return text_x, text_y, bg_x1, bg_y1, bg_x2, bg_y2

    # All candidates overlapped: fall back to the last one tried (top right),
    # matching the original behavior.
    return text_x, text_y, bg_x1, bg_y1, bg_x2, bg_y2
|
config.json
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"box_threshold": 0.05,
|
| 3 |
+
"model_type": "omniparserv2",
|
| 4 |
+
"transformers_version": "4.48.1"
|
| 5 |
+
}
|
configuration_omniparser.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PretrainedConfig
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
class OmniparserConfig(PretrainedConfig):
    """Configuration class for the OmniParser v2 model.

    Args:
        box_threshold (float, defaults to 0.05):
            Minimum confidence for a YOLO icon detection to be kept.
    """

    # Identifier used to register this config with the transformers Auto classes.
    model_type = "omniparserv2"

    def __init__(
        self,
        box_threshold: float = 0.05,
        **kwargs,
    ):
        # Confidence threshold forwarded to the YOLO icon detector at inference.
        self.box_threshold = box_threshold
        super().__init__(**kwargs)
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9bbcbf73561f6bc5d0a17ea6a2081feed2d1304e87602d8c502d9a5c4bd85576
|
| 3 |
+
size 16
|
modelling_omniparser.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PreTrainedModel
|
| 2 |
+
from transformers import AutoProcessor, AutoModelForCausalLM
|
| 3 |
+
from .configuration_omniparser import OmniparserConfig
|
| 4 |
+
from ultralytics import YOLO
|
| 5 |
+
from torchvision.transforms import ToPILImage
|
| 6 |
+
from torchvision.ops import box_convert
|
| 7 |
+
import base64
|
| 8 |
+
import supervision as sv
|
| 9 |
+
from .box_annotator import BoxAnnotator
|
| 10 |
+
import os
|
| 11 |
+
import torch
|
| 12 |
+
import numpy as np
|
| 13 |
+
from PIL import Image
|
| 14 |
+
import io
|
| 15 |
+
import cv2
|
| 16 |
+
import easyocr
|
| 17 |
+
from typing import List, Dict
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from transformers.modeling_outputs import ModelOutput
|
| 20 |
+
import json
|
| 21 |
+
|
| 22 |
+
@dataclass
class OmniparserOutput(ModelOutput):
    """
    Output type of [`OmniparserModel`].

    Args:
        annotated_image (`str`):
            The image with bounding boxes and labels drawn on it, returned as base64 encoded string.
        parsed_content_list (`List[Dict]`):
            Lists of detected elements with their properties (type, bbox, interactivity, content, source).
    """
    # None defaults (rather than default_factory) follow the ModelOutput
    # convention, where None fields are treated as absent.
    annotated_image: str = None
    parsed_content_list: List[Dict] = None
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class OmniparserModel(PreTrainedModel):
    """OmniParser v2 screen-parsing model.

    Combines three components to turn a UI screenshot into a list of
    localized, labeled elements:
      * a YOLO icon detector (``icon_detect/model.pt``),
      * EasyOCR for text regions,
      * a Florence-2 captioner for icon descriptions.
    """

    config_class = OmniparserConfig
    # Caption weights load from the Florence-2 checkpoint separately, so
    # missing `caption_model.*` keys are expected when loading this model.
    _keys_to_ignore_on_load_missing = [r"caption_model\..*"]

    def __init__(self, config):
        super().__init__(config)
        # Device is chosen once at construction time.
        # NOTE(review): a later `.to(...)` on the model does not update
        # self._device_type — confirm callers never move the model.
        self._device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Model assets are resolved relative to this file so the whole repo
        # can be fetched from the Hub as a unit.
        current_file_directory = os.path.dirname(os.path.abspath(__file__))
        self.icon_detect_path = os.path.join(current_file_directory, 'icon_detect/model.pt')
        icon_caption_path = os.path.join(current_file_directory, 'icon_caption')
        caption_model_processor = self.get_caption_model_processor(model_name_or_path=icon_caption_path)
        self.caption_model, self.caption_model_processor = caption_model_processor['model'], caption_model_processor['processor']
        # English-only OCR reader.
        self.reader = easyocr.Reader(['en'])
        self.to_pil = ToPILImage()
        # YOLO is constructed lazily in forward(); see the comment there.
        self.som_model_inited = False

    def forward(self, image_source):
        """Parse a screenshot into annotated image + element list.

        Args:
            image_source: a PIL image, or a filesystem path to an image.

        Returns:
            OmniparserOutput: base64-encoded annotated PNG plus the parsed
            element dicts.

        Raises:
            ValueError: if `image_source` is None.
        """
        # Lazy load YOLO model as autoload calls .eval which breaks YOLO unless init in forward
        if not self.som_model_inited:
            self.som_model = YOLO(self.icon_detect_path)
            self.som_model_inited = True

        if image_source is None:
            raise ValueError("No image provided")
        if isinstance(image_source, str):
            image_source = Image.open(image_source)
        image_source = image_source.convert("RGB")

        # Scan for text
        ocr_text, ocr_bbox = self.check_ocr_box(image_source, easyocr_args={'text_threshold': 0.8})
        # Yolo + merge text and YOLO + florence
        annotated_image, parsed_content_list = self.get_som_labeled_img(image_source,
                                                                        self.som_model,
                                                                        box_threshold=self.config.box_threshold,
                                                                        ocr_bbox=ocr_bbox,
                                                                        ocr_text=ocr_text)

        return OmniparserOutput(
            annotated_image=annotated_image,
            parsed_content_list=parsed_content_list
        )

    def get_caption_model_processor(self, model_name_or_path):
        """Load the Florence-2 caption model + processor.

        Uses float32 on CPU (Florence-2 fp16 is CUDA-only here) and float16
        on GPU. Returns a dict with 'model' and 'processor'.
        """
        processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
        if self._device_type == 'cpu':
            model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True)
        else:
            model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True).to(self._device_type)
        # The extra .to() is a no-op on GPU (already moved) and keeps CPU explicit.
        return {'model': model.to(self._device_type), 'processor': processor}

    def check_ocr_box(self, image_source, easyocr_args):
        """Run EasyOCR and return (texts, xyxy_boxes) in pixel coordinates."""
        image_np = np.array(image_source)
        # NOTE(review): `w, h` are unused here — presumably left over; confirm.
        w, h = image_source.size
        result = self.reader.readtext(image_np, **easyocr_args)
        # Each EasyOCR result item is (quad_points, text, confidence).
        coord = [item[0] for item in result]
        text = [item[1] for item in result]
        # Convert the 4-point quad to an axis-aligned xyxy box.
        bb = [self.get_xyxy(item) for item in coord]
        return text, bb

    def get_som_labeled_img(self, image_source,
                            model,
                            box_threshold,
                            ocr_bbox,
                            ocr_text,
                            iou_threshold=0.7,
                            prompt=None,
                            batch_size=128):
        """Detect icons, merge with OCR boxes, caption icons, and draw labels.

        Args:
            image_source: PIL image (RGB).
            model: the YOLO detector.
            box_threshold: YOLO confidence threshold.
            ocr_bbox: OCR boxes in pixel xyxy.
            ocr_text: OCR strings aligned with `ocr_bbox`.
            iou_threshold: overlap threshold for dropping duplicate boxes.
            prompt: optional caption prompt (defaults to "<CAPTION>").
            batch_size: caption batch size.

        Returns:
            (base64 PNG string, list of element dicts).
        """
        w, h = image_source.size
        # Scale drawing parameters relative to a 3200px reference resolution.
        box_overlay_ratio = max(w, h) / 3200
        draw_bbox_config = {
            'text_scale': 0.8 * box_overlay_ratio,
            'text_thickness': max(int(2 * box_overlay_ratio), 1),
            'text_padding': max(int(3 * box_overlay_ratio), 1),
            'thickness': max(int(3 * box_overlay_ratio), 1),
        }
        xyxy, logits, phrases = self.predict_yolo(yolo_model=model, image=image_source, box_threshold=box_threshold, iou_threshold=0.1)
        # Normalize boxes to [0, 1] relative coordinates.
        xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
        image_source = np.asarray(image_source)
        phrases = [str(i) for i in range(len(phrases))]

        # merge text and YOLO boxes
        ocr_bbox_elem = None
        if ocr_bbox:
            ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
            ocr_bbox=ocr_bbox.tolist()
            # Drop degenerate boxes whose integer pixel area is zero.
            ocr_bbox_elem = [{'type': 'text', 'bbox': box, 'interactivity': False, 'content': txt, 'source': 'box_ocr_content_ocr'} for box, txt in zip(ocr_bbox, ocr_text) if self.int_box_area(box, w, h) > 0]
        xyxy_elem = [{'type': 'icon', 'bbox': box, 'interactivity': True, 'content': None, 'source': 'box_yolo_content_yolo'} for box in xyxy.tolist() if self.int_box_area(box, w, h) > 0]
        filtered_boxes = self.remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)

        # sort the filtered_boxes so that ocr first, yolo last
        filtered_boxes_elem = sorted(filtered_boxes, key=lambda x: x['content'] is None)
        # get the index of the first yolo box
        # NOTE(review): when no box has content None, starting_idx is -1 and
        # the slice below takes only the LAST box — confirm that case cannot
        # occur (it would mis-caption one OCR box).
        starting_idx = next((i for i, box in enumerate(filtered_boxes_elem) if box['content'] is None), -1)
        filtered_boxes = torch.tensor([box['bbox'] for box in filtered_boxes_elem])

        # get local semantics using florence on yolo boxes
        # TODO: fix that should not be using starting_idx ... instead use a where content is none ... O(n) rather than using sort
        parsed_content_icon = self.get_parsed_content_icon(filtered_boxes, starting_idx, image_source, prompt=prompt, batch_size=batch_size)
        ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
        icon_start = len(ocr_text)
        parsed_content_icon_ls = []
        # fill the filtered_boxes_elem None content with parsed_content_icon in order
        for i, box in enumerate(filtered_boxes_elem):
            if box['content'] is None:
                box['content'] = parsed_content_icon.pop(0)
        for i, txt in enumerate(parsed_content_icon):
            parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}")
        # NOTE(review): parsed_content_merged is built but never returned —
        # presumably legacy; verify before removing.
        parsed_content_merged = ocr_text + parsed_content_icon_ls

        # draw boxes
        filtered_boxes = box_convert(boxes=filtered_boxes, in_fmt="xyxy", out_fmt="cxcywh")
        phrases = [i for i in range(len(filtered_boxes))]
        annotated_frame, _ = self.annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, **draw_bbox_config)

        # Encode the annotated frame as a base64 PNG for transport.
        pil_img = Image.fromarray(annotated_frame)
        buffered = io.BytesIO()
        pil_img.save(buffered, format="PNG")
        encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
        assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]

        return encoded_image, filtered_boxes_elem

    @staticmethod
    def predict_yolo(yolo_model, image, box_threshold, iou_threshold):
        """ Use huggingface model to replace the original model

        Runs YOLO once and returns (xyxy boxes, confidences, index labels).
        """
        result = yolo_model.predict(
            source=image,
            conf=box_threshold,
            iou=iou_threshold,
            verbose=False
        )
        boxes = result[0].boxes.xyxy
        conf = result[0].boxes.conf
        phrases = [str(i) for i in range(len(boxes))]

        return boxes, conf, phrases

    @staticmethod
    def remove_overlap_new(boxes, iou_threshold, ocr_bbox):
        '''
        Merge YOLO icon boxes with OCR text boxes, dropping duplicates.

        boxes format: [{'type': 'icon', 'bbox':[x,y], 'interactivity':True, 'content':None }, ...]
        ocr_bbox format: [{'type': 'text', 'bbox':[x,y], 'interactivity':False, 'content':str }, ...]

        An icon box absorbs any OCR boxes it contains (their text becomes its
        content); an icon contained by an OCR box is discarded.
        '''
        assert ocr_bbox is None or isinstance(ocr_bbox, List)

        def box_area(box):
            # Area of an xyxy box.
            return (box[2] - box[0]) * (box[3] - box[1])

        def intersection_area(box1, box2):
            # Overlap area of two xyxy boxes (0 when disjoint).
            x1 = max(box1[0], box2[0])
            y1 = max(box1[1], box2[1])
            x2 = min(box1[2], box2[2])
            y2 = min(box1[3], box2[3])
            return max(0, x2 - x1) * max(0, y2 - y1)

        def IoU(box1, box2):
            intersection = intersection_area(box1, box2)
            # Epsilon guards against division by zero for degenerate boxes.
            union = box_area(box1) + box_area(box2) - intersection + 1e-6
            if box_area(box1) > 0 and box_area(box2) > 0:
                ratio1 = intersection / box_area(box1)
                ratio2 = intersection / box_area(box2)
            else:
                ratio1, ratio2 = 0, 0
            return max(intersection / union, ratio1, ratio2)

        def is_inside(box1, box2):
            # "Inside" means >80% of box1's area overlaps box2.
            # return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
            intersection = intersection_area(box1, box2)
            ratio1 = intersection / box_area(box1)
            return ratio1 > 0.80

        filtered_boxes = []
        if ocr_bbox:
            filtered_boxes.extend(ocr_bbox)
        for i, box1_elem in enumerate(boxes):
            box1 = box1_elem['bbox']
            is_valid_box = True
            for j, box2_elem in enumerate(boxes):
                # keep the smaller box
                box2 = box2_elem['bbox']
                if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
                    is_valid_box = False
                    break
            if is_valid_box:
                if ocr_bbox:
                    box_added = False
                    ocr_labels = ''
                    for box3_elem in ocr_bbox:
                        if not box_added:
                            box3 = box3_elem['bbox']
                            if is_inside(box3, box1): # ocr inside icon
                                # delete the box3_elem from ocr_bbox
                                # NOTE(review): bare except silently skips
                                # failures here (e.g. element already removed)
                                # — consider narrowing to ValueError/TypeError.
                                try:
                                    # gather all ocr labels
                                    ocr_labels += box3_elem['content'] + ' '
                                    filtered_boxes.remove(box3_elem)
                                except:
                                    continue
                            elif is_inside(box1, box3): # icon inside ocr, don't added this icon box, no need to check other ocr bbox bc no overlap between ocr bbox, icon can only be in one ocr box
                                box_added = True
                                break
                            else:
                                continue
                    if not box_added:
                        if ocr_labels:
                            filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': ocr_labels, 'source':'box_yolo_content_ocr'})
                        else:
                            filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None, 'source':'box_yolo_content_yolo'})
                else:
                    # NOTE(review): this appends the raw bbox LIST, not an
                    # element dict like every other branch — downstream code
                    # indexes box['bbox']/box['content'] and would fail when
                    # ocr_bbox is empty. Looks like it should be box1_elem;
                    # confirm before changing.
                    filtered_boxes.append(box1)
        return filtered_boxes

    @torch.inference_mode()
    def get_parsed_content_icon(self, filtered_boxes, starting_idx, image_source, prompt=None, batch_size=128):
        """Caption the non-OCR (icon) crops with Florence-2.

        Args:
            filtered_boxes: tensor of normalized xyxy boxes, OCR boxes first.
            starting_idx: index of the first icon (content-less) box.
            image_source: HxWxC numpy image.
            prompt: caption prompt; defaults to "<CAPTION>".
            batch_size: crops per generation batch.

        Returns:
            list[str]: one caption per successfully-cropped icon box.
        """
        # Number of samples per batch, --> 128 roughly takes 4 GB of GPU memory for florence v2 model
        non_ocr_boxes = filtered_boxes[starting_idx:]
        cropped_pil_image = []
        for i, coord in enumerate(non_ocr_boxes):
            # NOTE(review): bare except drops any crop that fails (e.g. empty
            # slice), which can silently misalign captions with boxes if a
            # middle crop fails — confirm acceptable.
            try:
                xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
                ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
                cropped_image = image_source[ymin:ymax, xmin:xmax, :]
                cropped_image = cv2.resize(cropped_image, (64, 64))
                cropped_pil_image.append(self.to_pil(cropped_image))
            except:
                continue

        caption_model, caption_model_processor = self.caption_model, self.caption_model_processor
        if not prompt:
            prompt = "<CAPTION>"

        generated_texts = []
        device = caption_model.device
        for i in range(0, len(cropped_pil_image), batch_size):
            batch = cropped_pil_image[i:i+batch_size]
            # On GPU the model runs in float16, so inputs must match dtype.
            if caption_model.device.type == 'cuda':
                inputs = caption_model_processor(images=batch, text=[prompt]*len(batch), return_tensors="pt", do_resize=False).to(device=device, dtype=torch.float16)
            else:
                inputs = caption_model_processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
            generated_ids = caption_model.generate(input_ids=inputs["input_ids"],pixel_values=inputs["pixel_values"],max_new_tokens=20,num_beams=1, do_sample=False, early_stopping=False)
            generated_text = caption_model_processor.batch_decode(generated_ids, skip_special_tokens=True)
            generated_text = [gen.strip() for gen in generated_text]
            generated_texts.extend(generated_text)

        return generated_texts

    @staticmethod
    def annotate(image_source, boxes, logits, phrases, text_scale,
                 text_padding=5, text_thickness=2, thickness=3):
        """
        This function annotates an image with bounding boxes and labels.

        Parameters:
        image_source (np.ndarray): The source image to be annotated.
        boxes (torch.Tensor): A tensor containing bounding box coordinates. in cxcywh format, pixel scale
        logits (torch.Tensor): A tensor containing confidence scores for each bounding box.
        phrases (List[str]): A list of labels for each bounding box.
        text_scale (float): The scale of the text to be displayed. 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web

        Returns:
        np.ndarray: The annotated image.
        """
        h, w, _ = image_source.shape
        # Boxes arrive normalized; scale back to pixel space before drawing.
        boxes = boxes * torch.Tensor([w, h, w, h])
        xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
        xywh = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xywh").numpy()
        detections = sv.Detections(xyxy=xyxy)

        # Labels are just the box indices.
        labels = [f"{phrase}" for phrase in range(boxes.shape[0])]

        box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding,text_thickness=text_thickness,thickness=thickness) # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
        annotated_frame = image_source.copy()
        annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w,h))

        label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
        return annotated_frame, label_coordinates

    @staticmethod
    def int_box_area(box, w, h):
        """Pixel area of a normalized xyxy box after integer truncation."""
        x1, y1, x2, y2 = box
        int_box = [int(x1*w), int(y1*h), int(x2*w), int(y2*h)]
        area = (int_box[2] - int_box[0]) * (int_box[3] - int_box[1])
        return area

    @staticmethod
    def get_xyxy(input):
        """Convert an EasyOCR 4-point quad to an (x1, y1, x2, y2) box.

        Uses the first (top-left) and third (bottom-right) quad corners.
        """
        x, y, xp, yp = input[0][0], input[0][1], input[2][0], input[2][1]
        x, y, xp, yp = int(x), int(y), int(xp), int(yp)
        return x, y, xp, yp
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
class OmniparserForCausalLM(OmniparserModel):
    """OmniParser model with a causal language modeling interface.

    Exposes a text-generation-style API so AutoModelForCausalLM users can
    call `generate(image_source=...)` and get a JSON string of parsed
    elements back.
    """
    _keys_to_ignore_on_load_missing = [r"caption_model\..*"]

    def __init__(self, config):
        super().__init__(config)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Required by HF generate() - passes image_source from kwargs to forward."""
        return {"image_source": kwargs.get("image_source", None)}

    def generate(self, image_source=None, **kwargs):
        """Process image and return formatted text output of detected elements.

        NOTE(review): this overrides HF's generate() entirely and ignores all
        other kwargs; the annotated image in the output is dropped and only
        the parsed element list is serialized. Confirm that is intended.
        """
        outputs = super().forward(image_source=image_source)
        res = json.dumps(outputs.parsed_content_list)
        return res
|
| 346 |
+
|
| 347 |
+
# Register the model with Auto classes so AutoModel.from_pretrained(...,
# trust_remote_code=True) can resolve "omniparserv2" to these classes.
# (Import is kept here, next to the registration, rather than at file top.)
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
AutoConfig.register("omniparserv2", OmniparserConfig)
AutoModel.register(OmniparserConfig, OmniparserModel)
AutoModelForCausalLM.register(OmniparserConfig, OmniparserForCausalLM)
|