added cross-attention saliency
Browse files
app.py
CHANGED
|
@@ -212,29 +212,33 @@ async def get_saliency_heatmap(file: UploadFile = File(...), query_text: str = Q
|
|
| 212 |
orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
| 213 |
|
| 214 |
blip = MODELS["blip"]
|
| 215 |
-
# We must explicitly call the vision_model to get the attentions cleanly
|
| 216 |
inputs = blip["processor"](images=orig_img, text=query_text, return_tensors="pt").to(DEVICE)
|
| 217 |
|
|
|
|
| 218 |
with torch.no_grad():
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
|
|
|
| 223 |
)
|
| 224 |
|
| 225 |
-
#
|
| 226 |
-
# Shape: (layers, batch, heads,
|
| 227 |
-
|
| 228 |
|
| 229 |
-
#
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
|
|
|
|
|
|
|
|
|
| 233 |
|
| 234 |
-
# Normalize
|
| 235 |
mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
|
| 236 |
mask_pill = Image.fromarray((mask * 255).astype('uint8')).resize(orig_img.size, resample=Image.BICUBIC)
|
| 237 |
-
mask_pill = mask_pill.filter(ImageFilter.GaussianBlur(radius=12))
|
| 238 |
|
| 239 |
heatmap_rgba = plt.get_cmap('jet')(np.array(mask_pill)/255.0)
|
| 240 |
heatmap_img = Image.fromarray((heatmap_rgba[:, :, :3] * 255).astype('uint8')).convert("RGB")
|
|
|
|
| 212 |
orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
| 213 |
|
| 214 |
blip = MODELS["blip"]
|
|
|
|
| 215 |
inputs = blip["processor"](images=orig_img, text=query_text, return_tensors="pt").to(DEVICE)
|
| 216 |
|
| 217 |
+
# We use the text_decoder because that's where the image and text actually 'meet'
|
| 218 |
with torch.no_grad():
|
| 219 |
+
outputs = blip["model"].text_decoder(
|
| 220 |
+
input_ids=inputs.input_ids,
|
| 221 |
+
attention_mask=inputs.attention_mask,
|
| 222 |
+
encoder_hidden_states=blip["model"].vision_model(inputs.pixel_values).last_hidden_state,
|
| 223 |
+
output_attentions=True # This is key
|
| 224 |
)
|
| 225 |
|
| 226 |
+
# Get Cross-Attentions (the link between text and image)
|
| 227 |
+
# Shape: (layers, batch, heads, text_tokens, image_patches)
|
| 228 |
+
cross_attentions = outputs.cross_attentions[-1]
|
| 229 |
|
| 230 |
+
# Average across heads and text tokens to get a single 1D map of image importance
|
| 231 |
+
# We exclude the first and last text tokens ([CLS], [SEP])
|
| 232 |
+
mask_1d = cross_attentions[0, :, 1:-1, :].mean(dim=(0, 1))
|
| 233 |
+
|
| 234 |
+
# Reshape to the grid (usually 16x16 for BLIP-large)
|
| 235 |
+
grid_size = int(np.sqrt(mask_1d.shape[-1]))
|
| 236 |
+
mask = mask_1d.view(grid_size, grid_size).cpu().numpy()
|
| 237 |
|
| 238 |
+
# Normalize and create the "Glow"
|
| 239 |
mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
|
| 240 |
mask_pill = Image.fromarray((mask * 255).astype('uint8')).resize(orig_img.size, resample=Image.BICUBIC)
|
| 241 |
+
mask_pill = mask_pill.filter(ImageFilter.GaussianBlur(radius=12)) # The XAI Glow
|
| 242 |
|
| 243 |
heatmap_rgba = plt.get_cmap('jet')(np.array(mask_pill)/255.0)
|
| 244 |
heatmap_img = Image.fromarray((heatmap_rgba[:, :, :3] * 255).astype('uint8')).convert("RGB")
|