Afsha001 commited on
Commit
8dffcbd
·
verified ·
1 Parent(s): 05ab8bc

update app.py

Browse files
Files changed (1) hide show
  1. app.py +489 -2
app.py CHANGED
@@ -1,3 +1,490 @@
1
 
2
- # PASTE YOUR COMPLETE app.py CONTENT HERE
3
- # (the one from /mnt/user-data/outputs/app.py)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
+ import os
3
+ import gc
4
+ import torch
5
+ import numpy as np
6
+ import pandas as pd
7
+ import requests
8
+ import base64
9
+ import streamlit as st
10
+ from PIL import Image
11
+ from io import BytesIO
12
+ from collections import Counter
13
+ from sklearn.metrics.pairwise import cosine_similarity
14
+ from sklearn.preprocessing import normalize
15
+
16
+ # ============================================================================
17
+ # PAGE CONFIG
18
+ # ============================================================================
19
+ st.set_page_config(
20
+ page_title="Image Caption Fusion System",
21
+ layout="wide",
22
+ initial_sidebar_state="expanded"
23
+ )
24
+
25
+ # ============================================================================
26
+ # CREDENTIALS
27
+ # ============================================================================
28
+ HF_TOKEN = os.environ.get("HF_TOKEN", "")
29
+ JINA_KEY = os.environ.get("JINA_KEY", "")
30
+
31
+ # ============================================================================
32
+ # API ENDPOINTS
33
+ # Florence-2: raw bytes, no Content-Type
34
+ # Qwen2.5: model-specific endpoint (not generic /v1/chat/completions)
35
+ # Jina: query=plain string, documents=list of data URI strings
36
+ # ============================================================================
37
+ FLORENCE_URL = "https://api-inference.huggingface.co/models/microsoft/Florence-2-large"
38
+ FLORENCE_HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}
39
+
40
+ QWEN_URL = "https://api-inference.huggingface.co/models/Qwen/Qwen2.5-1.5B-Instruct/v1/chat/completions"
41
+ HF_HEADERS = {
42
+ "Authorization": f"Bearer {HF_TOKEN}",
43
+ "Content-Type": "application/json"
44
+ }
45
+
46
+ JINA_URL = "https://api.jina.ai/v1/rerank"
47
+ JINA_HEADERS = {
48
+ "Authorization": f"Bearer {JINA_KEY}",
49
+ "Content-Type": "application/json"
50
+ }
51
+
52
+ DETECT_PROMPT = (
53
+ "person . child . man . woman . boy . girl . "
54
+ "dog . cat . horse . bird . animal . "
55
+ "ball . toy . bicycle . car . bench . "
56
+ "tree . grass . water . sky . mountain . "
57
+ "building . stairs . door . fence . floor . "
58
+ "jacket . dress . shirt . hat . bag ."
59
+ )
60
+
61
+ # ============================================================================
62
+ # CREDENTIAL CHECK
63
+ # ============================================================================
64
+ if not HF_TOKEN:
65
+ st.error("HF_TOKEN missing. Go to Space Settings → Secrets and add it.")
66
+ st.stop()
67
+
68
+ if not JINA_KEY:
69
+ st.error("JINA_KEY missing. Go to Space Settings → Secrets and add it.")
70
+ st.stop()
71
+
72
+ # ============================================================================
73
+ # LOAD LOCAL MODELS — BLIP ITM + GROUNDING DINO
74
+ # Cached so they load only once per session
75
+ # ============================================================================
76
+ @st.cache_resource
77
+ def load_local_models():
78
+ from transformers import (
79
+ BlipProcessor,
80
+ BlipForImageTextRetrieval,
81
+ AutoProcessor,
82
+ AutoModelForZeroShotObjectDetection
83
+ )
84
+ gc.collect()
85
+
86
+ blip_processor = BlipProcessor.from_pretrained(
87
+ "Salesforce/blip-image-captioning-large"
88
+ )
89
+ blip_itm_model = BlipForImageTextRetrieval.from_pretrained(
90
+ "Salesforce/blip-itm-large-coco",
91
+ torch_dtype=torch.float32
92
+ )
93
+ blip_itm_model.eval()
94
+
95
+ dino_processor = AutoProcessor.from_pretrained(
96
+ "IDEA-Research/grounding-dino-base"
97
+ )
98
+ dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
99
+ "IDEA-Research/grounding-dino-base",
100
+ torch_dtype=torch.float32
101
+ )
102
+ dino_model.eval()
103
+
104
+ return blip_processor, blip_itm_model, dino_processor, dino_model
105
+
106
+ # ============================================================================
107
+ # HELPERS
108
+ # ============================================================================
109
+ def image_to_bytes(image: Image.Image) -> bytes:
110
+ buf = BytesIO()
111
+ image.save(buf, format="JPEG", quality=85)
112
+ return buf.getvalue()
113
+
114
+ def image_to_data_uri(image: Image.Image) -> str:
115
+ raw = image_to_bytes(image)
116
+ b64 = base64.b64encode(raw).decode()
117
+ return f"data:image/jpeg;base64,{b64}"
118
+
119
+ # ============================================================================
120
+ # STEP 1 — FLORENCE-2-LARGE: GENERATE 5 CAPTIONS
121
+ # Fix applied: data=raw_bytes instead of json={"inputs": base64}
122
+ # ============================================================================
123
+ def generate_captions_florence(image: Image.Image) -> list:
124
+ img_bytes = image_to_bytes(image)
125
+ captions = []
126
+
127
+ for i in range(5):
128
+ try:
129
+ response = requests.post(
130
+ FLORENCE_URL,
131
+ headers=FLORENCE_HEADERS,
132
+ data=img_bytes,
133
+ params={"wait_for_model": True},
134
+ timeout=60
135
+ )
136
+ if response.status_code == 200:
137
+ result = response.json()
138
+ if isinstance(result, list):
139
+ cap = result[0].get("generated_text", "").strip().lower()
140
+ elif isinstance(result, dict):
141
+ cap = result.get("generated_text", "").strip().lower()
142
+ else:
143
+ cap = ""
144
+ captions.append(cap if cap else "a scene shown in the image")
145
+ else:
146
+ st.warning(f"Florence API error {response.status_code}")
147
+ captions.append("a scene shown in the image")
148
+ except Exception as e:
149
+ st.warning(f"Florence exception: {str(e)[:80]}")
150
+ captions.append("a scene shown in the image")
151
+
152
+ seen, unique = set(), []
153
+ for c in captions:
154
+ if c not in seen:
155
+ seen.add(c)
156
+ unique.append(c)
157
+ while len(unique) < 5:
158
+ unique.append(unique[0])
159
+ return unique[:5]
160
+
161
+ # ============================================================================
162
+ # STEP 2 — BLIP ITM: IMAGE-TEXT MATCHING SCORES
163
+ # Local model, no API call needed
164
+ # ============================================================================
165
+ def compute_itm_scores(image, captions, blip_proc, blip_itm) -> list:
166
+ scores = []
167
+ for cap in captions:
168
+ try:
169
+ inputs = blip_proc(
170
+ images=image, text=cap,
171
+ return_tensors="pt", padding=True
172
+ )
173
+ with torch.no_grad():
174
+ out = blip_itm(**inputs)
175
+ score = torch.nn.functional.softmax(
176
+ out.itm_score, dim=1
177
+ )[:, 1].item()
178
+ scores.append(round(float(score), 4))
179
+ except Exception as e:
180
+ st.warning(f"ITM error: {str(e)[:60]}")
181
+ scores.append(0.0)
182
+ return scores
183
+
184
+ # ============================================================================
185
+ # STEP 3 — JINA RERANKER M0: SEMANTIC SCORES
186
+ # Fix applied: query=plain string, documents=[data_uri_string]
187
+ # ============================================================================
188
+ def compute_jina_scores(image: Image.Image, captions: list) -> list:
189
+ img_data_uri = image_to_data_uri(image)
190
+ scores = []
191
+
192
+ for cap in captions:
193
+ try:
194
+ payload = {
195
+ "model": "jina-reranker-m0",
196
+ "query": cap,
197
+ "documents": [img_data_uri],
198
+ "top_n": 1
199
+ }
200
+ response = requests.post(
201
+ JINA_URL,
202
+ headers=JINA_HEADERS,
203
+ json=payload,
204
+ timeout=30
205
+ )
206
+ if response.status_code == 200:
207
+ result = response.json()
208
+ if "results" in result and result["results"]:
209
+ score = result["results"][0].get("relevance_score", 0.0)
210
+ scores.append(round(float(score), 4))
211
+ else:
212
+ scores.append(0.0)
213
+ else:
214
+ st.warning(f"Jina API error {response.status_code}: {response.text[:100]}")
215
+ scores.append(0.0)
216
+ except Exception as e:
217
+ st.warning(f"Jina exception: {str(e)[:60]}")
218
+ scores.append(0.0)
219
+ return scores
220
+
221
+ # ============================================================================
222
+ # STEP 4 — COSINE SIMILARITY: EMBEDDING SCORES
223
+ # Local model, reuses BLIP encoders
224
+ # ============================================================================
225
+ def compute_cosine_scores(image, captions, blip_proc, blip_itm) -> list:
226
+ try:
227
+ img_inp = blip_proc(images=image, return_tensors="pt")
228
+ with torch.no_grad():
229
+ vis = blip_itm.vision_model(pixel_values=img_inp["pixel_values"])
230
+ img_feat = blip_itm.vision_proj(vis.last_hidden_state[:, 0, :]).numpy()
231
+ img_feat = normalize(img_feat, norm="l2")
232
+
233
+ cap_inp = blip_proc(
234
+ text=captions, return_tensors="pt",
235
+ padding=True, truncation=True, max_length=512
236
+ )
237
+ with torch.no_grad():
238
+ txt = blip_itm.text_encoder(
239
+ input_ids=cap_inp["input_ids"],
240
+ attention_mask=cap_inp["attention_mask"]
241
+ )
242
+ cap_feat = blip_itm.text_proj(txt.last_hidden_state[:, 0, :]).numpy()
243
+ cap_feat = normalize(cap_feat, norm="l2")
244
+
245
+ sims = cosine_similarity(img_feat, cap_feat)[0]
246
+ return [round(float(s), 4) for s in sims]
247
+
248
+ except Exception as e:
249
+ st.warning(f"Cosine error: {str(e)[:60]}")
250
+ return [0.0] * len(captions)
251
+
252
+ # ============================================================================
253
+ # STEP 5 — MAJORITY VOTING: SELECT TOP 2 CAPTIONS
254
+ # Each of 3 methods votes for its top 2 — 6 votes total
255
+ # ============================================================================
256
+ def majority_voting(captions, itm, jina, cosine) -> tuple:
257
+ itm_r = np.argsort(itm)[::-1]
258
+ jina_r = np.argsort(jina)[::-1]
259
+ cosine_r = np.argsort(cosine)[::-1]
260
+
261
+ votes = [
262
+ int(itm_r[0]), int(itm_r[1]),
263
+ int(jina_r[0]), int(jina_r[1]),
264
+ int(cosine_r[0]), int(cosine_r[1])
265
+ ]
266
+ counts = Counter(votes)
267
+ top2 = [idx for idx, _ in counts.most_common(2)]
268
+ if len(top2) < 2:
269
+ top2 = [int(itm_r[0]), int(jina_r[0])]
270
+
271
+ return captions[top2[0]], captions[top2[1]], top2, dict(counts)
272
+
273
+ # ============================================================================
274
+ # STEP 6 — GROUNDING DINO: OBJECT DETECTION
275
+ # Local model, provides factual grounding for LLM fusion
276
+ # ============================================================================
277
+ def detect_objects(image, dino_proc, dino_mod, threshold=0.3) -> tuple:
278
+ try:
279
+ inputs = dino_proc(
280
+ images=image, text=DETECT_PROMPT, return_tensors="pt"
281
+ )
282
+ with torch.no_grad():
283
+ outputs = dino_mod(**inputs)
284
+
285
+ target_sizes = torch.tensor([image.size[::-1]])
286
+ results = dino_proc.post_process_grounded_object_detection(
287
+ outputs,
288
+ inputs.input_ids,
289
+ target_sizes=target_sizes
290
+ )[0]
291
+
292
+ scores = results["scores"]
293
+ labels = results.get("text_labels", results["labels"])
294
+
295
+ keep = scores >= threshold
296
+ kept_sc = scores[keep].tolist()
297
+ kept_lbl = [labels[i] for i in range(len(labels)) if keep[i]]
298
+
299
+ if not kept_lbl:
300
+ return "No objects detected", []
301
+
302
+ label_dict = {}
303
+ for lbl, sc in zip(kept_lbl, kept_sc):
304
+ lbl = lbl.strip().lower()
305
+ if lbl not in label_dict or label_dict[lbl] < sc:
306
+ label_dict[lbl] = sc
307
+
308
+ sorted_labels = [
309
+ l for l, _ in
310
+ sorted(label_dict.items(), key=lambda x: x[1], reverse=True)
311
+ ]
312
+ formatted = "Detected objects: [" + ", ".join(sorted_labels) + "]"
313
+ return formatted, sorted_labels
314
+
315
+ except Exception as e:
316
+ st.warning(f"DINO error: {str(e)[:80]}")
317
+ return "Object detection unavailable", []
318
+
319
+ # ============================================================================
320
+ # STEP 7 — QWEN2.5-1.5B: CAPTION FUSION
321
+ # Fix applied: model-specific endpoint URL
322
+ # ============================================================================
323
+ def fuse_captions(cap1: str, cap2: str, objects: str) -> str:
324
+ system_prompt = (
325
+ "You are an expert image captioning assistant. "
326
+ "Write ONE natural, fluent, descriptive caption combining the best details. "
327
+ "Return ONLY the caption, no explanation or prefix."
328
+ )
329
+ user_prompt = (
330
+ f"Caption A: {cap1}\n"
331
+ f"Caption B: {cap2}\n"
332
+ f"{objects}\n\n"
333
+ "Fused caption:"
334
+ )
335
+ try:
336
+ payload = {
337
+ "model": "Qwen/Qwen2.5-1.5B-Instruct",
338
+ "messages": [
339
+ {"role": "system", "content": system_prompt},
340
+ {"role": "user", "content": user_prompt}
341
+ ],
342
+ "max_tokens": 100,
343
+ "temperature": 0.3,
344
+ "top_p": 0.9
345
+ }
346
+ response = requests.post(
347
+ QWEN_URL,
348
+ headers=HF_HEADERS,
349
+ json=payload,
350
+ timeout=40
351
+ )
352
+ if response.status_code == 200:
353
+ fused = response.json()["choices"][0]["message"]["content"].strip()
354
+ for prefix in ["Fused caption:", "Caption:", "Result:"]:
355
+ if fused.lower().startswith(prefix.lower()):
356
+ fused = fused[len(prefix):].strip()
357
+ return fused if fused else cap1
358
+ else:
359
+ st.warning(f"Qwen API error {response.status_code}")
360
+ return cap1
361
+ except Exception as e:
362
+ st.warning(f"Qwen exception: {str(e)[:60]}")
363
+ return cap1
364
+
365
+ # ============================================================================
366
+ # SIDEBAR
367
+ # ============================================================================
368
+ with st.sidebar:
369
+ st.title("Image Caption Fusion")
370
+ st.markdown("---")
371
+ st.markdown("### Pipeline Steps")
372
+ st.markdown("""
373
+ **1. Florence-2-Large** (API)
374
+ Generate 5 captions
375
+
376
+ **2. BLIP ITM** (Local)
377
+ Image-text matching
378
+
379
+ **3. Jina Reranker M0** (API)
380
+ Semantic reranking
381
+
382
+ **4. Cosine Similarity** (Local)
383
+ Embedding similarity
384
+
385
+ **5. Majority Voting**
386
+ Best 2 captions selected
387
+
388
+ **6. Grounding DINO** (Local)
389
+ Object detection
390
+
391
+ **7. Qwen2.5-1.5B** (API)
392
+ Caption fusion
393
+ """)
394
+ st.markdown("---")
395
+ st.markdown("**Local:** BLIP ITM, DINO")
396
+ st.markdown("**API:** Florence-2, Jina, Qwen2.5")
397
+
398
+ # ============================================================================
399
+ # MAIN UI
400
+ # ============================================================================
401
+ st.title("Image Caption Fusion System")
402
+ st.markdown("Upload an image to generate a refined, grounded caption.")
403
+ st.markdown("---")
404
+
405
+ uploaded_file = st.file_uploader(
406
+ "Select an image",
407
+ type=["jpg", "jpeg", "png"]
408
+ )
409
+
410
+ if uploaded_file is not None:
411
+ input_image = Image.open(uploaded_file).convert("RGB")
412
+
413
+ col_img, col_run = st.columns([1, 1])
414
+
415
+ with col_img:
416
+ st.image(input_image, caption="Uploaded Image", use_column_width=True)
417
+
418
+ with col_run:
419
+ if st.button("Run Pipeline", type="primary", use_container_width=True):
420
+
421
+ with st.spinner("Loading local models (first run takes 1-2 min)..."):
422
+ blip_proc, blip_itm, dino_proc, dino_mod = load_local_models()
423
+
424
+ progress = st.progress(0)
425
+ status = st.empty()
426
+
427
+ status.info("Step 1/7: Generating captions with Florence-2-Large...")
428
+ captions = generate_captions_florence(input_image)
429
+ progress.progress(14)
430
+
431
+ with st.expander("5 Generated Captions", expanded=True):
432
+ for i, cap in enumerate(captions):
433
+ st.write(f"**{i+1}.** {cap}")
434
+
435
+ status.info("Step 2/7: Computing BLIP ITM scores...")
436
+ itm_scores = compute_itm_scores(input_image, captions, blip_proc, blip_itm)
437
+ progress.progress(28)
438
+
439
+ status.info("Step 3/7: Computing Jina Reranker scores...")
440
+ jina_scores = compute_jina_scores(input_image, captions)
441
+ progress.progress(42)
442
+
443
+ status.info("Step 4/7: Computing Cosine Similarity scores...")
444
+ cosine_scores = compute_cosine_scores(input_image, captions, blip_proc, blip_itm)
445
+ progress.progress(57)
446
+
447
+ scores_df = pd.DataFrame({
448
+ "Caption": [f"Cap {i+1}: {c[:50]}" for i, c in enumerate(captions)],
449
+ "ITM": itm_scores,
450
+ "Jina": jina_scores,
451
+ "Cosine": cosine_scores
452
+ })
453
+ with st.expander("All Scores", expanded=False):
454
+ st.dataframe(scores_df, use_container_width=True, hide_index=True)
455
+
456
+ status.info("Step 5/7: Running majority voting...")
457
+ best_1, best_2, _, _ = majority_voting(
458
+ captions, itm_scores, jina_scores, cosine_scores
459
+ )
460
+ progress.progress(71)
461
+
462
+ st.markdown("### Majority Voted Captions")
463
+ c1, c2 = st.columns(2)
464
+ with c1:
465
+ st.success(f"1. {best_1}")
466
+ with c2:
467
+ st.info(f"2. {best_2}")
468
+
469
+ status.info("Step 6/7: Detecting objects with DINO...")
470
+ obj_str, obj_list = detect_objects(input_image, dino_proc, dino_mod)
471
+ progress.progress(85)
472
+
473
+ st.markdown("### Detected Objects")
474
+ st.write(" | ".join(obj_list) if obj_list else obj_str)
475
+
476
+ status.info("Step 7/7: Fusing captions with Qwen2.5-1.5B...")
477
+ final = fuse_captions(best_1, best_2, obj_str)
478
+ progress.progress(100)
479
+ status.success("Pipeline complete!")
480
+
481
+ st.markdown("---")
482
+ st.markdown("### Final Fused Caption")
483
+ st.markdown(
484
+ f"<div style='"
485
+ f"background:linear-gradient(135deg,#667eea,#764ba2);"
486
+ f"padding:24px;border-radius:12px;color:white;"
487
+ f"font-size:18px;font-weight:500;text-align:center;"
488
+ f"line-height:1.6;'>{final}</div>",
489
+ unsafe_allow_html=True
490
+ )