# app.py – Gradio 6+ (CPU‑only) – safe for limited sandbox resources
import base64
import gc
import logging
import threading
import time
import urllib.parse
from io import BytesIO
import gradio as gr
import requests
import torch
from PIL import Image
from transformers import (
AutoTokenizer,
VisionEncoderDecoderModel,
ViTImageProcessor,
T5ForConditionalGeneration,
T5Tokenizer,
)
# -------------------------------------------------
# Runtime limits (sandbox‑friendly)
# -------------------------------------------------
# Pin PyTorch to a single thread for both intra-op and inter-op work so the
# app stays inside the sandbox's CPU quota; gradients are disabled globally
# because this app only ever runs inference.
torch.set_num_threads(1) # one CPU thread
torch.set_num_interop_threads(1) # one inter‑op thread
torch.set_grad_enabled(False) # inference‑only
logging.basicConfig(level=logging.INFO)
device = torch.device("cpu")
# -------------------------------------------------
# Model loading (fp16 only when a GPU is present)
# -------------------------------------------------
# Hub ids: a ViT-encoder/GPT-2-decoder captioner and a small T5 rewriter.
IMG_MODEL = "nlpconnect/vit-gpt2-image-captioning"
TXT_MODEL = "t5-small"
# NOTE(review): `device` above is hard-coded to CPU, so this cuda check
# effectively always selects float32 here — confirm whether GPU support
# was intended.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Vision‑caption model: processor turns PIL images into pixel tensors,
# tokenizer decodes the generated ids back into text.
processor = ViTImageProcessor.from_pretrained(IMG_MODEL)
tokenizer = AutoTokenizer.from_pretrained(IMG_MODEL)
vision = (
    VisionEncoderDecoderModel.from_pretrained(IMG_MODEL, torch_dtype=dtype)
    .to(device)
    .eval()
)
# Text‑rewriter model: expands the short caption into a detailed one.
rewriter_tok = T5Tokenizer.from_pretrained(TXT_MODEL)
rewriter = (
    T5ForConditionalGeneration.from_pretrained(TXT_MODEL, torch_dtype=dtype)
    .to(device)
    .eval()
)
# Release any temporary download buffers
gc.collect()
torch.cuda.empty_cache() # no‑op on CPU, kept for symmetry
# -------------------------------------------------
# Helper utilities
# -------------------------------------------------
def load_image(url: str):
    """Fetch an image from an http(s) URL or a base64 data-URL.

    Args:
        url: an ``http://``/``https://`` URL, or a ``data:*;base64,`` URL.
            ``None`` and blank strings are tolerated.

    Returns:
        ``(PIL.Image, None)`` on success, ``(None, error_message)`` on any
        failure — this loader never raises.
    """
    try:
        url = (url or "").strip()
        if not url:
            return None, "No URL provided."
        # data-URL (base64-encoded image), e.g. "data:image/png;base64,...."
        if url.startswith("data:"):
            _, data = url.split(",", 1)
            img = Image.open(BytesIO(base64.b64decode(data))).convert("RGB")
            return img, None
        # Normal web URL. FIX: the original only checked that *some* scheme
        # was present, so e.g. "ftp://host/x.png" slipped past despite the
        # "Missing http/https scheme" message; restrict to http/https.
        if urllib.parse.urlsplit(url).scheme not in ("http", "https"):
            return None, "Missing http/https scheme."
        resp = requests.get(url, timeout=10, headers={"User-Agent": "duck.ai"})
        resp.raise_for_status()
        img = Image.open(BytesIO(resp.content)).convert("RGB")
        return img, None
    except Exception as exc:
        # Best-effort loader: surface every failure as a message string.
        return None, f"Load error: {exc}"
def generate_base(img: Image.Image, max_len: int = 40, beams: int = 2, sample: bool = False):
    """Create a short caption for *img* with the vision model.

    Args:
        img: RGB PIL image.
        max_len: maximum token length of each candidate caption.
        beams: beam count for beam search (only used when ``sample`` is False).
        sample: when True, draw 3 candidates with top-k/top-p sampling
            instead of beam search.

    Returns:
        The candidate caption with the most words (usually the most
        complete one), stripped.
    """
    inputs = processor(images=img, return_tensors="pt")
    pix = inputs.pixel_values.to(device)
    if sample:
        # FIX: `early_stopping` only applies to beam search; passing it with
        # do_sample=True is ignored and triggers a transformers warning, so
        # it is omitted from the sampling call.
        out = vision.generate(
            pix,
            max_length=max_len,
            do_sample=True,
            temperature=0.8,
            top_k=50,
            top_p=0.9,
            num_return_sequences=3,
        )
    else:
        out = vision.generate(
            pix,
            max_length=max_len,
            num_beams=beams,
            num_return_sequences=min(3, beams),
            early_stopping=True,
        )
    captions = [tokenizer.decode(o, skip_special_tokens=True).strip() for o in out]
    # Pick the longest caption by word count.
    return max(captions, key=lambda s: len(s.split()))
def expand_caption(base: str, prompt: str | None = None, max_len: int = 160):
    """Rewrite/expand the base caption with the T5 model.

    Args:
        base: short caption produced by the vision model.
        prompt: optional user guidance to steer the expansion; blank or
            None falls back to a generic "rich visual detail" instruction.
        max_len: maximum token length of the expanded caption.

    Returns:
        The expanded caption, stripped of surrounding whitespace.
    """
    if prompt and prompt.strip():
        instruction = f"Expand using: '{prompt}'. Caption: \"{base}\""
    else:
        instruction = f"Expand with rich visual detail. Caption: \"{base}\""
    # FIX: the original used padding="max_length", which padded every single
    # input out to 256 tokens and wasted CPU; the attention mask is passed to
    # generate() either way, so dropping the padding does not change output.
    toks = rewriter_tok(
        instruction,
        return_tensors="pt",
        truncation=True,
        max_length=256,
    ).to(device)
    out = rewriter.generate(
        **toks,
        max_length=max_len,
        num_beams=4,
        early_stopping=True,
        no_repeat_ngram_size=3,
    )
    return rewriter_tok.decode(out[0], skip_special_tokens=True).strip()
def async_expand(base, prompt, max_len, status_dict):
    """Background worker: expand *base* and report progress via *status_dict*.

    Sets ``status_dict["text"]`` to "Expanding…" while working, then to
    either "Done" (with the expanded caption in ``"final"``) or an error
    message (with the original caption as the ``"final"`` fallback).
    """
    status_dict["text"] = "Expanding…"
    try:
        expanded = expand_caption(base, prompt, max_len)
    except Exception as exc:
        status_dict["text"] = f"Error: {exc}"
        status_dict["final"] = base
    else:
        status_dict["final"] = expanded
        status_dict["text"] = "Done"
# -------------------------------------------------
# Gradio callbacks
# -------------------------------------------------
def fast_describe(url, prompt, detail, beams, sample):
    """Quick path: return the image and a short caption immediately, while a
    daemon thread expands the caption in the background.

    Returns (image, short caption, status text) — or (None, "", error) when
    the image cannot be loaded.
    """
    img, err = load_image(url)
    if err:
        return None, "", err
    expand_budget = {"Low": 80, "Medium": 140, "High": 220}.get(detail, 140)
    base = generate_base(img, beams=beams, sample=sample)
    # Shared mutable state the worker updates; the UI only receives the
    # initial status string returned below and may poll for the rest later.
    status = {"text": "Queued…", "final": ""}
    worker = threading.Thread(
        target=async_expand,
        args=(base, prompt, expand_budget, status),
        daemon=True,
    )
    worker.start()
    return img, base, status["text"]
def final_caption(url, prompt, detail, beams, sample):
    """Blocking path: caption the image and return the fully expanded text.

    Returns (caption, status) — on load failure the caption is "" and the
    status carries the error; on expansion failure the short base caption
    is returned instead.
    """
    img, err = load_image(url)
    if err:
        return "", err
    budget = {"Low": 80, "Medium": 140, "High": 220}.get(detail, 140)
    base = generate_base(img, beams=beams, sample=sample)
    try:
        return expand_caption(base, prompt, budget), "Done"
    except Exception as exc:
        return base, f"Expand error: {exc}"
# -------------------------------------------------
# UI layout
# -------------------------------------------------
# Hide Gradio's default footer.
css = "footer {display:none !important;}"
with gr.Blocks(title="Image Describer (CPU‑only)", css=css) as demo:
    gr.Markdown("## Image Describer (CPU‑only)")
    with gr.Row():
        # ---- Left column – inputs ----
        with gr.Column():
            url_in = gr.Textbox(label="Image URL / data‑URL")
            prompt_in = gr.Textbox(label="Optional prompt")
            detail_in = gr.Radio(
                ["Low", "Medium", "High"], value="Medium", label="Detail level"
            )
            beams_in = gr.Slider(1, 4, step=1, value=2, label="Beams")
            sample_in = gr.Checkbox(
                label="Enable sampling (more diverse)", value=False
            )
            go_btn = gr.Button("Load & Describe (fast)")
            final_btn = gr.Button("Get final caption (detailed)")
            status_out = gr.Textbox(label="Status", interactive=False)
        # ---- Middle column – image preview ----
        with gr.Column():
            img_out = gr.Image(type="pil", label="Image")
        # ---- Right column – caption output ----
        with gr.Column():
            caption_out = gr.Textbox(label="Caption", lines=8)
    # Fast path: returns image + short caption immediately
    # NOTE(review): nothing here polls the background thread's status dict,
    # so status_out only ever shows the initial text — confirm intended.
    go_btn.click(
        fn=fast_describe,
        inputs=[url_in, prompt_in, detail_in, beams_in, sample_in],
        outputs=[img_out, caption_out, status_out],
    )
    # Detailed path: blocks until the expanded caption is ready
    final_btn.click(
        fn=final_caption,
        inputs=[url_in, prompt_in, detail_in, beams_in, sample_in],
        outputs=[caption_out, status_out],
    )
if __name__ == "__main__":
    demo.queue() # enables request queuing (helps with sandbox limits)
    # Bind to all interfaces so the hosting container can route traffic in.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)