# Grad-CAM explainer for ViT-based image retrieval models.
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
class GradCAMExplainer:
    """
    Generates Grad-CAM heatmaps showing which spatial regions of a garment
    most influenced a retrieval result.

    Works with ViT-based encoders (e.g. FashionSigLIP): hooks into the final
    transformer block and reshapes the sequence output into a 2-D spatial grid.

    NOTE(review): the hook target path (``model.visual.transformer.resblocks``)
    matches open_clip-style ViTs — confirm before using with other families.
    """

    def __init__(self, model, preprocess):
        # model: image encoder exposing ``encode_image`` and a ViT visual tower.
        # preprocess: callable mapping a PIL image to a (C, H, W) tensor.
        self.model = model
        self.preprocess = preprocess
        self._hooks: list = []
        self._activations: torch.Tensor | None = None  # captured on forward
        self._gradients: torch.Tensor | None = None    # captured on backward

    # ── Public API ────────────────────────────────────────────────────────────
    def explain(self, image: "Image.Image", query_vec: np.ndarray) -> np.ndarray:
        """
        Returns a float32 H×W array (values in [0, 1]) highlighting which
        parts of `image` are most responsible for its similarity to `query_vec`.
        """
        self._register_hooks()
        try:
            img_tensor = self.preprocess(image).unsqueeze(0)  # batch of 1
            # Input must carry grad so backward reaches the hooked block even
            # when the encoder's parameters are frozen.
            img_tensor.requires_grad_(True)

            # Forward pass
            img_vec = self.model.encode_image(img_tensor)  # (1, dim)

            # Similarity score w.r.t. the query vector is our scalar target.
            # Match the embedding's dtype/device so this works on CPU and GPU.
            q = torch.as_tensor(query_vec, dtype=img_vec.dtype, device=img_vec.device)
            score = (img_vec @ q).sum()

            self.model.zero_grad()
            score.backward()
            return self._compute_cam()
        finally:
            # Always detach hooks, even if the forward/backward pass raises,
            # so repeated calls never accumulate stale hooks.
            self._remove_hooks()

    # ── Grad-CAM computation ──────────────────────────────────────────────────
    def _compute_cam(self) -> np.ndarray:
        """
        ViT blocks output tensors of shape (seq_len, batch, dim) or
        (batch, seq_len, dim) depending on the open_clip version.
        We strip the [CLS] token, reshape to a square spatial grid,
        and apply the standard Grad-CAM formula.
        """
        act = self._activations  # captured during forward
        grad = self._gradients   # captured during backward
        if act is None or grad is None:
            # Hooks never fired: return an uninformative uniform heatmap
            # (14×14 matches ViT-B/16 @ 224 px).
            return np.ones((14, 14), dtype=np.float32)

        # Normalise tensor layout to (batch, seq_len, dim). The batch is
        # always 1 (``explain`` unsqueezes), so a leading dim != 1 means the
        # block emitted (seq_len, batch, dim) and must be permuted.
        if act.dim() == 3 and act.shape[0] != 1:
            act = act.permute(1, 0, 2)
            grad = grad.permute(1, 0, 2)

        # Drop CLS token (index 0) → (batch, patches, dim)
        act = act[:, 1:, :]
        grad = grad[:, 1:, :]

        # Grad-CAM channel weights: gradient pooled over the spatial (patch)
        # axis, then activations are channel-weighted and summed.
        weights = grad.mean(dim=1, keepdim=True)           # (1, 1, dim)
        cam_flat = (weights * act).sum(dim=-1).squeeze(0)  # (patches,)
        cam_flat = F.relu(cam_flat)

        # Reshape to square spatial grid (typically 14×14 for ViT-B/16 @ 224px)
        n_patches = cam_flat.shape[0]
        grid_size = int(n_patches ** 0.5)
        cam_2d = cam_flat[: grid_size * grid_size].reshape(grid_size, grid_size)

        # Normalise to [0, 1]; .cpu() so GPU tensors convert cleanly.
        cam_np = cam_2d.detach().cpu().numpy()
        cam_np = (cam_np - cam_np.min()) / (cam_np.max() - cam_np.min() + 1e-8)
        return cam_np.astype(np.float32)

    # ── Hook registration ─────────────────────────────────────────────────────
    def _register_hooks(self) -> None:
        # Hook the last transformer block; its output is the spatial sequence.
        target = self.model.visual.transformer.resblocks[-1]
        self._hooks.append(
            target.register_forward_hook(self._save_activation)
        )
        self._hooks.append(
            target.register_full_backward_hook(self._save_gradient)
        )

    def _remove_hooks(self) -> None:
        for h in self._hooks:
            h.remove()
        self._hooks.clear()
        self._activations = None
        self._gradients = None

    # ── Hook callbacks ────────────────────────────────────────────────────────
    def _save_activation(self, module, input, output) -> None:
        # output may be a tuple (e.g. (tensor, attn_weights)); take first element
        self._activations = output[0].detach() if isinstance(output, tuple) else output.detach()

    def _save_gradient(self, module, grad_input, grad_output) -> None:
        self._gradients = grad_output[0].detach() if isinstance(grad_output, tuple) else grad_output.detach()