# Hugging Face Space (page capture noted runtime status: "Runtime error")
| # app.py – Gradio 6+ (CPU) | |
| import base64 | |
| import threading | |
| import time | |
| import urllib.parse | |
| from io import BytesIO | |
| import gradio as gr | |
| import requests | |
| import torch | |
| from PIL import Image | |
| from transformers import ( | |
| AutoTokenizer, | |
| VisionEncoderDecoderModel, | |
| ViTImageProcessor, | |
| T5ForConditionalGeneration, | |
| T5Tokenizer, | |
| ) | |
device = torch.device("cpu")  # CPU-only Space; no GPU assumed

# Captioning backbone: ViT encoder + GPT-2 decoder checkpoint.
IMG_MODEL = "nlpconnect/vit-gpt2-image-captioning"
# Lightweight seq2seq model used to rewrite/expand the base caption.
TXT_MODEL = "t5-small"

# Vision captioner: image preprocessor, its tokenizer, and the model itself.
# Models are loaded once at import time and kept in eval mode on CPU.
processor = ViTImageProcessor.from_pretrained(IMG_MODEL)
tokenizer = AutoTokenizer.from_pretrained(IMG_MODEL)
vision = VisionEncoderDecoderModel.from_pretrained(IMG_MODEL).to(device).eval()

# Caption rewriter (T5) used by expand_caption().
rewriter_tok = T5Tokenizer.from_pretrained(TXT_MODEL)
rewriter = T5ForConditionalGeneration.from_pretrained(TXT_MODEL).to(device).eval()
def load_image(url: str):
    """Return (PIL.Image, None) on success or (None, error_message) on failure.

    Accepts http/https URLs and base64 data-URLs. Never raises; every
    failure is reported through the second tuple element.
    """
    try:
        url = (url or "").strip()
        if not url:
            return None, "No URL provided."
        if url.startswith("data:"):
            # data:[<mediatype>][;base64],<payload> — decode after the comma.
            _, data = url.split(",", 1)
            img = Image.open(BytesIO(base64.b64decode(data))).convert("RGB")
            return img, None
        # Bug fix: the old check only tested that *some* scheme was present,
        # so e.g. ftp:// URLs slipped past a guard whose error message claims
        # http/https. Require the schemes we actually support.
        if urllib.parse.urlsplit(url).scheme not in ("http", "https"):
            return None, "Missing http/https scheme."
        r = requests.get(url, timeout=10, headers={"User-Agent": "duck.ai"})
        r.raise_for_status()
        return Image.open(BytesIO(r.content)).convert("RGB"), None
    except Exception as e:
        return None, f"Load error: {e}"
def generate_base(img: Image.Image, max_len=40, beams=2, sample=False):
    """Produce a short base caption for *img* with the ViT-GPT2 captioner.

    Generates several candidate captions (sampled when *sample* is True,
    otherwise beam-searched) and returns the one with the most words.
    """
    pixel_values = processor(images=img, return_tensors="pt").pixel_values.to(device)

    gen_kwargs = {"max_length": max_len, "early_stopping": True}
    if sample:
        gen_kwargs.update(
            do_sample=True,
            temperature=0.8,
            top_k=50,
            top_p=0.9,
            num_return_sequences=3,
        )
    else:
        gen_kwargs.update(num_beams=beams, num_return_sequences=min(3, beams))

    sequences = vision.generate(pixel_values, **gen_kwargs)
    candidates = [
        tokenizer.decode(seq, skip_special_tokens=True).strip() for seq in sequences
    ]
    # Prefer the wordiest candidate as the "richest" base caption.
    return max(candidates, key=lambda text: len(text.split()))
def expand_caption(base: str, prompt=None, max_len=160):
    """Rewrite *base* into a longer, more detailed caption with T5.

    prompt: optional user guidance folded into the instruction text.
    max_len: maximum generated token length.
    Returns the decoded expansion as a stripped string.
    """
    if prompt and prompt.strip():
        instr = f"Expand using: '{prompt}'. Caption: \"{base}\""
    else:
        instr = f"Expand with rich visual detail. Caption: \"{base}\""
    # Fix: dropped padding="max_length" — padding a single sequence out to
    # 256 tokens only burns CPU in the encoder; the attention mask already
    # made the pad tokens a no-op for the output.
    toks = rewriter_tok(
        instr,
        return_tensors="pt",
        truncation=True,
        max_length=256,
    ).to(device)
    out = rewriter.generate(
        **toks,
        max_length=max_len,
        num_beams=4,
        early_stopping=True,
        no_repeat_ngram_size=3,
    )
    return rewriter_tok.decode(out[0], skip_special_tokens=True).strip()
def async_expand(base, prompt, max_len, status):
    """Run expand_caption and post progress into the shared *status* dict.

    status keys written:
      "text"  — progress message ("Expanding…", "Done", or "Error: …")
      "final" — expanded caption; falls back to *base* on any failure
    """
    try:
        status["text"] = "Expanding…"
        time.sleep(0.1)
        expanded = expand_caption(base, prompt, max_len)
    except Exception as exc:
        status["text"] = f"Error: {exc}"
        status["final"] = base
    else:
        status["text"] = "Done"
        status["final"] = expanded
def fast_describe(url, prompt, detail, beams, sample):
    """Quick path: load the image and return the short base caption.

    Returns (image, caption, status) for the (img_out, caption_out,
    status_out) components. The detailed expansion is produced by
    final_caption via the second button; *prompt* and *detail* are kept in
    the signature for interface compatibility with that handler.
    """
    img, err = load_image(url)
    if err:
        return None, "", err
    base = generate_base(img, beams=beams, sample=sample)
    # Bug fix: the old version spawned a background thread that wrote its
    # expanded caption into a *local* dict which was then discarded — the UI
    # could never observe the result, and the status stayed "Queued…"
    # forever. Return the base caption with an honest status instead.
    return img, base, "Base caption ready — use 'Get final caption (detailed)' to expand."
def final_caption(url, prompt, detail, beams, sample):
    """Full path: load the image, caption it, and expand synchronously.

    Returns (caption, status) for the (caption_out, status_out) components.
    Falls back to the unexpanded base caption if the rewriter fails.
    """
    img, err = load_image(url)
    if err:
        return "", err
    expand_budget = {"Low": 80, "Medium": 140, "High": 220}.get(detail, 140)
    base = generate_base(img, beams=beams, sample=sample)
    try:
        return expand_caption(base, prompt, expand_budget), "Done"
    except Exception as exc:
        return base, f"Expand error: {exc}"
css = "footer {display:none !important;}"

# Fix: in Gradio 4+ custom CSS is passed to gr.Blocks(css=...);
# launch() no longer accepts a `css` keyword.
with gr.Blocks(css=css) as demo:
    gr.Markdown("## Image Describer")
    with gr.Row():
        with gr.Column():
            url_in = gr.Textbox(label="Image URL / data‑URL")
            prompt_in = gr.Textbox(label="Optional prompt")
            detail_in = gr.Radio(["Low", "Medium", "High"], value="Medium", label="Detail level")
            beams_in = gr.Slider(1, 4, step=1, value=2, label="Beams")
            sample_in = gr.Checkbox(label="Enable sampling (more diverse)", value=False)
            go_btn = gr.Button("Load & Describe (fast)")
            final_btn = gr.Button("Get final caption (detailed)")
            status_out = gr.Textbox(label="Status", interactive=False)
        with gr.Column():
            img_out = gr.Image(type="pil", label="Image")
        with gr.Column():
            caption_out = gr.Textbox(label="Caption", lines=8)

    # Fast path: image + base caption.
    go_btn.click(
        fn=fast_describe,
        inputs=[url_in, prompt_in, detail_in, beams_in, sample_in],
        outputs=[img_out, caption_out, status_out],
    )
    # Detailed path: synchronous T5 expansion.
    final_btn.click(
        fn=final_caption,
        inputs=[url_in, prompt_in, detail_in, beams_in, sample_in],
        outputs=[caption_out, status_out],
    )
if __name__ == "__main__":
    demo.queue()
    # Fixes two launch() problems:
    #  - `css` is not a launch() kwarg in Gradio 4+ (TypeError at startup);
    #    it belongs on gr.Blocks(css=...).
    #  - prevent_thread_lock=True makes launch() return immediately, so the
    #    script exits and the server dies. Block in the foreground instead.
    demo.launch(server_name="0.0.0.0", server_port=7860)