Hugging Face Spaces page header (scrape residue) — the Space's status banner reads "Runtime error".
| # app.py – Gradio 6+ (CPU‑only) – safe for limited sandbox resources | |
| import base64 | |
| import gc | |
| import logging | |
| import threading | |
| import time | |
| import urllib.parse | |
| from io import BytesIO | |
| import gradio as gr | |
| import requests | |
| import torch | |
| from PIL import Image | |
| from transformers import ( | |
| AutoTokenizer, | |
| VisionEncoderDecoderModel, | |
| ViTImageProcessor, | |
| T5ForConditionalGeneration, | |
| T5Tokenizer, | |
| ) | |
# -------------------------------------------------
# Runtime limits (sandbox-friendly)
# -------------------------------------------------
torch.set_num_threads(1)          # one CPU thread
torch.set_num_interop_threads(1)  # one inter-op thread
torch.set_grad_enabled(False)     # inference-only
logging.basicConfig(level=logging.INFO)

# This app is deliberately CPU-only (see title); pin the device explicitly.
device = torch.device("cpu")

# -------------------------------------------------
# Model loading
# -------------------------------------------------
IMG_MODEL = "nlpconnect/vit-gpt2-image-captioning"
TXT_MODEL = "t5-small"
# Derive dtype from the device actually used.  The original selected fp16
# whenever CUDA was merely *available*, then moved the weights to the CPU
# anyway — half-precision on CPU is slow and numerically fragile.
dtype = torch.float16 if device.type == "cuda" else torch.float32

# Vision-caption model (ViT encoder + GPT-2 decoder)
processor = ViTImageProcessor.from_pretrained(IMG_MODEL)
tokenizer = AutoTokenizer.from_pretrained(IMG_MODEL)
vision = (
    VisionEncoderDecoderModel.from_pretrained(IMG_MODEL, torch_dtype=dtype)
    .to(device)
    .eval()
)

# Text-rewriter model (T5) used to expand the short caption
rewriter_tok = T5Tokenizer.from_pretrained(TXT_MODEL)
rewriter = (
    T5ForConditionalGeneration.from_pretrained(TXT_MODEL, torch_dtype=dtype)
    .to(device)
    .eval()
)

# Release any temporary download buffers
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
| # ------------------------------------------------- | |
| # Helper utilities | |
| # ------------------------------------------------- | |
def load_image(url: str):
    """Fetch an image from an http(s) URL or a base64 data-URL.

    Args:
        url: remote URL, ``data:`` URL, or None/empty.

    Returns:
        ``(PIL.Image, None)`` on success or ``(None, error_message)`` on
        failure — callers branch on the second element instead of catching.
    """
    try:
        url = (url or "").strip()
        if not url:
            return None, "No URL provided."
        # data-URL (base64-encoded image): "data:<mime>;base64,<payload>"
        if url.startswith("data:"):
            _, data = url.split(",", 1)
            img = Image.open(BytesIO(base64.b64decode(data))).convert("RGB")
            return img, None
        # Remote URL — restrict to http/https so other requests-supported
        # schemes (ftp, file-like adapters, …) cannot be smuggled in.
        scheme = urllib.parse.urlsplit(url).scheme
        if not scheme:
            return None, "Missing http/https scheme."
        if scheme not in ("http", "https"):
            return None, f"Unsupported scheme: {scheme}"
        resp = requests.get(url, timeout=10, headers={"User-Agent": "duck.ai"})
        resp.raise_for_status()
        img = Image.open(BytesIO(resp.content)).convert("RGB")
        return img, None
    except Exception as exc:  # broad by design: every failure becomes a UI message
        return None, f"Load error: {exc}"
def generate_base(img: Image.Image, max_len=40, beams=2, sample=False):
    """Create a short caption for *img* with the vision model.

    Args:
        img: RGB PIL image (as produced by ``load_image``).
        max_len: token cap for the generated caption.
        beams: beam width for the deterministic path.
        sample: when True, use top-k/top-p sampling instead of beam search.

    Returns:
        The longest of up to three decoded candidate captions.
    """
    inputs = processor(images=img, return_tensors="pt")
    # Match the model's dtype as well as its device: the processor always
    # emits fp32 pixel values, which would mismatch fp16 weights.
    pix = inputs.pixel_values.to(device=device, dtype=vision.dtype)
    if sample:
        out = vision.generate(
            pix,
            max_length=max_len,
            do_sample=True,
            temperature=0.8,
            top_k=50,
            top_p=0.9,
            num_return_sequences=3,
            # early_stopping dropped here: it is a beam-search-only flag and
            # transformers warns when it is combined with sampling.
        )
    else:
        out = vision.generate(
            pix,
            max_length=max_len,
            num_beams=beams,
            num_return_sequences=min(3, beams),
            early_stopping=True,
        )
    captions = [tokenizer.decode(o, skip_special_tokens=True).strip() for o in out]
    # Heuristic: the caption with the most words is usually the most complete.
    return max(captions, key=lambda s: len(s.split()))
def expand_caption(base: str, prompt: str = None, max_len=160):
    """Rewrite/expand the base caption with the T5 model.

    Args:
        base: short caption produced by the vision model.
        prompt: optional user guidance folded into the instruction.
        max_len: generation length cap for the expanded caption.

    Returns:
        The expanded caption as a stripped string.
    """
    instruction = (
        f"Expand using: '{prompt}'. Caption: \"{base}\""
        if prompt and prompt.strip()
        else f"Expand with rich visual detail. Caption: \"{base}\""
    )
    # No padding: a single sequence needs none, and padding="max_length"
    # only inflated the encoder workload (always 256 tokens) on CPU.
    toks = rewriter_tok(
        instruction,
        return_tensors="pt",
        truncation=True,
        max_length=256,
    ).to(device)
    out = rewriter.generate(
        **toks,
        max_length=max_len,
        num_beams=4,
        early_stopping=True,
        no_repeat_ngram_size=3,  # avoid the repetitive loops T5 is prone to
    )
    return rewriter_tok.decode(out[0], skip_special_tokens=True).strip()
def async_expand(base, prompt, max_len, status_dict):
    """Worker-thread entry point: expand *base* and report via *status_dict*.

    Mutates status_dict in place: "text" tracks progress and "final" receives
    the expanded caption — or the unmodified base caption if expansion fails.
    """
    status_dict["text"] = "Expanding…"
    try:
        expanded = expand_caption(base, prompt, max_len)
    except Exception as exc:
        status_dict["final"] = base
        status_dict["text"] = f"Error: {exc}"
    else:
        status_dict["final"] = expanded
        status_dict["text"] = "Done"
| # ------------------------------------------------- | |
| # Gradio callbacks | |
| # ------------------------------------------------- | |
def fast_describe(url, prompt, detail, beams, sample):
    """Quick path – returns image, short caption and a transient status."""
    img, err = load_image(url)
    if err:
        return None, "", err

    expand_budget = {"Low": 80, "Medium": 140, "High": 220}.get(detail, 140)
    base = generate_base(img, beams=beams, sample=sample)

    # Mutable dict shared with the worker thread; the thread fills in
    # "final"/"text" once expansion completes.  NOTE(review): nothing in the
    # UI visibly re-reads this dict after the callback returns, so only the
    # initial "Queued…" text is shown — confirm whether a polling hook was
    # intended.
    status = {"text": "Queued…", "final": ""}
    worker = threading.Thread(
        target=async_expand,
        args=(base, prompt, expand_budget, status),
        daemon=True,
    )
    worker.start()
    return img, base, status["text"]
def final_caption(url, prompt, detail, beams, sample):
    """Blocking path – returns (caption, status) with the fully expanded text."""
    img, err = load_image(url)
    if err:
        return "", err

    expand_budget = {"Low": 80, "Medium": 140, "High": 220}.get(detail, 140)
    base = generate_base(img, beams=beams, sample=sample)
    try:
        return expand_caption(base, prompt, expand_budget), "Done"
    except Exception as exc:
        # Degrade gracefully: surface the short caption rather than failing.
        return base, f"Expand error: {exc}"
# -------------------------------------------------
# UI layout
# -------------------------------------------------
# Hide the default Gradio footer.
css = "footer {display:none !important;}"
with gr.Blocks(title="Image Describer (CPU‑only)", css=css) as demo:
    gr.Markdown("## Image Describer (CPU‑only)")
    with gr.Row():
        # ---- Left column – inputs ----
        with gr.Column():
            url_in = gr.Textbox(label="Image URL / data‑URL")
            prompt_in = gr.Textbox(label="Optional prompt")
            detail_in = gr.Radio(
                ["Low", "Medium", "High"], value="Medium", label="Detail level"
            )
            beams_in = gr.Slider(1, 4, step=1, value=2, label="Beams")
            sample_in = gr.Checkbox(
                label="Enable sampling (more diverse)", value=False
            )
            go_btn = gr.Button("Load & Describe (fast)")
            final_btn = gr.Button("Get final caption (detailed)")
            status_out = gr.Textbox(label="Status", interactive=False)
        # ---- Middle column – image preview ----
        with gr.Column():
            img_out = gr.Image(type="pil", label="Image")
        # ---- Right column – caption output ----
        with gr.Column():
            caption_out = gr.Textbox(label="Caption", lines=8)
    # Fast path: returns image + short caption immediately.
    # NOTE(review): fast_describe returns only the *initial* status string;
    # no event here polls the background expansion result into status_out —
    # confirm whether a timer/polling hook was intended.
    go_btn.click(
        fn=fast_describe,
        inputs=[url_in, prompt_in, detail_in, beams_in, sample_in],
        outputs=[img_out, caption_out, status_out],
    )
    # Detailed path: blocks until the expanded caption is ready.
    final_btn.click(
        fn=final_caption,
        inputs=[url_in, prompt_in, detail_in, beams_in, sample_in],
        outputs=[caption_out, status_out],
    )
if __name__ == "__main__":
    demo.queue()  # enables request queuing (helps with sandbox limits)
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)