Spaces:
Running
Running
File size: 4,430 Bytes
641b32e 49f8ccd 641b32e 210def2 641b32e 210def2 afd6ed3 49f8ccd 641b32e 210def2 641b32e 210def2 641b32e 210def2 641b32e da2a069 210def2 641b32e 49f8ccd 641b32e 49f8ccd 641b32e 49f8ccd 641b32e 210def2 641b32e 210def2 641b32e 210def2 49f8ccd 210def2 641b32e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 | from __future__ import annotations
import os
import re
from io import BytesIO
from typing import Any
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
# Hugging Face repo for the captioning model; override via env for other checkpoints.
MODEL_ID = os.getenv("MODEL_ID", "microsoft/Florence-2-base")
# Optional git revision (branch/tag/commit) to pin the model repo to; None -> default branch.
MODEL_REVISION = os.getenv("MODEL_REVISION")
# Generation length used when callers do not ask for a specific budget.
DEFAULT_MAX_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "64"))
# Hard upper bound on max_new_tokens regardless of caller input.
MAX_MAX_TOKENS = int(os.getenv("MAX_MAX_TOKENS", "256"))
# Images larger than this on either side are downscaled before inference.
MAX_IMAGE_SIDE = int(os.getenv("MAX_IMAGE_SIDE", "896"))
# Resized dimensions are snapped down to a multiple of this value.
RESIZE_MULTIPLE = int(os.getenv("RESIZE_MULTIPLE", "32"))
# Beam width for beam-search decoding (clamped to >= 1 at generation time).
NUM_BEAMS = int(os.getenv("NUM_BEAMS", "3"))
# Task prompt used when the caller provides no usable `text`.
DEFAULT_PROMPT = os.getenv("DEFAULT_PROMPT", "<CAPTION>")
# Matches a leading Florence-2 task token such as '<CAPTION>'.
TASK_TOKEN_PATTERN = re.compile(r"^<[^>\s]+>")
# Lazily-populated singletons (see load_model); module scope so the model is
# loaded at most once per process.
_model = None
_processor = None
# Inference device/dtype as configured here: CPU, float32.
_device = torch.device("cpu")
_dtype = torch.float32
def _prepare_image(image_bytes: bytes) -> Image.Image:
    """Decode raw bytes into an RGB image, downscaling oversized inputs.

    If either side exceeds MAX_IMAGE_SIDE the image is scaled down (aspect
    ratio preserved) so the longer side fits, and the target dimensions are
    snapped down to a multiple of RESIZE_MULTIPLE (never below one multiple)
    for tensor-core friendly shapes. Small-enough images pass through as-is.
    """
    img = Image.open(BytesIO(image_bytes)).convert("RGB")
    w, h = img.size
    longer = max(w, h)
    if longer <= MAX_IMAGE_SIDE:
        return img
    # Scale by the longer side; this caps width for landscape and height for
    # portrait while preserving the aspect ratio.
    scale = MAX_IMAGE_SIDE / longer
    target_w = max(1, int(w * scale))
    target_h = max(1, int(h * scale))
    if RESIZE_MULTIPLE > 1:
        # Round each side down to the nearest multiple, keeping at least one.
        target_w = max(RESIZE_MULTIPLE, target_w - target_w % RESIZE_MULTIPLE)
        target_h = max(RESIZE_MULTIPLE, target_h - target_h % RESIZE_MULTIPLE)
    return img.resize((target_w, target_h), Image.Resampling.LANCZOS)
def load_model() -> tuple[Any, Any]:
    """Lazily instantiate and cache the Florence-2 model and processor.

    Returns the module-level (model, processor) singletons, loading them on
    the first call only; subsequent calls return the cached pair.
    """
    global _model, _processor
    if _model is not None and _processor is not None:
        return _model, _processor
    kwargs: dict[str, Any] = {"trust_remote_code": True}
    if MODEL_REVISION:
        # Pin to an exact repo revision when configured.
        kwargs["revision"] = MODEL_REVISION
    _processor = AutoProcessor.from_pretrained(MODEL_ID, **kwargs)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=_dtype,
        attn_implementation="eager",
        **kwargs,
    )
    _model = model.to(_device)
    _model.eval()
    return _model, _processor
def _build_prompt(text_input: str | None) -> str:
if text_input is None:
return DEFAULT_PROMPT
prompt = text_input.strip()
if not prompt:
return DEFAULT_PROMPT
if not prompt.startswith("<"):
raise ValueError(
"Invalid prompt in `text`: expected a Florence-2 task token like "
"'<CAPTION>' or '<CAPTION_TO_PHRASE_GROUNDING>phrase'."
)
return prompt
def _task_token_from_prompt(prompt: str) -> str:
    """Return the leading '<TASK>' token of *prompt*, or DEFAULT_PROMPT if none."""
    found = TASK_TOKEN_PATTERN.match(prompt)
    if found is None:
        return DEFAULT_PROMPT
    return found.group(0)
def generate_caption(
    image_bytes: bytes,
    text_input: str | None = None,
    max_tokens: int = DEFAULT_MAX_TOKENS,
) -> dict[str, Any]:
    """Run Florence-2 on one image and return the decoded output.

    The result dict always has "text" (raw decoded string, special tokens
    included) and, when the processor successfully post-processes it, a
    "parsed" entry with the task-specific structured output.

    Raises ValueError when `text_input` encodes an invalid task format.
    """
    model, processor = load_model()
    prompt = _build_prompt(text_input)
    # Clamp the generation budget to [8, MAX_MAX_TOKENS].
    token_budget = int(max_tokens)
    token_budget = max(token_budget, 8)
    token_budget = min(token_budget, MAX_MAX_TOKENS)
    image = _prepare_image(image_bytes)
    try:
        encoded = processor(text=prompt, images=image, return_tensors="pt")
    except AssertionError as exc:
        # The Florence-2 processor asserts on malformed task strings; surface
        # that as a caller-facing ValueError instead.
        raise ValueError(
            "Invalid Florence-2 task format in `text`. For plain captioning, use only "
            "'<CAPTION>' with no extra words."
        ) from exc
    with torch.inference_mode():
        generated_ids = model.generate(
            input_ids=encoded["input_ids"].to(_device),
            pixel_values=encoded["pixel_values"].to(_device, _dtype),
            do_sample=False,
            max_new_tokens=token_budget,
            num_beams=max(1, NUM_BEAMS),
        )
    # Keep special tokens: post_process_generation relies on them to locate
    # the task output in the decoded string.
    decoded = processor.batch_decode(generated_ids, skip_special_tokens=False)[0].strip()
    parsed = None
    post_process = getattr(processor, "post_process_generation", None)
    if callable(post_process):
        try:
            parsed = post_process(
                decoded,
                task=_task_token_from_prompt(prompt),
                image_size=(image.width, image.height),
            )
        except Exception:
            # Best-effort parsing: fall back to raw text on any failure.
            parsed = None
    result: dict[str, Any] = {"text": decoded}
    if parsed:
        result["parsed"] = parsed
    return result
|