# CPU-only: Image → Caption (Florence-2-base), concise build
import os
from functools import lru_cache
import torch
import gradio as gr
from PIL import Image
# AVIF/HEIF support (optional, safe to ignore if unavailable)
try:
import pillow_avif # noqa: F401
except Exception:
pass
try:
from pillow_heif import register_heif_opener
register_heif_opener()
except Exception:
pass
from transformers import AutoProcessor, AutoModelForCausalLM
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports as _orig_get_imports
CAPTION_MODEL_ID = "microsoft/Florence-2-base"
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
DEVICE = "cpu"
DTYPE = torch.float32
MAX_IMG_SIDE = int(os.getenv("MAX_IMG_SIDE", "1024"))
def _resize_max(img: Image.Image, max_side: int = MAX_IMG_SIDE) -> Image.Image:
    """Downscale *img* so its longest side is at most *max_side*, keeping aspect.

    Returns the image unchanged when it already fits. Uses LANCZOS resampling
    for high-quality downscaling.
    """
    w, h = img.size
    if max(w, h) <= max_side:
        return img
    r = max_side / max(w, h)
    # Clamp to >= 1 px: int() truncation on extreme aspect ratios (e.g. 8000x1)
    # would otherwise produce a zero dimension and make resize() raise.
    return img.resize((max(1, int(w * r)), max(1, int(h * r))), Image.LANCZOS)
def _ensure_rgb(img) -> Image.Image:
    """Validate that *img* is a PIL image and normalize it to RGB mode."""
    if isinstance(img, Image.Image):
        return img.convert("RGB")
    raise gr.Error("Upload a valid image.")
def _no_flash_attn_get_imports(filename):
    """Wrapper around transformers' ``get_imports`` that drops ``flash_attn``.

    Florence-2's remote modeling code declares a flash_attn dependency that
    cannot be installed on CPU-only hosts; filtering it lets the model load.
    Any failure while inspecting the filename falls back to the raw imports.
    """
    imports = _orig_get_imports(filename)
    try:
        lowered = str(filename).lower()
        if "florence2" in lowered or "modeling_florence2.py" in lowered:
            imports = [imp for imp in imports if imp != "flash_attn"]
    except Exception:
        pass
    return imports
@lru_cache(maxsize=1)
def _load_florence():
    """Load and cache the Florence-2 processor and model (CPU, float32).

    ``get_imports`` is patched during model load so the remote code's
    flash_attn import is skipped; attention falls back to SDPA, which is
    safe on CPU. Cached so the expensive download/load happens once.
    """
    processor = AutoProcessor.from_pretrained(
        CAPTION_MODEL_ID, trust_remote_code=True, token=HF_TOKEN
    )
    with patch("transformers.dynamic_module_utils.get_imports", _no_flash_attn_get_imports):
        model = AutoModelForCausalLM.from_pretrained(
            CAPTION_MODEL_ID,
            trust_remote_code=True,
            token=HF_TOKEN,
            attn_implementation="sdpa",  # CPU-safe attention backend
            torch_dtype=DTYPE,
            device_map="cpu",
        )
    return processor, model.eval()
@torch.inference_mode()
def caption(image: Image.Image, max_new_tokens: int = 128, num_beams: int = 3) -> str:
    """Generate a detailed caption for *image* with Florence-2.

    The image is normalized (RGB, bounded size) and run through the
    ``<MORE_DETAILED_CAPTION>`` task prompt. Returns the stripped caption
    text, or a fallback message when the model produces nothing.

    Args:
        image: input PIL image (any mode; converted to RGB).
        max_new_tokens: generation length cap.
        num_beams: beam-search width (greedy-like, do_sample=False).

    Raises:
        gr.Error: when *image* is not a valid PIL image.
    """
    image = _ensure_rgb(_resize_max(image))
    processor, model = _load_florence()
    batch = processor(
        text="<MORE_DETAILED_CAPTION>",
        images=[image],  # batch even for single
        padding=True,
        return_tensors="pt",
    )
    # move tensors to CPU device (BatchFeature may contain non-tensors)
    for k, v in list(batch.items()):
        if torch.is_tensor(v):
            batch[k] = v.to(DEVICE)
    out_ids = model.generate(
        **batch,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        do_sample=False,
        early_stopping=False,
    )
    # keep special tokens: Florence's post-processor needs the task markers
    gen = processor.batch_decode(out_ids, skip_special_tokens=False)[0]
    # post-processing maps raw text back into the task's structured output
    parsed = processor.post_process_generation(
        gen, task="<MORE_DETAILED_CAPTION>", image_size=[(image.width, image.height)]
    )
    # NOTE(review): post_process_generation appears to return either a dict or
    # a one-element list of dicts depending on version — both are handled here.
    data = parsed[0] if isinstance(parsed, list) else parsed
    return (data.get("<MORE_DETAILED_CAPTION>", "") or "Unable to generate a caption.").strip()
def run(image: Image.Image):
    """Gradio click handler: return (caption text, status line)."""
    return caption(image), "Model: Florence-2-base (CPU)"
# ---- Gradio UI: image on the left, caption + status on the right ----------
with gr.Blocks(css="footer{visibility:hidden}") as demo:
    gr.Markdown("# Image → Caption (CPU) — Florence-2-base")
    with gr.Row():
        with gr.Column():
            image_in = gr.Image(type="pil", label="Image", sources=["upload", "clipboard", "webcam"])
            caption_btn = gr.Button("Caption", variant="primary")
        with gr.Column():
            caption_out = gr.Textbox(label="Caption", lines=10)
            status_md = gr.Markdown()
    caption_btn.click(run, [image_in], [caption_out, status_md], scroll_to_output=True)

if __name__ == "__main__":
    # Bounded queue keeps the CPU-only Space responsive under load.
    demo.queue(max_size=8).launch()