import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image


class GradCAMExplainer:
    """
    Generates Grad-CAM heatmaps showing which spatial regions of a garment
    most influenced a retrieval result.

    Works with ViT-based encoders (e.g. FashionSigLIP): hooks into the final
    transformer block and reshapes the sequence output into a 2-D spatial grid.

    The hooked layer is ``model.visual.transformer.resblocks[-1]``.
    NOTE(review): encoders whose vision tower is laid out differently will
    raise AttributeError when hooks are registered — confirm the path for the
    deployed model.
    """

    # Side length of the uniform heatmap returned when nothing was captured.
    _FALLBACK_GRID = 14

    def __init__(self, model, preprocess):
        """
        Args:
            model: encoder exposing ``encode_image`` and ``zero_grad`` plus the
                ``visual.transformer.resblocks`` sequence that gets hooked.
            preprocess: callable mapping a PIL image to an unbatched tensor.
        """
        self.model = model
        self.preprocess = preprocess
        self._hooks: list = []
        self._activations: torch.Tensor | None = None
        self._gradients: torch.Tensor | None = None

    # ── Public API ────────────────────────────────────────────────────────────

    def explain(self, image: Image.Image, query_vec: np.ndarray) -> np.ndarray:
        """
        Returns a float32 square array (values in [0, 1]) highlighting which
        parts of `image` are most responsible for its similarity to
        `query_vec`.

        Args:
            image: the image to explain; passed to ``preprocess``.
            query_vec: query embedding. Any shape is accepted as long as it
                flattens to the image-embedding dimension.

        Returns:
            A (grid, grid) float32 heatmap, one cell per ViT patch. A uniform
            all-ones map is returned if no activations/gradients were captured.
        """
        try:
            # Registration sits inside the try so that a partially registered
            # pair of hooks is also cleaned up.
            self._register_hooks()

            # Callers frequently run retrieval under torch.no_grad(); the
            # backward pass below needs a graph regardless.
            with torch.enable_grad():
                img_tensor = self.preprocess(image).unsqueeze(0)
                img_tensor = img_tensor.to(self._model_device())
                img_tensor.requires_grad_(True)

                # Forward pass → (1, embed_dim)
                img_vec = self.model.encode_image(img_tensor)

                # Similarity score w.r.t. the query vector is our scalar
                # target. Match the embedding's device/dtype and flatten so a
                # (1, dim) query works just like a (dim,) one.
                q = torch.as_tensor(
                    query_vec, dtype=img_vec.dtype, device=img_vec.device
                ).reshape(-1)
                score = (img_vec @ q).sum()

                self.model.zero_grad()
                score.backward()

            return self._compute_cam()
        finally:
            # Always detach from the model, even when the pass above raised;
            # otherwise stale hooks would accumulate across calls.
            self._remove_hooks()

    # ── Grad-CAM computation ──────────────────────────────────────────────────

    def _compute_cam(self) -> np.ndarray:
        """
        Turns the captured activations/gradients into a normalised heatmap.

        ViT blocks output tensors of shape (seq_len, batch, dim) or
        (batch, seq_len, dim) depending on the open_clip version. We strip the
        [CLS] token, apply the standard Grad-CAM formula (channel weights are
        the gradients averaged over all patch positions), and reshape to a
        square spatial grid.
        """
        act = self._activations   # captured during forward
        grad = self._gradients    # captured during backward

        if act is None or grad is None:
            # Fallback: uniform heatmap
            return self._uniform_cam()

        # Normalise tensor layout to (batch, seq_len, dim). `explain` always
        # feeds a batch of one, so a singleton in position 1 (and not in
        # position 0) identifies the sequence-first layout.
        if act.dim() == 3 and act.shape[1] == 1 and act.shape[0] != 1:
            act = act.permute(1, 0, 2)
            grad = grad.permute(1, 0, 2)

        # Drop CLS token (index 0) → (batch, patches, dim)
        act = act[:, 1:, :]
        grad = grad[:, 1:, :]

        # Reshape target: square spatial grid (typically 14×14 for
        # ViT-B/16 @ 224px). Surplus tokens beyond grid² are truncated.
        n_patches = act.shape[1]
        grid_size = int(n_patches ** 0.5)
        if grid_size == 0:
            # No patch tokens to visualise.
            return self._uniform_cam()

        # Grad-CAM: one weight per channel = gradient averaged over patches,
        # then a channel-weighted sum of the activations at every patch.
        weights = grad.mean(dim=1, keepdim=True)              # (1, 1, dim)
        cam_flat = (weights * act).sum(dim=-1).squeeze(0)     # (patches,)
        cam_flat = F.relu(cam_flat)

        cam_2d = cam_flat[: grid_size * grid_size].reshape(grid_size, grid_size)

        # float() + cpu(): numpy cannot ingest CUDA or bfloat16 tensors.
        cam_np = cam_2d.detach().float().cpu().numpy()

        # Normalise to [0, 1]; the epsilon keeps a constant map from dividing
        # by zero (it normalises to all zeros).
        low = cam_np.min()
        high = cam_np.max()
        cam_np = (cam_np - low) / (high - low + 1e-8)
        return cam_np.astype(np.float32)

    def _uniform_cam(self) -> np.ndarray:
        """Returns the all-ones fallback heatmap."""
        side = self._FALLBACK_GRID
        return np.ones((side, side), dtype=np.float32)

    def _model_device(self) -> torch.device:
        """Returns the device of the model's first parameter (CPU if none)."""
        parameters = getattr(self.model, "parameters", None)
        if callable(parameters):
            first = next(iter(parameters()), None)
            if first is not None:
                return first.device
        return torch.device("cpu")

    # ── Hook registration ─────────────────────────────────────────────────────

    def _register_hooks(self) -> None:
        """Attaches the forward/backward capture hooks to the last block."""
        target = self.model.visual.transformer.resblocks[-1]
        self._hooks.append(
            target.register_forward_hook(self._save_activation)
        )
        self._hooks.append(
            target.register_full_backward_hook(self._save_gradient)
        )

    def _remove_hooks(self) -> None:
        """Detaches every registered hook and drops the captured tensors."""
        for handle in self._hooks:
            handle.remove()
        self._hooks.clear()
        self._activations = None
        self._gradients = None

    # ── Hook callbacks ────────────────────────────────────────────────────────

    def _save_activation(self, module, input, output) -> None:
        """Forward hook: stores a detached copy of the block's output."""
        # output may be a tuple (e.g. (tensor, attn_weights)); take first element
        if isinstance(output, tuple):
            output = output[0]
        self._activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output) -> None:
        """Backward hook: stores the gradient w.r.t. the block's output."""
        if isinstance(grad_output, tuple):
            grad_output = grad_output[0]
        self._gradients = grad_output.detach()