Afsha001 commited on
Commit
00a1160
Β·
verified Β·
1 Parent(s): 7f1f360

update gemini

Browse files
Files changed (1) hide show
  1. app.py +82 -142
app.py CHANGED
@@ -6,6 +6,7 @@ import pandas as pd
6
  import requests
7
  import base64
8
  import streamlit as st
 
9
  from PIL import Image
10
  from io import BytesIO
11
  from collections import Counter
@@ -18,8 +19,11 @@ st.set_page_config(
18
  initial_sidebar_state="expanded"
19
  )
20
 
21
- HF_TOKEN = os.environ.get("HF_TOKEN", "")
22
- JINA_KEY = os.environ.get("JINA_KEY", "")
 
 
 
23
 
24
  JINA_URL = "https://api.jina.ai/v1/rerank"
25
  JINA_HEADERS = {
@@ -41,33 +45,39 @@ DETECT_PROMPT = (
41
  "car . bicycle . motorcycle . bus . truck . street . kitchen . restaurant . cafe"
42
  )
43
 
 
 
 
44
  if not JINA_KEY:
45
  st.error("JINA_KEY missing. Go to Space Settings β†’ Secrets and add it.")
46
  st.stop()
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  @st.cache_resource
49
  def load_local_models():
50
  from transformers import (
51
- AutoProcessor,
52
  AutoModelForCausalLM,
53
  AutoTokenizer,
54
  BlipProcessor,
55
  BlipForImageTextRetrieval,
 
56
  AutoModelForZeroShotObjectDetection
57
  )
58
  gc.collect()
59
 
60
- florence_processor = AutoProcessor.from_pretrained(
61
- "microsoft/Florence-2-large",
62
- trust_remote_code=True
63
- )
64
- florence_model = AutoModelForCausalLM.from_pretrained(
65
- "microsoft/Florence-2-large",
66
- trust_remote_code=True,
67
- torch_dtype=torch.float32
68
- )
69
- florence_model.eval()
70
-
71
  blip_processor = BlipProcessor.from_pretrained(
72
  "Salesforce/blip-image-captioning-large"
73
  )
@@ -77,6 +87,7 @@ def load_local_models():
77
  )
78
  blip_itm_model.eval()
79
 
 
80
  dino_processor = AutoProcessor.from_pretrained(
81
  "IDEA-Research/grounding-dino-base"
82
  )
@@ -86,6 +97,7 @@ def load_local_models():
86
  )
87
  dino_model.eval()
88
 
 
89
  qwen_tokenizer = AutoTokenizer.from_pretrained(
90
  "Qwen/Qwen2.5-1.5B-Instruct"
91
  )
@@ -96,12 +108,14 @@ def load_local_models():
96
  qwen_model.eval()
97
 
98
  return (
99
- florence_processor, florence_model,
100
  blip_processor, blip_itm_model,
101
  dino_processor, dino_model,
102
  qwen_tokenizer, qwen_model
103
  )
104
 
 
 
 
105
  def image_to_bytes(image: Image.Image) -> bytes:
106
  buf = BytesIO()
107
  image.save(buf, format="JPEG", quality=85)
@@ -112,126 +126,36 @@ def image_to_data_uri(image: Image.Image) -> str:
112
  b64 = base64.b64encode(raw).decode()
113
  return f"data:image/jpeg;base64,{b64}"
114
 
115
- def generate_captions_florence(image: Image.Image, florence_proc, florence_mod) -> list:
116
-
117
- captions = []
118
- image_size = (image.width, image.height)
119
-
120
- # Task 1: Short caption
121
- try:
122
- inputs = florence_proc(
123
- text="<CAPTION>", images=image, return_tensors="pt"
124
- )
125
- with torch.no_grad():
126
- ids = florence_mod.generate(
127
- input_ids=inputs["input_ids"],
128
- pixel_values=inputs["pixel_values"],
129
- max_new_tokens=50, num_beams=3
130
- )
131
- raw = florence_proc.batch_decode(ids, skip_special_tokens=False)[0]
132
- parsed = florence_proc.post_process_generation(raw, task="<CAPTION>", image_size=image_size)
133
- cap = parsed.get("<CAPTION>", "").strip().lower()
134
- captions.append(cap if cap else "a scene shown in the image")
135
- except Exception as e:
136
- st.warning(f"Florence CAPTION error: {str(e)[:80]}")
137
- captions.append("a scene shown in the image")
138
 
139
- # Task 2: Detailed caption
140
- try:
141
- inputs = florence_proc(
142
- text="<DETAILED_CAPTION>", images=image, return_tensors="pt"
143
- )
144
- with torch.no_grad():
145
- ids = florence_mod.generate(
146
- input_ids=inputs["input_ids"],
147
- pixel_values=inputs["pixel_values"],
148
- max_new_tokens=100, num_beams=3
149
- )
150
- raw = florence_proc.batch_decode(ids, skip_special_tokens=False)[0]
151
- parsed = florence_proc.post_process_generation(raw, task="<DETAILED_CAPTION>", image_size=image_size)
152
- cap = parsed.get("<DETAILED_CAPTION>", "").strip().lower()
153
- captions.append(cap if cap else "a scene shown in the image")
154
- except Exception as e:
155
- st.warning(f"Florence DETAILED_CAPTION error: {str(e)[:80]}")
156
- captions.append("a scene shown in the image")
157
 
158
- # Task 3: More detailed caption
159
- try:
160
- inputs = florence_proc(
161
- text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt"
162
- )
163
- with torch.no_grad():
164
- ids = florence_mod.generate(
165
- input_ids=inputs["input_ids"],
166
- pixel_values=inputs["pixel_values"],
167
- max_new_tokens=150, num_beams=3
168
- )
169
- raw = florence_proc.batch_decode(ids, skip_special_tokens=False)[0]
170
- parsed = florence_proc.post_process_generation(raw, task="<MORE_DETAILED_CAPTION>", image_size=image_size)
171
- cap = parsed.get("<MORE_DETAILED_CAPTION>", "").strip().lower()
172
- captions.append(cap if cap else "a scene shown in the image")
173
- except Exception as e:
174
- st.warning(f"Florence MORE_DETAILED_CAPTION error: {str(e)[:80]}")
175
- captions.append("a scene shown in the image")
176
 
177
- # Task 4: Dense region caption
178
- try:
179
- inputs = florence_proc(
180
- text="<DENSE_REGION_CAPTION>", images=image, return_tensors="pt"
181
- )
182
- with torch.no_grad():
183
- ids = florence_mod.generate(
184
- input_ids=inputs["input_ids"],
185
- pixel_values=inputs["pixel_values"],
186
- max_new_tokens=200, num_beams=3
187
- )
188
- raw = florence_proc.batch_decode(ids, skip_special_tokens=False)[0]
189
- parsed = florence_proc.post_process_generation(raw, task="<DENSE_REGION_CAPTION>", image_size=image_size)
190
- labels = parsed.get("<DENSE_REGION_CAPTION>", {}).get("labels", [])
191
-
192
- if labels:
193
- seen_r, unique_r = set(), []
194
- for l in labels:
195
- if l.lower() not in seen_r:
196
- seen_r.add(l.lower())
197
- unique_r.append(l.lower())
198
- cap = ", ".join(unique_r[:6]) + " visible in the scene"
199
- else:
200
- cap = "a scene shown in the image"
201
- captions.append(cap)
202
- except Exception as e:
203
- st.warning(f"Florence DENSE_REGION error: {str(e)[:80]}")
204
- captions.append("a scene shown in the image")
205
 
206
- # Task 5: Object detection
207
- try:
208
- inputs = florence_proc(
209
- text="<OD>", images=image, return_tensors="pt"
210
- )
211
- with torch.no_grad():
212
- ids = florence_mod.generate(
213
- input_ids=inputs["input_ids"],
214
- pixel_values=inputs["pixel_values"],
215
- max_new_tokens=200, num_beams=3
216
- )
217
- raw = florence_proc.batch_decode(ids, skip_special_tokens=False)[0]
218
- parsed = florence_proc.post_process_generation(raw, task="<OD>", image_size=image_size)
219
- labels = parsed.get("<OD>", {}).get("labels", [])
220
-
221
- if labels:
222
- seen_o, unique_o = set(), []
223
- for l in labels:
224
- if l.lower() not in seen_o:
225
- seen_o.add(l.lower())
226
- unique_o.append(l.lower())
227
- cap = "a scene containing " + ", ".join(unique_o[:6])
228
- else:
229
- cap = "a scene shown in the image"
230
- captions.append(cap)
231
- except Exception as e:
232
- st.warning(f"Florence OD error: {str(e)[:80]}")
233
- captions.append("a scene shown in the image")
234
 
 
235
  seen, unique = set(), []
236
  for c in captions:
237
  if c not in seen:
@@ -246,6 +170,9 @@ def generate_captions_florence(image: Image.Image, florence_proc, florence_mod)
246
 
247
  return unique[:5]
248
 
 
 
 
249
  def compute_itm_scores(image, captions, blip_proc, blip_itm) -> list:
250
  scores = []
251
  for cap in captions:
@@ -265,6 +192,9 @@ def compute_itm_scores(image, captions, blip_proc, blip_itm) -> list:
265
  scores.append(0.0)
266
  return scores
267
 
 
 
 
268
  def compute_jina_scores(image: Image.Image, captions: list) -> list:
269
  img_data_uri = image_to_data_uri(image)
270
  scores = []
@@ -295,6 +225,9 @@ def compute_jina_scores(image: Image.Image, captions: list) -> list:
295
  scores.append(0.0)
296
  return scores
297
 
 
 
 
298
  def compute_cosine_scores(image, captions, blip_proc, blip_itm) -> list:
299
  try:
300
  img_inp = blip_proc(images=image, return_tensors="pt")
@@ -321,6 +254,9 @@ def compute_cosine_scores(image, captions, blip_proc, blip_itm) -> list:
321
  st.warning(f"Cosine error: {str(e)[:60]}")
322
  return [0.0] * len(captions)
323
 
 
 
 
324
  def majority_voting(captions, itm, jina, cosine) -> tuple:
325
  itm_r = np.argsort(itm)[::-1]
326
  jina_r = np.argsort(jina)[::-1]
@@ -338,6 +274,9 @@ def majority_voting(captions, itm, jina, cosine) -> tuple:
338
 
339
  return captions[top2[0]], captions[top2[1]], top2, dict(counts)
340
 
 
 
 
341
  def detect_objects(image, dino_proc, dino_mod, threshold=0.3) -> tuple:
342
  try:
343
  inputs = dino_proc(
@@ -378,11 +317,7 @@ def detect_objects(image, dino_proc, dino_mod, threshold=0.3) -> tuple:
378
  return "Object detection unavailable", []
379
 
380
  # ============================================================================
381
- # fuse_captions β€” CHANGED
382
- # system_prompt: explicitly covers clothing, colors, people, objects, setting
383
- # user_prompt: asks for all specific details including clothing and background
384
- # max_new_tokens: 100 β†’ 180 (room for 3-4 full sentences)
385
- # temperature: 0.2 β†’ 0.4 (more expressive while staying factual)
386
  # ============================================================================
387
  def fuse_captions(cap1: str, cap2: str, objects: str, qwen_tok, qwen_mod) -> str:
388
 
@@ -443,12 +378,15 @@ def fuse_captions(cap1: str, cap2: str, objects: str, qwen_tok, qwen_mod) -> str
443
  st.warning(f"Qwen fusion error: {str(e)[:80]}")
444
  return cap1
445
 
 
 
 
446
  with st.sidebar:
447
  st.title("Image Caption Fusion")
448
  st.markdown("---")
449
  st.markdown("### Pipeline Steps")
450
  st.markdown("""
451
- **1. Florence-2-Large** (Local)
452
  Generate 5 captions
453
 
454
  **2. BLIP ITM** (Local)
@@ -470,9 +408,12 @@ Object detection
470
  Caption fusion
471
  """)
472
  st.markdown("---")
473
- st.markdown("**Local:** Florence-2, BLIP ITM, DINO, Qwen2.5")
474
- st.markdown("**API:** Jina")
475
 
 
 
 
476
  st.title("Image Caption Fusion System")
477
  st.markdown("Upload an image to generate a refined, grounded caption.")
478
  st.markdown("---")
@@ -493,9 +434,8 @@ if uploaded_file is not None:
493
  with col_run:
494
  if st.button("Generate Caption", type="primary", use_container_width=True):
495
 
496
- with st.spinner("Loading local models (first run takes 3-4 min)..."):
497
  (
498
- florence_proc, florence_mod,
499
  blip_proc, blip_itm,
500
  dino_proc, dino_mod,
501
  qwen_tok, qwen_mod
@@ -504,8 +444,8 @@ if uploaded_file is not None:
504
  progress = st.progress(0)
505
  status = st.empty()
506
 
507
- status.info("Step 1/7: Generating captions with Florence-2-Large...")
508
- captions = generate_captions_florence(input_image, florence_proc, florence_mod)
509
  progress.progress(14)
510
 
511
  with st.expander("5 Generated Captions", expanded=True):
 
6
  import requests
7
  import base64
8
  import streamlit as st
9
+ import google.generativeai as genai
10
  from PIL import Image
11
  from io import BytesIO
12
  from collections import Counter
 
19
  initial_sidebar_state="expanded"
20
  )
21
 
22
+ # ============================================================================
23
+ # CREDENTIALS
24
+ # ============================================================================
25
+ JINA_KEY = os.environ.get("JINA_KEY", "")
26
+ GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY", "")
27
 
28
  JINA_URL = "https://api.jina.ai/v1/rerank"
29
  JINA_HEADERS = {
 
45
  "car . bicycle . motorcycle . bus . truck . street . kitchen . restaurant . cafe"
46
  )
47
 
48
+ # ============================================================================
49
+ # CREDENTIAL CHECK
50
+ # ============================================================================
51
  if not JINA_KEY:
52
  st.error("JINA_KEY missing. Go to Space Settings β†’ Secrets and add it.")
53
  st.stop()
54
 
55
+ if not GOOGLE_API_KEY:
56
+ st.error("GOOGLE_API_KEY missing. Go to Space Settings β†’ Secrets and add it.")
57
+ st.stop()
58
+
59
+ # Configure Gemini API
60
+ genai.configure(api_key=GOOGLE_API_KEY)
61
+
62
+ # ============================================================================
63
+ # LOAD LOCAL MODELS
64
+ # Florence-2-Large removed β€” replaced by Gemini 1.5 Flash API
65
+ # Saves 1.6GB RAM and 2-3 min startup time
66
+ # Local: BLIP ITM, DINO, Qwen2.5
67
+ # ============================================================================
68
  @st.cache_resource
69
  def load_local_models():
70
  from transformers import (
 
71
  AutoModelForCausalLM,
72
  AutoTokenizer,
73
  BlipProcessor,
74
  BlipForImageTextRetrieval,
75
+ AutoProcessor,
76
  AutoModelForZeroShotObjectDetection
77
  )
78
  gc.collect()
79
 
80
+ # BLIP β€” ITM scoring and cosine similarity
 
 
 
 
 
 
 
 
 
 
81
  blip_processor = BlipProcessor.from_pretrained(
82
  "Salesforce/blip-image-captioning-large"
83
  )
 
87
  )
88
  blip_itm_model.eval()
89
 
90
+ # DINO β€” object detection
91
  dino_processor = AutoProcessor.from_pretrained(
92
  "IDEA-Research/grounding-dino-base"
93
  )
 
97
  )
98
  dino_model.eval()
99
 
100
+ # Qwen2.5-1.5B β€” caption fusion
101
  qwen_tokenizer = AutoTokenizer.from_pretrained(
102
  "Qwen/Qwen2.5-1.5B-Instruct"
103
  )
 
108
  qwen_model.eval()
109
 
110
  return (
 
111
  blip_processor, blip_itm_model,
112
  dino_processor, dino_model,
113
  qwen_tokenizer, qwen_model
114
  )
115
 
116
+ # ============================================================================
117
+ # HELPERS
118
+ # ============================================================================
119
  def image_to_bytes(image: Image.Image) -> bytes:
120
  buf = BytesIO()
121
  image.save(buf, format="JPEG", quality=85)
 
126
  b64 = base64.b64encode(raw).decode()
127
  return f"data:image/jpeg;base64,{b64}"
128
 
129
+ # ============================================================================
130
+ # STEP 1 β€” GEMINI 1.5 FLASH (API): GENERATE 5 DIVERSE CAPTIONS
131
+ # 5 different prompts β€” each focuses on a different aspect of the image
132
+ # Gemini sees the image directly as a VLM β€” no hallucination from task tokens
133
+ # API response ~2-4 sec per caption β€” 5 captions in ~15-20 sec total
134
+ # ============================================================================
135
+ def generate_captions_gemini(image: Image.Image) -> list:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
+ model = genai.GenerativeModel("gemini-1.5-flash")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
+ prompts = [
140
+ "Describe this image in detail covering the overall scene.",
141
+ "Describe the people in this image β€” their clothing colors, style, and what they are doing.",
142
+ "Describe the background, setting, and surroundings visible in this image.",
143
+ "Describe all the objects, plants, and items visible around the people in this image.",
144
+ "Write a full description of this image covering who is in it, what is happening, their appearance, and where it takes place."
145
+ ]
 
 
 
 
 
 
 
 
 
 
 
146
 
147
+ captions = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
+ for prompt in prompts:
150
+ try:
151
+ response = model.generate_content([prompt, image])
152
+ cap = response.text.strip().lower()
153
+ captions.append(cap if cap else "a scene shown in the image")
154
+ except Exception as e:
155
+ st.warning(f"Gemini error: {str(e)[:80]}")
156
+ captions.append("a scene shown in the image")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
+ # Deduplicate while keeping order
159
  seen, unique = set(), []
160
  for c in captions:
161
  if c not in seen:
 
170
 
171
  return unique[:5]
172
 
173
+ # ============================================================================
174
+ # STEP 2 β€” BLIP ITM: IMAGE-TEXT MATCHING SCORES
175
+ # ============================================================================
176
  def compute_itm_scores(image, captions, blip_proc, blip_itm) -> list:
177
  scores = []
178
  for cap in captions:
 
192
  scores.append(0.0)
193
  return scores
194
 
195
+ # ============================================================================
196
+ # STEP 3 β€” JINA RERANKER M0: SEMANTIC SCORES
197
+ # ============================================================================
198
  def compute_jina_scores(image: Image.Image, captions: list) -> list:
199
  img_data_uri = image_to_data_uri(image)
200
  scores = []
 
225
  scores.append(0.0)
226
  return scores
227
 
228
+ # ============================================================================
229
+ # STEP 4 β€” COSINE SIMILARITY: EMBEDDING SCORES
230
+ # ============================================================================
231
  def compute_cosine_scores(image, captions, blip_proc, blip_itm) -> list:
232
  try:
233
  img_inp = blip_proc(images=image, return_tensors="pt")
 
254
  st.warning(f"Cosine error: {str(e)[:60]}")
255
  return [0.0] * len(captions)
256
 
257
+ # ============================================================================
258
+ # STEP 5 β€” MAJORITY VOTING
259
+ # ============================================================================
260
  def majority_voting(captions, itm, jina, cosine) -> tuple:
261
  itm_r = np.argsort(itm)[::-1]
262
  jina_r = np.argsort(jina)[::-1]
 
274
 
275
  return captions[top2[0]], captions[top2[1]], top2, dict(counts)
276
 
277
+ # ============================================================================
278
+ # STEP 6 β€” GROUNDING DINO: OBJECT DETECTION
279
+ # ============================================================================
280
  def detect_objects(image, dino_proc, dino_mod, threshold=0.3) -> tuple:
281
  try:
282
  inputs = dino_proc(
 
317
  return "Object detection unavailable", []
318
 
319
  # ============================================================================
320
+ # STEP 7 β€” QWEN2.5-1.5B (LOCAL): CAPTION FUSION
 
 
 
 
321
  # ============================================================================
322
  def fuse_captions(cap1: str, cap2: str, objects: str, qwen_tok, qwen_mod) -> str:
323
 
 
378
  st.warning(f"Qwen fusion error: {str(e)[:80]}")
379
  return cap1
380
 
381
+ # ============================================================================
382
+ # SIDEBAR
383
+ # ============================================================================
384
  with st.sidebar:
385
  st.title("Image Caption Fusion")
386
  st.markdown("---")
387
  st.markdown("### Pipeline Steps")
388
  st.markdown("""
389
+ **1. Gemini 1.5 Flash** (API)
390
  Generate 5 captions
391
 
392
  **2. BLIP ITM** (Local)
 
408
  Caption fusion
409
  """)
410
  st.markdown("---")
411
+ st.markdown("**Local:** BLIP ITM, DINO, Qwen2.5")
412
+ st.markdown("**API:** Gemini 1.5 Flash, Jina")
413
 
414
+ # ============================================================================
415
+ # MAIN UI
416
+ # ============================================================================
417
  st.title("Image Caption Fusion System")
418
  st.markdown("Upload an image to generate a refined, grounded caption.")
419
  st.markdown("---")
 
434
  with col_run:
435
  if st.button("Generate Caption", type="primary", use_container_width=True):
436
 
437
+ with st.spinner("Loading local models (first run takes 2-3 min)..."):
438
  (
 
439
  blip_proc, blip_itm,
440
  dino_proc, dino_mod,
441
  qwen_tok, qwen_mod
 
444
  progress = st.progress(0)
445
  status = st.empty()
446
 
447
+ status.info("Step 1/7: Generating captions with Gemini 1.5 Flash...")
448
+ captions = generate_captions_gemini(input_image)
449
  progress.progress(14)
450
 
451
  with st.expander("5 Generated Captions", expanded=True):