YiYiXu HF Staff committed on
Commit
efe6ec3
·
verified ·
1 Parent(s): c103541

Upload ModularPipeline

Browse files
Files changed (4) hide show
  1. README.md +3 -0
  2. block.py +398 -0
  3. modular_config.json +2 -2
  4. modular_model_index.json +33 -0
README.md CHANGED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ---
2
+ library_name: diffusers
3
+ ---
block.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ from diffusers.modular_pipelines import (
6
+ ComponentSpec,
7
+ InputParam,
8
+ ModularPipelineBlocks,
9
+ OutputParam,
10
+ PipelineState,
11
+ )
12
+ from PIL import Image, ImageDraw
13
+ from transformers import AutoProcessor, Florence2ForConditionalGeneration
14
+
15
+
16
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
    """Modular pipeline block that annotates images with a Florence-2 model.

    The block runs a Florence-2 task (detection, segmentation, captioning,
    OCR, ...) on one or more PIL images, stores the post-processed
    predictions in ``annotations`` and, for tasks that produce regions,
    renders them into ``images`` as a mask, a mask overlay or bounding boxes.
    """

    @property
    def expected_components(self):
        """Return the model and processor components this block loads."""
        return [
            ComponentSpec(
                name="image_annotator",
                type_hint=Florence2ForConditionalGeneration,
                repo="florence-community/Florence-2-base-ft",
            ),
            ComponentSpec(
                name="image_annotator_processor",
                type_hint=AutoProcessor,
                repo="florence-community/Florence-2-base-ft",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        """Return the user-facing inputs accepted by this block."""
        return [
            InputParam(
                "image",
                type_hint=Union[Image.Image, List[Image.Image]],
                required=True,
                description="Image(s) to annotate",
            ),
            InputParam(
                "annotation_task",
                type_hint=Union[str, List[str]],
                default="<REFERRING_EXPRESSION_SEGMENTATION>",
                description="""Annotation Task to perform on the image.
                Supported Tasks:

                <OD>
                <REFERRING_EXPRESSION_SEGMENTATION>
                <CAPTION>
                <DETAILED_CAPTION>
                <MORE_DETAILED_CAPTION>
                <DENSE_REGION_CAPTION>
                <REGION_PROPOSAL>
                <CAPTION_TO_PHRASE_GROUNDING>
                <OPEN_VOCABULARY_DETECTION>
                <OCR>
                <OCR_WITH_REGION>

                """,
            ),
            InputParam(
                "annotation_prompt",
                type_hint=Union[str, List[str]],
                required=True,
                description="""Annotation Prompt to provide more context to the task.
                Can be used to detect or segment out specific elements in the image
                """,
            ),
            InputParam(
                "annotation_output_type",
                type_hint=str,
                default="mask_image",
                description="""Output type from annotation predictions. Availabe options are
                annotation:
                    - raw annotation predictions from the model based on task type.
                mask_image:
                    -black and white mask image for the given image based on the task type
                mask_overlay:
                    - white mask overlayed on the original image
                bounding_box:
                    - bounding boxes drawn on the original image
                """,
            ),
            # NOTE(review): this input is not read anywhere in this block; it
            # has a default, so it is declared optional rather than
            # `required=True` (which contradicted the default).
            InputParam(
                "annotation_overlay",
                type_hint=bool,
                default=False,
                description="",
            ),
            # Fill colour handed to `prepare_mask` for the "mask_overlay"
            # output type.
            InputParam(
                "fill",
                type_hint=str,
                default="white",
                description="",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Return the values this block writes back into the pipeline state."""
        return [
            OutputParam(
                "annotations",
                # One post-processed prediction dict per input image.
                type_hint=List[dict],
                description="Annotations Predictions for input Image(s)",
            ),
            OutputParam(
                "images",
                # `Image` alone is the PIL module; the class is `Image.Image`.
                type_hint=List[Image.Image],
                description="Annotated input Image(s)",
            ),
        ]

    def get_annotations(self, components, images, prompts, task):
        """Run the annotator on ``images`` and post-process its predictions.

        Args:
            components: Object exposing ``image_annotator`` (the model) and
                ``image_annotator_processor`` (its processor).
            images: List of PIL images.
            prompts: List of prompt strings, one per image.
            task: Florence-2 task token (e.g. ``"<OD>"``); it is prepended to
                every prompt and passed to the processor's post-processing.

        Returns:
            A list with one post-processed annotation per image.
        """
        processor = components.image_annotator_processor
        model = components.image_annotator

        task_prompts = [task + prompt for prompt in prompts]

        inputs = processor(
            text=task_prompts, images=images, return_tensors="pt"
        ).to(model.device, model.dtype)

        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        # Special tokens are kept because `post_process_generation` is given
        # the raw decoded text.
        decoded = processor.batch_decode(generated_ids, skip_special_tokens=False)

        return [
            processor.post_process_generation(
                text, task=task, image_size=(image.width, image.height)
            )
            for image, text in zip(images, decoded)
        ]

    def _iter_polygon_point_sets(self, poly):
        """
        Yields lists of (x, y) points for all simple polygons found in `poly`.
        Supports formats:
        - [x1, y1, x2, y2, ...]
        - [[x, y], [x, y], ...]
        - [xs, ys]
        - dict {'x': xs, 'y': ys}
        - nested lists containing any of the above

        Point sets with fewer than 3 points are not yielded; inputs of any
        other type are ignored.
        """
        if poly is None:
            return

        def is_num(v):
            return isinstance(v, (int, float, np.number))

        # dict {'x': [...], 'y': [...]}
        if isinstance(poly, dict) and "x" in poly and "y" in poly:
            xs, ys = poly["x"], poly["y"]
            if (
                isinstance(xs, (list, tuple))
                and isinstance(ys, (list, tuple))
                and len(xs) == len(ys)
            ):
                pts = list(zip(xs, ys))
                if len(pts) >= 3:
                    yield pts
            return

        if not isinstance(poly, (list, tuple)):
            # other types are ignored
            return

        # flat numeric [x1, y1, ...]
        if all(is_num(v) for v in poly):
            coords = list(poly)
            if len(coords) >= 6 and len(coords) % 2 == 0:
                yield list(zip(coords[0::2], coords[1::2]))
            return

        # list of pairs [[x, y], ...]
        if all(
            isinstance(v, (list, tuple))
            and len(v) == 2
            and all(is_num(n) for n in v)
            for v in poly
        ):
            if len(poly) >= 3:
                yield [tuple(v) for v in poly]
            return

        # [xs, ys]
        if len(poly) == 2 and all(isinstance(v, (list, tuple)) for v in poly):
            xs, ys = poly
            try:
                if len(xs) == len(ys) and len(xs) >= 3:
                    yield list(zip(xs, ys))
                    return
            except TypeError:
                pass

        # nested: recurse into parts
        for part in poly:
            yield from self._iter_polygon_point_sets(part)

    @staticmethod
    def _clip_points(pts, width, height):
        """Clamp (x, y) points to the image bounds and flatten to [x0, y0, x1, y1, ...]."""
        flat = []
        for x, y in pts:
            flat.append(int(round(max(0, min(width - 1, x)))))
            flat.append(int(round(max(0, min(height - 1, y)))))
        return flat

    def prepare_mask(self, images, annotations, overlay=False, fill="white"):
        """Draw filled regions for each image's annotations.

        Args:
            images: List of PIL images.
            annotations: One annotation dict per image, as returned by
                ``get_annotations``.
            overlay: When true, draw on a copy of the image; otherwise draw
                on a new black "L"-mode image of the same size.
            fill: Fill colour. For non-overlay masks a string fill is
                replaced by 255.

        Returns:
            A list of PIL images, one per input image.
        """
        masks = []
        for image, annotation in zip(images, annotations):
            mask_image = image.copy() if overlay else Image.new("L", image.size, 0)
            draw = ImageDraw.Draw(mask_image)

            # use a safe fill for grayscale masks: for "L" mode, white -> 255
            mask_fill = fill
            if not overlay and isinstance(fill, str):
                mask_fill = 255

            for entry in annotation.values():
                # Only the first matching region kind of each entry is drawn.
                if "polygons" in entry:
                    for poly in entry["polygons"]:
                        for pts in self._iter_polygon_point_sets(poly):
                            if len(pts) < 3:
                                continue
                            flat = self._clip_points(pts, image.width, image.height)
                            draw.polygon(flat, fill=mask_fill)

                elif "bboxes" in entry:
                    for bbox in entry["bboxes"]:
                        flat = np.array(bbox).flatten().tolist()
                        if len(flat) == 4:
                            x0, y0, x1, y1 = flat
                            draw.rectangle(
                                (
                                    int(round(x0)),
                                    int(round(y0)),
                                    int(round(x1)),
                                    int(round(y1)),
                                ),
                                fill=mask_fill,
                            )

                elif "quad_boxes" in entry:
                    for quad in entry["quad_boxes"]:
                        for pts in self._iter_polygon_point_sets(quad):
                            if len(pts) < 3:
                                continue
                            flat = self._clip_points(pts, image.width, image.height)
                            draw.polygon(flat, fill=mask_fill)

            masks.append(mask_image)

        return masks

    def prepare_bounding_boxes(self, images, annotations):
        """Draw red boxes, quadrilaterals and their labels on copies of the images.

        Args:
            images: List of PIL images.
            annotations: One annotation dict per image, as returned by
                ``get_annotations``.

        Returns:
            A list of annotated copies, one per input image.
        """
        outputs = []
        for image, annotation in zip(images, annotations):
            image_copy = image.copy()
            draw = ImageDraw.Draw(image_copy)
            for entry in annotation.values():
                # Standard axis-aligned boxes
                bboxes = entry.get("bboxes", [])
                labels = entry.get("labels", [])

                if len(labels) == 0:
                    labels = entry.get("bboxes_labels", [])

                for i, bbox in enumerate(bboxes):
                    flat = np.array(bbox).flatten().tolist()

                    if len(flat) != 4:
                        continue

                    x0, y0, x1, y1 = flat
                    draw.rectangle(
                        (
                            int(round(x0)),
                            int(round(y0)),
                            int(round(x1)),
                            int(round(y1)),
                        ),
                        outline="red",
                        width=3,
                    )
                    label = labels[i] if i < len(labels) else ""
                    if label:
                        # Place the label just above the box, kept inside the image.
                        text_y = max(0, int(y0) - 20)
                        draw.text((int(x0), text_y), label, fill="red")

                # Quadrilateral boxes (draw as polygons)
                quad_boxes = entry.get("quad_boxes", [])
                qlabels = entry.get("labels", [])
                for i, quad in enumerate(quad_boxes):
                    for pts in self._iter_polygon_point_sets(quad):
                        if len(pts) < 3:
                            continue
                        flat = self._clip_points(pts, image.width, image.height)
                        xs, ys = flat[0::2], flat[1::2]

                        # Outline polygon
                        try:
                            draw.polygon(flat, outline="red", width=3)
                        except TypeError:
                            # Pillow without width for polygon
                            draw.polygon(flat, outline="red")

                        # Optional label at centroid (inside the quad)
                        label = qlabels[i] if i < len(qlabels) else ""
                        if label:
                            cx = int(round(sum(xs) / len(xs)))
                            cy = int(round(sum(ys) / len(ys)))
                            cx = max(0, min(image.width - 1, cx))
                            cy = max(0, min(image.height - 1, cy))
                            draw.text((cx, cy), label, fill="red")

            outputs.append(image_copy)

        return outputs

    def prepare_inputs(self, images, prompts):
        """Normalise images and prompts to equally long lists.

        A single image is wrapped in a list. A single string prompt
        (including the empty prompt used by prompt-less tasks, or ``None``)
        is repeated for every image, so a batch of images can share one
        prompt.

        Args:
            images: A PIL image or a list of PIL images.
            prompts: A prompt string, a list of prompt strings, or ``None``.

        Returns:
            Tuple ``(images, prompts)`` of two lists with the same length.

        Raises:
            ValueError: If no image is given, or if a list of prompts does
                not have one entry per image.
        """
        prompts = prompts or ""

        if isinstance(images, Image.Image):
            images = [images]

        if len(images) == 0:
            raise ValueError("At least one image is required for annotation.")

        if isinstance(prompts, str):
            # Broadcast the shared prompt; previously this produced a
            # one-element list and failed for any multi-image batch.
            prompts = [prompts] * len(images)

        if len(images) != len(prompts):
            raise ValueError("Number of images and annotation prompts must match.")

        return images, prompts

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        """Annotate the input image(s) and store the results in the pipeline state.

        Writes ``annotations`` (always) and ``images`` (``None`` for
        text-only tasks or when the output type is not one of
        ``mask_image`` / ``mask_overlay`` / ``bounding_box``).

        Returns:
            Tuple ``(components, state)``.
        """
        block_state = self.get_block_state(state)
        skip_image = False

        # Tuples (not sets) so that an unhashable task value, e.g. a list,
        # simply compares unequal instead of raising.
        box_tasks = (
            "<OD>",
            "<DENSE_REGION_CAPTION>",
            "<REGION_PROPOSAL>",
            "<OCR_WITH_REGION>",
        )
        text_tasks = (
            "<CAPTION>",
            "<DETAILED_CAPTION>",
            "<MORE_DETAILED_CAPTION>",
            "<OCR>",
        )

        if block_state.annotation_task in box_tasks:
            # these don't require a prompt and fail if one is given
            block_state.annotation_prompt = ""
            block_state.annotation_output_type = "bounding_box"
        elif block_state.annotation_task in text_tasks:
            # these don't require a prompt and don't output an image
            block_state.annotation_prompt = ""
            skip_image = True

        images, annotation_task_prompt = self.prepare_inputs(
            block_state.image, block_state.annotation_prompt
        )
        task = block_state.annotation_task
        fill = block_state.fill

        annotations = self.get_annotations(
            components, images, annotation_task_prompt, task
        )

        block_state.annotations = annotations
        block_state.images = None

        if not skip_image:
            output_type = block_state.annotation_output_type
            if output_type == "mask_image":
                block_state.images = self.prepare_mask(images, annotations)
            elif output_type == "mask_overlay":
                block_state.images = self.prepare_mask(
                    images, annotations, overlay=True, fill=fill
                )
            elif output_type == "bounding_box":
                block_state.images = self.prepare_bounding_boxes(images, annotations)

        self.set_block_state(state, block_state)

        return components, state
modular_config.json CHANGED
@@ -1,7 +1,7 @@
1
  {
2
  "_class_name": "Florence2ImageAnnotatorBlock",
3
- "_diffusers_version": "0.37.0.dev0",
4
  "auto_map": {
5
  "ModularPipelineBlocks": "block.Florence2ImageAnnotatorBlock"
6
  }
7
- }
 
1
  {
2
  "_class_name": "Florence2ImageAnnotatorBlock",
3
+ "_diffusers_version": "0.35.1",
4
  "auto_map": {
5
  "ModularPipelineBlocks": "block.Florence2ImageAnnotatorBlock"
6
  }
7
+ }
modular_model_index.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_blocks_class_name": "Florence2ImageAnnotatorBlock",
3
+ "_class_name": "ModularPipeline",
4
+ "_diffusers_version": "0.37.0.dev0",
5
+ "image_annotator": [
6
+ null,
7
+ null,
8
+ {
9
+ "pretrained_model_name_or_path": "florence-community/Florence-2-base-ft",
10
+ "revision": null,
11
+ "subfolder": "",
12
+ "type_hint": [
13
+ "transformers",
14
+ "Florence2ForConditionalGeneration"
15
+ ],
16
+ "variant": null
17
+ }
18
+ ],
19
+ "image_annotator_processor": [
20
+ null,
21
+ null,
22
+ {
23
+ "pretrained_model_name_or_path": "florence-community/Florence-2-base-ft",
24
+ "revision": null,
25
+ "subfolder": "",
26
+ "type_hint": [
27
+ "transformers",
28
+ "AutoProcessor"
29
+ ],
30
+ "variant": null
31
+ }
32
+ ]
33
+ }