File size: 3,787 Bytes
5251d80
e6ef52a
787d948
5251d80
787d948
 
 
 
 
5251d80
eb1bdd5
5251d80
eb1bdd5
 
 
 
 
 
 
 
5251d80
 
 
787d948
25a17bd
5251d80
fc12805
5251d80
 
787d948
 
 
 
 
5251d80
 
787d948
5251d80
ba8d791
5251d80
ba8d791
a6262d6
5251d80
 
3864546
 
 
5251d80
3864546
 
5251d80
3864546
5251d80
 
 
 
 
3864546
 
 
5251d80
3864546
 
 
5251d80
25a17bd
787d948
5251d80
 
 
 
25a17bd
5251d80
 
 
0663dd7
5251d80
 
 
 
25a17bd
5251d80
 
787d948
 
5251d80
ba8d791
787d948
5251d80
25a17bd
5251d80
25a17bd
5251d80
 
787d948
ba8d791
5251d80
 
df4efd1
5251d80
 
df4efd1
 
5251d80
 
df4efd1
5251d80
 
 
df4efd1
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# CPU-only: Image → Caption (Florence-2-base), concise build

import os
from functools import lru_cache

import torch
import gradio as gr
from PIL import Image

# AVIF/HEIF support (optional, safe to ignore if unavailable)
try:
    import pillow_avif  # noqa: F401
except Exception:
    pass
try:
    from pillow_heif import register_heif_opener
    register_heif_opener()
except Exception:
    pass

from transformers import AutoProcessor, AutoModelForCausalLM
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports as _orig_get_imports

CAPTION_MODEL_ID = "microsoft/Florence-2-base"
# Accept either env var name; HUGGINGFACE_HUB_TOKEN is the hub's canonical fallback.
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
DEVICE = "cpu"          # CPU-only deployment by design (see header comment)
DTYPE = torch.float32   # fp32 is the safe dtype on CPU
MAX_IMG_SIDE = int(os.getenv("MAX_IMG_SIDE", "1024"))  # longest-side cap applied before captioning

def _resize_max(img: Image.Image, max_side: int = MAX_IMG_SIDE) -> Image.Image:
    """Downscale *img* so its longest side is at most *max_side*.

    Preserves aspect ratio and never upscales: the image is returned
    unchanged when it already fits, otherwise a new LANCZOS-resampled
    copy is returned.
    """
    w, h = img.size
    longest = max(w, h)
    if longest <= max_side:
        return img
    ratio = max_side / longest
    # max(1, ...): for extreme aspect ratios int() truncation could yield a
    # zero-pixel dimension, which PIL rejects.
    new_size = (max(1, int(w * ratio)), max(1, int(h * ratio)))
    return img.resize(new_size, Image.LANCZOS)

def _ensure_rgb(img) -> Image.Image:
    """Validate the upload and normalize it to RGB mode.

    Raises gr.Error for anything that is not a PIL image.
    """
    if isinstance(img, Image.Image):
        return img.convert("RGB")
    raise gr.Error("Upload a valid image.")

def _no_flash_attn_get_imports(filename):
    """Wrap transformers' get_imports, dropping ``flash_attn`` for Florence-2.

    Florence-2's remote modeling code lists flash_attn among its imports;
    filtering it here lets the model load without that package installed
    (presumably it is unused on CPU — the sdpa backend is selected at load
    time). Any error while inspecting the filename falls back to returning
    the unmodified import list.
    """
    imports = _orig_get_imports(filename)
    try:
        fname = str(filename).lower()
        if "florence2" in fname or "modeling_florence2.py" in fname:
            return [imp for imp in imports if imp != "flash_attn"]
    except Exception:
        pass
    return imports

@lru_cache(maxsize=1)
def _load_florence():
    """Load and memoize the Florence-2 processor/model pair (CPU, fp32).

    The get_imports patch is active only while the remote modeling code is
    resolved, so the spurious flash_attn requirement is skipped exactly
    where it matters. lru_cache(maxsize=1) makes this a one-time load.
    """
    processor = AutoProcessor.from_pretrained(
        CAPTION_MODEL_ID, trust_remote_code=True, token=HF_TOKEN
    )
    with patch("transformers.dynamic_module_utils.get_imports", _no_flash_attn_get_imports):
        model = AutoModelForCausalLM.from_pretrained(
            CAPTION_MODEL_ID,
            trust_remote_code=True,
            token=HF_TOKEN,
            attn_implementation="sdpa",  # CPU-safe attention backend
            torch_dtype=DTYPE,
            device_map="cpu",
        )
    return processor, model.eval()

@torch.inference_mode()
def caption(image: Image.Image, max_new_tokens: int = 128, num_beams: int = 3) -> str:
    """Generate a detailed caption for *image* with Florence-2 on CPU.

    Args:
        image: input PIL image (any mode; converted to RGB and downscaled).
        max_new_tokens: generation budget for the caption.
        num_beams: beam-search width (deterministic; do_sample=False).

    Returns:
        The stripped caption text, or a fallback message when the model
        produces nothing.

    Raises:
        gr.Error: when *image* is not a valid PIL image.
    """
    # Validate FIRST: the original resized before validating, so a missing or
    # invalid upload raised AttributeError inside _resize_max instead of the
    # intended friendly gr.Error.
    image = _resize_max(_ensure_rgb(image))
    processor, model = _load_florence()
    batch = processor(
        text="<MORE_DETAILED_CAPTION>",
        images=[image],          # batched even for a single image
        padding=True,
        return_tensors="pt",
    )
    # BatchFeature may carry non-tensor entries; move only tensors to the device.
    for k, v in list(batch.items()):
        if torch.is_tensor(v):
            batch[k] = v.to(DEVICE)

    out_ids = model.generate(
        **batch,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        do_sample=False,
        early_stopping=False,
    )
    # Florence-2 post-processing parses the task tokens, so keep special tokens.
    gen = processor.batch_decode(out_ids, skip_special_tokens=False)[0]
    parsed = processor.post_process_generation(
        gen, task="<MORE_DETAILED_CAPTION>", image_size=[(image.width, image.height)]
    )
    data = parsed[0] if isinstance(parsed, list) else parsed
    return (data.get("<MORE_DETAILED_CAPTION>", "") or "Unable to generate a caption.").strip()

def run(image: Image.Image):
    """Gradio click handler: return (caption text, status line)."""
    return caption(image), "Model: Florence-2-base (CPU)"

# UI layout: input column (image + button) on the left, results on the right.
with gr.Blocks(css="footer{visibility:hidden}") as demo:
    gr.Markdown("# Image → Caption (CPU) — Florence-2-base")
    with gr.Row():
        with gr.Column():
            image_in = gr.Image(type="pil", label="Image", sources=["upload", "clipboard", "webcam"])
            caption_btn = gr.Button("Caption", variant="primary")
        with gr.Column():
            caption_out = gr.Textbox(label="Caption", lines=10)
            status_md = gr.Markdown()
    caption_btn.click(run, [image_in], [caption_out, status_md], scroll_to_output=True)

if __name__ == "__main__":
    # Bound the request queue so the single CPU model isn't swamped by uploads.
    demo.queue(max_size=8).launch()