File size: 4,430 Bytes
641b32e
 
 
49f8ccd
641b32e
 
 
 
 
 
 
 
210def2
641b32e
 
 
210def2
 
afd6ed3
49f8ccd
641b32e
 
 
210def2
 
641b32e
 
 
 
 
 
 
 
210def2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
641b32e
 
 
 
 
 
210def2
 
 
 
 
641b32e
 
 
da2a069
210def2
641b32e
 
 
 
 
49f8ccd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
641b32e
 
 
 
 
 
49f8ccd
641b32e
 
 
 
49f8ccd
 
 
 
 
 
 
641b32e
 
 
 
 
 
 
 
 
210def2
641b32e
 
210def2
641b32e
 
 
 
 
210def2
 
49f8ccd
210def2
 
641b32e
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from __future__ import annotations

import os
import re
from io import BytesIO
from typing import Any

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Hugging Face model id; override via the MODEL_ID env var.
MODEL_ID = os.getenv("MODEL_ID", "microsoft/Florence-2-base")
# Optional pinned model revision (commit/tag/branch); unset -> default branch.
MODEL_REVISION = os.getenv("MODEL_REVISION")
# Default generation budget per request, and the hard cap callers cannot exceed.
DEFAULT_MAX_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "64"))
MAX_MAX_TOKENS = int(os.getenv("MAX_MAX_TOKENS", "256"))
# Input images are downscaled so neither side exceeds this many pixels.
MAX_IMAGE_SIDE = int(os.getenv("MAX_IMAGE_SIDE", "896"))
# Resized dimensions are floored to this multiple (see _prepare_image).
RESIZE_MULTIPLE = int(os.getenv("RESIZE_MULTIPLE", "32"))
# Beam-search width passed to model.generate.
NUM_BEAMS = int(os.getenv("NUM_BEAMS", "3"))
# Florence-2 task prompt used when the caller supplies none.
DEFAULT_PROMPT = os.getenv("DEFAULT_PROMPT", "<CAPTION>")
# Matches a leading Florence-2 task token such as '<CAPTION>'.
TASK_TOKEN_PATTERN = re.compile(r"^<[^>\s]+>")

# Lazily-initialized singletons, populated on first load_model() call.
_model = None
_processor = None
# CPU/float32 — NOTE(review): presumably chosen for portability; confirm no GPU path is intended.
_device = torch.device("cpu")
_dtype = torch.float32


def _prepare_image(image_bytes: bytes) -> Image.Image:
    """Decode raw image bytes as RGB, downscaling if either side is too large.

    Aspect ratio is preserved (up to rounding); when RESIZE_MULTIPLE > 1 each
    resized side is floored to that multiple, but never below one multiple.
    Images already within MAX_IMAGE_SIDE are returned untouched.
    """
    image = Image.open(BytesIO(image_bytes)).convert("RGB")
    width, height = image.size
    longest = max(width, height)
    if longest <= MAX_IMAGE_SIDE:
        return image

    # Scale the longer side down to exactly MAX_IMAGE_SIDE.
    scale = MAX_IMAGE_SIDE / longest
    target_w = max(1, int(width * scale))
    target_h = max(1, int(height * scale))

    # Align dimensions to improve tensor-core friendly shapes.
    if RESIZE_MULTIPLE > 1:
        target_w = max(RESIZE_MULTIPLE, target_w - target_w % RESIZE_MULTIPLE)
        target_h = max(RESIZE_MULTIPLE, target_h - target_h % RESIZE_MULTIPLE)

    return image.resize((target_w, target_h), Image.Resampling.LANCZOS)


def load_model() -> tuple[Any, Any]:
    """Return the cached (model, processor) pair, loading them on first use.

    The first call downloads/initializes the Florence-2 processor and model
    (optionally pinned to MODEL_REVISION), moves the model to the configured
    device, and switches it to eval mode. Subsequent calls are no-ops.
    """
    global _model, _processor
    if _model is not None and _processor is not None:
        return _model, _processor

    common_kwargs: dict[str, Any] = {"trust_remote_code": True}
    if MODEL_REVISION:
        common_kwargs["revision"] = MODEL_REVISION

    _processor = AutoProcessor.from_pretrained(MODEL_ID, **common_kwargs)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=_dtype,
        attn_implementation="eager",
        **common_kwargs,
    )
    _model = model.to(_device)
    _model.eval()
    return _model, _processor


def _build_prompt(text_input: str | None) -> str:
    if text_input is None:
        return DEFAULT_PROMPT

    prompt = text_input.strip()
    if not prompt:
        return DEFAULT_PROMPT
    if not prompt.startswith("<"):
        raise ValueError(
            "Invalid prompt in `text`: expected a Florence-2 task token like "
            "'<CAPTION>' or '<CAPTION_TO_PHRASE_GROUNDING>phrase'."
        )
    return prompt


def _task_token_from_prompt(prompt: str) -> str:
    match = TASK_TOKEN_PATTERN.match(prompt)
    return match.group(0) if match else DEFAULT_PROMPT


def generate_caption(
    image_bytes: bytes,
    text_input: str | None = None,
    max_tokens: int = DEFAULT_MAX_TOKENS,
) -> dict[str, Any]:
    """Run Florence-2 on one image and return the decoded generation result.

    Args:
        image_bytes: Raw encoded image bytes (any format PIL can open).
        text_input: Optional Florence-2 task prompt (e.g. '<CAPTION>'); None
            or blank falls back to DEFAULT_PROMPT via _build_prompt.
        max_tokens: Requested generation budget; clamped to [8, MAX_MAX_TOKENS].

    Returns:
        {"text": <raw decoded output>} and, when the processor's task-specific
        post-processing succeeds with a truthy result, also {"parsed": <...>}.

    Raises:
        ValueError: for a malformed prompt (from _build_prompt, or when the
            processor rejects the task format).
    """
    model, processor = load_model()
    prompt = _build_prompt(text_input)

    # Clamp the budget: at least 8 new tokens, never above the configured cap.
    safe_max_tokens = min(max(int(max_tokens), 8), MAX_MAX_TOKENS)
    image = _prepare_image(image_bytes)

    try:
        inputs = processor(text=prompt, images=image, return_tensors="pt")
    except AssertionError as exc:
        # The Florence-2 processor asserts on malformed task strings; surface
        # a caller-friendly ValueError instead of a bare AssertionError.
        raise ValueError(
            "Invalid Florence-2 task format in `text`. For plain captioning, use only "
            "'<CAPTION>' with no extra words."
        ) from exc
    input_ids = inputs["input_ids"].to(_device)
    # pixel_values must also match the model dtype, not just the device.
    pixel_values = inputs["pixel_values"].to(_device, _dtype)

    # Deterministic beam search; inference_mode disables autograd bookkeeping.
    with torch.inference_mode():
        generated_ids = model.generate(
            input_ids=input_ids,
            pixel_values=pixel_values,
            do_sample=False,
            max_new_tokens=safe_max_tokens,
            num_beams=max(1, NUM_BEAMS),
        )

    # skip_special_tokens=False: the task/location tokens in the raw text are
    # consumed by post_process_generation below.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0].strip()

    # Best-effort structured parsing; any failure falls back to raw text only.
    parsed = None
    post_process = getattr(processor, "post_process_generation", None)
    if callable(post_process):
        try:
            parsed = post_process(
                generated_text,
                task=_task_token_from_prompt(prompt),
                image_size=(image.width, image.height),
            )
        except Exception:
            parsed = None

    # NOTE(review): a falsy-but-valid parsed result (e.g. empty dict) is
    # dropped here — presumably intentional; confirm with callers.
    return {"text": generated_text, "parsed": parsed} if parsed else {"text": generated_text}