Hug0endob committed on
Commit
7d16cdf
·
verified ·
1 Parent(s): 63ffe59

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -42
app.py CHANGED
@@ -18,9 +18,6 @@ from transformers import (
18
  T5Tokenizer,
19
  )
20
 
21
- # -------------------------------------------------
22
- # Device & models
23
- # -------------------------------------------------
24
  device = torch.device("cpu")
25
 
26
  IMG_MODEL = "nlpconnect/vit-gpt2-image-captioning"
@@ -34,13 +31,12 @@ rewriter_tok = T5Tokenizer.from_pretrained(TXT_MODEL)
34
  rewriter = T5ForConditionalGeneration.from_pretrained(TXT_MODEL).to(device).eval()
35
 
36
 
37
- # -------------------------------------------------
38
- # Helpers
39
- # -------------------------------------------------
40
  def load_image(url: str):
41
  """Return (PIL.Image, None) or (None, error). Handles http/https and data‑URL."""
42
  try:
43
- url = url.strip()
 
 
44
  if url.startswith("data:"):
45
  _, data = url.split(",", 1)
46
  img = Image.open(BytesIO(base64.b64decode(data))).convert("RGB")
@@ -55,10 +51,8 @@ def load_image(url: str):
55
 
56
 
57
  def generate_base(img: Image.Image, max_len=40, beams=2, sample=False):
58
- """Return the longest caption (most detailed) from the vision model."""
59
  inputs = processor(images=img, return_tensors="pt")
60
  pix = inputs.pixel_values.to(device)
61
-
62
  if sample:
63
  out = vision.generate(
64
  pix,
@@ -83,12 +77,10 @@ def generate_base(img: Image.Image, max_len=40, beams=2, sample=False):
83
 
84
 
85
  def expand_caption(base: str, prompt: str = None, max_len=160):
86
- """Use T5 to expand the base caption."""
87
  if prompt and prompt.strip():
88
  instr = f"Expand using: '{prompt}'. Caption: \"{base}\""
89
  else:
90
  instr = f"Expand with rich visual detail. Caption: \"{base}\""
91
-
92
  toks = rewriter_tok(
93
  instr,
94
  return_tensors="pt",
@@ -96,7 +88,6 @@ def expand_caption(base: str, prompt: str = None, max_len=160):
96
  padding="max_length",
97
  max_length=256,
98
  ).to(device)
99
-
100
  out = rewriter.generate(
101
  **toks,
102
  max_length=max_len,
@@ -108,36 +99,26 @@ def expand_caption(base: str, prompt: str = None, max_len=160):
108
 
109
 
110
  def async_expand(base, prompt, max_len, status):
111
- """Background expansion; updates status dict."""
112
  try:
113
  status["text"] = "Expanding…"
114
- time.sleep(0.1) # tiny yield for UI responsiveness
115
  result = expand_caption(base, prompt, max_len)
116
  status["text"] = "Done"
117
- return result
118
  except Exception as e:
119
  status["text"] = f"Error: {e}"
120
- return base
121
 
122
 
123
- # -------------------------------------------------
124
- # Gradio callbacks
125
- # -------------------------------------------------
126
  def fast_describe(url, prompt, detail, beams, sample):
127
  img, err = load_image(url)
128
  if err:
129
  return None, "", err
130
-
131
  detail_map = {"Low": 80, "Medium": 140, "High": 220}
132
  max_expand = detail_map.get(detail, 140)
133
-
134
  base = generate_base(img, beams=beams, sample=sample)
135
- status = {"text": "Queued…"}
136
-
137
- def worker():
138
- status["final"] = async_expand(base, prompt, max_expand, status)
139
-
140
- threading.Thread(target=worker, daemon=True).start()
141
  return img, base, status["text"]
142
 
143
 
@@ -145,10 +126,8 @@ def final_caption(url, prompt, detail, beams, sample):
145
  img, err = load_image(url)
146
  if err:
147
  return "", err
148
-
149
  detail_map = {"Low": 80, "Medium": 140, "High": 220}
150
  max_expand = detail_map.get(detail, 140)
151
-
152
  base = generate_base(img, beams=beams, sample=sample)
153
  try:
154
  final = expand_caption(base, prompt, max_expand)
@@ -157,9 +136,6 @@ def final_caption(url, prompt, detail, beams, sample):
157
  return base, f"Expand error: {e}"
158
 
159
 
160
- # -------------------------------------------------
161
- # UI
162
- # -------------------------------------------------
163
  css = "footer {display:none !important;}"
164
  with gr.Blocks() as demo:
165
  gr.Markdown("## Image Describer")
@@ -189,15 +165,6 @@ with gr.Blocks() as demo:
189
  outputs=[caption_out, status_out],
190
  )
191
 
192
- # -------------------------------------------------
193
- # Launch
194
- # -------------------------------------------------
195
  if __name__ == "__main__":
196
  demo.queue()
197
- demo.launch(
198
- server_name="0.0.0.0",
199
- server_port=7860,
200
- css=css,
201
- title="Image Describer (CPU)",
202
- prevent_thread_lock=True,
203
- )
 
18
  T5Tokenizer,
19
  )
20
 
 
 
 
21
  device = torch.device("cpu")
22
 
23
  IMG_MODEL = "nlpconnect/vit-gpt2-image-captioning"
 
31
  rewriter = T5ForConditionalGeneration.from_pretrained(TXT_MODEL).to(device).eval()
32
 
33
 
 
 
 
34
  def load_image(url: str):
35
  """Return (PIL.Image, None) or (None, error). Handles http/https and data‑URL."""
36
  try:
37
+ url = (url or "").strip()
38
+ if not url:
39
+ return None, "No URL provided."
40
  if url.startswith("data:"):
41
  _, data = url.split(",", 1)
42
  img = Image.open(BytesIO(base64.b64decode(data))).convert("RGB")
 
51
 
52
 
53
  def generate_base(img: Image.Image, max_len=40, beams=2, sample=False):
 
54
  inputs = processor(images=img, return_tensors="pt")
55
  pix = inputs.pixel_values.to(device)
 
56
  if sample:
57
  out = vision.generate(
58
  pix,
 
77
 
78
 
79
  def expand_caption(base: str, prompt: str = None, max_len=160):
 
80
  if prompt and prompt.strip():
81
  instr = f"Expand using: '{prompt}'. Caption: \"{base}\""
82
  else:
83
  instr = f"Expand with rich visual detail. Caption: \"{base}\""
 
84
  toks = rewriter_tok(
85
  instr,
86
  return_tensors="pt",
 
88
  padding="max_length",
89
  max_length=256,
90
  ).to(device)
 
91
  out = rewriter.generate(
92
  **toks,
93
  max_length=max_len,
 
99
 
100
 
101
def async_expand(base, prompt, max_len, status):
    """Expand *base* in a background worker, publishing progress via *status*.

    Mutates the shared ``status`` dict in place:
      - ``status["text"]``  — progress string ("Expanding…", "Done", or "Error: …")
      - ``status["final"]`` — the expanded caption on success, or *base* on failure

    Never raises: any failure is captured into ``status`` so the daemon
    thread that runs this cannot die with an unreported exception.
    """
    try:
        status["text"] = "Expanding…"
        time.sleep(0.1)  # brief yield so the UI thread can repaint first
        expanded = expand_caption(base, prompt, max_len)
    except Exception as exc:  # best-effort: fall back to the base caption
        status["text"] = f"Error: {exc}"
        status["final"] = base
    else:
        status["text"] = "Done"
        status["final"] = expanded
 
112
 
 
 
 
113
def fast_describe(url, prompt, detail, beams, sample):
    """Gradio callback: return a quick base caption and start expansion async.

    Loads the image from *url*, generates the fast vision-model caption
    synchronously, then fires a daemon thread that runs the T5 expansion in
    the background (results land in the shared ``status`` dict).

    Returns ``(image, base_caption, status_text)``; on a load failure
    returns ``(None, "", error_message)`` instead.
    """
    img, err = load_image(url)
    if err:
        return None, "", err

    # Map the UI detail choice onto a max expansion length (default: Medium).
    max_expand = {"Low": 80, "Medium": 140, "High": 220}.get(detail, 140)

    base = generate_base(img, beams=beams, sample=sample)

    # Fire-and-forget: the worker mutates `status` as it progresses.
    status = {"text": "Queued…", "final": ""}
    worker = threading.Thread(
        target=async_expand,
        args=(base, prompt, max_expand, status),
        daemon=True,
    )
    worker.start()

    return img, base, status["text"]
123
 
124
 
 
126
  img, err = load_image(url)
127
  if err:
128
  return "", err
 
129
  detail_map = {"Low": 80, "Medium": 140, "High": 220}
130
  max_expand = detail_map.get(detail, 140)
 
131
  base = generate_base(img, beams=beams, sample=sample)
132
  try:
133
  final = expand_caption(base, prompt, max_expand)
 
136
  return base, f"Expand error: {e}"
137
 
138
 
 
 
 
139
  css = "footer {display:none !important;}"
140
  with gr.Blocks() as demo:
141
  gr.Markdown("## Image Describer")
 
165
  outputs=[caption_out, status_out],
166
  )
167
 
 
 
 
168
if __name__ == "__main__":
    demo.queue()
    # FIX: `Blocks.launch()` has no `css` parameter — passing `css=css` raises
    # TypeError and the app never starts. Custom CSS must be given to the
    # `gr.Blocks(css=...)` constructor instead (the `css` string defined above
    # is currently unused; wire it into `gr.Blocks()` where `demo` is built).
    # NOTE(review): `prevent_thread_lock=True` returns immediately after
    # launch — confirm the hosting environment keeps the process alive.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        prevent_thread_lock=True,
    )