|
|
|
|
|
|
|
|
import os |
|
|
from functools import lru_cache |
|
|
|
|
|
import torch |
|
|
import gradio as gr |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
try: |
|
|
import pillow_avif |
|
|
except Exception: |
|
|
pass |
|
|
try: |
|
|
from pillow_heif import register_heif_opener |
|
|
register_heif_opener() |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
from transformers import AutoProcessor, AutoModelForCausalLM |
|
|
from unittest.mock import patch |
|
|
from transformers.dynamic_module_utils import get_imports as _orig_get_imports |
|
|
|
|
|
CAPTION_MODEL_ID = "microsoft/Florence-2-base" |
|
|
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") |
|
|
DEVICE = "cpu" |
|
|
DTYPE = torch.float32 |
|
|
MAX_IMG_SIDE = int(os.getenv("MAX_IMG_SIDE", "1024")) |
|
|
|
|
|
def _resize_max(img: Image.Image, max_side: int = MAX_IMG_SIDE) -> Image.Image:
    """Downscale *img* so its longer side is at most *max_side*, keeping aspect ratio.

    Args:
        img: Source PIL image.
        max_side: Maximum allowed length of the longer side, in pixels.

    Returns:
        The original image when it already fits, otherwise a LANCZOS-resampled copy.
    """
    w, h = img.size
    longest = max(w, h)
    if longest <= max_side:
        return img
    scale = max_side / longest
    # max(1, ...) guards against a zero dimension for extreme aspect ratios
    # (e.g. a 4000x1 strip), which would make Image.resize raise.
    return img.resize((max(1, int(w * scale)), max(1, int(h * scale))), Image.LANCZOS)
|
|
|
|
|
def _ensure_rgb(img) -> Image.Image:
    """Coerce *img* to an RGB PIL image, rejecting anything that is not an image.

    Raises:
        gr.Error: When the input is not a ``PIL.Image.Image`` instance.
    """
    if isinstance(img, Image.Image):
        return img.convert("RGB")
    raise gr.Error("Upload a valid image.")
|
|
|
|
|
def _no_flash_attn_get_imports(filename):
    """Return transformers' import list for *filename*, minus ``flash_attn``
    for Florence-2 modules.

    Florence-2's remote modeling code declares an optional flash-attn import
    that is unavailable on CPU-only hosts; stripping it lets the module load.
    Falls back to the unmodified list on any inspection failure.
    """
    imports = _orig_get_imports(filename)
    try:
        fname = str(filename).lower()
        if "florence2" not in fname and "modeling_florence2.py" not in fname:
            return imports
        return [module for module in imports if module != "flash_attn"]
    except Exception:
        # Best effort: never let name inspection break module loading.
        return imports
|
|
|
|
|
@lru_cache(maxsize=1)
def _load_florence():
    """Load and memoize the Florence-2 processor and model (singleton via lru_cache).

    Returns:
        tuple: ``(AutoProcessor, AutoModelForCausalLM)`` configured for CPU inference.
    """
    proc = AutoProcessor.from_pretrained(CAPTION_MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
    # Patch get_imports while the remote (trust_remote_code) module is resolved
    # so the optional, GPU-only flash_attn dependency is not required here.
    with patch("transformers.dynamic_module_utils.get_imports", _no_flash_attn_get_imports):
        mdl = AutoModelForCausalLM.from_pretrained(
            CAPTION_MODEL_ID,
            trust_remote_code=True,
            token=HF_TOKEN,
            attn_implementation="sdpa",  # SDPA attention works on CPU; avoids flash-attn
            torch_dtype=DTYPE,
            device_map="cpu",
        ).eval()  # inference mode: disables dropout etc.
    return proc, mdl
|
|
|
|
|
@torch.inference_mode()
def caption(image: Image.Image, max_new_tokens: int = 128, num_beams: int = 3) -> str:
    """Generate a detailed caption for *image* with Florence-2.

    Args:
        image: Input PIL image (any mode; converted to RGB and size-capped).
        max_new_tokens: Generation length budget.
        num_beams: Beam-search width (greedy decoding, no sampling).

    Returns:
        The stripped caption text, or a fallback message when parsing yields nothing.

    Raises:
        gr.Error: When *image* is not a valid PIL image.
    """
    # Validate BEFORE resizing: the original order called _resize_max first,
    # so a None/non-image upload raised an opaque AttributeError instead of
    # the friendly gr.Error from _ensure_rgb.
    image = _resize_max(_ensure_rgb(image))
    processor, model = _load_florence()
    batch = processor(
        text="<MORE_DETAILED_CAPTION>",
        images=[image],
        padding=True,
        return_tensors="pt",
    )

    # Move all tensors to the target device (CPU here; keeps code device-agnostic).
    for key, value in list(batch.items()):
        if torch.is_tensor(value):
            batch[key] = value.to(DEVICE)

    out_ids = model.generate(
        **batch,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        do_sample=False,
        early_stopping=False,
    )
    # Keep special tokens: Florence's post-processor parses the raw task markup.
    gen = processor.batch_decode(out_ids, skip_special_tokens=False)[0]
    parsed = processor.post_process_generation(
        gen, task="<MORE_DETAILED_CAPTION>", image_size=[(image.width, image.height)]
    )
    data = parsed[0] if isinstance(parsed, list) else parsed
    return (data.get("<MORE_DETAILED_CAPTION>", "") or "Unable to generate a caption.").strip()
|
|
|
|
|
def run(image: Image.Image):
    """Gradio click handler: caption *image* and report the model used."""
    return caption(image), "Model: Florence-2-base (CPU)"
|
|
|
|
|
# UI layout: image input + button on the left, caption output on the right.
with gr.Blocks(css="footer{visibility:hidden}") as demo:
    # Fix mojibake in the heading: the source contained encoding-damaged
    # characters ("β") where the arrow/separator glyphs belong.
    gr.Markdown("# Image → Caption (CPU) — Florence-2-base")
    with gr.Row():
        with gr.Column():
            img = gr.Image(type="pil", label="Image", sources=["upload", "clipboard", "webcam"])
            btn = gr.Button("Caption", variant="primary")
        with gr.Column():
            out = gr.Textbox(label="Caption", lines=10)
            status = gr.Markdown()
    btn.click(run, [img], [out, status], scroll_to_output=True)
|
|
|
|
|
if __name__ == "__main__":
    # Bound the request queue so a burst of uploads cannot pile up on one CPU box.
    demo.queue(max_size=8).launch()
|
|
|