Priyanshiiiii committed on
Commit
94c02a1
·
verified ·
1 Parent(s): 9c62ef8

Update explainability.py

Browse files
Files changed (1) hide show
  1. explainability.py +104 -32
explainability.py CHANGED
@@ -1,42 +1,114 @@
1
- # src/explainability.py
2
- import torch, numpy as np
3
- from PIL import Image
4
  import torch.nn.functional as F
 
 
5
 
6
  class GradCAMExplainer:
7
- """Generates attention heatmaps for why a result was retrieved."""
8
-
9
- def __init__(self, model):
10
- self.model = model
11
- self._hooks = []
12
- self._gradients = None
13
- self._activations = None
14
-
 
 
 
 
 
 
 
 
 
15
  def explain(self, image: Image.Image, query_vec: np.ndarray) -> np.ndarray:
16
- """Returns H×W heatmap (values 0-1) highlighting retrieved features."""
 
 
 
17
  self._register_hooks()
18
-
19
- img_tensor = self.preprocess(image).unsqueeze(0).requires_grad_(True)
20
- img_vec = self.model.encode_image(img_tensor)
21
-
22
- # Similarity to query is our scalar target
23
- q = torch.tensor(query_vec).float()
 
 
 
24
  score = (img_vec @ q).sum()
 
25
  score.backward()
26
-
27
- # Grad-CAM formula: global average pooled gradients × activations
28
- weights = self._gradients.mean(dim=[2, 3], keepdim=True)
29
- cam = (weights * self._activations).sum(dim=1).squeeze()
30
- cam = F.relu(cam)
31
- cam = cam / (cam.max() + 1e-8)
32
-
33
  self._remove_hooks()
34
- return cam.detach().numpy() # return to caller to overlay on image
35
-
36
- def _register_hooks(self):
37
- # Hook the last transformer block's attention output in ViT
38
- target_layer = self.model.visual.transformer.resblocks[-1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  self._hooks.append(
40
- target_layer.register_forward_hook(self._save_activation))
 
41
  self._hooks.append(
42
- target_layer.register_backward_hook(self._save_gradient))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
 
3
  import torch.nn.functional as F
4
+ from PIL import Image
5
+
6
 
class GradCAMExplainer:
    """
    Generate Grad-CAM style heatmaps showing which spatial regions of an
    image most influenced its similarity to a query embedding.

    The explainer hooks ``model.visual.transformer.resblocks[-1]``, so the
    model must expose that attribute path as well as ``encode_image`` and
    ``zero_grad``. The token sequence produced by that block is reshaped
    into a square 2-D grid.
    """

    # Side length of the uniform heatmap returned when nothing was captured.
    _FALLBACK_GRID = 14

    def __init__(self, model, preprocess):
        """
        Args:
            model: Encoder exposing ``encode_image`` and
                ``visual.transformer.resblocks``.
            preprocess: Callable turning one image into a tensor without a
                batch axis; ``explain`` adds the batch axis itself.
        """
        self.model = model
        self.preprocess = preprocess
        self._hooks: list = []
        self._activations: torch.Tensor | None = None
        self._gradients: torch.Tensor | None = None

    # ---- Public API -------------------------------------------------------

    def explain(self, image: "Image.Image", query_vec: np.ndarray) -> np.ndarray:
        """
        Return a float32 side x side array with values in [0, 1] marking the
        parts of ``image`` most responsible for its similarity to
        ``query_vec``.

        Hooks are always removed again, even when the forward or backward
        pass raises, so a failed call cannot leave hooks on the model.
        """
        self._register_hooks()
        try:
            # Gradients are required even if the caller runs under no_grad().
            with torch.enable_grad():
                img_tensor = self.preprocess(image).unsqueeze(0)
                device = self._model_device()
                if device is not None:
                    img_tensor = img_tensor.to(device)
                img_tensor.requires_grad_(True)

                # Forward pass; presumably (1, embed_dim) -- TODO confirm.
                img_vec = self.model.encode_image(img_tensor)

                # The similarity to the query is the scalar Grad-CAM target.
                # Matching dtype/device keeps the matmul valid on GPU or
                # half-precision models.
                q = torch.as_tensor(
                    query_vec, dtype=img_vec.dtype, device=img_vec.device
                )
                score = (img_vec @ q).sum()
                self.model.zero_grad()
                score.backward()

            return self._compute_cam()
        finally:
            self._remove_hooks()

    # ---- Grad-CAM computation ---------------------------------------------

    def _compute_cam(self) -> np.ndarray:
        """
        Turn the captured activations and gradients into a normalised map.

        The hooked block may emit ``(seq_len, batch, dim)`` or
        ``(batch, seq_len, dim)``. ``explain`` always feeds one image, so the
        batch axis is the axis of size 1. A leading class token is dropped
        unless the sequence length is itself a perfect square (no class
        token present). Returns a uniform map when nothing was captured or
        no patch tokens remain.

        Raises:
            ValueError: If the captured tensors are not matching 3-D tensors.
        """
        act = self._activations  # captured during forward
        grad = self._gradients   # captured during backward

        if act is None or grad is None:
            return self._uniform_heatmap()

        if act.dim() != 3 or act.shape != grad.shape:
            raise ValueError(
                "expected matching 3-D activation/gradient tensors, got "
                f"{tuple(act.shape)} and {tuple(grad.shape)}"
            )

        # Normalise layout to (batch, seq_len, dim). Only a sequence-first
        # tensor has its size-1 batch axis in position 1.
        if act.shape[0] != 1 and act.shape[1] == 1:
            act = act.permute(1, 0, 2)
            grad = grad.permute(1, 0, 2)

        act = act[0].float()    # (seq_len, dim)
        grad = grad[0].float()

        # Decide whether token 0 is a class token: keep every token only when
        # they already form a square grid and the remainder would not.
        seq_len = act.shape[0]
        has_cls = not (
            self._is_square(seq_len) and not self._is_square(seq_len - 1)
        )
        first_patch = 1 if has_cls else 0
        act = act[first_patch:]
        grad = grad[first_patch:]

        # Weight each token by the mean gradient over its feature axis.
        weights = grad.mean(dim=-1, keepdim=True)          # (patches, 1)
        cam_flat = F.relu((weights * act).sum(dim=-1))     # (patches,)

        # Reshape to the largest square grid; surplus tokens are truncated.
        side = self._isqrt(cam_flat.shape[0])
        if side == 0:
            return self._uniform_heatmap()
        cam_2d = cam_flat[: side * side].reshape(side, side)

        # Normalise to [0, 1]; .cpu() makes this work for GPU tensors too.
        cam_np = cam_2d.detach().cpu().numpy().astype(np.float32)
        lo = cam_np.min()
        hi = cam_np.max()
        return ((cam_np - lo) / (hi - lo + 1e-8)).astype(np.float32)

    def _uniform_heatmap(self) -> np.ndarray:
        """Return the all-ones fallback map used when no signal exists."""
        return np.ones(
            (self._FALLBACK_GRID, self._FALLBACK_GRID), dtype=np.float32
        )

    @staticmethod
    def _isqrt(n: int) -> int:
        """Return floor(sqrt(n)) for n > 0 and 0 otherwise."""
        if n <= 0:
            return 0
        root = int(n ** 0.5)
        # Correct possible floating-point rounding in either direction.
        while root * root > n:
            root -= 1
        while (root + 1) * (root + 1) <= n:
            root += 1
        return root

    @classmethod
    def _is_square(cls, n: int) -> bool:
        """Return True when n is a non-negative perfect square."""
        return n >= 0 and cls._isqrt(n) ** 2 == n

    def _model_device(self):
        """Return the device of the model's first parameter, or None."""
        params = getattr(self.model, "parameters", None)
        if not callable(params):
            return None
        first = next(iter(params()), None)
        return None if first is None else first.device

    # ---- Hook registration ------------------------------------------------

    def _register_hooks(self) -> None:
        """Attach forward/backward hooks to the last transformer block."""
        # Drop leftovers first so repeated calls never stack hooks.
        self._remove_hooks()
        target = self.model.visual.transformer.resblocks[-1]
        self._hooks.append(
            target.register_forward_hook(self._save_activation)
        )
        self._hooks.append(
            target.register_full_backward_hook(self._save_gradient)
        )

    def _remove_hooks(self) -> None:
        """Detach every registered hook and forget the captured tensors."""
        for h in self._hooks:
            h.remove()
        self._hooks.clear()
        self._activations = None
        self._gradients = None

    # ---- Hook callbacks ---------------------------------------------------

    def _save_activation(self, module, input, output) -> None:
        """Forward hook: store the block output detached from the graph."""
        # output may be a tuple (e.g. (tensor, attn_weights)); take the first.
        first = output[0] if isinstance(output, tuple) else output
        self._activations = first.detach()

    def _save_gradient(self, module, grad_input, grad_output) -> None:
        """Backward hook: store the gradient w.r.t. the block output."""
        first = grad_output[0] if isinstance(grad_output, tuple) else grad_output
        # An output that received no gradient is reported as None.
        self._gradients = None if first is None else first.detach()