OmniParser-v2.0 / modeling_omniparser.py
ThomasDh-C's picture
Update modeling_omniparser.py
c4b9391 verified
from transformers import PreTrainedModel
from transformers import AutoProcessor, AutoModelForCausalLM
from .configuration_omniparser import OmniparserConfig
from ultralytics import YOLO
from torchvision.transforms import ToPILImage
from torchvision.ops import box_convert
import base64
import supervision as sv
import os
import torch
import numpy as np
from PIL import Image
import io
import cv2
import easyocr
from typing import List, Dict
from dataclasses import dataclass
from transformers.modeling_outputs import ModelOutput
import json
from supervision.detection.core import Detections
from supervision.draw.color import Color, ColorPalette
from typing import List, Optional, Union, Tuple
from huggingface_hub import snapshot_download
@dataclass
class OmniparserOutput(ModelOutput):
"""
Output type of [`OmniparserModel`].
Args:
annotated_image (`str`):
The image with bounding boxes and labels drawn on it, returned as base64 encoded string.
parsed_content_list (`List[Dict]`):
Lists of detected elements with their properties (type, bbox, interactivity, content, source).
"""
annotated_image: str = None
parsed_content_list: List[Dict] = None
class OmniparserModel(PreTrainedModel):
config_class = OmniparserConfig
_keys_to_ignore_on_load_missing = [r"caption_model\..*"]
def __init__(self, config):
super().__init__(config)
self._device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
icon_detect_path = snapshot_download(
repo_id="ThomasDh-C/OmniParser-v2.0",
allow_patterns="icon_detect/**",
)
self.icon_detect_path = os.path.join(icon_detect_path, 'icon_detect', 'model.pt')
icon_caption_path = snapshot_download(
repo_id="ThomasDh-C/OmniParser-v2.0",
allow_patterns="icon_caption/**",
)
caption_model_processor = self.get_caption_model_processor(model_name_or_path=os.path.join(icon_caption_path, 'icon_caption'))
self.caption_model, self.caption_model_processor = caption_model_processor['model'], caption_model_processor['processor']
self.reader = easyocr.Reader(['en'])
self.to_pil = ToPILImage()
self.som_model_inited = False
def forward(self, image_source):
# Lazy load YOLO model as autoload calls .eval which breaks YOLO unless init in forward
if not self.som_model_inited:
self.som_model = YOLO(self.icon_detect_path)
self.som_model_inited = True
if image_source is None:
raise ValueError("No image provided")
if isinstance(image_source, str):
image_source = Image.open(image_source)
image_source = image_source.convert("RGB")
# Scan for text
ocr_text, ocr_bbox = self.check_ocr_box(image_source, easyocr_args={'text_threshold': 0.8})
# Yolo + merge text and YOLO + florence
annotated_image, parsed_content_list = self.get_som_labeled_img(image_source,
self.som_model,
box_threshold=self.config.box_threshold,
ocr_bbox=ocr_bbox,
ocr_text=ocr_text)
return OmniparserOutput(
annotated_image=annotated_image,
parsed_content_list=parsed_content_list
)
def get_caption_model_processor(self, model_name_or_path):
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
if self._device_type == 'cpu':
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True)
else:
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True).to(self._device_type)
return {'model': model.to(self._device_type), 'processor': processor}
def check_ocr_box(self, image_source, easyocr_args):
image_np = np.array(image_source)
w, h = image_source.size
result = self.reader.readtext(image_np, **easyocr_args)
coord = [item[0] for item in result]
text = [item[1] for item in result]
bb = [self.get_xyxy(item) for item in coord]
return text, bb
def get_som_labeled_img(self, image_source,
model,
box_threshold,
ocr_bbox,
ocr_text,
iou_threshold=0.7,
prompt=None,
batch_size=128):
w, h = image_source.size
box_overlay_ratio = max(w, h) / 3200
draw_bbox_config = {
'text_scale': 0.8 * box_overlay_ratio,
'text_thickness': max(int(2 * box_overlay_ratio), 1),
'text_padding': max(int(3 * box_overlay_ratio), 1),
'thickness': max(int(3 * box_overlay_ratio), 1),
}
xyxy, logits, phrases = self.predict_yolo(yolo_model=model, image=image_source, box_threshold=box_threshold, iou_threshold=0.1)
xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
image_source = np.asarray(image_source)
phrases = [str(i) for i in range(len(phrases))]
# merge text and YOLO boxes
ocr_bbox_elem = None
if ocr_bbox:
ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
ocr_bbox=ocr_bbox.tolist()
ocr_bbox_elem = [{'type': 'text', 'bbox': box, 'interactivity': False, 'content': txt, 'source': 'box_ocr_content_ocr'} for box, txt in zip(ocr_bbox, ocr_text) if self.int_box_area(box, w, h) > 0]
xyxy_elem = [{'type': 'icon', 'bbox': box, 'interactivity': True, 'content': None, 'source': 'box_yolo_content_yolo'} for box in xyxy.tolist() if self.int_box_area(box, w, h) > 0]
filtered_boxes = self.remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)
# sort the filtered_boxes so that ocr first, yolo last
filtered_boxes_elem = sorted(filtered_boxes, key=lambda x: x['content'] is None)
# get the index of the first yolo box
starting_idx = next((i for i, box in enumerate(filtered_boxes_elem) if box['content'] is None), -1)
filtered_boxes = torch.tensor([box['bbox'] for box in filtered_boxes_elem])
# get local semantics using florence on yolo boxes
# TODO: fix that should not be using starting_idx ... instead use a where content is none ... O(n) rather than using sort
parsed_content_icon = self.get_parsed_content_icon(filtered_boxes, starting_idx, image_source, prompt=prompt, batch_size=batch_size)
ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
icon_start = len(ocr_text)
parsed_content_icon_ls = []
# fill the filtered_boxes_elem None content with parsed_content_icon in order
for i, box in enumerate(filtered_boxes_elem):
if box['content'] is None:
box['content'] = parsed_content_icon.pop(0)
for i, txt in enumerate(parsed_content_icon):
parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}")
parsed_content_merged = ocr_text + parsed_content_icon_ls
# draw boxes
filtered_boxes = box_convert(boxes=filtered_boxes, in_fmt="xyxy", out_fmt="cxcywh")
phrases = [i for i in range(len(filtered_boxes))]
annotated_frame, _ = self.annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, **draw_bbox_config)
pil_img = Image.fromarray(annotated_frame)
buffered = io.BytesIO()
pil_img.save(buffered, format="PNG")
encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]
return encoded_image, filtered_boxes_elem
@staticmethod
def predict_yolo(yolo_model, image, box_threshold, iou_threshold):
""" Use huggingface model to replace the original model
"""
result = yolo_model.predict(
source=image,
conf=box_threshold,
iou=iou_threshold,
verbose=False
)
boxes = result[0].boxes.xyxy
conf = result[0].boxes.conf
phrases = [str(i) for i in range(len(boxes))]
return boxes, conf, phrases
def box_area(self, box):
return (box[2] - box[0]) * (box[3] - box[1])
def intersection_area(self, box1, box2):
x1 = max(box1[0], box2[0])
y1 = max(box1[1], box2[1])
x2 = min(box1[2], box2[2])
y2 = min(box1[3], box2[3])
return max(0, x2 - x1) * max(0, y2 - y1)
def IoU(self, box1, box2):
intersection = self.intersection_area(box1, box2)
union = self.box_area(box1) + self.box_area(box2) - intersection + 1e-6
if self.box_area(box1) > 0 and self.box_area(box2) > 0:
ratio1 = intersection / self.box_area(box1)
ratio2 = intersection / self.box_area(box2)
else:
ratio1, ratio2 = 0, 0
return max(intersection / union, ratio1, ratio2)
def remove_overlap_new(self, boxes, iou_threshold, ocr_bbox):
'''
boxes format: [{'type': 'icon', 'bbox':[x,y], 'interactivity':True, 'content':None }, ...]
ocr_bbox format: [{'type': 'text', 'bbox':[x,y], 'interactivity':False, 'content':str }, ...]
'''
assert ocr_bbox is None or isinstance(ocr_bbox, List)
def is_inside(box1, box2):
# return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
intersection = self.intersection_area(box1, box2)
ratio1 = intersection / self.box_area(box1)
return ratio1 > 0.80
filtered_boxes = []
if ocr_bbox:
filtered_boxes.extend(ocr_bbox)
for i, box1_elem in enumerate(boxes):
box1 = box1_elem['bbox']
is_valid_box = True
for j, box2_elem in enumerate(boxes):
# keep the smaller box
box2 = box2_elem['bbox']
if i != j and self.IoU(box1, box2) > iou_threshold and self.box_area(box1) > self.box_area(box2):
is_valid_box = False
break
if is_valid_box:
if ocr_bbox:
box_added = False
ocr_labels = ''
for box3_elem in ocr_bbox:
if not box_added:
box3 = box3_elem['bbox']
if is_inside(box3, box1): # ocr inside icon
# delete the box3_elem from ocr_bbox
try:
# gather all ocr labels
ocr_labels += box3_elem['content'] + ' '
filtered_boxes.remove(box3_elem)
except:
continue
elif is_inside(box1, box3): # icon inside ocr, don't added this icon box, no need to check other ocr bbox bc no overlap between ocr bbox, icon can only be in one ocr box
box_added = True
break
else:
continue
if not box_added:
if ocr_labels:
filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': ocr_labels, 'source':'box_yolo_content_ocr'})
else:
filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None, 'source':'box_yolo_content_yolo'})
else:
filtered_boxes.append(box1)
return filtered_boxes
@torch.inference_mode()
def get_parsed_content_icon(self, filtered_boxes, starting_idx, image_source, prompt=None, batch_size=128):
# Number of samples per batch, --> 128 roughly takes 4 GB of GPU memory for florence v2 model
non_ocr_boxes = filtered_boxes[starting_idx:]
cropped_pil_image = []
for i, coord in enumerate(non_ocr_boxes):
try:
xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
cropped_image = image_source[ymin:ymax, xmin:xmax, :]
cropped_image = cv2.resize(cropped_image, (64, 64))
cropped_pil_image.append(self.to_pil(cropped_image))
except:
continue
caption_model, caption_model_processor = self.caption_model, self.caption_model_processor
if not prompt:
prompt = "<CAPTION>"
generated_texts = []
device = caption_model.device
for i in range(0, len(cropped_pil_image), batch_size):
batch = cropped_pil_image[i:i+batch_size]
if caption_model.device.type == 'cuda':
inputs = caption_model_processor(images=batch, text=[prompt]*len(batch), return_tensors="pt", do_resize=False).to(device=device, dtype=torch.float16)
else:
inputs = caption_model_processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
generated_ids = caption_model.generate(input_ids=inputs["input_ids"],pixel_values=inputs["pixel_values"],max_new_tokens=20,num_beams=1, do_sample=False, early_stopping=False)
generated_text = caption_model_processor.batch_decode(generated_ids, skip_special_tokens=True)
generated_text = [gen.strip() for gen in generated_text]
generated_texts.extend(generated_text)
return generated_texts
class BoxAnnotator:
"""
A class for drawing bounding boxes on an image using detections provided.
Attributes:
color (Union[Color, ColorPalette]): The color to draw the bounding box,
can be a single color or a color palette
thickness (int): The thickness of the bounding box lines, default is 2
text_color (Color): The color of the text on the bounding box, default is white
text_scale (float): The scale of the text on the bounding box, default is 0.5
text_thickness (int): The thickness of the text on the bounding box,
default is 1
text_padding (int): The padding around the text on the bounding box,
default is 5
"""
def __init__(
self,
color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
thickness: int = 3, # 1 for seeclick 2 for mind2web and 3 for demo
text_color: Color = Color.BLACK,
text_scale: float = 0.5, # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
text_thickness: int = 2, #1, # 2 for demo
text_padding: int = 10,
avoid_overlap: bool = True,
get_is_overlap_fn=None,
get_optimal_label_pos_fn=None,
):
self.color: Union[Color, ColorPalette] = color
self.thickness: int = thickness
self.text_color: Color = text_color
self.text_scale: float = text_scale
self.text_thickness: int = text_thickness
self.text_padding: int = text_padding
self.avoid_overlap: bool = avoid_overlap
self.get_is_overlap_fn = get_is_overlap_fn
self.get_optimal_label_pos_fn = get_optimal_label_pos_fn
def annotate(
self,
scene: np.ndarray,
detections: Detections,
labels: Optional[List[str]] = None,
skip_label: bool = False,
image_size: Optional[Tuple[int, int]] = None,
) -> np.ndarray:
"""
Draws bounding boxes on the frame using the detections provided.
Args:
scene (np.ndarray): The image on which the bounding boxes will be drawn
detections (Detections): The detections for which the
bounding boxes will be drawn
labels (Optional[List[str]]): An optional list of labels
corresponding to each detection. If `labels` are not provided,
corresponding `class_id` will be used as label.
skip_label (bool): Is set to `True`, skips bounding box label annotation.
Returns:
np.ndarray: The image with the bounding boxes drawn on it
Example:
```python
import supervision as sv
classes = ['person', ...]
image = ...
detections = sv.Detections(...)
box_annotator = sv.BoxAnnotator()
labels = [
f"{classes[class_id]} {confidence:0.2f}"
for _, _, confidence, class_id, _ in detections
]
annotated_frame = box_annotator.annotate(
scene=image.copy(),
detections=detections,
labels=labels
)
```
"""
font = cv2.FONT_HERSHEY_SIMPLEX
for i in range(len(detections)):
x1, y1, x2, y2 = detections.xyxy[i].astype(int)
class_id = (
detections.class_id[i] if detections.class_id is not None else None
)
idx = class_id if class_id is not None else i
color = (
self.color.by_idx(idx)
if isinstance(self.color, ColorPalette)
else self.color
)
cv2.rectangle(
img=scene,
pt1=(x1, y1),
pt2=(x2, y2),
color=color.as_bgr(),
thickness=self.thickness,
)
if skip_label:
continue
text = (
f"{class_id}"
if (labels is None or len(detections) != len(labels))
else labels[i]
)
text_width, text_height = cv2.getTextSize(
text=text,
fontFace=font,
fontScale=self.text_scale,
thickness=self.text_thickness,
)[0]
if not self.avoid_overlap:
text_x = x1 + self.text_padding
text_y = y1 - self.text_padding
text_background_x1 = x1
text_background_y1 = y1 - 2 * self.text_padding - text_height
text_background_x2 = x1 + 2 * self.text_padding + text_width
text_background_y2 = y1
else:
text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 = self.get_optimal_label_pos_fn(
self.text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size
)
cv2.rectangle(
img=scene,
pt1=(text_background_x1, text_background_y1),
pt2=(text_background_x2, text_background_y2),
color=color.as_bgr(),
thickness=cv2.FILLED,
)
# import pdb; pdb.set_trace()
box_color = color.as_rgb()
luminance = 0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2]
text_color = (0,0,0) if luminance > 160 else (255,255,255)
cv2.putText(
img=scene,
text=text,
org=(text_x, text_y),
fontFace=font,
fontScale=self.text_scale,
# color=self.text_color.as_rgb(),
color=text_color,
thickness=self.text_thickness,
lineType=cv2.LINE_AA,
)
return scene
def get_is_overlap(self, detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size):
is_overlap = False
for i in range(len(detections)):
detection = detections.xyxy[i].astype(int)
if self.IoU([text_background_x1, text_background_y1, text_background_x2, text_background_y2], detection) > 0.3:
is_overlap = True
break
# check if the text is out of the image
if text_background_x1 < 0 or text_background_x2 > image_size[0] or text_background_y1 < 0 or text_background_y2 > image_size[1]:
is_overlap = True
return is_overlap
def get_optimal_label_pos(self, text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size):
""" check overlap of text and background detection box, and get_optimal_label_pos,
pos: str, position of the text, must be one of 'top left', 'top right', 'outer left', 'outer right' TODO: if all are overlapping, return the last one, i.e. outer right
Threshold: default to 0.3
"""
# if pos == 'top left':
text_x = x1 + text_padding
text_y = y1 - text_padding
text_background_x1 = x1
text_background_y1 = y1 - 2 * text_padding - text_height
text_background_x2 = x1 + 2 * text_padding + text_width
text_background_y2 = y1
is_overlap = self.get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
if not is_overlap:
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
# elif pos == 'outer left':
text_x = x1 - text_padding - text_width
text_y = y1 + text_padding + text_height
text_background_x1 = x1 - 2 * text_padding - text_width
text_background_y1 = y1
text_background_x2 = x1
text_background_y2 = y1 + 2 * text_padding + text_height
is_overlap = self.get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
if not is_overlap:
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
# elif pos == 'outer right':
text_x = x2 + text_padding
text_y = y1 + text_padding + text_height
text_background_x1 = x2
text_background_y1 = y1
text_background_x2 = x2 + 2 * text_padding + text_width
text_background_y2 = y1 + 2 * text_padding + text_height
is_overlap = self.get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
if not is_overlap:
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
# elif pos == 'top right':
text_x = x2 - text_padding - text_width
text_y = y1 - text_padding
text_background_x1 = x2 - 2 * text_padding - text_width
text_background_y1 = y1 - 2 * text_padding - text_height
text_background_x2 = x2
text_background_y2 = y1
is_overlap = self.get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
if not is_overlap:
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
def annotate(self, image_source, boxes, logits, phrases, text_scale,
text_padding=5, text_thickness=2, thickness=3):
"""
This function annotates an image with bounding boxes and labels.
Parameters:
image_source (np.ndarray): The source image to be annotated.
boxes (torch.Tensor): A tensor containing bounding box coordinates. in cxcywh format, pixel scale
logits (torch.Tensor): A tensor containing confidence scores for each bounding box.
phrases (List[str]): A list of labels for each bounding box.
text_scale (float): The scale of the text to be displayed. 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
Returns:
np.ndarray: The annotated image.
"""
h, w, _ = image_source.shape
boxes = boxes * torch.Tensor([w, h, w, h])
xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
xywh = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xywh").numpy()
detections = sv.Detections(xyxy=xyxy)
labels = [f"{phrase}" for phrase in range(boxes.shape[0])]
box_annotator = self.BoxAnnotator(
text_scale=text_scale,
text_padding=text_padding,
text_thickness=text_thickness,
thickness=thickness,
get_is_overlap_fn=self.get_is_overlap,
get_optimal_label_pos_fn=self.get_optimal_label_pos
) # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
annotated_frame = image_source.copy()
annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w,h))
label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
return annotated_frame, label_coordinates
@staticmethod
def int_box_area(box, w, h):
x1, y1, x2, y2 = box
int_box = [int(x1*w), int(y1*h), int(x2*w), int(y2*h)]
area = (int_box[2] - int_box[0]) * (int_box[3] - int_box[1])
return area
@staticmethod
def get_xyxy(input):
x, y, xp, yp = input[0][0], input[0][1], input[2][0], input[2][1]
x, y, xp, yp = int(x), int(y), int(xp), int(yp)
return x, y, xp, yp
class OmniparserForCausalLM(OmniparserModel):
"""OmniParser model with a causal language modeling interface."""
_keys_to_ignore_on_load_missing = [r"caption_model\..*"]
def __init__(self, config):
super().__init__(config)
def prepare_inputs_for_generation(self, input_ids, **kwargs):
"""Required by HF generate() - passes image_source from kwargs to forward."""
return {"image_source": kwargs.get("image_source", None)}
def generate(self, image_source=None, **kwargs):
"""Process image and return formatted text output of detected elements."""
outputs = super().forward(image_source=image_source)
res = json.dumps(outputs.parsed_content_list)
return res
# Register the model with Auto classes
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
AutoConfig.register("omniparserv2", OmniparserConfig)
AutoModel.register(OmniparserConfig, OmniparserModel)
AutoModelForCausalLM.register(OmniparserConfig, OmniparserForCausalLM)