|
|
from typing import List, Union |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from diffusers.modular_pipelines import ( |
|
|
ComponentSpec, |
|
|
InputParam, |
|
|
ModularPipelineBlocks, |
|
|
OutputParam, |
|
|
PipelineState, |
|
|
) |
|
|
from PIL import Image, ImageDraw |
|
|
from transformers import AutoProcessor, Florence2ForConditionalGeneration |
|
|
|
|
|
|
|
|
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
    """Annotate image(s) using the Florence-2 vision-language model.

    Runs a Florence-2 task prompt (object detection, referring-expression
    segmentation, captioning, OCR, ...) over the input image(s). Depending on
    ``annotation_output_type`` the block emits the raw predictions and/or the
    image(s) with masks, mask overlays, or bounding boxes rendered on them.
    """

    @property
    def expected_components(self):
        # Model and processor are loaded from the same Florence-2 checkpoint repo.
        return [
            ComponentSpec(
                name="image_annotator",
                type_hint=Florence2ForConditionalGeneration,
                repo="florence-community/Florence-2-base-ft",
            ),
            ComponentSpec(
                name="image_annotator_processor",
                type_hint=AutoProcessor,
                repo="florence-community/Florence-2-base-ft",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "image",
                type_hint=Union[Image.Image, List[Image.Image]],
                required=True,
                description="Image(s) to annotate",
                metadata={"mellon": "image"},
            ),
            InputParam(
                "annotation_task",
                type_hint=Union[str, List[str]],
                default="<REFERRING_EXPRESSION_SEGMENTATION>",
                metadata={"mellon": "dropdown"},
                description="""Annotation Task to perform on the image.
Supported Tasks:

<OD>
<REFERRING_EXPRESSION_SEGMENTATION>
<CAPTION>
<DETAILED_CAPTION>
<MORE_DETAILED_CAPTION>
<DENSE_REGION_CAPTION>
<REGION_PROPOSAL>
<CAPTION_TO_PHRASE_GROUNDING>
<OPEN_VOCABULARY_DETECTION>
<OCR>
<OCR_WITH_REGION>

""",
            ),
            InputParam(
                "annotation_prompt",
                type_hint=Union[str, List[str]],
                required=True,
                metadata={"mellon": "textbox"},
                description="""Annotation Prompt to provide more context to the task.
Can be used to detect or segment out specific elements in the image
""",
            ),
            InputParam(
                "annotation_output_type",
                type_hint=str,
                default="mask_image",
                # BUGFIX: widget type was "dropbox"; every other selector in
                # this block uses "dropdown".
                metadata={"mellon": "dropdown"},
                description="""Output type from annotation predictions. Available options are
annotation:
- raw annotation predictions from the model based on task type.
mask_image:
- black and white mask image for the given image based on the task type
mask_overlay:
- white mask overlayed on the original image
bounding_box:
- bounding boxes drawn on the original image
""",
            ),
            InputParam(
                "annotation_overlay",
                type_hint=bool,
                required=True,
                default=False,
                description="",
                metadata={"mellon": "checkbox"},
            ),
            InputParam(
                "fill",
                type_hint=str,
                default="white",
                description="",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "annotations",
                type_hint=dict,
                description="Annotations Predictions for input Image(s)",
            ),
            OutputParam(
                "images",
                # BUGFIX: `Image` is the PIL *module*; the objects produced by
                # this block are `PIL.Image.Image` instances.
                type_hint=Image.Image,
                description="Annotated input Image(s)",
                metadata={"mellon": "image"},
            ),
        ]

    def get_annotations(self, components, images, prompts, task):
        """Run Florence-2 on `images` and return one parsed annotation dict per image.

        The task token is prepended to each per-image prompt. Generation output
        is decoded with special tokens KEPT, because the Florence-2
        post-processor relies on them to parse task-specific results.
        """
        task_prompts = [task + prompt for prompt in prompts]

        inputs = components.image_annotator_processor(
            text=task_prompts, images=images, return_tensors="pt"
        ).to(components.image_annotator.device, components.image_annotator.dtype)

        generated_ids = components.image_annotator.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        annotations = components.image_annotator_processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )

        # Post-process each decoded string into task-specific structures
        # (boxes, polygons, labels, ...) scaled to the source image size.
        outputs = []
        for image, annotation in zip(images, annotations):
            outputs.append(
                components.image_annotator_processor.post_process_generation(
                    annotation, task=task, image_size=(image.width, image.height)
                )
            )

        return outputs

    def _iter_polygon_point_sets(self, poly):
        """
        Yields lists of (x, y) points for all simple polygons found in `poly`.
        Supports formats:
        - [x1, y1, x2, y2, ...]
        - [[x, y], [x, y], ...]
        - [xs, ys]
        - dict {'x': xs, 'y': ys}
        - nested lists containing any of the above
        """
        if poly is None:
            return

        def is_num(v):
            return isinstance(v, (int, float, np.number))

        # dict {'x': xs, 'y': ys} with parallel coordinate sequences.
        if isinstance(poly, dict) and "x" in poly and "y" in poly:
            xs, ys = poly["x"], poly["y"]
            if (
                isinstance(xs, (list, tuple))
                and isinstance(ys, (list, tuple))
                and len(xs) == len(ys)
            ):
                pts = list(zip(xs, ys))
                if len(pts) >= 3:
                    yield pts
            return

        if isinstance(poly, (list, tuple)):
            # Flat coordinate list [x1, y1, x2, y2, ...].
            if all(is_num(v) for v in poly):
                coords = list(poly)
                if len(coords) >= 6 and len(coords) % 2 == 0:
                    yield list(zip(coords[0::2], coords[1::2]))
                return

            # Point-pair list [[x, y], [x, y], ...].
            if all(
                isinstance(v, (list, tuple))
                and len(v) == 2
                and all(is_num(n) for n in v)
                for v in poly
            ):
                if len(poly) >= 3:
                    yield [tuple(v) for v in poly]
                return

            # Parallel coordinate lists [xs, ys]; on length mismatch or
            # non-sized members, fall through to the nested-container scan.
            if len(poly) == 2 and all(isinstance(v, (list, tuple)) for v in poly):
                xs, ys = poly
                try:
                    if len(xs) == len(ys) and len(xs) >= 3:
                        yield list(zip(xs, ys))
                        return
                except TypeError:
                    pass

            # Nested containers: recurse into each part.
            for part in poly:
                yield from self._iter_polygon_point_sets(part)

    def _clamp_flatten(self, pts, image):
        """Clamp (x, y) points into `image` bounds, round to int, and flatten
        into PIL's [x0, y0, x1, y1, ...] coordinate layout."""
        flat = []
        for x, y in pts:
            xi = int(round(max(0, min(image.width - 1, x))))
            yi = int(round(max(0, min(image.height - 1, y))))
            flat.extend([xi, yi])
        return flat

    def prepare_mask(self, images, annotations, overlay=False, fill="white"):
        """Render annotation regions as mask images.

        With `overlay=False` a single-channel black-and-white mask is produced
        (regions filled with 255); with `overlay=True` regions are painted with
        `fill` directly onto a copy of the original image.
        """
        masks = []
        for image, annotation in zip(images, annotations):
            mask_image = image.copy() if overlay else Image.new("L", image.size, 0)
            draw = ImageDraw.Draw(mask_image)

            # A plain mask is mode "L", so a color name like "white" must be
            # replaced with the integer pixel value 255.
            mask_fill = fill
            if not overlay and isinstance(fill, str):
                mask_fill = 255

            for _, _annotation in annotation.items():
                if "polygons" in _annotation:
                    for poly in _annotation["polygons"]:
                        for pts in self._iter_polygon_point_sets(poly):
                            if len(pts) < 3:
                                continue
                            draw.polygon(self._clamp_flatten(pts, image), fill=mask_fill)

                elif "bboxes" in _annotation:
                    for bbox in _annotation["bboxes"]:
                        flat = np.array(bbox).flatten().tolist()
                        if len(flat) == 4:
                            x0, y0, x1, y1 = flat
                            draw.rectangle(
                                (
                                    int(round(x0)),
                                    int(round(y0)),
                                    int(round(x1)),
                                    int(round(y1)),
                                ),
                                fill=mask_fill,
                            )

                elif "quad_boxes" in _annotation:
                    for quad in _annotation["quad_boxes"]:
                        for pts in self._iter_polygon_point_sets(quad):
                            if len(pts) < 3:
                                continue
                            draw.polygon(self._clamp_flatten(pts, image), fill=mask_fill)

            masks.append(mask_image)

        return masks

    def prepare_bounding_boxes(self, images, annotations):
        """Draw red bounding boxes / quad boxes (with labels when available)
        onto copies of the input images."""
        outputs = []
        for image, annotation in zip(images, annotations):
            image_copy = image.copy()
            draw = ImageDraw.Draw(image_copy)
            for _, _annotation in annotation.items():
                bboxes = _annotation.get("bboxes", [])
                labels = _annotation.get("labels", [])
                # Some task results report box labels under a different key.
                if len(labels) == 0:
                    labels = _annotation.get("bboxes_labels", [])

                for i, bbox in enumerate(bboxes):
                    flat = np.array(bbox).flatten().tolist()
                    if len(flat) != 4:
                        continue
                    x0, y0, x1, y1 = flat
                    draw.rectangle(
                        (
                            int(round(x0)),
                            int(round(y0)),
                            int(round(x1)),
                            int(round(y1)),
                        ),
                        outline="red",
                        width=3,
                    )
                    label = labels[i] if i < len(labels) else ""
                    if label:
                        # Place the label just above the box, clamped to the top edge.
                        text_y = max(0, int(y0) - 20)
                        draw.text((int(x0), text_y), label, fill="red")

                # Quadrilateral boxes (e.g. <OCR_WITH_REGION>) are drawn as polygons.
                quad_boxes = _annotation.get("quad_boxes", [])
                qlabels = _annotation.get("labels", [])
                for i, quad in enumerate(quad_boxes):
                    for pts in self._iter_polygon_point_sets(quad):
                        if len(pts) < 3:
                            continue
                        flat = self._clamp_flatten(pts, image)
                        xs, ys = flat[0::2], flat[1::2]

                        try:
                            draw.polygon(flat, outline="red", width=3)
                        except TypeError:
                            # Older Pillow versions lack the `width` kwarg.
                            draw.polygon(flat, outline="red")

                        label = qlabels[i] if i < len(qlabels) else ""
                        if label:
                            # Label at the polygon centroid, clamped in-bounds.
                            cx = int(round(sum(xs) / len(xs)))
                            cy = int(round(sum(ys) / len(ys)))
                            cx = max(0, min(image.width - 1, cx))
                            cy = max(0, min(image.height - 1, cy))
                            draw.text((cx, cy), label, fill="red")

            outputs.append(image_copy)

        return outputs

    def prepare_inputs(self, images, prompts):
        """Normalize images/prompts into equal-length lists.

        A single prompt (including the empty prompt forced by region tasks) is
        broadcast across all images.
        """
        prompts = prompts or ""

        if isinstance(images, Image.Image):
            images = [images]
        if isinstance(prompts, str):
            prompts = [prompts]

        # BUGFIX: previously a single prompt with a multi-image batch raised
        # ValueError; broadcast it instead (backward compatible).
        if len(prompts) == 1 and len(images) > 1:
            prompts = prompts * len(images)

        if len(images) != len(prompts):
            raise ValueError("Number of images and annotation prompts must match.")

        return images, prompts

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        skip_image = False

        # Region-style tasks take no text prompt and only make sense as boxes.
        if block_state.annotation_task in (
            "<OD>",
            "<DENSE_REGION_CAPTION>",
            "<REGION_PROPOSAL>",
            "<OCR_WITH_REGION>",
        ):
            block_state.annotation_prompt = ""
            block_state.annotation_output_type = "bounding_box"

        # Pure text tasks produce no drawable regions, so skip image output.
        elif block_state.annotation_task in (
            "<CAPTION>",
            "<DETAILED_CAPTION>",
            "<MORE_DETAILED_CAPTION>",
            "<OCR>",
        ):
            block_state.annotation_prompt = ""
            skip_image = True

        images, annotation_task_prompt = self.prepare_inputs(
            block_state.image, block_state.annotation_prompt
        )
        task = block_state.annotation_task
        fill = block_state.fill

        annotations = self.get_annotations(
            components, images, annotation_task_prompt, task
        )

        block_state.annotations = annotations
        # Stays None for "annotation" output or text-only tasks.
        block_state.images = None

        if not skip_image:
            if block_state.annotation_output_type == "mask_image":
                block_state.images = self.prepare_mask(images, annotations)
            elif block_state.annotation_output_type == "mask_overlay":
                block_state.images = self.prepare_mask(
                    images, annotations, overlay=True, fill=fill
                )
            elif block_state.annotation_output_type == "bounding_box":
                block_state.images = self.prepare_bounding_boxes(images, annotations)

        self.set_block_state(state, block_state)

        return components, state
|
|
|