# Image-describer / app.py
# Author: Hug0endob — "Update app.py" (commit cd43691, verified)
# app.py – Gradio 6+ (CPU‑only) – safe for limited sandbox resources
import base64
import gc
import logging
import threading
import time
import urllib.parse
from io import BytesIO
import gradio as gr
import requests
import torch
from PIL import Image
from transformers import (
AutoTokenizer,
VisionEncoderDecoderModel,
ViTImageProcessor,
T5ForConditionalGeneration,
T5Tokenizer,
)
# -------------------------------------------------
# Runtime limits (sandbox‑friendly)
# -------------------------------------------------
# Pin PyTorch to one thread in both the intra-op and inter-op pools so the
# app cannot starve a shared / resource-limited CPU sandbox.
torch.set_num_threads(1) # one CPU thread
torch.set_num_interop_threads(1) # one inter‑op thread
torch.set_grad_enabled(False) # inference‑only
logging.basicConfig(level=logging.INFO)
# All inference runs on the CPU regardless of available hardware.
device = torch.device("cpu")
# -------------------------------------------------
# Model loading (fp16 only when a GPU is present)
# -------------------------------------------------
# Hugging Face Hub model IDs: a ViT-encoder / GPT-2-decoder captioner and a
# small T5 used to rewrite/expand the caption.
IMG_MODEL = "nlpconnect/vit-gpt2-image-captioning"
TXT_MODEL = "t5-small"
# NOTE(review): dtype is fp16 when CUDA is available, but `device` above is
# hard-coded to CPU — on a CUDA machine this would place fp16 weights on the
# CPU, where many ops are slow or unsupported. Harmless in a CPU-only
# sandbox (fp32 path taken); confirm intent before deploying on GPU hosts.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Vision‑caption model: image preprocessor, caption tokenizer, and the
# encoder-decoder weights, loaded once at import time in eval mode.
processor = ViTImageProcessor.from_pretrained(IMG_MODEL)
tokenizer = AutoTokenizer.from_pretrained(IMG_MODEL)
vision = (
    VisionEncoderDecoderModel.from_pretrained(IMG_MODEL, torch_dtype=dtype)
    .to(device)
    .eval()
)
# Text‑rewriter model (T5): tokenizer + weights, also eval-only.
rewriter_tok = T5Tokenizer.from_pretrained(TXT_MODEL)
rewriter = (
    T5ForConditionalGeneration.from_pretrained(TXT_MODEL, torch_dtype=dtype)
    .to(device)
    .eval()
)
# Release any temporary download buffers
gc.collect()
torch.cuda.empty_cache() # no‑op on CPU, kept for symmetry
# -------------------------------------------------
# Helper utilities
# -------------------------------------------------
def load_image(url: str):
    """Fetch an image from an HTTP(S) URL or an inline base64 data-URL.

    Returns a ``(PIL.Image, None)`` pair on success, or ``(None, message)``
    on any failure — no exception ever escapes this helper.
    """
    try:
        cleaned = (url or "").strip()
        if not cleaned:
            return None, "No URL provided."
        if cleaned.startswith("data:"):
            # data-URL: everything after the first comma is the base64 payload
            _, payload = cleaned.split(",", 1)
            raw = base64.b64decode(payload)
        else:
            # Plain URL: require an explicit scheme before fetching.
            if not urllib.parse.urlsplit(cleaned).scheme:
                return None, "Missing http/https scheme."
            resp = requests.get(cleaned, timeout=10, headers={"User-Agent": "duck.ai"})
            resp.raise_for_status()
            raw = resp.content
        return Image.open(BytesIO(raw)).convert("RGB"), None
    except Exception as exc:
        return None, f"Load error: {exc}"
def generate_base(img: Image.Image, max_len=40, beams=2, sample=False):
    """Produce a short caption for *img* with the ViT-GPT2 vision model.

    Generates several candidate captions — via sampling when ``sample`` is
    true, otherwise via beam search — and returns the one with the most
    words, which is usually the most complete.
    """
    pixel_values = processor(images=img, return_tensors="pt").pixel_values.to(device)
    gen_kwargs = dict(max_length=max_len, early_stopping=True)
    if sample:
        gen_kwargs.update(
            do_sample=True,
            temperature=0.8,
            top_k=50,
            top_p=0.9,
            num_return_sequences=3,
        )
    else:
        gen_kwargs.update(num_beams=beams, num_return_sequences=min(3, beams))
    sequences = vision.generate(pixel_values, **gen_kwargs)
    decoded = [tokenizer.decode(seq, skip_special_tokens=True).strip() for seq in sequences]
    # Pick the candidate with the highest word count.
    return max(decoded, key=lambda c: len(c.split()))
def expand_caption(base: str, prompt: str = None, max_len=160):
    """Rewrite/expand the base caption with the T5 model.

    Parameters
    ----------
    base : short caption produced by the vision model.
    prompt : optional user guidance folded into the T5 instruction.
    max_len : generation length cap for the expanded caption.

    Returns the expanded caption as a stripped string.
    """
    instruction = (
        f"Expand using: '{prompt}'. Caption: \"{base}\""
        if prompt and prompt.strip()
        else f"Expand with rich visual detail. Caption: \"{base}\""
    )
    # Fix: the original padded every prompt to max_length (256 tokens).
    # With a single sequence there is nothing to align, and the padded
    # positions are attention-masked anyway, so padding only added dead
    # encoder work on CPU.  Truncation still caps runaway inputs at 256.
    toks = rewriter_tok(
        instruction,
        return_tensors="pt",
        truncation=True,
        max_length=256,
    ).to(device)
    out = rewriter.generate(
        **toks,  # input_ids + attention_mask
        max_length=max_len,
        num_beams=4,
        early_stopping=True,
        no_repeat_ngram_size=3,  # discourage verbatim repetition loops
    )
    return rewriter_tok.decode(out[0], skip_special_tokens=True).strip()
def async_expand(base, prompt, max_len, status_dict):
    """Background worker: expand *base* and report progress via *status_dict*.

    Mutates *status_dict* in place: ``"text"`` tracks progress
    ("Expanding…" → "Done" or "Error: …") and ``"final"`` receives the
    expanded caption, falling back to the unchanged base on failure.
    """
    status_dict["text"] = "Expanding…"
    try:
        status_dict["final"] = expand_caption(base, prompt, max_len)
    except Exception as exc:
        status_dict["text"] = f"Error: {exc}"
        status_dict["final"] = base
    else:
        status_dict["text"] = "Done"
# -------------------------------------------------
# Gradio callbacks
# -------------------------------------------------
def fast_describe(url, prompt, detail, beams, sample):
    """Streaming path: show the image and a quick caption immediately,
    then replace it with the expanded caption when ready.

    Fix: the original spawned a daemon thread that wrote the expansion
    into a local status dict which nothing ever polled, so the expanded
    caption was silently discarded.  A generator lets Gradio stream both
    stages through the same (image, caption, status) outputs — this
    requires queuing, which ``demo.queue()`` enables.
    """
    img, err = load_image(url)
    if err:
        yield None, "", err
        return
    detail_map = {"Low": 80, "Medium": 140, "High": 220}
    max_expand = detail_map.get(detail, 140)
    base = generate_base(img, beams=beams, sample=sample)
    # First update: image + short caption while the expansion runs.
    yield img, base, "Expanding…"
    try:
        final = expand_caption(base, prompt, max_expand)
        yield img, final, "Done"
    except Exception as exc:
        # Degrade gracefully: keep the short caption, surface the error.
        yield img, base, f"Error: {exc}"
def final_caption(url, prompt, detail, beams, sample):
    """Blocking path: caption the image and return (caption, status).

    Falls back to the short base caption (with an error status) when the
    T5 expansion step fails.
    """
    img, err = load_image(url)
    if err:
        return "", err
    length_by_detail = {"Low": 80, "Medium": 140, "High": 220}
    base = generate_base(img, beams=beams, sample=sample)
    try:
        expanded = expand_caption(base, prompt, length_by_detail.get(detail, 140))
    except Exception as exc:
        # Degrade gracefully to the short caption on expansion failure.
        return base, f"Expand error: {exc}"
    return expanded, "Done"
# -------------------------------------------------
# UI layout
# -------------------------------------------------
# Hide the default Gradio footer.
css = "footer {display:none !important;}"
with gr.Blocks(title="Image Describer (CPU‑only)", css=css) as demo:
    gr.Markdown("## Image Describer (CPU‑only)")
    with gr.Row():
        # ---- Left column – inputs ----
        with gr.Column():
            url_in = gr.Textbox(label="Image URL / data‑URL")
            prompt_in = gr.Textbox(label="Optional prompt")
            detail_in = gr.Radio(
                ["Low", "Medium", "High"], value="Medium", label="Detail level"
            )
            # Beam count feeds generate_base; more beams = better captions,
            # slower inference (capped at 4 for the CPU sandbox).
            beams_in = gr.Slider(1, 4, step=1, value=2, label="Beams")
            sample_in = gr.Checkbox(
                label="Enable sampling (more diverse)", value=False
            )
            go_btn = gr.Button("Load & Describe (fast)")
            final_btn = gr.Button("Get final caption (detailed)")
            status_out = gr.Textbox(label="Status", interactive=False)
        # ---- Middle column – image preview ----
        with gr.Column():
            img_out = gr.Image(type="pil", label="Image")
        # ---- Right column – caption output ----
        with gr.Column():
            caption_out = gr.Textbox(label="Caption", lines=8)
    # Fast path: returns image + short caption immediately
    go_btn.click(
        fn=fast_describe,
        inputs=[url_in, prompt_in, detail_in, beams_in, sample_in],
        outputs=[img_out, caption_out, status_out],
    )
    # Detailed path: blocks until the expanded caption is ready
    final_btn.click(
        fn=final_caption,
        inputs=[url_in, prompt_in, detail_in, beams_in, sample_in],
        outputs=[caption_out, status_out],
    )

if __name__ == "__main__":
    demo.queue() # enables request queuing (helps with sandbox limits)
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)