Hug0endob commited on
Commit
cd43691
·
verified ·
1 Parent(s): ccc20d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -45
app.py CHANGED
@@ -1,6 +1,7 @@
1
- # app.py – Gradio 6+ (CPU) SSE / SSR-safe
2
 
3
  import base64
 
4
  import logging
5
  import threading
6
  import time
@@ -19,42 +20,79 @@ from transformers import (
19
  T5Tokenizer,
20
  )
21
 
 
 
 
 
 
 
22
  logging.basicConfig(level=logging.INFO)
23
 
24
  device = torch.device("cpu")
25
 
 
 
 
26
  IMG_MODEL = "nlpconnect/vit-gpt2-image-captioning"
27
  TXT_MODEL = "t5-small"
28
 
 
 
 
29
  processor = ViTImageProcessor.from_pretrained(IMG_MODEL)
30
  tokenizer = AutoTokenizer.from_pretrained(IMG_MODEL)
31
- vision = VisionEncoderDecoderModel.from_pretrained(IMG_MODEL).to(device).eval()
32
 
 
 
 
 
 
 
 
33
  rewriter_tok = T5Tokenizer.from_pretrained(TXT_MODEL)
34
- rewriter = T5ForConditionalGeneration.from_pretrained(TXT_MODEL).to(device).eval()
 
 
 
 
35
 
 
 
 
36
 
 
 
 
37
  def load_image(url: str):
 
38
  try:
39
  url = (url or "").strip()
40
  if not url:
41
  return None, "No URL provided."
 
 
42
  if url.startswith("data:"):
43
  _, data = url.split(",", 1)
44
  img = Image.open(BytesIO(base64.b64decode(data))).convert("RGB")
45
  return img, None
 
 
46
  if not urllib.parse.urlsplit(url).scheme:
47
  return None, "Missing http/https scheme."
48
- r = requests.get(url, timeout=10, headers={"User-Agent": "duck.ai"})
49
- r.raise_for_status()
50
- return Image.open(BytesIO(r.content)).convert("RGB"), None
51
- except Exception as e:
52
- return None, f"Load error: {e}"
 
 
53
 
54
 
55
  def generate_base(img: Image.Image, max_len=40, beams=2, sample=False):
 
56
  inputs = processor(images=img, return_tensors="pt")
57
  pix = inputs.pixel_values.to(device)
 
58
  if sample:
59
  out = vision.generate(
60
  pix,
@@ -74,22 +112,26 @@ def generate_base(img: Image.Image, max_len=40, beams=2, sample=False):
74
  num_return_sequences=min(3, beams),
75
  early_stopping=True,
76
  )
77
- caps = [tokenizer.decode(o, skip_special_tokens=True).strip() for o in out]
78
- return max(caps, key=lambda s: len(s.split()))
 
79
 
80
 
81
  def expand_caption(base: str, prompt: str = None, max_len=160):
82
- if prompt and prompt.strip():
83
- instr = f"Expand using: '{prompt}'. Caption: \"{base}\""
84
- else:
85
- instr = f"Expand with rich visual detail. Caption: \"{base}\""
 
 
86
  toks = rewriter_tok(
87
- instr,
88
  return_tensors="pt",
89
  truncation=True,
90
  padding="max_length",
91
  max_length=256,
92
  ).to(device)
 
93
  out = rewriter.generate(
94
  **toks,
95
  max_length=max_len,
@@ -100,67 +142,101 @@ def expand_caption(base: str, prompt: str = None, max_len=160):
100
  return rewriter_tok.decode(out[0], skip_special_tokens=True).strip()
101
 
102
 
103
- def async_expand(base, prompt, max_len, status):
 
104
  try:
105
- status["text"] = "Expanding…"
106
- time.sleep(0.1)
107
  result = expand_caption(base, prompt, max_len)
108
- status["text"] = "Done"
109
- status["final"] = result
110
- except Exception as e:
111
- status["text"] = f"Error: {e}"
112
- status["final"] = base
113
 
114
 
 
 
 
115
  def fast_describe(url, prompt, detail, beams, sample):
 
116
  img, err = load_image(url)
117
  if err:
118
  return None, "", err
 
119
  detail_map = {"Low": 80, "Medium": 140, "High": 220}
120
  max_expand = detail_map.get(detail, 140)
 
121
  base = generate_base(img, beams=beams, sample=sample)
 
 
122
  status = {"text": "Queued…", "final": ""}
123
- threading.Thread(target=async_expand, args=(base, prompt, max_expand, status), daemon=True).start()
 
 
 
 
 
 
 
124
  return img, base, status["text"]
125
 
126
 
127
  def final_caption(url, prompt, detail, beams, sample):
 
128
  img, err = load_image(url)
129
  if err:
130
  return "", err
 
131
  detail_map = {"Low": 80, "Medium": 140, "High": 220}
132
  max_expand = detail_map.get(detail, 140)
 
133
  base = generate_base(img, beams=beams, sample=sample)
134
  try:
135
  final = expand_caption(base, prompt, max_expand)
136
  return final, "Done"
137
- except Exception as e:
138
- return base, f"Expand error: {e}"
139
 
140
 
 
 
 
141
  css = "footer {display:none !important;}"
142
- with gr.Blocks(title="Image Describer (CPU)") as demo:
143
- gr.Markdown("## Image Describer")
 
144
  with gr.Row():
 
145
  with gr.Column():
146
  url_in = gr.Textbox(label="Image URL / data‑URL")
147
  prompt_in = gr.Textbox(label="Optional prompt")
148
- detail_in = gr.Radio(["Low", "Medium", "High"], value="Medium", label="Detail level")
 
 
149
  beams_in = gr.Slider(1, 4, step=1, value=2, label="Beams")
150
- sample_in = gr.Checkbox(label="Enable sampling (more diverse)", value=False)
 
 
151
  go_btn = gr.Button("Load & Describe (fast)")
152
  final_btn = gr.Button("Get final caption (detailed)")
153
  status_out = gr.Textbox(label="Status", interactive=False)
 
 
154
  with gr.Column():
155
  img_out = gr.Image(type="pil", label="Image")
 
 
156
  with gr.Column():
157
  caption_out = gr.Textbox(label="Caption", lines=8)
158
 
 
159
  go_btn.click(
160
  fn=fast_describe,
161
  inputs=[url_in, prompt_in, detail_in, beams_in, sample_in],
162
  outputs=[img_out, caption_out, status_out],
163
  )
 
 
164
  final_btn.click(
165
  fn=final_caption,
166
  inputs=[url_in, prompt_in, detail_in, beams_in, sample_in],
@@ -168,18 +244,5 @@ with gr.Blocks(title="Image Describer (CPU)") as demo:
168
  )
169
 
170
  if __name__ == "__main__":
171
- demo.queue()
172
- try:
173
- demo.launch(
174
- server_name="0.0.0.0",
175
- server_port=7860,
176
- css=css,
177
- prevent_thread_lock=True,
178
- ssr_mode=False, # disable server-side rendering (avoids SSE/SSR SSE issues)
179
- share=False,
180
- )
181
- except Exception as e:
182
- logging.exception("Launch failed")
183
- with open("/tmp/gradio_launch_err.txt", "w") as fh:
184
- fh.write(str(e))
185
- raise
 
1
+ # app.py – Gradio6+ (CPU‑only) safe for limited sandbox resources
2
 
3
  import base64
4
+ import gc
5
  import logging
6
  import threading
7
  import time
 
20
  T5Tokenizer,
21
  )
22
 
23
+ # -------------------------------------------------
24
+ # Runtime limits (sandbox‑friendly)
25
+ # -------------------------------------------------
26
+ torch.set_num_threads(1) # one CPU thread
27
+ torch.set_num_interop_threads(1) # one inter‑op thread
28
+ torch.set_grad_enabled(False) # inference‑only
29
  logging.basicConfig(level=logging.INFO)
30
 
31
  device = torch.device("cpu")
32
 
33
+ # -------------------------------------------------
34
+ # Model loading (fp16 only when a GPU is present)
35
+ # -------------------------------------------------
36
  IMG_MODEL = "nlpconnect/vit-gpt2-image-captioning"
37
  TXT_MODEL = "t5-small"
38
 
39
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
40
+
41
+ # Vision‑caption model
42
  processor = ViTImageProcessor.from_pretrained(IMG_MODEL)
43
  tokenizer = AutoTokenizer.from_pretrained(IMG_MODEL)
 
44
 
45
+ vision = (
46
+ VisionEncoderDecoderModel.from_pretrained(IMG_MODEL, torch_dtype=dtype)
47
+ .to(device)
48
+ .eval()
49
+ )
50
+
51
+ # Text‑rewriter model
52
  rewriter_tok = T5Tokenizer.from_pretrained(TXT_MODEL)
53
+ rewriter = (
54
+ T5ForConditionalGeneration.from_pretrained(TXT_MODEL, torch_dtype=dtype)
55
+ .to(device)
56
+ .eval()
57
+ )
58
 
59
+ # Release any temporary download buffers
60
+ gc.collect()
61
+ torch.cuda.empty_cache() # no‑op on CPU, kept for symmetry
62
 
63
+ # -------------------------------------------------
64
+ # Helper utilities
65
+ # -------------------------------------------------
66
def load_image(url: str):
    """Fetch an image from an HTTP(S) URL or an inline data-URL.

    Returns a ``(PIL.Image, error)`` pair where exactly one element is
    ``None``.  Every failure is reported as an error string instead of an
    exception, so callers can show it directly in the UI.
    """
    try:
        cleaned = (url or "").strip()
        if not cleaned:
            return None, "No URL provided."

        # Inline payload: "data:<mime>;base64,<payload>"
        if cleaned.startswith("data:"):
            _header, payload = cleaned.split(",", 1)
            decoded = base64.b64decode(payload)
            return Image.open(BytesIO(decoded)).convert("RGB"), None

        # Plain network URL — require an explicit scheme before fetching.
        if not urllib.parse.urlsplit(cleaned).scheme:
            return None, "Missing http/https scheme."

        resp = requests.get(cleaned, timeout=10, headers={"User-Agent": "duck.ai"})
        resp.raise_for_status()
        return Image.open(BytesIO(resp.content)).convert("RGB"), None
    except Exception as exc:
        return None, f"Load error: {exc}"
89
 
90
 
91
  def generate_base(img: Image.Image, max_len=40, beams=2, sample=False):
92
+ """Create a short caption with the vision model."""
93
  inputs = processor(images=img, return_tensors="pt")
94
  pix = inputs.pixel_values.to(device)
95
+
96
  if sample:
97
  out = vision.generate(
98
  pix,
 
112
  num_return_sequences=min(3, beams),
113
  early_stopping=True,
114
  )
115
+ captions = [tokenizer.decode(o, skip_special_tokens=True).strip() for o in out]
116
+ # pick the longest (usually the most complete) caption
117
+ return max(captions, key=lambda s: len(s.split()))
118
 
119
 
120
  def expand_caption(base: str, prompt: str = None, max_len=160):
121
+ """Rewrite/expand the base caption with the T5 model."""
122
+ instruction = (
123
+ f"Expand using: '{prompt}'. Caption: \"{base}\""
124
+ if prompt and prompt.strip()
125
+ else f"Expand with rich visual detail. Caption: \"{base}\""
126
+ )
127
  toks = rewriter_tok(
128
+ instruction,
129
  return_tensors="pt",
130
  truncation=True,
131
  padding="max_length",
132
  max_length=256,
133
  ).to(device)
134
+
135
  out = rewriter.generate(
136
  **toks,
137
  max_length=max_len,
 
142
  return rewriter_tok.decode(out[0], skip_special_tokens=True).strip()
143
 
144
 
145
def async_expand(base, prompt, max_len, status_dict):
    """Worker-thread target: expand *base* and report progress in place.

    Mutates *status_dict*: ``"text"`` tracks progress ("Expanding…",
    "Done", or "Error: …") and ``"final"`` receives the expanded caption.
    On any failure the short *base* caption is used as the fallback, so
    the UI is never left without text.
    """
    try:
        status_dict["text"] = "Expanding…"
        expanded = expand_caption(base, prompt, max_len)
        status_dict["final"] = expanded
        status_dict["text"] = "Done"
    except Exception as exc:
        status_dict["text"] = f"Error: {exc}"
        status_dict["final"] = base
155
 
156
 
157
+ # -------------------------------------------------
158
+ # Gradio callbacks
159
+ # -------------------------------------------------
160
def fast_describe(url, prompt, detail, beams, sample):
    """Quick path: return (image, short caption, status string).

    Loads the image, produces the short vision-model caption, then hands
    the expansion off to a daemon thread so the UI responds immediately.

    NOTE(review): the thread writes its result into a local ``status``
    dict, but nothing re-reads that dict after this function returns —
    confirm whether a polling/streaming hook was intended to surface the
    expanded caption.
    """
    img, err = load_image(url)
    if err:
        return None, "", err

    # Map the UI detail level to a generation length budget.
    length_by_detail = {"Low": 80, "Medium": 140, "High": 220}
    max_expand = length_by_detail.get(detail, 140)

    base = generate_base(img, beams=beams, sample=sample)

    # Mutable holder the background worker updates in place.
    status = {"text": "Queued…", "final": ""}
    worker = threading.Thread(
        target=async_expand,
        args=(base, prompt, max_expand, status),
        daemon=True,
    )
    worker.start()

    return img, base, status["text"]
182
 
183
 
184
def final_caption(url, prompt, detail, beams, sample):
    """Blocking path: return (caption, status) with the expanded caption.

    Unlike ``fast_describe`` this waits for the T5 expansion to finish;
    if the expansion fails, it degrades gracefully to the short base
    caption and reports the error in the status string.
    """
    img, err = load_image(url)
    if err:
        return "", err

    # Same detail-level → length-budget mapping as the fast path.
    length_by_detail = {"Low": 80, "Medium": 140, "High": 220}
    max_expand = length_by_detail.get(detail, 140)

    base = generate_base(img, beams=beams, sample=sample)
    try:
        return expand_caption(base, prompt, max_expand), "Done"
    except Exception as exc:
        return base, f"Expand error: {exc}"
199
 
200
 
201
+ # -------------------------------------------------
202
+ # UI layout
203
+ # -------------------------------------------------
204
  css = "footer {display:none !important;}"
205
+ with gr.Blocks(title="Image Describer (CPU‑only)", css=css) as demo:
206
+ gr.Markdown("## Image Describer (CPU‑only)")
207
+
208
  with gr.Row():
209
+ # ---- Left column – inputs ----
210
  with gr.Column():
211
  url_in = gr.Textbox(label="Image URL / data‑URL")
212
  prompt_in = gr.Textbox(label="Optional prompt")
213
+ detail_in = gr.Radio(
214
+ ["Low", "Medium", "High"], value="Medium", label="Detail level"
215
+ )
216
  beams_in = gr.Slider(1, 4, step=1, value=2, label="Beams")
217
+ sample_in = gr.Checkbox(
218
+ label="Enable sampling (more diverse)", value=False
219
+ )
220
  go_btn = gr.Button("Load & Describe (fast)")
221
  final_btn = gr.Button("Get final caption (detailed)")
222
  status_out = gr.Textbox(label="Status", interactive=False)
223
+
224
+ # ---- Middle column – image preview ----
225
  with gr.Column():
226
  img_out = gr.Image(type="pil", label="Image")
227
+
228
+ # ---- Right column – caption output ----
229
  with gr.Column():
230
  caption_out = gr.Textbox(label="Caption", lines=8)
231
 
232
+ # Fast path: returns image + short caption immediately
233
  go_btn.click(
234
  fn=fast_describe,
235
  inputs=[url_in, prompt_in, detail_in, beams_in, sample_in],
236
  outputs=[img_out, caption_out, status_out],
237
  )
238
+
239
+ # Detailed path: blocks until the expanded caption is ready
240
  final_btn.click(
241
  fn=final_caption,
242
  inputs=[url_in, prompt_in, detail_in, beams_in, sample_in],
 
244
  )
245
 
246
  if __name__ == "__main__":
247
+ demo.queue() # enables request queuing (helps with sandbox limits)
248
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False)