# app.py – Gradio 6+ (CPU-only) – safe for limited sandbox resources
import base64
import gc
import logging
import threading
import time
import urllib.parse
from io import BytesIO

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import (
    AutoTokenizer,
    VisionEncoderDecoderModel,
    ViTImageProcessor,
    T5ForConditionalGeneration,
    T5Tokenizer,
)

# -------------------------------------------------
# Runtime limits (sandbox-friendly)
# -------------------------------------------------
torch.set_num_threads(1)  # one CPU thread
torch.set_num_interop_threads(1)  # one inter-op thread
torch.set_grad_enabled(False)  # inference-only
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

device = torch.device("cpu")

# -------------------------------------------------
# Model loading (fp16 only when running on a GPU)
# -------------------------------------------------
IMG_MODEL = "nlpconnect/vit-gpt2-image-captioning"
TXT_MODEL = "t5-small"

# The dtype must follow the device the models actually run on. Models are
# always moved to `device` (CPU) below, so fp16 weights would be slow on CPU
# and would clash with the fp32 tensors the processor/tokenizer produce.
dtype = torch.float16 if device.type == "cuda" else torch.float32

# Vision-caption model
processor = ViTImageProcessor.from_pretrained(IMG_MODEL)
tokenizer = AutoTokenizer.from_pretrained(IMG_MODEL)
vision = (
    VisionEncoderDecoderModel.from_pretrained(IMG_MODEL, torch_dtype=dtype)
    .to(device)
    .eval()
)

# Text-rewriter model
rewriter_tok = T5Tokenizer.from_pretrained(TXT_MODEL)
rewriter = (
    T5ForConditionalGeneration.from_pretrained(TXT_MODEL, torch_dtype=dtype)
    .to(device)
    .eval()
)

# Release any temporary download buffers
gc.collect()
torch.cuda.empty_cache()  # no-op on CPU, kept for symmetry

# -------------------------------------------------
# Shared settings
# -------------------------------------------------
# Upper bound on bytes downloaded for a remote image (protects sandbox memory).
MAX_DOWNLOAD_BYTES = 20 * 1024 * 1024
# UI "Detail level" -> max_length passed to the T5 expansion.
DETAIL_MAX_LEN = {"Low": 80, "Medium": 140, "High": 220}
DEFAULT_DETAIL_MAX_LEN = 140


# -------------------------------------------------
# Helper utilities
# -------------------------------------------------
def _decode_data_url(url: str):
    """Decode a ``data:[<mediatype>][;base64],<payload>`` URL to an RGB image.

    Returns ``(image, None)`` or ``(None, message)``; decoding errors raise
    and are handled by the caller.
    """
    header, sep, payload = url.partition(",")
    if not sep:
        return None, "Malformed data URL (missing ',')."
    if header.lower().endswith(";base64"):
        raw = base64.b64decode(payload)
    else:
        # Per RFC 2397, non-base64 data URLs carry percent-encoded bytes.
        raw = urllib.parse.unquote_to_bytes(payload)
    return Image.open(BytesIO(raw)).convert("RGB"), None


def _download(url: str, max_bytes: int = MAX_DOWNLOAD_BYTES) -> bytes:
    """GET *url* and return its body, refusing bodies larger than *max_bytes*.

    Raises ``requests`` exceptions on network/HTTP errors and ``ValueError``
    when the size cap is exceeded.
    """
    # stream=True lets us stop reading once the cap is hit instead of
    # buffering an arbitrarily large response; `with` closes the connection.
    with requests.get(
        url, timeout=10, headers={"User-Agent": "duck.ai"}, stream=True
    ) as resp:
        resp.raise_for_status()
        buf = bytearray()
        for chunk in resp.iter_content(chunk_size=64 * 1024):
            buf.extend(chunk)
            if len(buf) > max_bytes:
                raise ValueError(f"image larger than {max_bytes} bytes")
        return bytes(buf)


def load_image(url: str):
    """Fetch an image from an http(s) URL or a data-URL.

    Args:
        url: user-supplied URL string (may be empty/None).

    Returns:
        ``(image, None)`` on success, where ``image`` is an RGB PIL image,
        or ``(None, message)`` on failure. Never raises.
    """
    try:
        url = (url or "").strip()
        if not url:
            return None, "No URL provided."

        # data-URL (inline image)
        if url.startswith("data:"):
            return _decode_data_url(url)

        # normal HTTP/HTTPS URL
        scheme = urllib.parse.urlsplit(url).scheme.lower()
        if not scheme:
            return None, "Missing http/https scheme."
        if scheme not in ("http", "https"):
            return None, f"Unsupported URL scheme {scheme!r} (use http/https)."
        # NOTE(review): the server fetches arbitrary user-supplied URLs; if it
        # can reach internal hosts, consider an allow-list (SSRF risk).
        img = Image.open(BytesIO(_download(url))).convert("RGB")
        return img, None
    except Exception as exc:
        logger.warning("Image load failed for %.80s: %s", url, exc)
        return None, f"Load error: {exc}"


def generate_base(img: Image.Image, max_len=40, beams=2, sample=False):
    """Create a short caption for *img* with the vision model.

    Args:
        img: RGB PIL image.
        max_len: maximum generated length (tokens).
        beams: beam width for deterministic decoding; coerced to an int >= 1
            because UI sliders may deliver floats. Ignored when sampling.
        sample: if true, draw 3 sampled captions instead of beam search.

    Returns:
        The candidate caption with the most words.
    """
    beams = max(1, int(beams))
    inputs = processor(images=img, return_tensors="pt")
    # Cast to the model's weight dtype so inputs and weights always agree.
    pix = inputs.pixel_values.to(device=device, dtype=vision.dtype)
    if sample:
        out = vision.generate(
            pix,
            max_length=max_len,
            do_sample=True,
            temperature=0.8,
            top_k=50,
            top_p=0.9,
            num_return_sequences=3,
        )
    else:
        out = vision.generate(
            pix,
            max_length=max_len,
            num_beams=beams,
            num_return_sequences=min(3, beams),
            # early_stopping only has meaning for beam search (num_beams > 1).
            early_stopping=beams > 1,
        )
    captions = [tokenizer.decode(o, skip_special_tokens=True).strip() for o in out]
    # pick the longest (usually the most complete) caption
    return max(captions, key=lambda s: len(s.split()))


def expand_caption(base: str, prompt: str = None, max_len=160):
    """Rewrite/expand *base* with the T5 model.

    Args:
        base: short caption to expand.
        prompt: optional user guidance; blank/None uses a generic instruction.
        max_len: maximum generated length (tokens).

    Returns:
        The decoded T5 output, stripped.
    """
    # NOTE(review): t5-small is not instruction-tuned, so free-form
    # instructions like these may be followed loosely - confirm output quality.
    instruction = (
        f"Expand using: '{prompt}'. Caption: \"{base}\""
        if prompt and prompt.strip()
        else f"Expand with rich visual detail. Caption: \"{base}\""
    )
    # A single sequence needs no padding; padding to 256 tokens only made the
    # encoder process pad positions on a one-thread CPU.
    toks = rewriter_tok(
        instruction,
        return_tensors="pt",
        truncation=True,
        max_length=256,
    ).to(device)
    out = rewriter.generate(
        **toks,
        max_length=max_len,
        num_beams=4,
        early_stopping=True,
        no_repeat_ngram_size=3,
    )
    return rewriter_tok.decode(out[0], skip_special_tokens=True).strip()


def async_expand(base, prompt, max_len, status_dict):
    """Run :func:`expand_caption` and record progress in *status_dict*.

    Writes ``"text"`` (status message) and ``"final"`` (expanded caption, or
    *base* on error). Intended as a thread target; not used by the UI below.
    """
    try:
        status_dict["text"] = "Expanding…"
        result = expand_caption(base, prompt, max_len)
        status_dict["final"] = result
        status_dict["text"] = "Done"
    except Exception as exc:
        logger.exception("Background caption expansion failed")
        status_dict["text"] = f"Error: {exc}"
        status_dict["final"] = base


# -------------------------------------------------
# Gradio callbacks
# -------------------------------------------------
def fast_describe(url, prompt, detail, beams, sample):
    """Quick path – return (image, short caption, status).

    ``prompt`` and ``detail`` only matter for the detailed path; they are
    accepted so both buttons can share one input list.
    """
    img, err = load_image(url)
    if err:
        return None, "", err
    try:
        base = generate_base(img, beams=beams, sample=sample)
    except Exception as exc:
        logger.exception("Base caption generation failed")
        return img, "", f"Caption error: {exc}"
    # Expansion is deliberately not started in the background here: nothing
    # could deliver its result to the UI, and it would compete with other
    # requests for the single CPU thread. The detailed button runs it.
    return img, base, "Base caption ready."


def final_caption(url, prompt, detail, beams, sample):
    """Blocking path – return (expanded caption, status).

    Falls back to the base caption when expansion fails.
    """
    img, err = load_image(url)
    if err:
        return "", err
    max_expand = DETAIL_MAX_LEN.get(detail, DEFAULT_DETAIL_MAX_LEN)
    try:
        base = generate_base(img, beams=beams, sample=sample)
    except Exception as exc:
        logger.exception("Base caption generation failed")
        return "", f"Caption error: {exc}"
    try:
        final = expand_caption(base, prompt, max_expand)
        return final, "Done"
    except Exception as exc:
        logger.exception("Caption expansion failed")
        return base, f"Expand error: {exc}"


# -------------------------------------------------
# UI layout
# -------------------------------------------------
css = "footer {display:none !important;}"

with gr.Blocks(title="Image Describer (CPU‑only)", css=css) as demo:
    gr.Markdown("## Image Describer (CPU‑only)")
    with gr.Row():
        # ---- Left column – inputs ----
        with gr.Column():
            url_in = gr.Textbox(label="Image URL / data‑URL")
            prompt_in = gr.Textbox(label="Optional prompt")
            detail_in = gr.Radio(
                ["Low", "Medium", "High"], value="Medium", label="Detail level"
            )
            beams_in = gr.Slider(1, 4, step=1, value=2, label="Beams")
            sample_in = gr.Checkbox(
                label="Enable sampling (more diverse)", value=False
            )
            go_btn = gr.Button("Load & Describe (fast)")
            final_btn = gr.Button("Get final caption (detailed)")
            status_out = gr.Textbox(label="Status", interactive=False)

        # ---- Middle column – image preview ----
        with gr.Column():
            img_out = gr.Image(type="pil", label="Image")

        # ---- Right column – caption output ----
        with gr.Column():
            caption_out = gr.Textbox(label="Caption", lines=8)

    # Fast path: returns image + short caption immediately
    go_btn.click(
        fn=fast_describe,
        inputs=[url_in, prompt_in, detail_in, beams_in, sample_in],
        outputs=[img_out, caption_out, status_out],
    )

    # Detailed path: blocks until the expanded caption is ready
    final_btn.click(
        fn=final_caption,
        inputs=[url_in, prompt_in, detail_in, beams_in, sample_in],
        outputs=[caption_out, status_out],
    )

if __name__ == "__main__":
    demo.queue()  # enables request queuing (helps with sandbox limits)
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)