"""
step3_gradcam_flow.py
======================
STEP 3 — Multi-Layer Gradient-Weighted Attention Flow.
What this implements (Iteration 3 upgrade over single-layer GradCAM):
────────────────────────────────────────────────────────────────────
In Iteration 2, we only looked at the last cross-attention layer.
Here, we hook into ALL cross-attention layers simultaneously, run
GradCAM on each one, and then combine them using "Attention Flow"
(Abnar & Zuidema, 2020):
    rollout[0] = gradcam(layer 0)
    rollout[i] = normalize(rollout[i-1]) × normalize(gradcam(layer i))
                 + 0.05 × uniform   ← small residual term keeps a
                                      "floor" of global context so
                                      the rollout never collapses to
                                      zero for deep layers.
This multi-layer aggregation captures:
    • Early layers  → edges, textures, spatial structure.
    • Middle layers → part-level features (ears, wheels, legs).
    • Last layer    → high-level semantic concepts.
The combined result is a 14×14 heatmap with dramatically tighter
object contours and less background bleed.
Additionally, instead of plain OpenCV nearest-neighbour upsampling,
we use PyTorch bicubic interpolation to upscale 14×14 → 224×224,
producing pixel-smooth, non-blocky heatmap edges.
"""
import os
import sys
import math
import torch
import torch.nn.functional as F
import numpy as np
# Make project-level packages importable when this file is run directly:
# the project root is taken to be the directory two levels above the
# directory that contains this file.
_THIS_DIR = os.path.split(os.path.abspath(__file__))[0]
_PROJECT_ROOT = os.path.split(os.path.split(_THIS_DIR)[0])[0]
if not (_PROJECT_ROOT in sys.path):
    # Prepend so the project's own modules win over same-named installed ones.
    sys.path.insert(0, _PROJECT_ROOT)
# ────────────────────────────────────────────────────────────────────────────
class FlowExtractor:
    """
    Capture cross-attention maps and their gradients from every decoder layer.

    A forward hook is registered on the ``crossattention.self`` sub-module of
    each entry of ``model.text_decoder.bert.encoder.layer`` that has a
    ``crossattention`` attribute.  Whenever such a module runs and its second
    output is a tensor, that tensor is stored under ``"fwd"`` in the layer's
    holder; if it requires grad, a tensor hook is attached so the gradient
    reaching it during ``backward()`` is stored (detached) under ``"grad"``.

    Attributes:
        model:  The model whose layers were hooked.
        layers: One ``{"fwd": ..., "grad": ...}`` dict per hooked layer, in
                the order the layers are listed by the model.  Both values
                are ``None`` until captured.

    The instance may be used as a context manager; leaving the ``with`` block
    removes every registered hook, even when an exception is raised.
    """

    def __init__(self, model):
        self.model = model
        self._hooks = []   # handles returned by register_forward_hook
        self.layers = []   # one holder dict per hooked layer, in order

        for layer in model.text_decoder.bert.encoder.layer:
            if not hasattr(layer, "crossattention"):
                continue
            holder = {"fwd": None, "grad": None}
            self.layers.append(holder)
            handle = layer.crossattention.self.register_forward_hook(
                self._make_forward_hook(holder)
            )
            self._hooks.append(handle)

    @staticmethod
    def _make_forward_hook(holder):
        """Build a forward hook bound to *holder* (one closure per layer)."""

        def _store_grad(grad):
            # Returning None leaves the gradient flowing through unchanged.
            holder["grad"] = grad.detach()

        def _forward_hook(module, inputs, output):
            # Only a (context, attention_probs, ...) style output is usable.
            if not isinstance(output, (tuple, list)) or len(output) < 2:
                return
            attn = output[1]
            # The second slot is not always an attention tensor (it can be
            # None or a nested tuple when attentions are not requested);
            # skip it instead of failing on a missing ``requires_grad``.
            if not isinstance(attn, torch.Tensor):
                return
            holder["fwd"] = attn
            if attn.requires_grad:
                attn.register_hook(_store_grad)

        return _forward_hook

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove()
        return False  # never swallow exceptions

    def clear(self):
        """Forget all captured maps and gradients; hooks stay registered."""
        for holder in self.layers:
            holder["fwd"] = None
            holder["grad"] = None

    def remove(self):
        """Unregister every forward hook; safe to call more than once."""
        for handle in self._hooks:
            handle.remove()
        self._hooks = []
def _single_layer_gradcam(holder, token_idx: int = -1) -> torch.Tensor:
    """
    Return the GradCAM relevance of one layer for one query position.

    Args:
        holder:    Dict whose ``"fwd"`` and ``"grad"`` entries are 4-D
                   tensors indexed as [batch, head, query, key].  Callers in
                   this file are expected to pass batch size 1.
        token_idx: Query position to explain (default: the last one).

    Returns:
        The product attention * gradient at ``token_idx``, averaged over the
        head axis, with singleton dimensions squeezed away and negative
        values clamped to zero.
    """
    attn_row, grad_row = (
        holder[key][:, :, token_idx, :] for key in ("fwd", "grad")
    )
    per_head = attn_row * grad_row          # gradient-weighted attention
    head_mean = per_head.mean(dim=1)        # average over the head axis
    # Keep only positive evidence (ReLU), as in standard GradCAM.
    return head_mean.squeeze().clamp(min=0.0)
def _normalize1d(t: torch.Tensor) -> torch.Tensor:
    """
    Scale *t* so that its elements sum to 1.

    When the sum is not strictly positive the input is returned untouched
    (same object), which avoids dividing by zero or flipping signs.
    """
    total = torch.sum(t)
    return t / total if total > 0 else t
def compute_attention_flow(
    extractor: FlowExtractor,
    num_image_tokens: int = 197,
    residual_weight: float = 0.05,
    out_resolution: int = 224,
) -> np.ndarray:
    """
    Combine per-layer GradCAM maps into a single Attention Flow heatmap.

    Algorithm
    ---------
    1. For every layer holder that has both a forward map and a gradient,
       compute the single-layer GradCAM vector (already clamped to >= 0).
    2. Roll the layers up multiplicatively with a small uniform residual:
           rollout[0] = norm(cam[0])
           rollout[i] = norm(rollout[i-1]) * norm(cam[i])
                        + residual_weight * uniform
       The uniform term is a floor that stops the product from collapsing
       to zero.  It must stay small: a large value lets the flat baseline
       dominate after many layers, so every word yields the same heatmap.
    3. Drop the first token (treated as [CLS]); the remaining tokens are
       reshaped to a square grid, so their count must be a perfect square.
    4. Upsample the grid bicubically to out_resolution x out_resolution.
    5. Min-max normalise to [0, 1].

    Args:
        extractor:        FlowExtractor holding maps from a forward pass and
                          gradients from the following backward pass.
        num_image_tokens: Length of the uniform residual vector; it has to
                          match the key length of the captured attention
                          maps (197 by default).
        residual_weight:  Weight of the uniform residual (keep <= 0.1).
        out_resolution:   Side length of the returned square heatmap.

    Returns:
        float32 array of shape (out_resolution, out_resolution).  It is all
        zeros when no layer captured both a map and a gradient, or when the
        upsampled map is constant.
    """
    # One non-negative relevance vector per layer that captured both maps.
    valid_cams = [
        _single_layer_gradcam(holder).detach()
        for holder in extractor.layers
        if holder["fwd"] is not None and holder["grad"] is not None
    ]
    if not valid_cams:
        # Nothing to explain: return a blank (all-zero) heatmap.
        return np.zeros((out_resolution, out_resolution), dtype=np.float32)

    # --- Attention Flow rollout ---
    uniform = (
        torch.ones(num_image_tokens, device=valid_cams[0].device)
        / num_image_tokens
    )
    rollout = _normalize1d(valid_cams[0])
    for cam in valid_cams[1:]:
        rollout = (
            _normalize1d(rollout) * _normalize1d(cam)
            + residual_weight * uniform
        )
        # Guard against negatives (only possible with a negative weight).
        rollout = torch.clamp(rollout, min=0.0)

    # Drop the first token and lay the remaining ones out on a square grid.
    spatial = rollout[1:]
    grid_sz = math.isqrt(spatial.numel())
    hm_tensor = spatial.detach().cpu().reshape(1, 1, grid_sz, grid_sz).float()

    # Bicubic upsampling; index the batch/channel axes explicitly instead of
    # squeeze() so a 1x1 target still yields a 2-D array.
    hm_up = F.interpolate(
        hm_tensor,
        size=(out_resolution, out_resolution),
        mode="bicubic",
        align_corners=False,
    )[0, 0]
    hm_np = hm_up.numpy()

    # Min-max normalise to [0, 1]; a constant map carries no information.
    lo, hi = hm_np.min(), hm_np.max()
    if hi > lo:
        hm_np = (hm_np - lo) / (hi - lo)
    else:
        hm_np = np.zeros_like(hm_np)
    return hm_np.astype(np.float32)
# ── Main decoding loop ────────────────────────────────────────────────────────
def generate_with_flow(
    model,
    processor,
    device,
    encoder_hidden,
    encoder_mask,
    max_tokens: int = 20,
    verbose: bool = True,
) -> tuple[list[str], list[np.ndarray]]:
    """
    Greedily decode a caption and build one Attention Flow heatmap per token.

    Each step runs the text decoder on the tokens generated so far, picks the
    arg-max next token, back-propagates that token's logit so the hooked
    cross-attention layers receive gradients, and turns the captured maps
    into a heatmap.  Decoding stops at the separator token or after
    ``max_tokens`` steps.  The hooks are always removed before returning,
    including when a step raises.

    Args:
        model:          Captioning model exposing ``text_decoder`` and
                        ``config.text_config`` (expected: a BLIP
                        conditional-generation model in eval mode).
        processor:      Object whose ``tokenizer.decode`` maps ids to text.
        device:         Device on which the token ids are created.
        encoder_hidden: Image features passed as ``encoder_hidden_states``
                        (expected shape (1, 197, 768)).
        encoder_mask:   Mask passed as ``encoder_attention_mask``
                        (expected shape (1, 197)).
        max_tokens:     Maximum number of tokens to generate.
        verbose:        Print per-token progress and the final caption.

    Returns:
        ``(tokens, heatmaps)``: the decoded token strings and, in parallel,
        one heatmap per token as returned by ``compute_attention_flow``.
    """
    text_cfg = model.config.text_config
    input_ids = torch.tensor(
        [[text_cfg.bos_token_id]], dtype=torch.long, device=device
    )
    tokens: list[str] = []
    heatmaps: list[np.ndarray] = []

    if verbose:
        # NOTE(review): the leading emoji of this message was lost to an
        # encoding error in the source; restore it if the original is known.
        print("Generating caption with Attention Flow heatmaps …")

    extractor = FlowExtractor(model)
    try:
        # Gradients are required for GradCAM even if the caller disabled them.
        with torch.enable_grad():
            for step in range(max_tokens):
                model.zero_grad()
                extractor.clear()
                outputs = model.text_decoder(
                    input_ids=input_ids,
                    encoder_hidden_states=encoder_hidden,
                    encoder_attention_mask=encoder_mask,
                    output_attentions=True,
                    return_dict=True,
                )
                logits = outputs.logits[:, -1, :]
                next_token = torch.argmax(logits, dim=-1)
                token_id = int(next_token.item())
                if token_id == text_cfg.sep_token_id:
                    break

                # Backward from the chosen token's logit.
                logits[0, token_id].backward()

                # Compute Attention Flow across all layers.
                hm = compute_attention_flow(extractor)
                heatmaps.append(hm)

                word = processor.tokenizer.decode([token_id]).strip()
                tokens.append(word)
                if verbose:
                    print(f" [{step+1:02d}] '{word}' heatmap peak={hm.max():.3f}")

                input_ids = torch.cat(
                    [input_ids, next_token.reshape(1, 1)], dim=-1
                )
    finally:
        # Never leave hooks attached to the caller's model.
        extractor.remove()

    if verbose:
        caption = " ".join(tokens)
        print(f"\n✅ Caption: {caption}")
    return tokens, heatmaps
|