# Atelier-AI / explainability.py  (commit 94c02a1, verified — uploaded by Priyanshiiiii)
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
class GradCAMExplainer:
    """
    Generates Grad-CAM heatmaps showing which spatial regions of a garment
    most influenced a retrieval result.

    Works with ViT-based encoders (e.g. FashionSigLIP): hooks into the final
    transformer block and reshapes the sequence output into a 2-D spatial grid.
    """

    def __init__(self, model, preprocess):
        # model: image encoder exposing .encode_image(tensor) and an
        #        open_clip-style .visual.transformer.resblocks list.
        self.model = model
        # preprocess: callable mapping a PIL image to a CHW float tensor.
        self.preprocess = preprocess
        self._hooks: list = []
        self._activations: torch.Tensor | None = None
        self._gradients: torch.Tensor | None = None

    # ── Public API ────────────────────────────────────────────────────────────
    def explain(self, image: "Image.Image", query_vec: np.ndarray) -> np.ndarray:
        """
        Returns a float32 H×W array (values in [0, 1]) highlighting which
        parts of `image` are most responsible for its similarity to `query_vec`.
        """
        self._register_hooks()
        try:
            img_tensor = self.preprocess(image).unsqueeze(0)
            img_tensor.requires_grad_(True)
            # Gradients are required even if the caller runs retrieval
            # inside a torch.no_grad() block.
            with torch.enable_grad():
                img_vec = self.model.encode_image(img_tensor)  # (1, dim)
                # Similarity w.r.t. the query vector is the scalar target;
                # match dtype/device so this also works on GPU / fp16 encoders.
                q = torch.tensor(query_vec, dtype=img_vec.dtype, device=img_vec.device)
                score = (img_vec @ q).sum()
                self.model.zero_grad()
                score.backward()
            return self._compute_cam()
        finally:
            # Always detach hooks, even when the forward/backward pass raises,
            # so a failed call does not leave stale hooks on the model.
            self._remove_hooks()

    # ── Grad-CAM computation ──────────────────────────────────────────────────
    def _compute_cam(self) -> np.ndarray:
        """
        ViT blocks output tensors of shape (seq_len, batch, dim) or
        (batch, seq_len, dim) depending on the open_clip version.
        We strip the [CLS] token, reshape to a square spatial grid,
        and apply the standard Grad-CAM formula.
        """
        act = self._activations   # captured during forward
        grad = self._gradients    # captured during backward
        if act is None or grad is None:
            # Fallback: uniform heatmap (hooks never fired)
            return np.ones((14, 14), dtype=np.float32)
        # Normalise tensor layout to (batch, seq_len, dim). explain() always
        # forwards batch size 1, so a seq-first tensor is (L, 1, D): detect it
        # by the size-1 *second* axis. (Comparing shape[0] != shape[1] would
        # wrongly permute a batch-first (1, L, D) tensor as well.)
        if act.dim() == 3 and act.shape[0] != 1 and act.shape[1] == 1:
            act = act.permute(1, 0, 2)
            grad = grad.permute(1, 0, 2)
        # Drop CLS token (index 0) → (batch, patches, dim)
        act = act[:, 1:, :]
        grad = grad[:, 1:, :]
        # Grad-CAM channel weights: gradient averaged over the *spatial*
        # (patch) axis — one weight per feature channel, per the original
        # Grad-CAM formulation. (Averaging over the channel axis instead
        # would yield per-patch scalars, which is not Grad-CAM.)
        weights = grad.mean(dim=1, keepdim=True)            # (1, 1, dim)
        cam_flat = (weights * act).sum(dim=-1).squeeze(0)   # (patches,)
        cam_flat = F.relu(cam_flat)
        # Reshape to square spatial grid (typically 14×14 for ViT-B/16 @ 224px)
        n_patches = cam_flat.shape[0]
        grid_size = int(n_patches ** 0.5)
        cam_2d = cam_flat[: grid_size * grid_size].reshape(grid_size, grid_size)
        # Normalise to [0, 1]; .cpu() so this also works for CUDA tensors.
        cam_np = cam_2d.detach().cpu().numpy()
        cam_np = (cam_np - cam_np.min()) / (cam_np.max() - cam_np.min() + 1e-8)
        return cam_np.astype(np.float32)

    # ── Hook registration ─────────────────────────────────────────────────────
    def _register_hooks(self) -> None:
        # Hook the final transformer block: its token activations carry the
        # most semantically meaningful spatial information.
        target = self.model.visual.transformer.resblocks[-1]
        self._hooks.append(
            target.register_forward_hook(self._save_activation)
        )
        self._hooks.append(
            target.register_full_backward_hook(self._save_gradient)
        )

    def _remove_hooks(self) -> None:
        for h in self._hooks:
            h.remove()
        self._hooks.clear()
        self._activations = None
        self._gradients = None

    # ── Hook callbacks ────────────────────────────────────────────────────────
    def _save_activation(self, module, input, output) -> None:
        # output may be a tuple (e.g. (tensor, attn_weights)); take first element
        self._activations = output[0].detach() if isinstance(output, tuple) else output.detach()

    def _save_gradient(self, module, grad_input, grad_output) -> None:
        # grad_output is a tuple of gradients w.r.t. the module's outputs.
        self._gradients = grad_output[0].detach() if isinstance(grad_output, tuple) else grad_output.detach()