Hug0endob committed on
Commit
9bccfcb
·
verified ·
1 Parent(s): e3c1c79

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -49
app.py CHANGED
@@ -11,33 +11,34 @@ from transformers import (
11
  T5Tokenizer,
12
  )
13
  import urllib.parse
 
 
14
 
15
  device = torch.device("cpu")
16
 
17
- # Models
18
  PROCESSOR_NAME = "nlpconnect/vit-gpt2-image-captioning"
 
 
 
19
  processor = ViTImageProcessor.from_pretrained(PROCESSOR_NAME)
20
  tokenizer = AutoTokenizer.from_pretrained(PROCESSOR_NAME)
21
  model = VisionEncoderDecoderModel.from_pretrained(PROCESSOR_NAME).to(device)
22
  model.eval()
23
 
24
- # Optional rewriter (T5-small) to make captions more natural / respond to prompt
25
- rewriter_tokenizer = T5Tokenizer.from_pretrained("t5-small")
26
- rewriter = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)
27
  rewriter.eval()
28
 
29
  def load_image_from_url(url: str, timeout=10):
30
  try:
31
- # allow string that is data URL or direct URL
32
  url = url.strip()
33
  if url.startswith("data:"):
34
- # let PIL handle data URLs via BytesIO after splitting
35
  header, encoded = url.split(",", 1)
36
  import base64
37
  data = base64.b64decode(encoded)
38
  img = Image.open(BytesIO(data)).convert("RGB")
39
  return img, None
40
- # ensure proper URL encoding
41
  parsed = urllib.parse.urlsplit(url)
42
  if parsed.scheme == "":
43
  return None, "Invalid URL (missing scheme: http/https)."
@@ -48,68 +49,156 @@ def load_image_from_url(url: str, timeout=10):
48
  except Exception as e:
49
  return None, f"Error loading image: {e}"
50
 
51
- def generate_caption(img: Image.Image, prompt: str = None, max_len: int = 30, num_beams: int = 2):
52
- # Prepare encoder inputs
53
  inputs = processor(images=img, return_tensors="pt")
54
  pixel_values = inputs.pixel_values.to(device)
55
 
56
- # If a prompt is provided, prepend it to the decoder start tokens via tokenizer (prefix)
57
- # This is a lightweight way to bias output by using the tokenizer's bos/tokenizer decoding prefix.
58
- gen_kwargs = {"max_length": max_len, "num_beams": num_beams, "early_stopping": True}
59
- if prompt:
60
- # For vit-gpt2 model, we can try to use forced_decoder_input_ids or prefix decoding
61
- # Simpler approach: generate normally and then rely on rewriter to apply prompt.
62
- pass
63
-
64
- out = model.generate(pixel_values, **gen_kwargs)
65
- caption = tokenizer.decode(out[0], skip_special_tokens=True).strip()
66
- return caption
67
-
68
- def rewrite_caption_with_prompt(caption: str, prompt: str = None, max_len: int = 64):
69
- # If prompt provided, use it to instruct T5; otherwise paraphrase
70
- if prompt:
71
- input_text = f"paraphrase: {caption} prompt: {prompt}"
 
 
 
 
 
 
 
 
 
 
 
 
72
  else:
73
- input_text = "paraphrase: " + caption
74
- tok = rewriter_tokenizer(input_text, return_tensors="pt", truncation=True).to(device)
75
- out = rewriter.generate(**tok, max_length=max_len, num_beams=2, early_stopping=True)
76
- rewritten = rewriter_tokenizer.decode(out[0], skip_special_tokens=True).strip()
77
- return rewritten
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- def describe_image(url: str, prompt: str, max_caption_len: int = 30, expand: bool = True, beams: int = 2):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  img, err = load_image_from_url(url)
81
  if err:
82
- return None, f"Error: {err}"
83
- caption = generate_caption(img, prompt=prompt, max_len=max_caption_len, num_beams=beams)
84
- if expand:
85
- try:
86
- caption = rewrite_caption_with_prompt(caption, prompt=prompt, max_len=64)
87
- except Exception:
88
- pass
89
- if len(caption.split()) < 6:
90
- caption = f"{caption}. The scene appears to contain: {caption.lower()}."
91
- return img, caption
92
 
 
93
  css = """
94
  footer {display: none !important;}
95
  """
96
 
97
- with gr.Blocks(css=css, title="Image Describer (vit-gpt2, uncensored, promptable)") as demo:
98
- gr.Markdown("## Image Describer — uncensored captions, optional prompt to bias description")
99
  with gr.Row():
100
  with gr.Column(scale=1):
101
  url_in = gr.Textbox(label="Image URL or data URL", placeholder="https://example.com/photo.jpg")
102
- prompt_in = gr.Textbox(label="Optional prompt (e.g. 'Describe people and actions')", placeholder="Focus on people, actions, or colors.")
103
- max_len = gr.Slider(minimum=8, maximum=60, value=30, label="Max caption length")
 
104
  beams = gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Num beams (higher = better quality, slower)")
105
- expand_chk = gr.Checkbox(label="Rewrite/Paraphrase with prompt (slower)", value=True)
106
- go = gr.Button("Load & Describe")
 
 
107
  with gr.Column(scale=1):
108
  img_out = gr.Image(type="pil", label="Image")
109
  with gr.Column(scale=1):
110
- caption_out = gr.Textbox(label="Descriptive caption", lines=6)
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
- go.click(fn=describe_image, inputs=[url_in, prompt_in, max_len, expand_chk, beams], outputs=[img_out, caption_out])
113
 
114
  if __name__ == "__main__":
115
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
11
  T5Tokenizer,
12
  )
13
  import urllib.parse
14
+ import threading
15
+ import time
16
 
17
  device = torch.device("cpu")
18
 
19
+ # Model names
20
  PROCESSOR_NAME = "nlpconnect/vit-gpt2-image-captioning"
21
+ REWRITER_NAME = "t5-small"
22
+
23
+ # Load models (CPU)
24
  processor = ViTImageProcessor.from_pretrained(PROCESSOR_NAME)
25
  tokenizer = AutoTokenizer.from_pretrained(PROCESSOR_NAME)
26
  model = VisionEncoderDecoderModel.from_pretrained(PROCESSOR_NAME).to(device)
27
  model.eval()
28
 
29
+ rewriter_tokenizer = T5Tokenizer.from_pretrained(REWRITER_NAME)
30
+ rewriter = T5ForConditionalGeneration.from_pretrained(REWRITER_NAME).to(device)
 
31
  rewriter.eval()
32
 
33
  def load_image_from_url(url: str, timeout=10):
34
  try:
 
35
  url = url.strip()
36
  if url.startswith("data:"):
 
37
  header, encoded = url.split(",", 1)
38
  import base64
39
  data = base64.b64decode(encoded)
40
  img = Image.open(BytesIO(data)).convert("RGB")
41
  return img, None
 
42
  parsed = urllib.parse.urlsplit(url)
43
  if parsed.scheme == "":
44
  return None, "Invalid URL (missing scheme: http/https)."
 
49
  except Exception as e:
50
  return None, f"Error loading image: {e}"
51
 
52
+ # --- Generation & rewriting helpers ---
53
+ def generate_caption_candidates(img: Image.Image, max_len: int = 40, num_beams: int = 2, num_return_sequences: int = 3, do_sample: bool = False):
54
  inputs = processor(images=img, return_tensors="pt")
55
  pixel_values = inputs.pixel_values.to(device)
56
 
57
+ gen_kwargs = {
58
+ "max_length": max_len,
59
+ "num_beams": num_beams,
60
+ "early_stopping": True,
61
+ "do_sample": do_sample,
62
+ "num_return_sequences": num_return_sequences,
63
+ }
64
+ # model.generate returns tensor of shape (num_return_sequences, seq_len) when requested
65
+ outputs = model.generate(pixel_values, **gen_kwargs)
66
+ captions = [tokenizer.decode(o, skip_special_tokens=True).strip() for o in outputs]
67
+ # Deduplicate preserving order
68
+ seen = set()
69
+ unique = []
70
+ for c in captions:
71
+ if c not in seen:
72
+ seen.add(c)
73
+ unique.append(c)
74
+ return unique
75
+
76
+ def pick_most_detailed(candidates):
77
+ # heuristic: prefer longer by word count, then more unique words
78
+ best = max(candidates, key=lambda s: (len(s.split()), len(set(s.split()))))
79
+ return best
80
+
81
+ def expand_with_t5(caption: str, prompt: str = None, max_len: int = 160):
82
+ # Instruction to expand and add rich visual detail
83
+ if prompt and prompt.strip():
84
+ instr = f"Expand and elaborate the caption using this instruction: '{prompt}'. Caption: \"{caption}\""
85
  else:
86
+ instr = f"Expand and elaborate the caption with rich visual detail (objects, colors, textures, scene, actions). Caption: \"{caption}\""
87
+ tok = rewriter_tokenizer(instr, return_tensors="pt", truncation=True, padding=True).to(device)
88
+ out = rewriter.generate(**tok, max_length=max_len, num_beams=4, early_stopping=True, no_repeat_ngram_size=3)
89
+ expanded = rewriter_tokenizer.decode(out[0], skip_special_tokens=True).strip()
90
+ return expanded
91
+
92
+ # Background worker pattern to run expansion and report progress
93
+ def _background_expand_and_return(caption, prompt, max_expand_len, status_callback):
94
+ try:
95
+ # Inform start
96
+ status_callback("Expanding caption (step 1/2)...")
97
+ # Small sleep allows UI update
98
+ time.sleep(0.1)
99
+ expanded = expand_with_t5(caption, prompt=prompt, max_len=max_expand_len)
100
+ status_callback("Finalizing (step 2/2)...")
101
+ time.sleep(0.1)
102
+ status_callback("Done")
103
+ return expanded
104
+ except Exception as e:
105
+ status_callback(f"Error during expand: {e}")
106
+ return caption
107
+
108
+ # Main describe function used by Gradio; it triggers generation and then expansion in background
109
+ def describe_image_controller(url: str, prompt: str, detail_level: str, max_caption_len: int = 40, beams: int = 2, do_sample: bool = True):
110
+ """
111
+ Returns: (img or None, caption_text, status_text)
112
+ The UI will start background expansion and update status via a small helper.
113
+ """
114
+ img, err = load_image_from_url(url)
115
+ if err:
116
+ return None, "", f"Error: {err}"
117
+
118
+ # Map detail_level to rewriter max_len
119
+ detail_map = {"Low": 80, "Medium": 140, "High": 220}
120
+ max_expand_len = detail_map.get(detail_level, 140)
121
+
122
+ # Generate candidates
123
+ candidates = generate_caption_candidates(img, max_len=max_caption_len, num_beams=beams, num_return_sequences=3, do_sample=do_sample)
124
+ base = pick_most_detailed(candidates)
125
 
126
+ # Start background thread to expand (T5) and update status via a Gradio status element (we'll use a simple polling text)
127
+ # We'll use a small mutable container to send status updates via closure
128
+ status = {"text": "Queued for expansion..."}
129
+ def status_callback(s):
130
+ status["text"] = s
131
+
132
+ result_container = {"final": base}
133
+
134
+ def worker():
135
+ expanded = _background_expand_and_return(base, prompt, max_expand_len, status_callback)
136
+ result_container["final"] = expanded
137
+
138
+ thread = threading.Thread(target=worker, daemon=True)
139
+ thread.start()
140
+
141
+ # Return image, initial base caption, and initial status. The frontend will poll for status/final via separate endpoints
142
+ return img, base, status["text"]
143
+
144
+ # Polling endpoints to retrieve status and final caption
145
+ def poll_status_and_caption(url: str, prompt: str, _placeholder):
146
+ # In this simple pattern we re-run a lightweight check by storing results in a global map keyed by URL+prompt
147
+ # For simplicity in this Space we will re-run expansion synchronously here if needed.
148
+ # But to avoid redoing heavy work, you can implement a shared cache (omitted for brevity).
149
+ return "If expansion still running, refresh in a few seconds. Final caption will replace base when ready."
150
+
151
+ # Simple endpoint to get final expanded caption synchronously (used when user hits 'Get final caption')
152
+ def get_final_caption(url: str, prompt: str, detail_level: str, max_caption_len: int = 40, beams: int = 2, do_sample: bool = True):
153
  img, err = load_image_from_url(url)
154
  if err:
155
+ return "", f"Error: {err}"
156
+ candidates = generate_caption_candidates(img, max_len=max_caption_len, num_beams=beams, num_return_sequences=3, do_sample=do_sample)
157
+ base = pick_most_detailed(candidates)
158
+ detail_map = {"Low": 80, "Medium": 140, "High": 220}
159
+ max_expand_len = detail_map.get(detail_level, 140)
160
+ try:
161
+ expanded = expand_with_t5(base, prompt=prompt, max_len=max_expand_len)
162
+ return expanded, "Done"
163
+ except Exception as e:
164
+ return base, f"Expand error: {e}"
165
 
166
+ # Gradio UI
167
  css = """
168
  footer {display: none !important;}
169
  """
170
 
171
+ with gr.Blocks(css=css, title="Image Describer (vit-gpt2, promptable, detailed)") as demo:
172
+ gr.Markdown("## Image Describer — uncensored captions, optional prompt to bias description. Use 'Get final caption' for the detailed expanded output (may take longer).")
173
  with gr.Row():
174
  with gr.Column(scale=1):
175
  url_in = gr.Textbox(label="Image URL or data URL", placeholder="https://example.com/photo.jpg")
176
+ prompt_in = gr.Textbox(label="Optional prompt (e.g. 'Focus on people and actions')", placeholder="Focus on people, actions, or colors.")
177
+ detail_level = gr.Radio(choices=["Low", "Medium", "High"], value="Medium", label="Detail level (affects expansion length)")
178
+ max_len = gr.Slider(minimum=8, maximum=80, value=40, label="Base caption max length")
179
  beams = gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Num beams (higher = better quality, slower)")
180
+ do_sample_chk = gr.Checkbox(label="Enable sampling (more diverse)", value=True)
181
+ go = gr.Button("Load & Describe (fast)")
182
+ get_final = gr.Button("Get final caption (detailed, slower)")
183
+ status_txt = gr.Textbox(label="Status", value="Idle", interactive=False)
184
  with gr.Column(scale=1):
185
  img_out = gr.Image(type="pil", label="Image")
186
  with gr.Column(scale=1):
187
+ caption_out = gr.Textbox(label="Caption (base or final)", lines=8)
188
+
189
+ # Fast path: generate base caption and immediately start background expand (status will be approximate)
190
+ def on_go(url, prompt, detail_level, max_len, beams, do_sample):
191
+ img, base_caption, status = describe_image_controller(url, prompt, detail_level, max_caption_len=max_len, beams=beams, do_sample=do_sample)
192
+ return img, base_caption, status
193
+
194
+ go.click(fn=on_go, inputs=[url_in, prompt_in, detail_level, max_len, beams, do_sample_chk], outputs=[img_out, caption_out, status_txt])
195
+
196
+ # Synchronous, explicit final result (user clicks when they want the full expanded caption)
197
+ def on_get_final(url, prompt, detail_level, max_len, beams, do_sample):
198
+ final_caption, status = get_final_caption(url, prompt, detail_level, max_caption_len=max_len, beams=beams, do_sample=do_sample)
199
+ return final_caption, status
200
 
201
+ get_final.click(fn=on_get_final, inputs=[url_in, prompt_in, detail_level, max_len, beams, do_sample_chk], outputs=[caption_out, status_txt])
202
 
203
  if __name__ == "__main__":
204
  demo.launch(server_name="0.0.0.0", server_port=7860)