Afsha001 committed on
Commit
cfb72d8
·
verified ·
1 Parent(s): e67f5ad

update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -28
app.py CHANGED
@@ -50,9 +50,6 @@ if not JINA_KEY:
50
  st.error("JINA_KEY missing. Go to Space Settings β†’ Secrets and add it.")
51
  st.stop()
52
 
53
- # ============================================================================
54
- # CHANGE 1: load_local_models — replaced moondream with GIT-Large-COCO
55
- # ============================================================================
56
  @st.cache_resource
57
  def load_local_models():
58
  from transformers import (
@@ -64,7 +61,6 @@ def load_local_models():
64
  )
65
  gc.collect()
66
 
67
- # GIT-Large-COCO — local caption generation, no API, no auth needed
68
  git_processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
69
  git_model = AutoModelForCausalLM.from_pretrained(
70
  "microsoft/git-large-coco",
@@ -72,7 +68,6 @@ def load_local_models():
72
  )
73
  git_model.eval()
74
 
75
- # BLIP — for ITM scoring and cosine similarity
76
  blip_processor = BlipProcessor.from_pretrained(
77
  "Salesforce/blip-image-captioning-large"
78
  )
@@ -82,7 +77,6 @@ def load_local_models():
82
  )
83
  blip_itm_model.eval()
84
 
85
- # DINO — for object detection
86
  dino_processor = AutoProcessor.from_pretrained(
87
  "IDEA-Research/grounding-dino-base"
88
  )
@@ -105,23 +99,55 @@ def image_to_data_uri(image: Image.Image) -> str:
105
  return f"data:image/jpeg;base64,{b64}"
106
 
107
  # ============================================================================
108
- # CHANGE 2: generate_captions_git — replaced moondream caption function
 
 
109
  # ============================================================================
110
  def generate_captions_git(image: Image.Image, git_proc, git_mod) -> list:
111
- length_params = [30, 50, 60, 70, 40]
112
- captions = []
113
 
114
- for max_tokens in length_params:
115
- try:
116
- pixel_values = git_proc(
117
- images=image,
118
- return_tensors="pt"
119
- ).pixel_values
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
 
 
121
  with torch.no_grad():
122
  generated_ids = git_mod.generate(
123
  pixel_values=pixel_values,
124
- max_new_tokens=max_tokens
125
  )
126
 
127
  cap = git_proc.batch_decode(
@@ -135,17 +161,22 @@ def generate_captions_git(image: Image.Image, git_proc, git_mod) -> list:
135
  st.warning(f"GIT error: {str(e)[:80]}")
136
  captions.append("a scene shown in the image")
137
 
 
138
  seen, unique = set(), []
139
  for c in captions:
140
  if c not in seen:
141
  seen.add(c)
142
  unique.append(c)
 
 
 
 
 
143
  while len(unique) < 5:
144
  unique.append(unique[0])
145
 
146
  return unique[:5]
147
 
148
- # unchanged
149
  def compute_itm_scores(image, captions, blip_proc, blip_itm) -> list:
150
  scores = []
151
  for cap in captions:
@@ -165,7 +196,6 @@ def compute_itm_scores(image, captions, blip_proc, blip_itm) -> list:
165
  scores.append(0.0)
166
  return scores
167
 
168
- # unchanged
169
  def compute_jina_scores(image: Image.Image, captions: list) -> list:
170
  img_data_uri = image_to_data_uri(image)
171
  scores = []
@@ -199,7 +229,6 @@ def compute_jina_scores(image: Image.Image, captions: list) -> list:
199
  scores.append(0.0)
200
  return scores
201
 
202
- # unchanged
203
  def compute_cosine_scores(image, captions, blip_proc, blip_itm) -> list:
204
  try:
205
  img_inp = blip_proc(images=image, return_tensors="pt")
@@ -227,7 +256,6 @@ def compute_cosine_scores(image, captions, blip_proc, blip_itm) -> list:
227
  st.warning(f"Cosine error: {str(e)[:60]}")
228
  return [0.0] * len(captions)
229
 
230
- # unchanged
231
  def majority_voting(captions, itm, jina, cosine) -> tuple:
232
  itm_r = np.argsort(itm)[::-1]
233
  jina_r = np.argsort(jina)[::-1]
@@ -245,7 +273,6 @@ def majority_voting(captions, itm, jina, cosine) -> tuple:
245
 
246
  return captions[top2[0]], captions[top2[1]], top2, dict(counts)
247
 
248
- # unchanged
249
  def detect_objects(image, dino_proc, dino_mod, threshold=0.3) -> tuple:
250
  try:
251
  inputs = dino_proc(
@@ -288,7 +315,6 @@ def detect_objects(image, dino_proc, dino_mod, threshold=0.3) -> tuple:
288
  st.warning(f"DINO error: {str(e)[:80]}")
289
  return "Object detection unavailable", []
290
 
291
- # unchanged
292
  def fuse_captions(cap1: str, cap2: str, objects: str) -> str:
293
  system_prompt = (
294
  "You are an expert image captioning assistant. "
@@ -331,9 +357,6 @@ def fuse_captions(cap1: str, cap2: str, objects: str) -> str:
331
  st.warning(f"Qwen exception: {str(e)[:60]}")
332
  return cap1
333
 
334
- # ============================================================================
335
- # CHANGE 3: sidebar — updated step 1 label
336
- # ============================================================================
337
  with st.sidebar:
338
  st.title("Image Caption Fusion")
339
  st.markdown("---")
@@ -385,14 +408,12 @@ if uploaded_file is not None:
385
  if st.button("Generate Caption", type="primary", use_container_width=True):
386
 
387
  with st.spinner("Loading local models (first run takes 2-3 min)..."):
388
- # CHANGE 4: updated unpacking — git_proc, git_mod instead of moon_mod
389
  git_proc, git_mod, blip_proc, blip_itm, dino_proc, dino_mod = load_local_models()
390
 
391
  progress = st.progress(0)
392
  status = st.empty()
393
 
394
  status.info("Step 1/7: Generating captions with GIT-Large-COCO...")
395
- # CHANGE 4: updated function call
396
  captions = generate_captions_git(input_image, git_proc, git_mod)
397
  progress.progress(14)
398
 
@@ -455,4 +476,4 @@ if uploaded_file is not None:
455
  f"font-size:18px;font-weight:500;text-align:center;"
456
  f"line-height:1.6;'>{final}</div>",
457
  unsafe_allow_html=True
458
- )
 
50
  st.error("JINA_KEY missing. Go to Space Settings β†’ Secrets and add it.")
51
  st.stop()
52
 
 
 
 
53
  @st.cache_resource
54
  def load_local_models():
55
  from transformers import (
 
61
  )
62
  gc.collect()
63
 
 
64
  git_processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
65
  git_model = AutoModelForCausalLM.from_pretrained(
66
  "microsoft/git-large-coco",
 
68
  )
69
  git_model.eval()
70
 
 
71
  blip_processor = BlipProcessor.from_pretrained(
72
  "Salesforce/blip-image-captioning-large"
73
  )
 
77
  )
78
  blip_itm_model.eval()
79
 
 
80
  dino_processor = AutoProcessor.from_pretrained(
81
  "IDEA-Research/grounding-dino-base"
82
  )
 
99
  return f"data:image/jpeg;base64,{b64}"
100
 
101
  # ============================================================================
102
+ # ONLY CHANGE: generate_captions_git
103
+ # Fix: 5 different generation strategies instead of just max_new_tokens
104
+ # Greedy / beam search / sampling with different temperatures
105
  # ============================================================================
106
  def generate_captions_git(image: Image.Image, git_proc, git_mod) -> list:
 
 
107
 
108
+ strategies = [
109
+ # Greedy — short deterministic baseline
110
+ {
111
+ "max_new_tokens": 30
112
+ },
113
+ # Beam search — explores multiple decode paths
114
+ {
115
+ "max_new_tokens": 50,
116
+ "num_beams": 5,
117
+ "early_stopping": True
118
+ },
119
+ # Sampling — low temperature, focused output
120
+ {
121
+ "max_new_tokens": 60,
122
+ "do_sample": True,
123
+ "temperature": 0.7,
124
+ "top_k": 50
125
+ },
126
+ # Sampling — high temperature, creative output
127
+ {
128
+ "max_new_tokens": 70,
129
+ "do_sample": True,
130
+ "temperature": 1.3,
131
+ "top_k": 100
132
+ },
133
+ # Nucleus sampling — top-p based
134
+ {
135
+ "max_new_tokens": 55,
136
+ "do_sample": True,
137
+ "top_p": 0.9,
138
+ "temperature": 1.0
139
+ },
140
+ ]
141
+
142
+ captions = []
143
+ pixel_values = git_proc(images=image, return_tensors="pt").pixel_values
144
 
145
+ for strategy in strategies:
146
+ try:
147
  with torch.no_grad():
148
  generated_ids = git_mod.generate(
149
  pixel_values=pixel_values,
150
+ **strategy
151
  )
152
 
153
  cap = git_proc.batch_decode(
 
161
  st.warning(f"GIT error: {str(e)[:80]}")
162
  captions.append("a scene shown in the image")
163
 
164
+ # Deduplicate while keeping order
165
  seen, unique = set(), []
166
  for c in captions:
167
  if c not in seen:
168
  seen.add(c)
169
  unique.append(c)
170
+
171
+ # If model still returns all duplicates keep originals so voting has input
172
+ if len(unique) < 2:
173
+ unique = captions
174
+
175
  while len(unique) < 5:
176
  unique.append(unique[0])
177
 
178
  return unique[:5]
179
 
 
180
  def compute_itm_scores(image, captions, blip_proc, blip_itm) -> list:
181
  scores = []
182
  for cap in captions:
 
196
  scores.append(0.0)
197
  return scores
198
 
 
199
  def compute_jina_scores(image: Image.Image, captions: list) -> list:
200
  img_data_uri = image_to_data_uri(image)
201
  scores = []
 
229
  scores.append(0.0)
230
  return scores
231
 
 
232
  def compute_cosine_scores(image, captions, blip_proc, blip_itm) -> list:
233
  try:
234
  img_inp = blip_proc(images=image, return_tensors="pt")
 
256
  st.warning(f"Cosine error: {str(e)[:60]}")
257
  return [0.0] * len(captions)
258
 
 
259
  def majority_voting(captions, itm, jina, cosine) -> tuple:
260
  itm_r = np.argsort(itm)[::-1]
261
  jina_r = np.argsort(jina)[::-1]
 
273
 
274
  return captions[top2[0]], captions[top2[1]], top2, dict(counts)
275
 
 
276
  def detect_objects(image, dino_proc, dino_mod, threshold=0.3) -> tuple:
277
  try:
278
  inputs = dino_proc(
 
315
  st.warning(f"DINO error: {str(e)[:80]}")
316
  return "Object detection unavailable", []
317
 
 
318
  def fuse_captions(cap1: str, cap2: str, objects: str) -> str:
319
  system_prompt = (
320
  "You are an expert image captioning assistant. "
 
357
  st.warning(f"Qwen exception: {str(e)[:60]}")
358
  return cap1
359
 
 
 
 
360
  with st.sidebar:
361
  st.title("Image Caption Fusion")
362
  st.markdown("---")
 
408
  if st.button("Generate Caption", type="primary", use_container_width=True):
409
 
410
  with st.spinner("Loading local models (first run takes 2-3 min)..."):
 
411
  git_proc, git_mod, blip_proc, blip_itm, dino_proc, dino_mod = load_local_models()
412
 
413
  progress = st.progress(0)
414
  status = st.empty()
415
 
416
  status.info("Step 1/7: Generating captions with GIT-Large-COCO...")
 
417
  captions = generate_captions_git(input_image, git_proc, git_mod)
418
  progress.progress(14)
419
 
 
476
  f"font-size:18px;font-weight:500;text-align:center;"
477
  f"line-height:1.6;'>{final}</div>",
478
  unsafe_allow_html=True
479
+ )