SaniaE committed on
Commit
7b80a15
·
verified ·
1 Parent(s): 5f1d4a9

added cross-attention saliency

Browse files
Files changed (1) hide show
  1. app.py +18 -14
app.py CHANGED
@@ -212,29 +212,33 @@ async def get_saliency_heatmap(file: UploadFile = File(...), query_text: str = Q
212
  orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
213
 
214
  blip = MODELS["blip"]
215
- # We must explicitly call the vision_model to get the attentions cleanly
216
  inputs = blip["processor"](images=orig_img, text=query_text, return_tensors="pt").to(DEVICE)
217
 
 
218
  with torch.no_grad():
219
- # Get vision outputs specifically to access the self-attention maps
220
- vision_outputs = blip["model"].vision_model(
221
- pixel_values=inputs.pixel_values,
222
- output_attentions=True
 
223
  )
224
 
225
- # Access attentions from the vision model output
226
- # Shape: (layers, batch, heads, patches, patches)
227
- attentions = vision_outputs.attentions[-1]
228
 
229
- # Grid size (usually 16x16 for BLIP)
230
- grid_size = int(np.sqrt(attentions.shape[-1] - 1))
231
- # Take attention from the [CLS] token (index 0) to all other patches
232
- mask = attentions[0, :, 0, 1:].mean(0).view(grid_size, grid_size).cpu().numpy()
 
 
 
233
 
234
- # Normalize, upscale, and blur for that "Pinterest-chic" glow
235
  mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
236
  mask_pill = Image.fromarray((mask * 255).astype('uint8')).resize(orig_img.size, resample=Image.BICUBIC)
237
- mask_pill = mask_pill.filter(ImageFilter.GaussianBlur(radius=12))
238
 
239
  heatmap_rgba = plt.get_cmap('jet')(np.array(mask_pill)/255.0)
240
  heatmap_img = Image.fromarray((heatmap_rgba[:, :, :3] * 255).astype('uint8')).convert("RGB")
 
212
  orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
213
 
214
  blip = MODELS["blip"]
 
215
  inputs = blip["processor"](images=orig_img, text=query_text, return_tensors="pt").to(DEVICE)
216
 
217
+ # We use the text_decoder because that's where the image and text actually 'meet'
218
  with torch.no_grad():
219
+ outputs = blip["model"].text_decoder(
220
+ input_ids=inputs.input_ids,
221
+ attention_mask=inputs.attention_mask,
222
+ encoder_hidden_states=blip["model"].vision_model(inputs.pixel_values).last_hidden_state,
223
+ output_attentions=True # This is key
224
  )
225
 
226
+ # Get Cross-Attentions (the link between text and image)
227
+ # Shape: (layers, batch, heads, text_tokens, image_patches)
228
+ cross_attentions = outputs.cross_attentions[-1]
229
 
230
+ # Average across heads and text tokens to get a single 1D map of image importance
231
+ # We exclude the first and last text tokens ([CLS], [SEP])
232
+ mask_1d = cross_attentions[0, :, 1:-1, :].mean(dim=(0, 1))
233
+
234
+ # Reshape to the grid (usually 16x16 for BLIP-large)
235
+ grid_size = int(np.sqrt(mask_1d.shape[-1]))
236
+ mask = mask_1d.view(grid_size, grid_size).cpu().numpy()
237
 
238
+ # Normalize and create the "Glow"
239
  mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
240
  mask_pill = Image.fromarray((mask * 255).astype('uint8')).resize(orig_img.size, resample=Image.BICUBIC)
241
+ mask_pill = mask_pill.filter(ImageFilter.GaussianBlur(radius=12)) # The XAI Glow
242
 
243
  heatmap_rgba = plt.get_cmap('jet')(np.array(mask_pill)/255.0)
244
  heatmap_img = Image.fromarray((heatmap_rgba[:, :, :3] * 255).astype('uint8')).convert("RGB")