# AiSolMM / app.py
# achase25's picture
# Update app.py
# 5251d80 verified
# CPU-only: Image → Caption (Florence-2-base), concise build
import os
from functools import lru_cache
import torch
import gradio as gr
from PIL import Image
# AVIF/HEIF support (optional, safe to ignore if unavailable)
try:
import pillow_avif # noqa: F401
except Exception:
pass
try:
from pillow_heif import register_heif_opener
register_heif_opener()
except Exception:
pass
from transformers import AutoProcessor, AutoModelForCausalLM
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports as _orig_get_imports
CAPTION_MODEL_ID = "microsoft/Florence-2-base"
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
DEVICE = "cpu"
DTYPE = torch.float32
MAX_IMG_SIDE = int(os.getenv("MAX_IMG_SIDE", "1024"))
def _resize_max(img: Image.Image, max_side: int = MAX_IMG_SIDE) -> Image.Image:
w, h = img.size
if max(w, h) <= max_side:
return img
r = max_side / max(w, h)
return img.resize((int(w * r), int(h * r)), Image.LANCZOS)
def _ensure_rgb(img) -> Image.Image:
if not isinstance(img, Image.Image):
raise gr.Error("Upload a valid image.")
return img.convert("RGB")
def _no_flash_attn_get_imports(filename):
imps = _orig_get_imports(filename)
try:
name = str(filename).lower()
if "florence2" in name or "modeling_florence2.py" in name:
return [x for x in imps if x != "flash_attn"]
except Exception:
pass
return imps
@lru_cache(maxsize=1)
def _load_florence():
proc = AutoProcessor.from_pretrained(CAPTION_MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
with patch("transformers.dynamic_module_utils.get_imports", _no_flash_attn_get_imports):
mdl = AutoModelForCausalLM.from_pretrained(
CAPTION_MODEL_ID,
trust_remote_code=True,
token=HF_TOKEN,
attn_implementation="sdpa", # CPU-safe
torch_dtype=DTYPE,
device_map="cpu",
).eval()
return proc, mdl
@torch.inference_mode()
def caption(image: Image.Image, max_new_tokens: int = 128, num_beams: int = 3) -> str:
image = _ensure_rgb(_resize_max(image))
processor, model = _load_florence()
batch = processor(
text="<MORE_DETAILED_CAPTION>",
images=[image], # batch even for single
padding=True,
return_tensors="pt",
)
# move tensors to CPU device (BatchFeature may contain non-tensors)
for k, v in list(batch.items()):
if torch.is_tensor(v):
batch[k] = v.to(DEVICE)
out_ids = model.generate(
**batch,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
do_sample=False,
early_stopping=False,
)
gen = processor.batch_decode(out_ids, skip_special_tokens=False)[0]
parsed = processor.post_process_generation(
gen, task="<MORE_DETAILED_CAPTION>", image_size=[(image.width, image.height)]
)
data = parsed[0] if isinstance(parsed, list) else parsed
return (data.get("<MORE_DETAILED_CAPTION>", "") or "Unable to generate a caption.").strip()
def run(image: Image.Image):
txt = caption(image)
return txt, "Model: Florence-2-base (CPU)"
with gr.Blocks(css="footer{visibility:hidden}") as demo:
gr.Markdown("# Image β†’ Caption (CPU) β€” Florence-2-base")
with gr.Row():
with gr.Column():
img = gr.Image(type="pil", label="Image", sources=["upload", "clipboard", "webcam"])
btn = gr.Button("Caption", variant="primary")
with gr.Column():
out = gr.Textbox(label="Caption", lines=10)
status = gr.Markdown()
btn.click(run, [img], [out, status], scroll_to_output=True)
if __name__ == "__main__":
demo.queue(max_size=8).launch()