"""Florence-2 image captioning / task inference helpers.

Configuration is read from environment variables at import time. The model and
processor are loaded lazily on first use, cached in module globals, and run on
CPU in float32.
"""

from __future__ import annotations

import os
import re
from io import BytesIO
from typing import Any

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

MODEL_ID = os.getenv("MODEL_ID", "microsoft/Florence-2-base")
MODEL_REVISION = os.getenv("MODEL_REVISION")
DEFAULT_MAX_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "64"))
MAX_MAX_TOKENS = int(os.getenv("MAX_MAX_TOKENS", "256"))
MAX_IMAGE_SIDE = int(os.getenv("MAX_IMAGE_SIDE", "896"))
RESIZE_MULTIPLE = int(os.getenv("RESIZE_MULTIPLE", "32"))
NUM_BEAMS = int(os.getenv("NUM_BEAMS", "3"))
# Florence-2 task token used when the caller supplies no prompt.
DEFAULT_PROMPT = os.getenv("DEFAULT_PROMPT", "<CAPTION>")
# Matches a leading task token such as "<CAPTION>" at the start of a prompt.
TASK_TOKEN_PATTERN = re.compile(r"^<[^>\s]+>")

_model = None
_processor = None
_device = torch.device("cpu")
_dtype = torch.float32


def _load_rgb_image(image_bytes: bytes) -> Image.Image:
    """Decode *image_bytes* and return an RGB copy of the image.

    The source image is closed before returning; ``convert`` loads the pixel
    data, so the returned image does not depend on the closed handle.
    """
    with Image.open(BytesIO(image_bytes)) as source:
        return source.convert("RGB")


def _fit_to_model(image: Image.Image) -> Image.Image:
    """Downscale *image* so that neither side exceeds ``MAX_IMAGE_SIDE``.

    Images already within the limit are returned unchanged. Otherwise the
    longer side is capped (preserving aspect ratio) and, when
    ``RESIZE_MULTIPLE > 1``, both sides are rounded down to a multiple of it
    (but never below one multiple).
    """
    width, height = image.size
    if width <= MAX_IMAGE_SIDE and height <= MAX_IMAGE_SIDE:
        return image

    # Cap the longer side; the same ratio applies to both sides.
    ratio = MAX_IMAGE_SIDE / max(width, height)
    new_w = max(1, int(width * ratio))
    new_h = max(1, int(height * ratio))

    # Align dimensions to improve tensor-core friendly shapes.
    if RESIZE_MULTIPLE > 1:
        new_w = max(RESIZE_MULTIPLE, (new_w // RESIZE_MULTIPLE) * RESIZE_MULTIPLE)
        new_h = max(RESIZE_MULTIPLE, (new_h // RESIZE_MULTIPLE) * RESIZE_MULTIPLE)

    return image.resize((new_w, new_h), Image.Resampling.LANCZOS)


def _prepare_image(image_bytes: bytes) -> Image.Image:
    """Decode *image_bytes* to RGB and downscale it for the model."""
    return _fit_to_model(_load_rgb_image(image_bytes))


def load_model() -> tuple[Any, Any]:
    """Return ``(model, processor)``, loading and caching them on first call.

    ``MODEL_ID`` (and ``MODEL_REVISION`` when set) select the checkpoint.
    ``trust_remote_code=True`` is passed because the checkpoint ships custom
    modelling code.
    """
    global _model, _processor
    if _model is None or _processor is None:
        pretrained_kwargs: dict[str, Any] = {"trust_remote_code": True}
        if MODEL_REVISION:
            pretrained_kwargs["revision"] = MODEL_REVISION

        _processor = AutoProcessor.from_pretrained(MODEL_ID, **pretrained_kwargs)
        _model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=_dtype,
            attn_implementation="eager",
            **pretrained_kwargs,
        ).to(_device)
        _model.eval()
    return _model, _processor


def _build_prompt(text_input: str | None) -> str:
    """Return the prompt to send to the processor.

    ``None`` or blank input falls back to ``DEFAULT_PROMPT``. Any other input
    is stripped and must start with a task token.

    Raises:
        ValueError: if a non-blank prompt does not start with ``<``.
    """
    if text_input is None:
        return DEFAULT_PROMPT
    prompt = text_input.strip()
    if not prompt:
        return DEFAULT_PROMPT
    if not prompt.startswith("<"):
        raise ValueError(
            "Invalid prompt in `text`: expected a Florence-2 task token like "
            "'<CAPTION>' or '<CAPTION_TO_PHRASE_GROUNDING>phrase'."
        )
    return prompt


def _task_token_from_prompt(prompt: str) -> str:
    """Return the leading ``<...>`` token of *prompt*, else ``DEFAULT_PROMPT``."""
    match = TASK_TOKEN_PATTERN.match(prompt)
    return match.group(0) if match else DEFAULT_PROMPT


def _post_process(
    processor: Any,
    generated_text: str,
    prompt: str,
    image_size: tuple[int, int],
) -> Any:
    """Best-effort structured parse of *generated_text*; ``None`` on failure.

    Returns ``None`` when the processor has no callable
    ``post_process_generation`` or when that call raises. Parsing is optional
    output, so a failure degrades to text-only instead of failing the request.
    """
    post_process = getattr(processor, "post_process_generation", None)
    if not callable(post_process):
        return None
    try:
        return post_process(
            generated_text,
            task=_task_token_from_prompt(prompt),
            image_size=image_size,
        )
    except Exception:
        return None


def generate_caption(
    image_bytes: bytes,
    text_input: str | None = None,
    max_tokens: int = DEFAULT_MAX_TOKENS,
) -> dict[str, Any]:
    """Run Florence-2 on an encoded image and return the generated output.

    Args:
        image_bytes: Encoded image in any format PIL can open.
        text_input: Optional Florence-2 prompt starting with a task token;
            blank or ``None`` uses ``DEFAULT_PROMPT``.
        max_tokens: Requested generation length, clamped to
            ``[8, MAX_MAX_TOKENS]``.

    Returns:
        ``{"text": <decoded output incl. special tokens>}``, plus a
        ``"parsed"`` key when the processor's post-processing returns a
        truthy result.

    Raises:
        ValueError: if the prompt is not a valid Florence-2 task format.
    """
    prompt = _build_prompt(text_input)
    safe_max_tokens = min(max(int(max_tokens), 8), MAX_MAX_TOKENS)
    model, processor = load_model()

    image = _load_rgb_image(image_bytes)
    # Post-processing gets the uploaded image's size, not the internal
    # pre-resize. Presumably the processor scales location outputs by
    # image_size, so results land in the caller's pixel frame.
    original_size = image.size
    image = _fit_to_model(image)

    try:
        inputs = processor(text=prompt, images=image, return_tensors="pt")
    except AssertionError as exc:
        # The processor signals a malformed task prompt via AssertionError.
        raise ValueError(
            "Invalid Florence-2 task format in `text`. For plain captioning, use only "
            "'<CAPTION>' with no extra words."
        ) from exc

    input_ids = inputs["input_ids"].to(_device)
    pixel_values = inputs["pixel_values"].to(_device, _dtype)

    with torch.inference_mode():
        generated_ids = model.generate(
            input_ids=input_ids,
            pixel_values=pixel_values,
            do_sample=False,
            max_new_tokens=safe_max_tokens,
            num_beams=max(1, NUM_BEAMS),
        )

    # Keep special tokens: post-processing relies on the task/location tokens.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0].strip()

    parsed = _post_process(processor, generated_text, prompt, original_size)
    if parsed:
        return {"text": generated_text, "parsed": parsed}
    return {"text": generated_text}