Hug0endob committed on
Commit
516d7c2
Β·
verified Β·
1 Parent(s): 9bccfcb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +134 -155
app.py CHANGED
@@ -1,7 +1,7 @@
 
1
  import gradio as gr
2
  from PIL import Image
3
- import requests
4
- from io import BytesIO
5
  import torch
6
  from transformers import (
7
  VisionEncoderDecoderModel,
@@ -10,195 +10,174 @@ from transformers import (
10
  T5ForConditionalGeneration,
11
  T5Tokenizer,
12
  )
13
- import urllib.parse
14
- import threading
15
- import time
16
 
 
 
 
17
  device = torch.device("cpu")
18
 
19
- # Model names
20
- PROCESSOR_NAME = "nlpconnect/vit-gpt2-image-captioning"
21
- REWRITER_NAME = "t5-small"
22
 
23
- # Load models (CPU)
24
- processor = ViTImageProcessor.from_pretrained(PROCESSOR_NAME)
25
- tokenizer = AutoTokenizer.from_pretrained(PROCESSOR_NAME)
26
- model = VisionEncoderDecoderModel.from_pretrained(PROCESSOR_NAME).to(device)
27
- model.eval()
28
 
29
- rewriter_tokenizer = T5Tokenizer.from_pretrained(REWRITER_NAME)
30
- rewriter = T5ForConditionalGeneration.from_pretrained(REWRITER_NAME).to(device)
31
- rewriter.eval()
32
 
33
- def load_image_from_url(url: str, timeout=10):
 
 
 
 
34
  try:
35
  url = url.strip()
36
  if url.startswith("data:"):
37
- header, encoded = url.split(",", 1)
38
  import base64
39
- data = base64.b64decode(encoded)
40
- img = Image.open(BytesIO(data)).convert("RGB")
41
  return img, None
42
- parsed = urllib.parse.urlsplit(url)
43
- if parsed.scheme == "":
44
- return None, "Invalid URL (missing scheme: http/https)."
45
- resp = requests.get(url, timeout=timeout, headers={"User-Agent": "huggingface-space/1.0"})
46
- resp.raise_for_status()
47
- img = Image.open(BytesIO(resp.content)).convert("RGB")
48
- return img, None
49
  except Exception as e:
50
- return None, f"Error loading image: {e}"
51
 
52
- # --- Generation & rewriting helpers ---
53
- def generate_caption_candidates(img: Image.Image, max_len: int = 40, num_beams: int = 2, num_return_sequences: int = 3, do_sample: bool = False):
54
  inputs = processor(images=img, return_tensors="pt")
55
- pixel_values = inputs.pixel_values.to(device)
56
-
57
- gen_kwargs = {
58
- "max_length": max_len,
59
- "num_beams": num_beams,
60
- "early_stopping": True,
61
- "do_sample": do_sample,
62
- "num_return_sequences": num_return_sequences,
63
- }
64
- # model.generate returns tensor of shape (num_return_sequences, seq_len) when requested
65
- outputs = model.generate(pixel_values, **gen_kwargs)
66
- captions = [tokenizer.decode(o, skip_special_tokens=True).strip() for o in outputs]
67
- # Deduplicate preserving order
68
- seen = set()
69
- unique = []
70
- for c in captions:
71
- if c not in seen:
72
- seen.add(c)
73
- unique.append(c)
74
- return unique
75
-
76
- def pick_most_detailed(candidates):
77
- # heuristic: prefer longer by word count, then more unique words
78
- best = max(candidates, key=lambda s: (len(s.split()), len(set(s.split()))))
79
- return best
80
-
81
- def expand_with_t5(caption: str, prompt: str = None, max_len: int = 160):
82
- # Instruction to expand and add rich visual detail
83
  if prompt and prompt.strip():
84
- instr = f"Expand and elaborate the caption using this instruction: '{prompt}'. Caption: \"{caption}\""
85
  else:
86
- instr = f"Expand and elaborate the caption with rich visual detail (objects, colors, textures, scene, actions). Caption: \"{caption}\""
87
- tok = rewriter_tokenizer(instr, return_tensors="pt", truncation=True, padding=True).to(device)
88
- out = rewriter.generate(**tok, max_length=max_len, num_beams=4, early_stopping=True, no_repeat_ngram_size=3)
89
- expanded = rewriter_tokenizer.decode(out[0], skip_special_tokens=True).strip()
90
- return expanded
91
-
92
- # Background worker pattern to run expansion and report progress
93
- def _background_expand_and_return(caption, prompt, max_expand_len, status_callback):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  try:
95
- # Inform start
96
- status_callback("Expanding caption (step 1/2)...")
97
- # Small sleep allows UI update
98
- time.sleep(0.1)
99
- expanded = expand_with_t5(caption, prompt=prompt, max_len=max_expand_len)
100
- status_callback("Finalizing (step 2/2)...")
101
  time.sleep(0.1)
102
- status_callback("Done")
103
- return expanded
 
104
  except Exception as e:
105
- status_callback(f"Error during expand: {e}")
106
- return caption
107
-
108
- # Main describe function used by Gradio; it triggers generation and then expansion in background
109
- def describe_image_controller(url: str, prompt: str, detail_level: str, max_caption_len: int = 40, beams: int = 2, do_sample: bool = True):
110
- """
111
- Returns: (img or None, caption_text, status_text)
112
- The UI will start background expansion and update status via a small helper.
113
- """
114
- img, err = load_image_from_url(url)
115
  if err:
116
- return None, "", f"Error: {err}"
117
 
118
- # Map detail_level to rewriter max_len
119
  detail_map = {"Low": 80, "Medium": 140, "High": 220}
120
- max_expand_len = detail_map.get(detail_level, 140)
121
-
122
- # Generate candidates
123
- candidates = generate_caption_candidates(img, max_len=max_caption_len, num_beams=beams, num_return_sequences=3, do_sample=do_sample)
124
- base = pick_most_detailed(candidates)
125
 
126
- # Start background thread to expand (T5) and update status via a Gradio status element (we'll use a simple polling text)
127
- # We'll use a small mutable container to send status updates via closure
128
- status = {"text": "Queued for expansion..."}
129
- def status_callback(s):
130
- status["text"] = s
131
-
132
- result_container = {"final": base}
133
 
134
  def worker():
135
- expanded = _background_expand_and_return(base, prompt, max_expand_len, status_callback)
136
- result_container["final"] = expanded
137
-
138
- thread = threading.Thread(target=worker, daemon=True)
139
- thread.start()
140
 
141
- # Return image, initial base caption, and initial status. The frontend will poll for status/final via separate endpoints
142
  return img, base, status["text"]
143
 
144
- # Polling endpoints to retrieve status and final caption
145
- def poll_status_and_caption(url: str, prompt: str, _placeholder):
146
- # In this simple pattern we re-run a lightweight check by storing results in a global map keyed by URL+prompt
147
- # For simplicity in this Space we will re-run expansion synchronously here if needed.
148
- # But to avoid redoing heavy work, you can implement a shared cache (omitted for brevity).
149
- return "If expansion still running, refresh in a few seconds. Final caption will replace base when ready."
150
-
151
- # Simple endpoint to get final expanded caption synchronously (used when user hits 'Get final caption')
152
- def get_final_caption(url: str, prompt: str, detail_level: str, max_caption_len: int = 40, beams: int = 2, do_sample: bool = True):
153
- img, err = load_image_from_url(url)
154
  if err:
155
- return "", f"Error: {err}"
156
- candidates = generate_caption_candidates(img, max_len=max_caption_len, num_beams=beams, num_return_sequences=3, do_sample=do_sample)
157
- base = pick_most_detailed(candidates)
158
  detail_map = {"Low": 80, "Medium": 140, "High": 220}
159
- max_expand_len = detail_map.get(detail_level, 140)
 
 
160
  try:
161
- expanded = expand_with_t5(base, prompt=prompt, max_len=max_expand_len)
162
- return expanded, "Done"
163
  except Exception as e:
164
  return base, f"Expand error: {e}"
165
 
166
- # Gradio UI
167
- css = """
168
- footer {display: none !important;}
169
- """
170
-
171
- with gr.Blocks(css=css, title="Image Describer (vit-gpt2, promptable, detailed)") as demo:
172
- gr.Markdown("## Image Describer β€” uncensored captions, optional prompt to bias description. Use 'Get final caption' for the detailed expanded output (may take longer).")
173
  with gr.Row():
174
- with gr.Column(scale=1):
175
- url_in = gr.Textbox(label="Image URL or data URL", placeholder="https://example.com/photo.jpg")
176
- prompt_in = gr.Textbox(label="Optional prompt (e.g. 'Focus on people and actions')", placeholder="Focus on people, actions, or colors.")
177
- detail_level = gr.Radio(choices=["Low", "Medium", "High"], value="Medium", label="Detail level (affects expansion length)")
178
- max_len = gr.Slider(minimum=8, maximum=80, value=40, label="Base caption max length")
179
- beams = gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Num beams (higher = better quality, slower)")
180
- do_sample_chk = gr.Checkbox(label="Enable sampling (more diverse)", value=True)
181
- go = gr.Button("Load & Describe (fast)")
182
- get_final = gr.Button("Get final caption (detailed, slower)")
183
- status_txt = gr.Textbox(label="Status", value="Idle", interactive=False)
184
- with gr.Column(scale=1):
185
  img_out = gr.Image(type="pil", label="Image")
186
- with gr.Column(scale=1):
187
- caption_out = gr.Textbox(label="Caption (base or final)", lines=8)
188
-
189
- # Fast path: generate base caption and immediately start background expand (status will be approximate)
190
- def on_go(url, prompt, detail_level, max_len, beams, do_sample):
191
- img, base_caption, status = describe_image_controller(url, prompt, detail_level, max_caption_len=max_len, beams=beams, do_sample=do_sample)
192
- return img, base_caption, status
193
-
194
- go.click(fn=on_go, inputs=[url_in, prompt_in, detail_level, max_len, beams, do_sample_chk], outputs=[img_out, caption_out, status_txt])
195
-
196
- # Synchronous, explicit final result (user clicks when they want the full expanded caption)
197
- def on_get_final(url, prompt, detail_level, max_len, beams, do_sample):
198
- final_caption, status = get_final_caption(url, prompt, detail_level, max_caption_len=max_len, beams=beams, do_sample=do_sample)
199
- return final_caption, status
200
-
201
- get_final.click(fn=on_get_final, inputs=[url_in, prompt_in, detail_level, max_len, beams, do_sample_chk], outputs=[caption_out, status_txt])
202
 
203
  if __name__ == "__main__":
204
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
+ # app.py – minimal, CPU‑only, high‑quality captions
2
  import gradio as gr
3
  from PIL import Image
4
+ import requests, urllib.parse, threading, time
 
5
  import torch
6
  from transformers import (
7
  VisionEncoderDecoderModel,
 
10
  T5ForConditionalGeneration,
11
  T5Tokenizer,
12
  )
 
 
 
13
 
14
+ # -------------------------------------------------
15
+ # Device & models (CPU)
16
+ # -------------------------------------------------
17
  device = torch.device("cpu")
18
 
19
+ IMG_MODEL = "nlpconnect/vit-gpt2-image-captioning"
20
+ TXT_MODEL = "t5-small"
 
21
 
22
+ processor = ViTImageProcessor.from_pretrained(IMG_MODEL)
23
+ tokenizer = AutoTokenizer.from_pretrained(IMG_MODEL)
24
+ vision = VisionEncoderDecoderModel.from_pretrained(IMG_MODEL).to(device).eval()
 
 
25
 
26
+ rewriter_tok = T5Tokenizer.from_pretrained(TXT_MODEL)
27
+ rewriter = T5ForConditionalGeneration.from_pretrained(TXT_MODEL).to(device).eval()
 
28
 
29
+ # -------------------------------------------------
30
+ # Helpers
31
+ # -------------------------------------------------
32
def load_image(url: str):
    """Return (PIL.Image, None) on success or (None, error_message) on failure.

    Accepts http/https URLs and base64 data-URLs.
    """
    # FIX: `from io import BytesIO` was removed from the module header in this
    # revision, so both branches below crashed with NameError. Import locally,
    # matching the existing function-scope `import base64` style.
    from io import BytesIO
    try:
        url = url.strip()
        if url.startswith("data:"):
            import base64
            # everything after the first comma is the base64 payload
            _, data = url.split(",", 1)
            img = Image.open(BytesIO(base64.b64decode(data))).convert("RGB")
            return img, None
        if not urllib.parse.urlsplit(url).scheme:
            return None, "Missing http/https scheme."
        r = requests.get(url, timeout=10, headers={"User-Agent": "duck.ai"})
        r.raise_for_status()
        return Image.open(BytesIO(r.content)).convert("RGB"), None
    except Exception as e:
        # return the error as text so the UI can display it instead of crashing
        return None, f"Load error: {e}"
48
 
49
def generate_base(img: Image.Image, max_len=40, beams=2, sample=False):
    """Produce one base caption: decode several candidates, keep the wordiest."""
    pixel_values = processor(images=img, return_tensors="pt").pixel_values.to(device)

    gen_kwargs = {"max_length": max_len, "early_stopping": True}
    if sample:
        # stochastic decoding: three independent samples for diversity
        gen_kwargs.update(
            do_sample=True,
            temperature=0.8,
            top_k=50,
            top_p=0.9,
            num_return_sequences=3,
        )
    else:
        # deterministic beam search; cannot return more sequences than beams
        gen_kwargs.update(num_beams=beams, num_return_sequences=min(3, beams))

    sequences = vision.generate(pixel_values, **gen_kwargs)
    candidates = [
        tokenizer.decode(seq, skip_special_tokens=True).strip() for seq in sequences
    ]
    # heuristic: the caption with the most words is treated as most detailed
    return max(candidates, key=lambda c: len(c.split()))
77
+
78
def expand_caption(base: str, prompt: str = None, max_len=160):
    """Rewrite *base* into a longer caption with the T5 model.

    When *prompt* is non-empty it is folded into the instruction to bias
    the expansion; otherwise a generic "rich visual detail" instruction is used.
    """
    if prompt and prompt.strip():
        instruction = f"Expand using: '{prompt}'. Caption: \"{base}\""
    else:
        instruction = f"Expand with rich visual detail. Caption: \"{base}\""

    encoded = rewriter_tok(
        instruction,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=256,
    ).to(device)

    generated = rewriter.generate(
        **encoded,
        max_length=max_len,
        num_beams=4,
        early_stopping=True,
        no_repeat_ngram_size=3,
    )
    return rewriter_tok.decode(generated[0], skip_special_tokens=True).strip()
101
+
102
+ # -------------------------------------------------
103
+ # Async expansion (background thread)
104
+ # -------------------------------------------------
105
def async_expand(base, prompt, max_len, status):
    """Run the T5 expansion, mirroring progress into the shared *status* dict.

    On any error the status text records the failure and the unexpanded
    base caption is returned as a safe fallback.
    """
    try:
        status["text"] = "Expanding…"
        time.sleep(0.1)  # brief pause so the UI can surface the status change
        expanded = expand_caption(base, prompt, max_len)
    except Exception as exc:
        status["text"] = f"Error: {exc}"
        return base
    else:
        status["text"] = "Done"
        return expanded
115
+
116
+ # -------------------------------------------------
117
+ # Gradio callbacks
118
+ # -------------------------------------------------
119
def fast_describe(url, prompt, detail, beams, sample):
    """Fast path: return the base caption immediately, expand in the background.

    Returns (image, base_caption, status_text) for the three UI outputs.
    NOTE(review): the background result is stored in status["final"], but no
    visible UI element reads it back — confirm a polling hook exists elsewhere.
    """
    img, err = load_image(url)
    if err:
        return None, "", err

    # detail level controls how long the T5 rewrite may grow
    detail_map = {"Low": 80, "Medium": 140, "High": 220}
    expand_len = detail_map.get(detail, 140)

    base = generate_base(img, beams=beams, sample=sample)
    status = {"text": "Queued…"}

    def worker():
        status["final"] = async_expand(base, prompt, expand_len, status)

    threading.Thread(target=worker, daemon=True).start()
    return img, base, status["text"]
135
 
136
def final_caption(url, prompt, detail, beams, sample):
    """Synchronous path: base caption plus full T5 expansion (slower).

    Returns (caption_text, status_text); on expansion failure the base
    caption is returned with the error in the status.
    """
    img, err = load_image(url)
    if err:
        return "", err

    # detail level controls how long the T5 rewrite may grow
    detail_map = {"Low": 80, "Medium": 140, "High": 220}
    expand_len = detail_map.get(detail, 140)

    base = generate_base(img, beams=beams, sample=sample)
    try:
        return expand_caption(base, prompt, expand_len), "Done"
    except Exception as e:
        return base, f"Expand error: {e}"
149
 
150
# -------------------------------------------------
# UI
# -------------------------------------------------
css = "footer {display:none !important;}"

with gr.Blocks(css=css, title="Image Describer (CPU)") as demo:
    gr.Markdown("## Image Describer – fast base caption + optional detailed rewrite")
    with gr.Row():
        # left column: all inputs and controls
        with gr.Column():
            url_in = gr.Textbox(label="Image URL / data‑URL")
            prompt_in = gr.Textbox(label="Optional prompt")
            detail_in = gr.Radio(["Low", "Medium", "High"], value="Medium", label="Detail level")
            beams_in = gr.Slider(1, 4, step=1, value=2, label="Beams (higher = better, slower)")
            sample_in = gr.Checkbox(label="Enable sampling (more diverse)", value=False)
            go_btn = gr.Button("Load & Describe (fast)")
            final_btn = gr.Button("Get final caption (detailed)")
            status_out = gr.Textbox(label="Status", interactive=False)
        # middle column: the loaded image
        with gr.Column():
            img_out = gr.Image(type="pil", label="Image")
        # right column: caption output
        with gr.Column():
            caption_out = gr.Textbox(label="Caption", lines=8)

    # both buttons share the same input widgets
    shared_inputs = [url_in, prompt_in, detail_in, beams_in, sample_in]
    go_btn.click(
        fn=fast_describe,
        inputs=shared_inputs,
        outputs=[img_out, caption_out, status_out],
    )
    final_btn.click(
        fn=final_caption,
        inputs=shared_inputs,
        outputs=[caption_out, status_out],
    )


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)