# Source: Hug0endob/Image-describer — app.py (Hugging Face Space, commit 7d16cdf, 5.6 kB)
# app.py – Gradio 6+ (CPU)
import base64
import threading
import time
import urllib.parse
from io import BytesIO
import gradio as gr
import requests
import torch
from PIL import Image
from transformers import (
AutoTokenizer,
VisionEncoderDecoderModel,
ViTImageProcessor,
T5ForConditionalGeneration,
T5Tokenizer,
)
# Force CPU inference; this Space runs without a GPU.
device = torch.device("cpu")
IMG_MODEL = "nlpconnect/vit-gpt2-image-captioning"  # image -> short base caption
TXT_MODEL = "t5-small"  # base caption -> expanded description
# Vision captioning stack: ViT image encoder + GPT-2 text decoder.
processor = ViTImageProcessor.from_pretrained(IMG_MODEL)
tokenizer = AutoTokenizer.from_pretrained(IMG_MODEL)
vision = VisionEncoderDecoderModel.from_pretrained(IMG_MODEL).to(device).eval()
# T5 "rewriter" used by expand_caption() to enrich the base caption.
rewriter_tok = T5Tokenizer.from_pretrained(TXT_MODEL)
rewriter = T5ForConditionalGeneration.from_pretrained(TXT_MODEL).to(device).eval()
def load_image(url: str):
    """Fetch an image from an http(s) URL or decode a base64 ``data:`` URL.

    Args:
        url: An http/https URL, or a data-URL whose payload is base64 image data.

    Returns:
        ``(PIL.Image, None)`` on success, ``(None, error_message)`` on failure.
        Never raises: all errors are folded into the message.
    """
    try:
        url = (url or "").strip()
        if not url:
            return None, "No URL provided."
        if url.startswith("data:"):
            # data:[<mediatype>][;base64],<payload> — require the comma so a
            # malformed data-URL yields a clear message, not a ValueError.
            if "," not in url:
                return None, "Malformed data URL (missing comma separator)."
            _, data = url.split(",", 1)
            img = Image.open(BytesIO(base64.b64decode(data))).convert("RGB")
            return img, None
        # Only http/https are fetchable here; reject file://, ftp://, etc.
        # (The original check only tested that *some* scheme was present.)
        if urllib.parse.urlsplit(url).scheme not in ("http", "https"):
            return None, "Missing http/https scheme."
        r = requests.get(url, timeout=10, headers={"User-Agent": "duck.ai"})
        r.raise_for_status()
        return Image.open(BytesIO(r.content)).convert("RGB"), None
    except Exception as e:
        return None, f"Load error: {e}"
def generate_base(img: Image.Image, max_len=40, beams=2, sample=False):
    """Produce a short base caption for *img* with the ViT-GPT2 model.

    Args:
        img: RGB PIL image.
        max_len: Maximum caption length in tokens.
        beams: Beam count for beam search (ignored when ``sample`` is True).
        sample: If True, use stochastic sampling for more diverse candidates.

    Returns:
        The longest (by word count) of the generated candidate captions.
    """
    inputs = processor(images=img, return_tensors="pt")
    pix = inputs.pixel_values.to(device)
    # Inference only: no_grad skips autograd bookkeeping (saves CPU time/RAM).
    with torch.no_grad():
        if sample:
            out = vision.generate(
                pix,
                max_length=max_len,
                do_sample=True,
                temperature=0.8,
                top_k=50,
                top_p=0.9,
                num_return_sequences=3,
                early_stopping=True,
            )
        else:
            out = vision.generate(
                pix,
                max_length=max_len,
                num_beams=beams,
                num_return_sequences=min(3, beams),
                early_stopping=True,
            )
    caps = [tokenizer.decode(o, skip_special_tokens=True).strip() for o in out]
    # Heuristic: keep the most detailed (longest) candidate.
    return max(caps, key=lambda s: len(s.split()))
def expand_caption(base: str, prompt: str = None, max_len=160):
    """Rewrite *base* into a longer, more detailed caption using T5.

    Args:
        base: Short caption produced by the vision model.
        prompt: Optional user hint to steer the expansion.
        max_len: Maximum length (tokens) of the expanded caption.

    Returns:
        The expanded caption string.
    """
    if prompt and prompt.strip():
        instr = f"Expand using: '{prompt}'. Caption: \"{base}\""
    else:
        instr = f"Expand with rich visual detail. Caption: \"{base}\""
    # No fixed-length padding: the batch size is 1, so padding every input to
    # 256 tokens only slows generation; truncation alone bounds the length.
    toks = rewriter_tok(
        instr,
        return_tensors="pt",
        truncation=True,
        max_length=256,
    ).to(device)
    with torch.no_grad():  # inference only
        out = rewriter.generate(
            **toks,
            max_length=max_len,
            num_beams=4,
            early_stopping=True,
            no_repeat_ngram_size=3,
        )
    return rewriter_tok.decode(out[0], skip_special_tokens=True).strip()
def async_expand(base, prompt, max_len, status):
    """Background worker: expand *base* and record the result in *status*.

    Communicates through the shared ``status`` dict: ``status["text"]`` holds
    a progress message, ``status["final"]`` receives the expanded caption
    (or falls back to *base* if expansion fails).
    """
    try:
        status["text"] = "Expanding…"
        time.sleep(0.1)  # brief pause so the progress message is observable
        status["final"] = expand_caption(base, prompt, max_len)
        status["text"] = "Done"
    except Exception as exc:
        status["text"] = f"Error: {exc}"
        status["final"] = base
def fast_describe(url, prompt, detail, beams, sample):
    """Fast path: load the image and return the short base caption right away.

    Starts caption expansion on a daemon thread, but NOTE(review): the local
    ``status`` dict is never read after this function returns, so the
    expanded result appears to be discarded — confirm whether the UI was
    meant to poll it.

    Returns:
        ``(image, base_caption, status_text)`` or ``(None, "", error)``.
    """
    image, error = load_image(url)
    if error:
        return None, "", error
    expand_budget = {"Low": 80, "Medium": 140, "High": 220}.get(detail, 140)
    caption = generate_base(image, beams=beams, sample=sample)
    status = {"text": "Queued…", "final": ""}
    worker = threading.Thread(
        target=async_expand,
        args=(caption, prompt, expand_budget, status),
        daemon=True,
    )
    worker.start()
    return image, caption, status["text"]
def final_caption(url, prompt, detail, beams, sample):
    """Slow path: caption the image and synchronously expand the caption.

    Returns:
        ``(caption_text, status_text)``; on expansion failure the base
        caption is returned together with the error message.
    """
    image, error = load_image(url)
    if error:
        return "", error
    budget = {"Low": 80, "Medium": 140, "High": 220}.get(detail, 140)
    caption = generate_base(image, beams=beams, sample=sample)
    try:
        return expand_caption(caption, prompt, budget), "Done"
    except Exception as exc:
        return caption, f"Expand error: {exc}"
css = "footer {display:none !important;}"
# Custom CSS must be attached to gr.Blocks(css=...) — launch() does not
# accept a `css` keyword, so the original string was never applied.
with gr.Blocks(css=css) as demo:
    gr.Markdown("## Image Describer")
    with gr.Row():
        with gr.Column():
            url_in = gr.Textbox(label="Image URL / data‑URL")
            prompt_in = gr.Textbox(label="Optional prompt")
            detail_in = gr.Radio(["Low", "Medium", "High"], value="Medium", label="Detail level")
            beams_in = gr.Slider(1, 4, step=1, value=2, label="Beams")
            sample_in = gr.Checkbox(label="Enable sampling (more diverse)", value=False)
            go_btn = gr.Button("Load & Describe (fast)")
            final_btn = gr.Button("Get final caption (detailed)")
            status_out = gr.Textbox(label="Status", interactive=False)
        with gr.Column():
            img_out = gr.Image(type="pil", label="Image")
        with gr.Column():
            caption_out = gr.Textbox(label="Caption", lines=8)
    # Fast path: show image + base caption immediately.
    go_btn.click(
        fn=fast_describe,
        inputs=[url_in, prompt_in, detail_in, beams_in, sample_in],
        outputs=[img_out, caption_out, status_out],
    )
    # Slow path: synchronous detailed caption.
    final_btn.click(
        fn=final_caption,
        inputs=[url_in, prompt_in, detail_in, beams_in, sample_in],
        outputs=[caption_out, status_out],
    )
if __name__ == "__main__":
    demo.queue()
    # launch() takes no `css` kwarg (CSS belongs on gr.Blocks), so passing it
    # raised TypeError at startup. prevent_thread_lock=True would also return
    # immediately and let the script — and the server — exit; let launch()
    # block the main thread instead.
    demo.launch(server_name="0.0.0.0", server_port=7860)