updated saliency map logic
Browse files
app.py
CHANGED
|
@@ -2,7 +2,7 @@ import os
|
|
| 2 |
import torch
|
| 3 |
import random
|
| 4 |
import asyncio
|
| 5 |
-
from PIL import Image
|
| 6 |
from fastapi import FastAPI, UploadFile, File, Query
|
| 7 |
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
from huggingface_hub import snapshot_download, login
|
|
@@ -171,16 +171,18 @@ async def concept_ensemble(file: UploadFile = File(...), user_prompt: str = Quer
|
|
| 171 |
inputs_text = blip["processor"](text=texts, return_tensors="pt", padding=True).to(DEVICE)
|
| 172 |
|
| 173 |
with torch.no_grad():
|
| 174 |
-
# Get
|
| 175 |
-
text_outputs = blip["model"].text_encoder(**inputs_text)
|
| 176 |
-
text_embeds = text_outputs.last_hidden_state[:, 0, :] # Use [CLS] token
|
| 177 |
-
|
| 178 |
vision_outputs = blip["model"].vision_model(inputs_gen["pixel_values"])
|
| 179 |
-
image_embeds = vision_outputs.last_hidden_state[:, 0, :]
|
| 180 |
|
| 181 |
-
#
|
| 182 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
image_embeds = F.normalize(image_embeds, p=2, dim=-1)
|
|
|
|
| 184 |
|
| 185 |
# Similarity Matrix calculation
|
| 186 |
sim_image_user = torch.matmul(image_embeds, text_embeds[0].T).item()
|
|
@@ -207,32 +209,39 @@ async def get_saliency_heatmap(file: UploadFile = File(...), query_text: str = Q
|
|
| 207 |
orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
| 208 |
|
| 209 |
blip = MODELS["blip"]
|
|
|
|
| 210 |
inputs = blip["processor"](images=orig_img, text=query_text, return_tensors="pt").to(DEVICE)
|
|
|
|
| 211 |
|
| 212 |
# 2. Extract Gradients for Saliency
|
| 213 |
-
inputs.pixel_values.requires_grad = True
|
| 214 |
outputs = blip["model"](**inputs, labels=inputs["input_ids"])
|
| 215 |
loss = outputs.loss
|
| 216 |
loss.backward()
|
| 217 |
|
| 218 |
-
#
|
|
|
|
| 219 |
grad = inputs.pixel_values.grad.abs().max(dim=1)[0][0].cpu().numpy()
|
| 220 |
|
| 221 |
-
# 3. Create Heatmap with
|
| 222 |
# Normalize to [0, 1]
|
| 223 |
grad = (grad - grad.min()) / (grad.max() - grad.min() + 1e-8)
|
| 224 |
|
| 225 |
-
# Apply
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
cm = plt.get_cmap('jet')
|
| 227 |
-
heatmap_rgba = cm(
|
| 228 |
|
| 229 |
-
# Convert heatmap to PIL
|
| 230 |
heatmap_img = Image.fromarray((heatmap_rgba[:, :, :3] * 255).astype('uint8')).convert("RGB")
|
| 231 |
heatmap_img = heatmap_img.resize(orig_img.size, resample=Image.BILINEAR)
|
| 232 |
|
| 233 |
-
# 4. Blend Original + Heatmap
|
| 234 |
-
# 0.
|
| 235 |
-
blended_img = Image.blend(orig_img, heatmap_img, alpha=0.
|
| 236 |
|
| 237 |
# 5. Stream back
|
| 238 |
buf = io.BytesIO()
|
|
|
|
| 2 |
import torch
|
| 3 |
import random
|
| 4 |
import asyncio
|
| 5 |
+
from PIL import Image, ImageFilter
|
| 6 |
from fastapi import FastAPI, UploadFile, File, Query
|
| 7 |
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
from huggingface_hub import snapshot_download, login
|
|
|
|
| 171 |
inputs_text = blip["processor"](text=texts, return_tensors="pt", padding=True).to(DEVICE)
|
| 172 |
|
| 173 |
with torch.no_grad():
|
| 174 |
+
# 1. Get Image Embeddings from the vision_model
|
|
|
|
|
|
|
|
|
|
| 175 |
vision_outputs = blip["model"].vision_model(inputs_gen["pixel_values"])
|
| 176 |
+
image_embeds = vision_outputs.last_hidden_state[:, 0, :] # Use [CLS] token
|
| 177 |
|
| 178 |
+
# 2. Get Text Embeddings using the text_decoder's bert model
|
| 179 |
+
# BLIP's text_decoder typically wraps a BERT-like architecture
|
| 180 |
+
text_outputs = blip["model"].text_decoder.bert(**inputs_text)
|
| 181 |
+
text_embeds = text_outputs.last_hidden_state[:, 0, :] # Use [CLS] token
|
| 182 |
+
|
| 183 |
+
# Normalize
|
| 184 |
image_embeds = F.normalize(image_embeds, p=2, dim=-1)
|
| 185 |
+
text_embeds = F.normalize(text_embeds, p=2, dim=-1)
|
| 186 |
|
| 187 |
# Similarity Matrix calculation
|
| 188 |
sim_image_user = torch.matmul(image_embeds, text_embeds[0].T).item()
|
|
|
|
| 209 |
orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
| 210 |
|
| 211 |
blip = MODELS["blip"]
|
| 212 |
+
# Preprocess the image and query text into model inputs; gradient tracking on pixel_values is enabled right after
|
| 213 |
inputs = blip["processor"](images=orig_img, text=query_text, return_tensors="pt").to(DEVICE)
|
| 214 |
+
inputs.pixel_values.requires_grad = True
|
| 215 |
|
| 216 |
# 2. Extract Gradients for Saliency
|
|
|
|
| 217 |
outputs = blip["model"](**inputs, labels=inputs["input_ids"])
|
| 218 |
loss = outputs.loss
|
| 219 |
loss.backward()
|
| 220 |
|
| 221 |
+
# Generate Saliency from gradients of pixel values
|
| 222 |
+
# We take the maximum absolute gradient across the RGB channels
|
| 223 |
grad = inputs.pixel_values.grad.abs().max(dim=1)[0][0].cpu().numpy()
|
| 224 |
|
| 225 |
+
# 3. Create Heatmap with "Glow" Effect (XAI Style)
|
| 226 |
# Normalize to [0, 1]
|
| 227 |
grad = (grad - grad.min()) / (grad.max() - grad.min() + 1e-8)
|
| 228 |
|
| 229 |
+
# Apply Gaussian Blur to smooth tiny speckles into a professional heatmap
|
| 230 |
+
grad_pill = Image.fromarray((grad * 255).astype('uint8'))
|
| 231 |
+
grad_pill = grad_pill.filter(ImageFilter.GaussianBlur(radius=8))
|
| 232 |
+
grad_smoothed = np.array(grad_pill) / 255.0
|
| 233 |
+
|
| 234 |
+
# Apply colormap (jet)
|
| 235 |
cm = plt.get_cmap('jet')
|
| 236 |
+
heatmap_rgba = cm(grad_smoothed)
|
| 237 |
|
| 238 |
+
# Convert heatmap to PIL and resize to original image dimensions
|
| 239 |
heatmap_img = Image.fromarray((heatmap_rgba[:, :, :3] * 255).astype('uint8')).convert("RGB")
|
| 240 |
heatmap_img = heatmap_img.resize(orig_img.size, resample=Image.BILINEAR)
|
| 241 |
|
| 242 |
+
# 4. Blend Original + Heatmap (Adjust alpha for visibility on dark/light UIs)
|
| 243 |
+
# alpha=0.5 weights the original image and the heatmap equally, giving a strong, clear highlight
|
| 244 |
+
blended_img = Image.blend(orig_img, heatmap_img, alpha=0.5)
|
| 245 |
|
| 246 |
# 5. Stream back
|
| 247 |
buf = io.BytesIO()
|