from typing import List, Union

import numpy as np
import torch
from diffusers.modular_pipelines import (
    ComponentSpec,
    InputParam,
    ModularPipelineBlocks,
    OutputParam,
    PipelineState,
)
from PIL import Image, ImageDraw
from transformers import AutoProcessor, Florence2ForConditionalGeneration


# Florence-2 tasks that take no text prompt (they fail if one is given) and
# whose predictions are boxes / quad boxes, so the output is always rendered
# as bounding boxes.
_PROMPTLESS_BOX_TASKS = (
    "<OD>",
    "<DENSE_REGION_CAPTION>",
    "<REGION_PROPOSAL>",
    "<OCR_WITH_REGION>",
)

# Florence-2 tasks that take no text prompt and produce only text, so there is
# no annotated image to render.
_PROMPTLESS_TEXT_TASKS = (
    "<CAPTION>",
    "<DETAILED_CAPTION>",
    "<MORE_DETAILED_CAPTION>",
    "<OCR>",
)


class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
    """Annotate images with Florence-2 and optionally render the predictions.

    The block runs one Florence-2 task on one or more images, stores the
    post-processed predictions in ``annotations`` and, depending on
    ``annotation_output_type``, renders them into ``images`` as a black/white
    mask, a mask drawn over the input image, or bounding boxes drawn on the
    input image.
    """

    @property
    def expected_components(self):
        """Florence-2 model and its processor, both loaded from the same repo."""
        return [
            ComponentSpec(
                name="image_annotator",
                type_hint=Florence2ForConditionalGeneration,
                repo="florence-community/Florence-2-base-ft",
            ),
            ComponentSpec(
                name="image_annotator_processor",
                type_hint=AutoProcessor,
                repo="florence-community/Florence-2-base-ft",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "image",
                type_hint=Union[Image.Image, List[Image.Image]],
                required=True,
                description="Image(s) to annotate",
                metadata={"mellon": "image"},
            ),
            InputParam(
                "annotation_task",
                # NOTE(review): only a single task string is actually supported;
                # get_annotations concatenates it with each prompt.
                type_hint=Union[str, List[str]],
                default="<REFERRING_EXPRESSION_SEGMENTATION>",
                metadata={"mellon": "dropdown"},
                description="""Annotation Task to perform on the image.
                Supported Tasks:

                <OD>
                <DENSE_REGION_CAPTION>
                <REGION_PROPOSAL>
                <OCR_WITH_REGION>
                <CAPTION>
                <DETAILED_CAPTION>
                <MORE_DETAILED_CAPTION>
                <OCR>
                <REFERRING_EXPRESSION_SEGMENTATION>
                <CAPTION_TO_PHRASE_GROUNDING>
                <OPEN_VOCABULARY_DETECTION>
                """,
            ),
            InputParam(
                "annotation_prompt",
                type_hint=Union[str, List[str]],
                required=True,
                metadata={"mellon": "textbox"},
                description="""Annotation Prompt to provide more context to the task.
                Can be used to detect or segment out specific elements in the image
                """,
            ),
            InputParam(
                "annotation_output_type",
                type_hint=str,
                default="mask_image",
                metadata={"mellon": "dropdown"},
                description="""Output type from annotation predictions. Available options are
                annotation:
                    - raw annotation predictions from the model based on task type.
                mask_image:
                    - black and white mask image for the given image based on the task type
                mask_overlay:
                    - mask (drawn with `fill`, white by default) overlaid on the original image
                bounding_box:
                    - bounding boxes drawn on the original image
                """,
            ),
            # NOTE(review): this input is not read by the block; overlays are
            # selected with annotation_output_type="mask_overlay".
            InputParam(
                "annotation_overlay",
                type_hint=bool,
                required=True,
                default=False,
                description="",
                metadata={"mellon": "checkbox"},
            ),
            # Colour used to fill shapes in "mask_overlay" output.
            InputParam(
                "fill",
                type_hint=str,
                default="white",
                description="",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "annotations",
                type_hint=dict,
                description="Annotations Predictions for input Image(s)",
            ),
            OutputParam(
                "images",
                type_hint=List[Image.Image],
                description="Annotated input Image(s)",
                metadata={"mellon": "image"},
            ),
        ]

    def get_annotations(self, components, images, prompts, task):
        """Run Florence-2 on ``images`` and return one post-processed result per image.

        Args:
            components: Provides ``image_annotator`` (model) and
                ``image_annotator_processor``.
            images: List of PIL images.
            prompts: List of text prompts, one per image (may be empty strings).
            task: Florence-2 task token; each model prompt is ``task + prompt``.

        Returns:
            A list with one ``post_process_generation`` result per image. The
            drawing helpers expect each result to be a dict of predictions.
        """
        task_prompts = [task + prompt for prompt in prompts]
        inputs = components.image_annotator_processor(
            text=task_prompts, images=images, return_tensors="pt"
        ).to(components.image_annotator.device, components.image_annotator.dtype)

        generated_ids = components.image_annotator.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        # Special tokens are kept: post_process_generation parses them, as in
        # the Florence-2 usage examples.
        annotations = components.image_annotator_processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )

        processor = components.image_annotator_processor
        return [
            processor.post_process_generation(
                annotation, task=task, image_size=(image.width, image.height)
            )
            for image, annotation in zip(images, annotations)
        ]

    def _iter_polygon_point_sets(self, poly):
        """
        Yields lists of (x, y) points for all simple polygons found in `poly`.

        Supports formats:
        - [x1, y1, x2, y2, ...]
        - [[x, y], [x, y], ...]
        - [xs, ys]  (two numeric lists of equal length)
        - dict {'x': xs, 'y': ys}
        - nested lists containing any of the above

        A two-element list of numeric lists is ambiguous between ``[xs, ys]``
        and two flat polygons. When both elements are valid flat polygons
        (even length >= 6), they are treated as two polygons, because that is
        the nested shape the Florence-2 post-processor uses for segmentation.
        Other types are ignored.
        """
        if poly is None:
            return

        def is_num(v):
            return isinstance(v, (int, float, np.number))

        def is_flat_polygon(seq):
            return len(seq) >= 6 and len(seq) % 2 == 0 and all(is_num(v) for v in seq)

        # dict {'x': [...], 'y': [...]}
        if isinstance(poly, dict) and "x" in poly and "y" in poly:
            xs, ys = poly["x"], poly["y"]
            if (
                isinstance(xs, (list, tuple))
                and isinstance(ys, (list, tuple))
                and len(xs) == len(ys)
            ):
                pts = list(zip(xs, ys))
                if len(pts) >= 3:
                    yield pts
            return

        if not isinstance(poly, (list, tuple)):
            return

        # flat numeric [x1, y1, ...]
        if all(is_num(v) for v in poly):
            if is_flat_polygon(poly):
                coords = list(poly)
                yield list(zip(coords[0::2], coords[1::2]))
            return

        # list of pairs [[x, y], ...]
        if all(
            isinstance(v, (list, tuple)) and len(v) == 2 and all(is_num(n) for n in v)
            for v in poly
        ):
            if len(poly) >= 3:
                yield [tuple(v) for v in poly]
                return

        # [xs, ys]: both parts must be numeric, and not both valid flat polygons.
        if (
            len(poly) == 2
            and all(isinstance(v, (list, tuple)) for v in poly)
            and all(is_num(n) for v in poly for n in v)
            and not all(is_flat_polygon(v) for v in poly)
        ):
            xs, ys = poly
            if len(xs) == len(ys) and len(xs) >= 3:
                yield list(zip(xs, ys))
                return

        # nested: recurse into parts
        for part in poly:
            yield from self._iter_polygon_point_sets(part)

    @staticmethod
    def _clip_to_flat(points, width, height):
        """Round (x, y) points to ints clamped inside a width x height image; flatten."""
        flat = []
        for x, y in points:
            flat.append(int(round(max(0, min(width - 1, x)))))
            flat.append(int(round(max(0, min(height - 1, y)))))
        return flat

    @staticmethod
    def _box_corners(bbox):
        """Return integer (x0, y0, x1, y1) with x0 <= x1 and y0 <= y1.

        Returns None when ``bbox`` does not flatten to exactly four numbers.
        Corners are ordered because Pillow rejects rectangles with x1 < x0 or
        y1 < y0.
        """
        flat = np.array(bbox).flatten().tolist()
        if len(flat) != 4:
            return None
        x0, y0, x1, y1 = (int(round(v)) for v in flat)
        return min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)

    def prepare_mask(self, images, annotations, overlay=False, fill="white"):
        """Render polygon / box / quad-box predictions as filled shapes.

        Args:
            images: List of PIL images.
            annotations: One prediction dict per image (see ``get_annotations``).
            overlay: If False, draw on a new black "L"-mode mask. If True, draw
                on a copy of the input image.
            fill: Fill colour for overlays. For "L" masks, any string colour is
                drawn as 255.

        Returns:
            A list with one rendered image per input image.
        """
        masks = []
        for image, annotation in zip(images, annotations):
            mask_image = image.copy() if overlay else Image.new("L", image.size, 0)
            draw = ImageDraw.Draw(mask_image)

            # "L" masks need a numeric fill, so colour names map to full white.
            mask_fill = 255 if (not overlay and isinstance(fill, str)) else fill

            for _annotation in annotation.values():
                if not isinstance(_annotation, dict):
                    continue
                if "polygons" in _annotation:
                    shapes = _annotation["polygons"]
                elif "bboxes" in _annotation:
                    for bbox in _annotation["bboxes"]:
                        corners = self._box_corners(bbox)
                        if corners is not None:
                            draw.rectangle(corners, fill=mask_fill)
                    continue
                elif "quad_boxes" in _annotation:
                    shapes = _annotation["quad_boxes"]
                else:
                    continue

                for shape in shapes:
                    for pts in self._iter_polygon_point_sets(shape):
                        if len(pts) < 3:
                            continue
                        flat = self._clip_to_flat(pts, image.width, image.height)
                        draw.polygon(flat, fill=mask_fill)
            masks.append(mask_image)
        return masks

    def prepare_bounding_boxes(self, images, annotations):
        """Draw predicted boxes and quad boxes, with labels, on copies of the images.

        Axis-aligned ``bboxes`` are outlined in red. Their labels come from
        ``labels``, falling back to ``bboxes_labels``. ``quad_boxes`` are
        outlined as polygons and labelled at their centroid from ``labels``.
        Non-dict predictions are skipped.

        Returns:
            A list with one annotated image per input image.
        """
        outputs = []
        for image, annotation in zip(images, annotations):
            canvas = image.copy()
            draw = ImageDraw.Draw(canvas)
            for _annotation in annotation.values():
                if not isinstance(_annotation, dict):
                    continue

                # Standard axis-aligned boxes
                labels = _annotation.get("labels", [])
                if len(labels) == 0:
                    labels = _annotation.get("bboxes_labels", [])
                for i, bbox in enumerate(_annotation.get("bboxes", [])):
                    corners = self._box_corners(bbox)
                    if corners is None:
                        continue
                    draw.rectangle(corners, outline="red", width=3)
                    label = labels[i] if i < len(labels) else ""
                    if label:
                        x0, y0 = corners[0], corners[1]
                        draw.text((x0, max(0, y0 - 20)), label, fill="red")

                # Quadrilateral boxes (drawn as polygons)
                qlabels = _annotation.get("labels", [])
                for i, quad in enumerate(_annotation.get("quad_boxes", [])):
                    label = qlabels[i] if i < len(qlabels) else ""
                    for pts in self._iter_polygon_point_sets(quad):
                        if len(pts) < 3:
                            continue
                        flat = self._clip_to_flat(pts, image.width, image.height)
                        try:
                            draw.polygon(flat, outline="red", width=3)
                        except TypeError:
                            # Pillow without `width` support for polygon()
                            draw.polygon(flat, outline="red")
                        if label:
                            # Points are already clamped, so the centroid is in bounds.
                            xs, ys = flat[0::2], flat[1::2]
                            cx = int(round(sum(xs) / len(xs)))
                            cy = int(round(sum(ys) / len(ys)))
                            draw.text((cx, cy), label, fill="red")
            outputs.append(canvas)
        return outputs

    def prepare_inputs(self, images, prompts):
        """Normalise images and prompts into equal-length lists.

        A single string prompt (including the empty prompt forced for
        prompt-less tasks) is broadcast to every image.

        Raises:
            ValueError: If no images are given, or if a list of prompts does not
                match the number of images.
        """
        prompts = prompts or ""
        if isinstance(images, Image.Image):
            images = [images]
        if len(images) == 0:
            raise ValueError("At least one image is required for annotation.")
        if isinstance(prompts, str):
            prompts = [prompts] * len(images)
        if len(images) != len(prompts):
            raise ValueError("Number of images and annotation prompts must match.")
        return images, prompts

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        """Annotate ``image`` and write ``annotations`` / ``images`` to the state.

        For prompt-less tasks, ``annotation_prompt`` is cleared in the block
        state. Box tasks additionally force
        ``annotation_output_type="bounding_box"``, and text-only tasks set
        ``images`` to None. Output type "annotation" also leaves ``images`` as
        None.
        """
        block_state = self.get_block_state(state)
        task = block_state.annotation_task

        skip_image = False
        if task in _PROMPTLESS_BOX_TASKS:
            # these don't require a prompt and fail if one is given
            block_state.annotation_prompt = ""
            block_state.annotation_output_type = "bounding_box"
        elif task in _PROMPTLESS_TEXT_TASKS:
            # these don't require a prompt and don't output an image
            block_state.annotation_prompt = ""
            skip_image = True

        images, prompts = self.prepare_inputs(
            block_state.image, block_state.annotation_prompt
        )
        annotations = self.get_annotations(components, images, prompts, task)
        block_state.annotations = annotations

        block_state.images = None
        if not skip_image:
            output_type = block_state.annotation_output_type
            if output_type == "mask_image":
                block_state.images = self.prepare_mask(images, annotations)
            elif output_type == "mask_overlay":
                block_state.images = self.prepare_mask(
                    images, annotations, overlay=True, fill=block_state.fill
                )
            elif output_type == "bounding_box":
                block_state.images = self.prepare_bounding_boxes(images, annotations)
            # "annotation": only the raw predictions are returned.

        self.set_block_state(state, block_state)
        return components, state