SaniaE commited on
Commit
5f1d4a9
·
verified ·
1 Parent(s): b5397cf

updated calibration logic

Browse files
Files changed (1) hide show
  1. app.py +38 -33
app.py CHANGED
@@ -160,43 +160,49 @@ async def concept_ensemble(file: UploadFile = File(...), user_prompt: str = Quer
160
  image = Image.open(file.file).convert("RGB")
161
  blip = MODELS["blip"]
162
 
163
- # Get model's caption
164
  inputs_gen = blip["processor"](images=image, return_tensors="pt").to(DEVICE)
165
  with torch.no_grad():
166
  generated_ids = blip["model"].generate(**inputs_gen, max_length=40)
167
  model_caption = blip["processor"].decode(generated_ids[0], skip_special_tokens=True)
168
 
169
- # 1. NEW: Localized Keyword Embedding
170
- # We focus on the core nouns and adjectives to prevent 'template bias'
171
- def get_focused_embedding(text):
172
  inputs = blip["processor"](text=text, return_tensors="pt", padding=True).to(DEVICE)
173
  with torch.no_grad():
174
- # Get output from the BERT-based text decoder
175
  outputs = blip["model"].text_decoder.bert(**inputs)
176
- # Average hidden states of ALL tokens to capture keyword specifics
177
  return F.normalize(outputs.last_hidden_state.mean(dim=1), p=2, dim=-1)
178
 
179
- user_embed = get_focused_embedding(user_prompt)
180
- model_embed = get_focused_embedding(model_caption)
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  # Visual alignment
183
  with torch.no_grad():
184
  vision_outputs = blip["model"].vision_model(inputs_gen["pixel_values"])
185
  image_embed = F.normalize(vision_outputs.last_hidden_state[:, 0, :], p=2, dim=-1)
186
-
187
- # 2. Calculate Corrected Scores
188
- sim_image_user = torch.matmul(image_embed, user_embed.T).item()
189
- sim_image_model = torch.matmul(image_embed, model_embed.T).item()
190
- sim_user_model = torch.matmul(user_embed, model_embed.T).item()
191
 
192
  return {
193
  "captions": {"user": user_prompt, "model": model_caption},
194
  "similarity_scores": {
195
- "visual_alignment_user": round(sim_image_user, 4),
196
- "visual_alignment_model": round(sim_image_model, 4),
197
- "semantic_overlap": round(sim_user_model, 4)
198
  },
199
- "interpretation": "Strong Agreement" if sim_user_model > 0.8 else "Perspective Divergence"
200
  }
201
 
202
 
@@ -206,32 +212,31 @@ async def get_saliency_heatmap(file: UploadFile = File(...), query_text: str = Q
206
  orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
207
 
208
  blip = MODELS["blip"]
209
- # We enable 'output_attentions' to grab the internal map directly
210
  inputs = blip["processor"](images=orig_img, text=query_text, return_tensors="pt").to(DEVICE)
211
 
212
  with torch.no_grad():
213
- outputs = blip["model"](**inputs, output_attentions=True)
214
- # Use the last layer of vision encoder self-attention
215
- # Shape: (batch, heads, patches, patches)
216
- attentions = outputs.vision_model_output.attentions[-1]
 
 
 
 
 
217
 
218
- # Average across heads and take the attention from the [CLS] token to all patches
219
- # Patch size for BLIP is typically 14x14 or 16x16
220
  grid_size = int(np.sqrt(attentions.shape[-1] - 1))
221
- # Remove [CLS] token and reshape to grid
222
  mask = attentions[0, :, 0, 1:].mean(0).view(grid_size, grid_size).cpu().numpy()
223
 
224
- # 1. Normalize and Upscale
225
  mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
226
  mask_pill = Image.fromarray((mask * 255).astype('uint8')).resize(orig_img.size, resample=Image.BICUBIC)
 
227
 
228
- # 2. Apply Gaussian Glow for XAI Aesthetic
229
- mask_pill = mask_pill.filter(ImageFilter.GaussianBlur(radius=15))
230
- mask_final = np.array(mask_pill) / 255.0
231
-
232
- # 3. Apply Colormap and Blend
233
- cm = plt.get_cmap('jet')
234
- heatmap_rgba = cm(mask_final)
235
  heatmap_img = Image.fromarray((heatmap_rgba[:, :, :3] * 255).astype('uint8')).convert("RGB")
236
 
237
  blended_img = Image.blend(orig_img, heatmap_img, alpha=0.5)
 
160
  image = Image.open(file.file).convert("RGB")
161
  blip = MODELS["blip"]
162
 
 
163
  inputs_gen = blip["processor"](images=image, return_tensors="pt").to(DEVICE)
164
  with torch.no_grad():
165
  generated_ids = blip["model"].generate(**inputs_gen, max_length=40)
166
  model_caption = blip["processor"].decode(generated_ids[0], skip_special_tokens=True)
167
 
168
+ def get_clean_embedding(text):
 
 
169
  inputs = blip["processor"](text=text, return_tensors="pt", padding=True).to(DEVICE)
170
  with torch.no_grad():
 
171
  outputs = blip["model"].text_decoder.bert(**inputs)
 
172
  return F.normalize(outputs.last_hidden_state.mean(dim=1), p=2, dim=-1)
173
 
174
+ user_embed = get_clean_embedding(user_prompt)
175
+ model_embed = get_clean_embedding(model_caption)
176
 
177
+ # --- MLE TRICK: Word-Level Calibration ---
178
+ # This prevents 'Pink Cafe' and 'Yellow Sofa' from being 0.99
179
+ user_words = set(user_prompt.lower().split())
180
+ model_words = set(model_caption.lower().split())
181
+ intersection = user_words.intersection(model_words)
182
+ union = user_words.union(model_words)
183
+ jaccard_sim = len(intersection) / len(union) if len(union) > 0 else 0
184
+
185
+ # Calculate raw embedding similarity
186
+ raw_sim = torch.matmul(user_embed, model_embed.T).item()
187
+
188
+ # Weighted Similarity: Combine vector meaning with actual word overlap
189
+ # This will pull the 0.99 score down if the keywords don't match
190
+ calibrated_overlap = (raw_sim * 0.4) + (jaccard_sim * 0.6)
191
+
192
  # Visual alignment
193
  with torch.no_grad():
194
  vision_outputs = blip["model"].vision_model(inputs_gen["pixel_values"])
195
  image_embed = F.normalize(vision_outputs.last_hidden_state[:, 0, :], p=2, dim=-1)
196
+ sim_image_user = torch.matmul(image_embed, user_embed.T).item()
 
 
 
 
197
 
198
  return {
199
  "captions": {"user": user_prompt, "model": model_caption},
200
  "similarity_scores": {
201
+ "semantic_overlap": round(calibrated_overlap, 4),
202
+ "visual_alignment": round(sim_image_user, 4),
203
+ "word_match_penalty": round(1 - jaccard_sim, 2)
204
  },
205
+ "interpretation": "Perspective Divergence" if calibrated_overlap < 0.6 else "Strong Agreement"
206
  }
207
 
208
 
 
212
  orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
213
 
214
  blip = MODELS["blip"]
215
+ # We must explicitly call the vision_model to get the attentions cleanly
216
  inputs = blip["processor"](images=orig_img, text=query_text, return_tensors="pt").to(DEVICE)
217
 
218
  with torch.no_grad():
219
+ # Get vision outputs specifically to access the self-attention maps
220
+ vision_outputs = blip["model"].vision_model(
221
+ pixel_values=inputs.pixel_values,
222
+ output_attentions=True
223
+ )
224
+
225
+ # Access attentions from the vision model output
226
+ # Shape: (layers, batch, heads, patches, patches)
227
+ attentions = vision_outputs.attentions[-1]
228
 
229
+ # Grid size (usually 16x16 for BLIP)
 
230
  grid_size = int(np.sqrt(attentions.shape[-1] - 1))
231
+ # Take attention from the [CLS] token (index 0) to all other patches
232
  mask = attentions[0, :, 0, 1:].mean(0).view(grid_size, grid_size).cpu().numpy()
233
 
234
+ # Normalize, upscale, and blur for that "Pinterest-chic" glow
235
  mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
236
  mask_pill = Image.fromarray((mask * 255).astype('uint8')).resize(orig_img.size, resample=Image.BICUBIC)
237
+ mask_pill = mask_pill.filter(ImageFilter.GaussianBlur(radius=12))
238
 
239
+ heatmap_rgba = plt.get_cmap('jet')(np.array(mask_pill)/255.0)
 
 
 
 
 
 
240
  heatmap_img = Image.fromarray((heatmap_rgba[:, :, :3] * 255).astype('uint8')).convert("RGB")
241
 
242
  blended_img = Image.blend(orig_img, heatmap_img, alpha=0.5)