Hug0endob committed on
Commit
63ffe59
·
verified ·
1 Parent(s): 049393b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -26
app.py CHANGED
@@ -1,18 +1,25 @@
1
- # app.py – corrected for Gradio 6+
 
 
 
 
 
 
 
2
  import gradio as gr
3
- from PIL import Image
4
- import requests, urllib.parse, threading, time
5
  import torch
 
6
  from transformers import (
 
7
  VisionEncoderDecoderModel,
8
  ViTImageProcessor,
9
- AutoTokenizer,
10
  T5ForConditionalGeneration,
11
  T5Tokenizer,
12
  )
13
 
14
  # -------------------------------------------------
15
- # Device & models (CPU)
16
  # -------------------------------------------------
17
  device = torch.device("cpu")
18
 
@@ -26,6 +33,7 @@ vision = VisionEncoderDecoderModel.from_pretrained(IMG_MODEL).to(device).eval()
26
  rewriter_tok = T5Tokenizer.from_pretrained(TXT_MODEL)
27
  rewriter = T5ForConditionalGeneration.from_pretrained(TXT_MODEL).to(device).eval()
28
 
 
29
  # -------------------------------------------------
30
  # Helpers
31
  # -------------------------------------------------
@@ -34,7 +42,6 @@ def load_image(url: str):
34
  try:
35
  url = url.strip()
36
  if url.startswith("data:"):
37
- import base64
38
  _, data = url.split(",", 1)
39
  img = Image.open(BytesIO(base64.b64decode(data))).convert("RGB")
40
  return img, None
@@ -46,8 +53,9 @@ def load_image(url: str):
46
  except Exception as e:
47
  return None, f"Load error: {e}"
48
 
 
49
  def generate_base(img: Image.Image, max_len=40, beams=2, sample=False):
50
- """Return a single most detailed base caption."""
51
  inputs = processor(images=img, return_tensors="pt")
52
  pix = inputs.pixel_values.to(device)
53
 
@@ -71,10 +79,11 @@ def generate_base(img: Image.Image, max_len=40, beams=2, sample=False):
71
  early_stopping=True,
72
  )
73
  caps = [tokenizer.decode(o, skip_special_tokens=True).strip() for o in out]
74
- return max(caps, key=lambda s: len(s.split())) # longest = most detailed
 
75
 
76
  def expand_caption(base: str, prompt: str = None, max_len=160):
77
- """Rich T5 expansion."""
78
  if prompt and prompt.strip():
79
  instr = f"Expand using: '{prompt}'. Caption: \"{base}\""
80
  else:
@@ -97,13 +106,12 @@ def expand_caption(base: str, prompt: str = None, max_len=160):
97
  )
98
  return rewriter_tok.decode(out[0], skip_special_tokens=True).strip()
99
 
100
- # -------------------------------------------------
101
- # Async expansion (background thread)
102
- # -------------------------------------------------
103
  def async_expand(base, prompt, max_len, status):
 
104
  try:
105
  status["text"] = "Expanding…"
106
- time.sleep(0.1)
107
  result = expand_caption(base, prompt, max_len)
108
  status["text"] = "Done"
109
  return result
@@ -111,6 +119,7 @@ def async_expand(base, prompt, max_len, status):
111
  status["text"] = f"Error: {e}"
112
  return base
113
 
 
114
  # -------------------------------------------------
115
  # Gradio callbacks
116
  # -------------------------------------------------
@@ -131,10 +140,12 @@ def fast_describe(url, prompt, detail, beams, sample):
131
  threading.Thread(target=worker, daemon=True).start()
132
  return img, base, status["text"]
133
 
 
134
  def final_caption(url, prompt, detail, beams, sample):
135
  img, err = load_image(url)
136
  if err:
137
  return "", err
 
138
  detail_map = {"Low": 80, "Medium": 140, "High": 220}
139
  max_expand = detail_map.get(detail, 140)
140
 
@@ -145,25 +156,20 @@ def final_caption(url, prompt, detail, beams, sample):
145
  except Exception as e:
146
  return base, f"Expand error: {e}"
147
 
 
148
  # -------------------------------------------------
149
  # UI
150
  # -------------------------------------------------
151
  css = "footer {display:none !important;}"
152
- with gr.Blocks() as demo: # no css here
153
- gr.Markdown(
154
- "## Image Describer"
155
- )
156
  with gr.Row():
157
  with gr.Column():
158
  url_in = gr.Textbox(label="Image URL / data‑URL")
159
  prompt_in = gr.Textbox(label="Optional prompt")
160
- detail_in = gr.Radio(
161
- ["Low", "Medium", "High"], value="Medium", label="Detail level"
162
- )
163
  beams_in = gr.Slider(1, 4, step=1, value=2, label="Beams")
164
- sample_in = gr.Checkbox(
165
- label="Enable sampling (more diverse)", value=False
166
- )
167
  go_btn = gr.Button("Load & Describe (fast)")
168
  final_btn = gr.Button("Get final caption (detailed)")
169
  status_out = gr.Textbox(label="Status", interactive=False)
@@ -184,14 +190,14 @@ with gr.Blocks() as demo: # no css here
184
  )
185
 
186
  # -------------------------------------------------
187
- # Launch – css and title are passed here (Gradio 6+)
188
  # -------------------------------------------------
189
  if __name__ == "__main__":
190
- demo.queue() # enables background threads without leaking the event loop
191
  demo.launch(
192
  server_name="0.0.0.0",
193
  server_port=7860,
194
  css=css,
195
  title="Image Describer (CPU)",
196
- prevent_thread_lock=True, # clean shutdown → no “invalid file descriptor” warnings
197
  )
 
1
+ # app.py – Gradio 6+ (CPU)
2
+
3
+ import base64
4
+ import threading
5
+ import time
6
+ import urllib.parse
7
+ from io import BytesIO
8
+
9
  import gradio as gr
10
+ import requests
 
11
  import torch
12
+ from PIL import Image
13
  from transformers import (
14
+ AutoTokenizer,
15
  VisionEncoderDecoderModel,
16
  ViTImageProcessor,
 
17
  T5ForConditionalGeneration,
18
  T5Tokenizer,
19
  )
20
 
21
  # -------------------------------------------------
22
+ # Device & models
23
  # -------------------------------------------------
24
  device = torch.device("cpu")
25
 
 
33
  rewriter_tok = T5Tokenizer.from_pretrained(TXT_MODEL)
34
  rewriter = T5ForConditionalGeneration.from_pretrained(TXT_MODEL).to(device).eval()
35
 
36
+
37
  # -------------------------------------------------
38
  # Helpers
39
  # -------------------------------------------------
 
42
  try:
43
  url = url.strip()
44
  if url.startswith("data:"):
 
45
  _, data = url.split(",", 1)
46
  img = Image.open(BytesIO(base64.b64decode(data))).convert("RGB")
47
  return img, None
 
53
  except Exception as e:
54
  return None, f"Load error: {e}"
55
 
56
+
57
  def generate_base(img: Image.Image, max_len=40, beams=2, sample=False):
58
+ """Return the longest caption (most detailed) from the vision model."""
59
  inputs = processor(images=img, return_tensors="pt")
60
  pix = inputs.pixel_values.to(device)
61
 
 
79
  early_stopping=True,
80
  )
81
  caps = [tokenizer.decode(o, skip_special_tokens=True).strip() for o in out]
82
+ return max(caps, key=lambda s: len(s.split()))
83
+
84
 
85
  def expand_caption(base: str, prompt: str = None, max_len=160):
86
+ """Use T5 to expand the base caption."""
87
  if prompt and prompt.strip():
88
  instr = f"Expand using: '{prompt}'. Caption: \"{base}\""
89
  else:
 
106
  )
107
  return rewriter_tok.decode(out[0], skip_special_tokens=True).strip()
108
 
109
+
 
 
110
  def async_expand(base, prompt, max_len, status):
111
+ """Background expansion; updates status dict."""
112
  try:
113
  status["text"] = "Expanding…"
114
+ time.sleep(0.1) # tiny yield for UI responsiveness
115
  result = expand_caption(base, prompt, max_len)
116
  status["text"] = "Done"
117
  return result
 
119
  status["text"] = f"Error: {e}"
120
  return base
121
 
122
+
123
  # -------------------------------------------------
124
  # Gradio callbacks
125
  # -------------------------------------------------
 
140
  threading.Thread(target=worker, daemon=True).start()
141
  return img, base, status["text"]
142
 
143
+
144
  def final_caption(url, prompt, detail, beams, sample):
145
  img, err = load_image(url)
146
  if err:
147
  return "", err
148
+
149
  detail_map = {"Low": 80, "Medium": 140, "High": 220}
150
  max_expand = detail_map.get(detail, 140)
151
 
 
156
  except Exception as e:
157
  return base, f"Expand error: {e}"
158
 
159
+
160
  # -------------------------------------------------
161
  # UI
162
  # -------------------------------------------------
163
  css = "footer {display:none !important;}"
164
+ with gr.Blocks() as demo:
165
+ gr.Markdown("## Image Describer")
 
 
166
  with gr.Row():
167
  with gr.Column():
168
  url_in = gr.Textbox(label="Image URL / data‑URL")
169
  prompt_in = gr.Textbox(label="Optional prompt")
170
+ detail_in = gr.Radio(["Low", "Medium", "High"], value="Medium", label="Detail level")
 
 
171
  beams_in = gr.Slider(1, 4, step=1, value=2, label="Beams")
172
+ sample_in = gr.Checkbox(label="Enable sampling (more diverse)", value=False)
 
 
173
  go_btn = gr.Button("Load & Describe (fast)")
174
  final_btn = gr.Button("Get final caption (detailed)")
175
  status_out = gr.Textbox(label="Status", interactive=False)
 
190
  )
191
 
192
  # -------------------------------------------------
193
+ # Launch
194
  # -------------------------------------------------
195
  if __name__ == "__main__":
196
+ demo.queue()
197
  demo.launch(
198
  server_name="0.0.0.0",
199
  server_port=7860,
200
  css=css,
201
  title="Image Describer (CPU)",
202
+ prevent_thread_lock=True,
203
  )