import torch
import torchvision.transforms.functional as F
import io
import os
from typing import List
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image, ImageDraw, ImageColor, ImageFont
import random
import numpy as np
import re

from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
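
# Workaround: the remote Florence-2 modeling code lists flash_attn as a hard
# import requirement even when the sdpa or eager attention backend is selected,
# which breaks loading on systems without flash_attn installed. Patching
# get_imports during from_pretrained drops that requirement.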
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    if not str(filename).endswith("modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    # Guard the removal: list.remove raises ValueError if "flash_attn" is absent.
    if "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports

import comfy.model_management as mm
from comfy.utils import ProgressBar
import folder_paths

script_directory = os.path.dirname(os.path.abspath(__file__))

from transformers import AutoModelForCausalLM, AutoProcessor
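
# Downloads (on first use) and loads a Florence-2 checkpoint into ComfyUI's
# models/LLM folder, returning the model, processor, and dtype as one bundle.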
class DownloadAndLoadFlorence2Model:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": (
                [
                    'microsoft/Florence-2-base',
                    'microsoft/Florence-2-base-ft',
                    'microsoft/Florence-2-large',
                    'microsoft/Florence-2-large-ft',
                    'HuggingFaceM4/Florence-2-DocVQA'
                ],
                {
                    "default": 'microsoft/Florence-2-base'
                }),
            "precision": (
                ['fp16', 'bf16', 'fp32'],
                {
                    "default": 'fp16'
                }),
            "attention": (
                ['flash_attention_2', 'sdpa', 'eager'],
                {
                    "default": 'sdpa'
                }),
            },
        }

    RETURN_TYPES = ("FL2MODEL",)
    RETURN_NAMES = ("florence2_model",)
    FUNCTION = "loadmodel"
    CATEGORY = "Florence2"

    def loadmodel(self, model, precision, attention):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]

        model_name = model.rsplit('/', 1)[-1]
        model_path = os.path.join(folder_paths.models_dir, "LLM", model_name)

        if not os.path.exists(model_path):
            print(f"Downloading Florence2 model to: {model_path}")
            from huggingface_hub import snapshot_download
            snapshot_download(repo_id=model,
                              local_dir=model_path,
                              local_dir_use_symlinks=False)
| print(f"using {attention} for attention")
|
| with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
|
| model = AutoModelForCausalLM.from_pretrained(model_path, attn_implementation=attention, device_map=device, torch_dtype=dtype,trust_remote_code=True)
|
| processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
|
|
|
| florence2_model = {
|
| 'model': model,
|
| 'processor': processor,
|
| 'dtype': dtype
|
| }
|
|
|
| return (florence2_model,)
|
|
|


def calculate_bounding_box(width, height, flat_points) -> List[float]:
    """
    Calculate the normalized bounding box of a polygon.

    Args:
        width (int): Image width, used to normalize the x coordinates.
        height (int): Image height, used to normalize the y coordinates.
        flat_points (list of float): Flat list of x, y coordinates defining the polygon points.

    Returns:
        list: [min_x, min_y, max_x, max_y] normalized to the 0..1 range.
    """
    if not flat_points or len(flat_points) % 2 != 0:
        raise ValueError("The list of points must be non-empty and have an even number of elements")

    x_coords = flat_points[0::2]
    y_coords = flat_points[1::2]

    min_x = min(x_coords)
    max_x = max(x_coords)
    min_y = min(y_coords)
    max_y = max(y_coords)

    return [min_x / width, min_y / height, max_x / width, max_y / height]
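
# Example with hypothetical values:
#   calculate_bounding_box(100, 100, [10, 20, 30, 5, 25, 40]) -> [0.1, 0.05, 0.3, 0.4]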


class Florence2Run:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE", ),
                "florence2_model": ("FL2MODEL", ),
                "text_input": ("STRING", {"default": "", "multiline": True}),
                "task": (
                    [
                        'region_caption',
                        'dense_region_caption',
                        'region_proposal',
                        'caption',
                        'detailed_caption',
                        'more_detailed_caption',
                        'caption_to_phrase_grounding',
                        'referring_expression_segmentation',
                        'ocr',
                        'ocr_with_region',
                        'docvqa'
                    ],
                ),
                "fill_mask": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "keep_model_loaded": ("BOOLEAN", {"default": False}),
                "max_new_tokens": ("INT", {"default": 1024, "min": 1, "max": 4096}),
                "num_beams": ("INT", {"default": 3, "min": 1, "max": 64}),
                "do_sample": ("BOOLEAN", {"default": True}),
                "output_mask_select": ("STRING", {"default": ""}),
            }
        }

    RETURN_TYPES = ("IMAGE", "MASK", "STRING", "JSON")
    RETURN_NAMES = ("image", "mask", "caption", "data")
    FUNCTION = "encode"
    CATEGORY = "Florence2"

    def encode(self, image, text_input, florence2_model, task, fill_mask, keep_model_loaded=False,
               num_beams=3, max_new_tokens=1024, do_sample=True, output_mask_select=""):
        device = mm.get_torch_device()
        _, height, width, _ = image.shape
        offload_device = mm.unet_offload_device()
        annotated_image_tensor = None
        mask_tensor = None
        processor = florence2_model['processor']
        model = florence2_model['model']
        dtype = florence2_model['dtype']
        model.to(device)

        colormap = ['blue', 'orange', 'green', 'purple', 'brown', 'pink', 'olive', 'cyan', 'red',
                    'lime', 'indigo', 'violet', 'aqua', 'magenta', 'gold', 'tan', 'skyblue']
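
        # Map the node's task names to Florence-2's special task prompt tokens.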
        prompts = {
            'region_caption': '<OD>',
            'dense_region_caption': '<DENSE_REGION_CAPTION>',
            'region_proposal': '<REGION_PROPOSAL>',
            'caption': '<CAPTION>',
            'detailed_caption': '<DETAILED_CAPTION>',
            'more_detailed_caption': '<MORE_DETAILED_CAPTION>',
            'caption_to_phrase_grounding': '<CAPTION_TO_PHRASE_GROUNDING>',
            'referring_expression_segmentation': '<REFERRING_EXPRESSION_SEGMENTATION>',
            'ocr': '<OCR>',
            'ocr_with_region': '<OCR_WITH_REGION>',
            'docvqa': '<DocVQA>'
        }
        task_prompt = prompts.get(task, '<OD>')

        if (task not in ['referring_expression_segmentation', 'caption_to_phrase_grounding', 'docvqa']) and text_input:
            raise ValueError("Text input (prompt) is only supported for 'referring_expression_segmentation', 'caption_to_phrase_grounding', and 'docvqa'")

        if text_input != "":
            prompt = task_prompt + " " + text_input
        else:
            prompt = task_prompt
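
        # ComfyUI IMAGE tensors are (B, H, W, C); torchvision expects (C, H, W) per image.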
        image = image.permute(0, 3, 1, 2)

        out = []
        out_masks = []
        out_results = []
        out_data = []
        pbar = ProgressBar(len(image))

        for img in image:
            image_pil = F.to_pil_image(img)
            inputs = processor(text=prompt, images=image_pil, return_tensors="pt", do_rescale=False).to(dtype).to(device)

            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                num_beams=num_beams,
            )

            results = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
            print(results)
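
            # Produce a human-readable string: strip Florence-2's special tokens,
            # and for OCR-with-region turn the region tokens into line breaks.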
            if task == 'ocr_with_region':
                clean_results = str(results)
                cleaned_string = re.sub(r'</?s>|<[^>]*>', '\n', clean_results)
                clean_results = re.sub(r'\n+', '\n', cleaned_string)
            else:
                clean_results = str(results)
                clean_results = clean_results.replace('</s>', '')
                clean_results = clean_results.replace('<s>', '')

            # Return a plain string for a single image, a list of strings for a batch.
            if len(image) == 1:
                out_results = clean_results
            else:
                out_results.append(clean_results)

            W, H = image_pil.size

            parsed_answer = processor.post_process_generation(results, task=task_prompt, image_size=(W, H))
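
            # Detection-style tasks return pixel-space 'bboxes' and 'labels' under
            # the task token key; draw them with matplotlib and optionally build a mask.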
            if task in ['region_caption', 'dense_region_caption', 'caption_to_phrase_grounding', 'region_proposal']:
                fig, ax = plt.subplots(figsize=(W / 100, H / 100), dpi=100)
                fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
                ax.imshow(image_pil)
                bboxes = parsed_answer[task_prompt]['bboxes']
                labels = parsed_answer[task_prompt]['labels']

                mask_indexes = []

                # output_mask_select is a comma-separated list of indices and/or labels;
                # when empty, every detected box is included in the mask.
                if output_mask_select != "":
                    mask_indexes = [n for n in output_mask_select.split(",")]
                    print(mask_indexes)
                else:
                    mask_indexes = [str(i) for i in range(len(bboxes))]

                if fill_mask:
                    mask_layer = Image.new('RGB', image_pil.size, (0, 0, 0))
                    mask_draw = ImageDraw.Draw(mask_layer)
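
                # Draw one rectangle per detection; selected boxes are also filled
                # white on the mask layer.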
                for index, (bbox, label) in enumerate(zip(bboxes, labels)):
                    indexed_label = f"{index}.{label}"

                    if fill_mask:
                        if str(index) in mask_indexes:
                            print("match index:", str(index), "in mask_indexes:", mask_indexes)
                            mask_draw.rectangle([bbox[0], bbox[1], bbox[2], bbox[3]], fill=(255, 255, 255))
                        if label in mask_indexes:
                            print("match label")
                            mask_draw.rectangle([bbox[0], bbox[1], bbox[2], bbox[3]], fill=(255, 255, 255))

                    rect = patches.Rectangle(
                        (bbox[0], bbox[1]),
                        bbox[2] - bbox[0],
                        bbox[3] - bbox[1],
                        linewidth=1,
                        edgecolor='r',
                        facecolor='none',
                        label=indexed_label
                    )

                    # Rough estimate of the rendered label size in pixels, used to
                    # keep the text inside the image bounds.
                    text_width = len(label) * 6
                    text_height = 12

                    text_x = bbox[0]
                    text_y = bbox[1] - text_height

                    if text_x < 0:
                        text_x = 0
                    elif text_x + text_width > W:
                        text_x = W - text_width

                    # If the label would fall above the image, place it below the box instead.
                    if text_y < 0:
                        text_y = bbox[3]

                    ax.add_patch(rect)
                    facecolor = random.choice(colormap) if len(image) == 1 else 'red'

                    plt.text(
                        text_x,
                        text_y,
                        indexed_label,
                        color='white',
                        fontsize=12,
                        bbox=dict(facecolor=facecolor, alpha=0.5)
                    )

                if fill_mask:
                    # Convert the PIL mask layer to a single-channel (B, H, W) MASK tensor.
                    mask_tensor = F.to_tensor(mask_layer)
                    mask_tensor = mask_tensor.unsqueeze(0).permute(0, 2, 3, 1).cpu().float()
                    mask_tensor = mask_tensor.mean(dim=0, keepdim=True)
                    mask_tensor = mask_tensor.repeat(1, 1, 1, 3)
                    mask_tensor = mask_tensor[:, :, :, 0]
                    out_masks.append(mask_tensor)
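
                # Strip axes and margins, render the annotated figure to a PNG
                # buffer, and convert it back into an image tensor for the output.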
                ax.axis('off')
                ax.margins(0, 0)
                ax.get_xaxis().set_major_locator(plt.NullLocator())
                ax.get_yaxis().set_major_locator(plt.NullLocator())
                fig.canvas.draw()
                buf = io.BytesIO()
                plt.savefig(buf, format='png', pad_inches=0)
                buf.seek(0)
                annotated_image_pil = Image.open(buf)

                annotated_image_tensor = F.to_tensor(annotated_image_pil)
                out_tensor = annotated_image_tensor[:3, :, :].unsqueeze(0).permute(0, 2, 3, 1).cpu().float()
                out.append(out_tensor)

                pbar.update(1)

                plt.close(fig)

            elif task == 'referring_expression_segmentation':
                # Accumulate a binary mask of all predicted polygons on a black canvas.
                mask_image = Image.new('RGB', (W, H), 'black')
                mask_draw = ImageDraw.Draw(mask_image)

                predictions = parsed_answer[task_prompt]
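
                # Each label can have several polygons; draw each onto the image
                # (translucent fill when fill_mask is set) and onto the mask canvas.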
                for polygons, label in zip(predictions['polygons'], predictions['labels']):
                    color = random.choice(colormap)
                    for _polygon in polygons:
                        _polygon = np.array(_polygon).reshape(-1, 2)
                        # Clamp coordinates to the image bounds and skip degenerate polygons.
                        _polygon = np.clip(_polygon, [0, 0], [W - 1, H - 1])
                        if len(_polygon) < 3:
                            print('Invalid polygon:', _polygon)
                            continue

                        _polygon = _polygon.reshape(-1).tolist()

                        if fill_mask:
                            overlay = Image.new('RGBA', image_pil.size, (255, 255, 255, 0))
                            image_pil = image_pil.convert('RGBA')
                            draw = ImageDraw.Draw(overlay)
                            color_with_opacity = ImageColor.getrgb(color) + (180,)
                            draw.polygon(_polygon, outline=color, fill=color_with_opacity, width=3)
                            image_pil = Image.alpha_composite(image_pil, overlay)
                        else:
                            draw = ImageDraw.Draw(image_pil)
                            draw.polygon(_polygon, outline=color, width=3)

                        mask_draw.polygon(_polygon, outline="white", fill="white")

                image_tensor = F.to_tensor(image_pil)
                image_tensor = image_tensor[:3, :, :].unsqueeze(0).permute(0, 2, 3, 1).cpu().float()
                out.append(image_tensor)

                mask_tensor = F.to_tensor(mask_image)
                mask_tensor = mask_tensor.unsqueeze(0).permute(0, 2, 3, 1).cpu().float()
                mask_tensor = mask_tensor.mean(dim=0, keepdim=True)
                mask_tensor = mask_tensor.repeat(1, 1, 1, 3)
                mask_tensor = mask_tensor[:, :, :, 0]
                out_masks.append(mask_tensor)
                pbar.update(1)

            elif task == 'ocr_with_region':
                try:
                    font = ImageFont.load_default().font_variant(size=24)
                except Exception:
                    font = ImageFont.load_default()
                predictions = parsed_answer[task_prompt]
                scale = 1
                draw = ImageDraw.Draw(image_pil)
                bboxes, labels = predictions['quad_boxes'], predictions['labels']
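
                # quad_boxes are flat [x1, y1, ..., x4, y4] quadrilaterals in pixel
                # space; the JSON output also gets a normalized axis-aligned box.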
                for box, label in zip(bboxes, labels):
                    bbox = calculate_bounding_box(width, height, box)
                    out_data.append({"label": label, "polygon": box, "box": bbox})
                    color = random.choice(colormap)
                    new_box = (np.array(box) * scale).tolist()
                    draw.polygon(new_box, width=3, outline=color)
                    draw.text((new_box[0] + 8, new_box[1] + 2),
                              "{}".format(label),
                              align="right",
                              font=font,
                              fill=color)

                image_tensor = F.to_tensor(image_pil)
                image_tensor = image_tensor[:3, :, :].unsqueeze(0).permute(0, 2, 3, 1).cpu().float()
                out.append(image_tensor)

            elif task == 'docvqa':
                if text_input == "":
                    raise ValueError("Text input (prompt) is required for 'docvqa'")
                prompt = "<DocVQA> " + text_input
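
                # DocVQA runs a second generation pass with the question appended
                # to the task token.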
                inputs = processor(text=prompt, images=image_pil, return_tensors="pt", do_rescale=False).to(dtype).to(device)
                generated_ids = model.generate(
                    input_ids=inputs["input_ids"],
                    pixel_values=inputs["pixel_values"],
                    max_new_tokens=max_new_tokens,
                    do_sample=do_sample,
                    num_beams=num_beams,
                )

                results = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
                clean_results = results.replace('</s>', '').replace('<s>', '')

                if len(image) == 1:
                    out_results = clean_results
                else:
                    out_results.append(clean_results)

                out.append(F.to_tensor(image_pil).unsqueeze(0).permute(0, 2, 3, 1).cpu().float())

                pbar.update(1)

        if len(out) > 0:
            out_tensor = torch.cat(out, dim=0)
        else:
            # Fall back to empty placeholders when the task produced no images or masks.
            out_tensor = torch.zeros((1, 64, 64, 3), dtype=torch.float32, device="cpu")
        if len(out_masks) > 0:
            out_mask_tensor = torch.cat(out_masks, dim=0)
        else:
            out_mask_tensor = torch.zeros((1, 64, 64), dtype=torch.float32, device="cpu")

        if not keep_model_loaded:
            print("Offloading model...")
            model.to(offload_device)
            mm.soft_empty_cache()

        return (out_tensor, out_mask_tensor, out_results, out_data)


NODE_CLASS_MAPPINGS = {
    "DownloadAndLoadFlorence2Model": DownloadAndLoadFlorence2Model,
    "Florence2Run": Florence2Run,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "DownloadAndLoadFlorence2Model": "DownloadAndLoadFlorence2Model",
    "Florence2Run": "Florence2Run",
}