File size: 9,855 Bytes
0710b5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
"""
step3_gradcam_flow.py
======================
STEP 3 — Multi-Layer Gradient-Weighted Attention Flow.

What this implements (Iteration 3 upgrade over single-layer GradCAM):
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
In Iteration 2, we only looked at the last cross-attention layer.
Here, we hook into ALL cross-attention layers simultaneously, run
GradCAM on each one, and then combine them using "Attention Flow"
(Abnar & Zuidema, 2020):

  rollout[0] = normalize(gradcam(layer 0))
  rollout[i] = normalize(rollout[i-1]) × normalize(gradcam(layer i))
             + residual_weight × uniform
                                     ← small uniform residual term
                                       (residual_weight defaults to
                                       0.05) that keeps a "floor" of
                                       global context so the rollout
                                       never collapses to zero for
                                       deep layers.

This multi-layer aggregation captures:
  • Early layers  → edges, textures, spatial structure.
  • Middle layers → part-level features (ears, wheels, legs).
  • Last layer    → high-level semantic concepts.

The combined result is a 14×14 heatmap with dramatically tighter
object contours and less background bleed.

Additionally, instead of plain OpenCV nearest-neighbour upsampling,
we use PyTorch bicubic interpolation to upscale 14×14 → 224×224,
producing pixel-smooth, non-blocky heatmap edges.
"""

import os
import sys
import math
import torch
import torch.nn.functional as F
import numpy as np

# Absolute path of the directory that contains this module.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
# Directory two levels above this module's directory.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(_THIS_DIR))
# Put that directory at the front of the import search path, once only
# (presumably so project-level packages can be imported when this file is
# run directly — confirm against the repository layout).
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)


# ────────────────────────────────────────────────────────────────────────────
class FlowExtractor:
    """
    Capture cross-attention maps and their gradients for every decoder layer.

    On construction, one forward hook is registered on the
    ``crossattention.self`` sub-module of every entry of
    ``model.text_decoder.bert.encoder.layer`` that has a ``crossattention``
    attribute.  Each hook stores the second element of that module's output
    (taken to be the attention probabilities) and, when that tensor requires
    grad, attaches a tensor hook that records the detached gradient during
    the next backward pass.

    Results are exposed through ``self.layers``: one dict per hooked layer,
    in layer order, with keys ``"fwd"`` (attention tensor or ``None``) and
    ``"grad"`` (its gradient or ``None``).

    The instance can also be used as a context manager; leaving the ``with``
    block removes all hooks::

        with FlowExtractor(model) as extractor:
            ...
    """

    def __init__(self, model):
        self.model = model
        self._hooks = []    # handles returned by register_forward_hook
        self.layers = []    # one {"fwd": ..., "grad": ...} dict per hooked layer

        for layer in model.text_decoder.bert.encoder.layer:
            if not hasattr(layer, "crossattention"):
                continue
            holder = {"fwd": None, "grad": None}
            self.layers.append(holder)
            handle = layer.crossattention.self.register_forward_hook(
                self._make_forward_hook(holder)
            )
            self._hooks.append(handle)

    @staticmethod
    def _make_forward_hook(holder):
        """Build a forward hook bound to *holder* (one closure per layer)."""

        def _store_grad(grad):
            # Returns None, so autograd keeps the gradient unchanged.
            holder["grad"] = grad.detach()

        def _forward_hook(module, inputs, output):
            # A bare tensor output has no "second element"; indexing it
            # would silently slice the batch axis instead.
            if isinstance(output, torch.Tensor) or len(output) < 2:
                return
            probs = output[1]
            # The second slot may hold something other than attention
            # probabilities (None, or a non-tensor such as a key/value
            # pair); only tensors can be stored and hooked.
            if not isinstance(probs, torch.Tensor):
                return
            holder["fwd"] = probs
            if probs.requires_grad:
                probs.register_hook(_store_grad)

        return _forward_hook

    def clear(self):
        """Forget all captured attention maps and gradients (hooks stay)."""
        for holder in self.layers:
            holder["fwd"] = None
            holder["grad"] = None

    def remove(self):
        """Detach every forward hook from the model.  Safe to call twice."""
        for handle in self._hooks:
            handle.remove()
        self._hooks = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Always release the hooks; never suppress the exception.
        self.remove()


def _single_layer_gradcam(holder, token_idx: int = -1) -> torch.Tensor:
    """
    Compute the GradCAM map of a single layer for one query token.

    holder    : dict whose 'fwd' and 'grad' entries are 4-D tensors indexed
                as [:, :, token_idx, :] (the original note gives the shape
                as (1, heads, seq, 197)).
    token_idx : position selected along the third axis; -1 picks the last.
    Returns   : attention × gradient, averaged over the second axis, with
                size-1 axes squeezed out and negative values clamped to 0
                (a 1-D tensor of length 197 for the shape noted above).
    """
    probs_at_token = holder["fwd"][:, :, token_idx, :]
    grads_at_token = holder["grad"][:, :, token_idx, :]

    # Element-wise weighting, then the mean across dim 1 (heads).
    weighted = torch.mul(probs_at_token, grads_at_token)
    head_mean = weighted.mean(dim=1)

    # Keep only positive evidence.
    return head_mean.squeeze().clamp(min=0.0)


def _normalize1d(t: torch.Tensor) -> torch.Tensor:
    """
    Scale *t* so that its elements sum to 1.

    When the sum is not strictly positive the input is returned as-is
    (the same object, unscaled).
    """
    total = t.sum()
    return t / total if total > 0 else t


def compute_attention_flow(
    extractor: "FlowExtractor",
    num_image_tokens: int = 197,
    residual_weight: float = 0.05,
    out_resolution: int = 224,
) -> np.ndarray:
    """
    Compute Attention Flow across all layers.

    Algorithm
    ---------
    1.  For each layer that has BOTH a forward map and a gradient map,
        compute the single-layer GradCAM (already clamped to >= 0).
    2.  Roll the per-layer maps up with element-wise multiplication plus a
        uniform residual term:
            rollout[0] = norm(cam[0])
            rollout[i] = norm(rollout[i-1]) * norm(cam[i])
                         + residual_weight * uniform
    3.  Drop the first token (the [CLS] position), leaving the spatial
        patches, and reshape them to a square grid (14x14 for 197 tokens).
    4.  Bicubic-upsample to (out_resolution, out_resolution).
    5.  Min-max normalise to [0, 1].

    Why residual_weight should stay small (0.05, not 0.5): the uniform term
    is a "background floor" that stops the rollout collapsing to zero, but
    if it is too large it dominates after many layers and every word yields
    the same flat heatmap.

    NOTE(review): the product of two L1-normalised maps has total mass of
    at most 1 (usually much less), while the floor always adds
    ``residual_weight`` of mass, so the effective strength of the floor
    depends on how peaked the maps are.

    Args:
        extractor       : Object whose ``layers`` attribute is a list of
                          dicts with 'fwd' and 'grad' entries, populated by
                          a forward + backward pass.
        num_image_tokens: Nominal number of image tokens (197 for BLIP
                          ViT-base).  Kept for backward compatibility; the
                          token count actually used is read from the
                          captured maps, so other encoder sizes work
                          without passing this.
        residual_weight : Weight of the uniform residual term (keep <= 0.1).
        out_resolution  : Side length of the returned square heatmap.

    Returns:
        (out_resolution, out_resolution) float32 numpy array in [0, 1].
        All zeros when no layer captured both maps, or when the rolled-up
        map is constant.

    Raises:
        ValueError: if the number of patch tokens (all tokens but the
                    first) is zero or not a perfect square.
    """
    cams = [
        _single_layer_gradcam(holder).detach()
        for holder in extractor.layers
        if holder["fwd"] is not None and holder["grad"] is not None
    ]

    if not cams:
        # Nothing was captured: fall back to an all-zero (blank) heatmap.
        return np.zeros((out_resolution, out_resolution), dtype=np.float32)

    # Size the uniform floor from the captured maps so it always broadcasts
    # against them, whatever the encoder's token count is.
    token_count = cams[0].numel()
    uniform = torch.ones(token_count, device=cams[0].device) / token_count

    # --- Attention Flow rollout ---
    rollout = _normalize1d(cams[0])
    for cam in cams[1:]:
        rollout = _normalize1d(rollout) * _normalize1d(cam) + residual_weight * uniform
        rollout = torch.clamp(rollout, min=0.0)

    # Drop the first ([CLS]) token; the rest are the spatial patches.
    spatial = rollout[1:]
    patch_count = spatial.numel()
    side = math.isqrt(patch_count)
    if side == 0 or side * side != patch_count:
        raise ValueError(
            f"cannot arrange {patch_count} patch tokens into a square grid"
        )

    # Shape (1, 1, side, side) is what F.interpolate expects for 2-D maps.
    grid = spatial.detach().cpu().reshape(1, 1, side, side).float()

    # Index instead of squeeze() so the result stays 2-D for any resolution.
    upsampled = F.interpolate(
        grid,
        size=(out_resolution, out_resolution),
        mode="bicubic",
        align_corners=False,
    )[0, 0]

    hm_np = upsampled.numpy()

    # Min-max normalise to [0, 1]; a constant map becomes all zeros.
    lo, hi = hm_np.min(), hm_np.max()
    if hi > lo:
        hm_np = (hm_np - lo) / (hi - lo)
    else:
        hm_np = np.zeros_like(hm_np)

    return hm_np.astype(np.float32)


# ── Main decoding loop ──────────────────────────────────────────────────────
def generate_with_flow(
    model,
    processor,
    device,
    encoder_hidden,
    encoder_mask,
    max_tokens: int = 20,
    verbose: bool = True,
) -> tuple[list[str], list[np.ndarray]]:
    """
    Token-by-token greedy decode with one Attention Flow heatmap per token.

    At every step the text decoder is run on the tokens produced so far,
    the arg-max token is selected, its logit is back-propagated, and the
    attention maps / gradients captured by a FlowExtractor are rolled up
    into a heatmap by compute_attention_flow.  Decoding stops at the
    separator token or after ``max_tokens`` steps.

    Gradients are force-enabled inside the loop, so the function also works
    when called under ``torch.no_grad()``.  The extractor's hooks are always
    removed and its captured tensors released, even if a step raises.

    Args:
        model           – model exposing ``text_decoder`` and
                          ``config.text_config`` (BlipForConditionalGeneration
                          in eval mode, grad-checkpointing disabled).
        processor       – object exposing ``tokenizer.decode``
                          (BlipProcessor).
        device          – torch.device on which ``input_ids`` is created.
        encoder_hidden  – image encoder states passed to the decoder;
                          (1, 197, 768) from step2_encode_image.
        encoder_mask    – attention mask for ``encoder_hidden``;
                          (1, 197) all-ones.
        max_tokens      – maximum number of tokens to generate.
        verbose         – print per-token progress and the final caption.

    Returns:
        tokens   – list of decoded token strings (whitespace-stripped).
        heatmaps – parallel list of numpy heatmaps, one per token, at
                   compute_attention_flow's default resolution (224, 224).
    """
    text_cfg = model.config.text_config
    bos_id = text_cfg.bos_token_id
    eos_id = text_cfg.sep_token_id

    input_ids = torch.tensor([[bos_id]], dtype=torch.long, device=device)
    tokens: list[str] = []
    heatmaps: list[np.ndarray] = []

    if verbose:
        print("🔄 Generating caption with Attention Flow heatmaps …")

    extractor = FlowExtractor(model)
    try:
        # backward() below needs a graph even if the caller disabled grads.
        with torch.enable_grad():
            for step in range(max_tokens):
                model.zero_grad()
                extractor.clear()

                outputs = model.text_decoder(
                    input_ids=input_ids,
                    encoder_hidden_states=encoder_hidden,
                    encoder_attention_mask=encoder_mask,
                    output_attentions=True,
                    return_dict=True,
                )

                logits = outputs.logits[:, -1, :]
                next_token = torch.argmax(logits, dim=-1)
                token_id = int(next_token.item())   # single host sync per step

                if token_id == eos_id:
                    break

                # Backward from the chosen token's logit; this fires the
                # gradient hooks that fill the extractor.
                logits[0, token_id].backward()

                # Combine the per-layer maps into one heatmap for this token.
                hm = compute_attention_flow(extractor)
                heatmaps.append(hm)

                word = processor.tokenizer.decode([token_id]).strip()
                tokens.append(word)

                if verbose:
                    print(f"  [{step+1:02d}] '{word}'  heatmap peak={hm.max():.3f}")

                input_ids = torch.cat([input_ids, next_token.reshape(1, 1)], dim=-1)
    finally:
        # Drop references to graph tensors and detach the hooks no matter
        # how the loop ended.
        extractor.clear()
        extractor.remove()

    if verbose:
        caption = " ".join(tokens)
        print(f"\n✅ Caption: {caption}")

    return tokens, heatmaps