project_02_DS / task /task_02 /step3_gradcam_flow.py
griddev's picture
Deploy Streamlit Space app
0710b5c verified
"""
step3_gradcam_flow.py
======================
STEP 3 — Multi-Layer Gradient-Weighted Attention Flow.
What this implements (Iteration 3 upgrade over single-layer GradCAM):
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
In Iteration 2, we only looked at the last cross-attention layer.
Here, we hook into ALL cross-attention layers simultaneously, run
GradCAM on each one, and then combine them using "Attention Flow"
(Abnar & Zuidema, 2020):
    rollout[0] = normalize(gradcam(layer 0))
    rollout[i] = normalize(rollout[i-1]) × normalize(gradcam(layer i))
                 + 0.05 × uniform   ← small uniform residual keeps a
                                      "floor" of global context so
                                      the rollout never collapses to
                                      zero for deep layers.
This multi-layer aggregation captures:
  • Early layers  → edges, textures, spatial structure.
  • Middle layers → part-level features (ears, wheels, legs).
  • Last layer    → high-level semantic concepts.
The combined result is a 14×14 heatmap with dramatically tighter
object contours and less background bleed.
Additionally, instead of plain OpenCV nearest-neighbour upsampling,
we use PyTorch bicubic interpolation to upscale 14×14 → 224×224,
producing pixel-smooth, non-blocky heatmap edges.
"""
import os
import sys
import math
import torch
import torch.nn.functional as F
import numpy as np
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.dirname(os.path.dirname(_THIS_DIR))
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
# ────────────────────────────────────────────────────────────────────────────
class FlowExtractor:
    """
    Capture cross-attention maps and their gradients from every decoder layer.

    For each entry of ``model.text_decoder.bert.encoder.layer`` that has a
    ``crossattention`` attribute, a forward hook is registered on
    ``layer.crossattention.self``.  On every forward pass the hook stores the
    second element of the module's output (expected to be the attention
    probabilities) and attaches a tensor hook that records the gradient
    flowing into it during ``backward()``.

    Attributes:
        model  : The model the hooks were registered on.
        layers : One ``{"fwd": tensor | None, "grad": tensor | None}`` dict per
                 hooked layer, in encoder-layer order.

    The instance can be used as a context manager; leaving the ``with`` block
    removes all hooks, so they cannot leak onto the model if an error occurs.
    """

    def __init__(self, model):
        self.model = model
        self._hooks = []   # forward-hook handles, needed by remove()
        self.layers = []   # list of {"fwd", "grad"} holders, one per layer

        for layer in model.text_decoder.bert.encoder.layer:
            if not hasattr(layer, "crossattention"):
                continue
            holder = {"fwd": None, "grad": None}
            self.layers.append(holder)
            target = layer.crossattention.self
            handle = target.register_forward_hook(self._make_forward_hook(holder))
            self._hooks.append(handle)

    @staticmethod
    def _make_forward_hook(holder):
        """Build a forward hook bound to *holder* (avoids late-binding in the loop)."""

        def _store_grad(grad):
            # Returns None on purpose so the gradient itself is left unmodified.
            holder["grad"] = grad.detach()

        def _fwd(module, inp, out):
            # Only a sequence output with a tensor in slot 1 is an attention map.
            # NOTE(review): some decoder implementations appear to place a
            # non-tensor (e.g. a key/value cache) at index 1 when attention
            # output is disabled; touching `.requires_grad` on it would raise.
            if not isinstance(out, (tuple, list)) or len(out) < 2:
                return
            probs = out[1]
            if not torch.is_tensor(probs):
                return
            holder["fwd"] = probs
            if probs.requires_grad:
                probs.register_hook(_store_grad)

        return _fwd

    def clear(self):
        """Forget the maps and gradients captured so far (hooks stay active)."""
        for holder in self.layers:
            holder["fwd"] = None
            holder["grad"] = None

    def remove(self):
        """Detach every forward hook from the model. Safe to call repeatedly."""
        for handle in self._hooks:
            handle.remove()
        self._hooks = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove()
        return False  # never swallow exceptions
def _single_layer_gradcam(holder, token_idx: int = -1) -> torch.Tensor:
"""
Compute GradCAM for one layer.
holder : dict with 'fwd' and 'grad' tensors, shape (1, heads, seq, 197)
Returns : 1D tensor of length 197, values β‰₯ 0.
"""
attn = holder["fwd"][:, :, token_idx, :] # (1, heads, 197)
grad = holder["grad"][:, :, token_idx, :] # (1, heads, 197)
cam = (attn * grad).mean(dim=1).squeeze() # (197,)
return torch.clamp(cam, min=0.0)
def _normalize1d(t: torch.Tensor) -> torch.Tensor:
"""L1-normalize a 1D tensor so it sums to 1."""
s = t.sum()
if s > 0:
return t / s
return t
def compute_attention_flow(
    extractor: "FlowExtractor",
    num_image_tokens: int = 197,
    residual_weight: float = 0.05,
    out_resolution: int = 224,
) -> np.ndarray:
    """
    Compute Attention Flow across all layers.

    Algorithm
    ---------
    1. For each layer i that has valid fwd AND grad maps:
         - Compute single-layer GradCAM (already ReLU-clamped).
    2. Roll up using recursive multiplication with a uniform residual:
         rollout[0] = norm(cam[0])
         rollout[i] = norm(rollout[i-1]) ⊙ norm(cam[i])
                      + residual_weight × uniform

       Why residual_weight MUST be small (0.05, not 0.5):
       The uniform term provides a "background floor" so the rollout
       never collapses to zero.  BUT if it is too large, the uniform
       baseline dominates after 12 layers and ALL words produce the
       same flat heatmap (everything lights up equally, losing
       word-specificity).  0.05 keeps a tiny floor while letting the
       gradient signal dominate.
    3. Drop the first token ([CLS]) → remaining patches → square grid
       (196 patches → 14×14 for 197 tokens).
    4. Bicubic upsample to (out_resolution × out_resolution) for
       pixel-smooth overlay.
    5. Min-max normalise to [0, 1].

    Args:
        extractor        : FlowExtractor after a backward pass.
        num_image_tokens : Nominal image-token count (197 for BLIP ViT-base).
                           Kept for interface compatibility; the uniform
                           baseline is sized from the captured maps so other
                           token counts also work.
        residual_weight  : Small regularisation weight (keep ≤ 0.1).
        out_resolution   : Target spatial size for the output heatmap.

    Returns:
        heatmap_np : (out_resolution, out_resolution) numpy float32 array.
                     All zeros when no layer captured both maps, or when the
                     upsampled map is constant.
    """
    # Only layers whose hooks captured both the map and its gradient count.
    valid_cams = [
        _single_layer_gradcam(holder).detach()
        for holder in extractor.layers
        if holder["fwd"] is not None and holder["grad"] is not None
    ]

    if not valid_cams:
        # Nothing was captured: there is no signal, so return a blank heatmap.
        return np.zeros((out_resolution, out_resolution), dtype=np.float32)

    # --- Attention Flow rollout ---
    # Uniform residual floor, sized from the actual maps rather than from
    # num_image_tokens so a mismatch cannot break the broadcast below.
    n_tokens = valid_cams[0].numel()
    uniform = torch.ones(n_tokens, device=valid_cams[0].device) / n_tokens

    rollout = _normalize1d(valid_cams[0])
    for cam in valid_cams[1:]:
        rollout = _normalize1d(rollout) * _normalize1d(cam) + residual_weight * uniform
    rollout = torch.clamp(rollout, min=0.0)

    # Drop [CLS] token (index 0); the rest are spatial patch tokens.
    spatial = rollout[1:]
    grid_sz = math.isqrt(spatial.numel())

    # Reshape → (1, 1, grid, grid) as required by F.interpolate.
    hm_tensor = spatial.detach().cpu().reshape(1, 1, grid_sz, grid_sz).float()

    # Bicubic upsampling → (1, 1, out_res, out_res).  Index instead of
    # squeeze() so the result stays 2D even when out_resolution == 1.
    hm_up = F.interpolate(
        hm_tensor,
        size=(out_resolution, out_resolution),
        mode="bicubic",
        align_corners=False,
    )[0, 0]
    hm_np = hm_up.numpy()

    # Min-max normalise to [0, 1]; a constant map carries no information.
    lo, hi = hm_np.min(), hm_np.max()
    if hi > lo:
        hm_np = (hm_np - lo) / (hi - lo)
    else:
        hm_np = np.zeros_like(hm_np)
    return hm_np.astype(np.float32)
# ── Main decoding loop ────────────────────────────────────────────────────────
def generate_with_flow(
    model,
    processor,
    device,
    encoder_hidden,
    encoder_mask,
    max_tokens: int = 20,
    verbose: bool = True,
) -> tuple[list[str], list[np.ndarray]]:
    """
    Token-by-token greedy decode with full Attention Flow heatmap per step.

    Each step runs the text decoder on the tokens generated so far, picks the
    arg-max token, back-propagates from that token's logit so the hooked
    cross-attention layers receive gradients, and converts the captured maps
    into one heatmap.  Decoding stops at the SEP token or after ``max_tokens``.

    Args:
        model          : BlipForConditionalGeneration (eval, grad-ckpt disabled).
        processor      : BlipProcessor (its tokenizer decodes the token ids).
        device         : torch.device for the running ``input_ids``.
        encoder_hidden : (1, 197, 768) from step2_encode_image.
        encoder_mask   : (1, 197) all-ones mask.
        max_tokens     : Maximum tokens to generate.
        verbose        : Print per-token progress.

    Returns:
        tokens   : List of decoded word strings.
        heatmaps : Parallel list of heatmaps as returned by
                   ``compute_attention_flow`` (224×224 by default).
    """
    text_cfg = model.config.text_config
    tokens: list[str] = []
    heatmaps: list[np.ndarray] = []

    extractor = FlowExtractor(model)
    try:
        input_ids = torch.tensor(
            [[text_cfg.bos_token_id]], dtype=torch.long, device=device
        )
        if verbose:
            print("πŸ”„ Generating caption with Attention Flow heatmaps …")

        for step in range(max_tokens):
            model.zero_grad()
            extractor.clear()

            # The heatmap needs gradients even if the caller wrapped us in
            # torch.no_grad(); enable_grad is a no-op otherwise.
            with torch.enable_grad():
                outputs = model.text_decoder(
                    input_ids=input_ids,
                    encoder_hidden_states=encoder_hidden,
                    encoder_attention_mask=encoder_mask,
                    output_attentions=True,
                    return_dict=True,
                )
                logits = outputs.logits[:, -1, :]
                next_token = torch.argmax(logits, dim=-1)
                # Read the id once: every .item() forces a device sync.
                token_id = next_token.item()
                if token_id == text_cfg.sep_token_id:
                    break
                # Backward from the chosen token's logit
                logits[0, token_id].backward(retain_graph=False)

            # Compute Attention Flow across all layers
            hm = compute_attention_flow(extractor)
            heatmaps.append(hm)

            word = processor.tokenizer.decode([token_id]).strip()
            tokens.append(word)
            if verbose:
                print(f" [{step+1:02d}] '{word}' heatmap peak={hm.max():.3f}")

            input_ids = torch.cat([input_ids, next_token.reshape(1, 1)], dim=-1)
    finally:
        # Always detach the hooks, otherwise an exception mid-decode would
        # leave them on the model and later calls would stack more on top.
        extractor.remove()

    if verbose:
        caption = " ".join(tokens)
        print(f"\nβœ… Caption: {caption}")
    return tokens, heatmaps