Spaces:
Sleeping
Sleeping
| """ | |
step3_gradcam_flow.py
======================
STEP 3 — Multi-Layer Gradient-Weighted Attention Flow.

What this implements (Iteration 3 upgrade over single-layer GradCAM):
────────────────────────────────────────────────────────────────────
In Iteration 2, we only looked at the last cross-attention layer.
Here, we hook into ALL cross-attention layers simultaneously, run
GradCAM on each one, and then combine them using "Attention Flow"
(Abnar & Zuidema, 2020):

    rollout[0] = gradcam(layer 0)
    rollout[i] = normalize(rollout[i-1]) × normalize(gradcam(layer i))
                 + residual_weight × uniform   ← residual term (default weight
                                                 0.05, not 0.5) keeps a "floor"
                                                 of global context so the
                                                 rollout never collapses to
                                                 zero for deep layers.

This multi-layer aggregation captures:
  • Early layers  → edges, textures, spatial structure.
  • Middle layers → part-level features (ears, wheels, legs).
  • Last layer    → high-level semantic concepts.

The combined result is a 14×14 heatmap with dramatically tighter
object contours and less background bleed.

Additionally, instead of plain OpenCV nearest-neighbour upsampling,
we use PyTorch bicubic interpolation to upscale 14×14 → 224×224,
producing pixel-smooth, non-blocky heatmap edges.
| """ | |
| import os | |
| import sys | |
| import math | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
# Directory containing this file, and the directory two levels above it.
# The latter is put at the front of sys.path (once) so that imports rooted
# there resolve regardless of the current working directory.
_THIS_DIR = os.path.split(os.path.abspath(__file__))[0]
_PROJECT_ROOT = os.path.split(os.path.split(_THIS_DIR)[0])[0]
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)
# ────────────────────────────────────────────────────────────────────────────
class FlowExtractor:
    """
    Capture cross-attention maps and their gradients from BLIP's text decoder.

    A forward hook is registered on the ``crossattention.self`` sub-module of
    every layer in ``model.text_decoder.bert.encoder.layer`` that exposes a
    ``crossattention`` attribute.  Each hook stores the module's second
    output (expected to be the attention probabilities, which the decoder
    only emits when called with ``output_attentions=True``) and, when that
    tensor requires grad, attaches a tensor hook that records its gradient
    during the backward pass.

    Attributes:
        model  : The model the hooks were installed on.
        layers : One ``{"fwd": Tensor | None, "grad": Tensor | None}`` dict
                 per hooked layer, in encoder-layer order.

    The instance may be used as a context manager; the hooks are removed on
    exit, even if the body raises.
    """

    def __init__(self, model):
        self.model = model
        self._hooks = []   # removable handles returned by register_forward_hook
        self.layers = []   # per-layer {"fwd", "grad"} holders, in layer order

        for layer in model.text_decoder.bert.encoder.layer:
            if not hasattr(layer, "crossattention"):
                continue
            holder = {"fwd": None, "grad": None}
            self.layers.append(holder)
            target = layer.crossattention.self
            self._hooks.append(
                target.register_forward_hook(self._make_forward_hook(holder))
            )

    @staticmethod
    def _make_forward_hook(holder):
        """Build a forward hook bound to one specific ``holder`` dict.

        Going through a factory gives every layer its own closure, so hooks
        never share (late-bind to) the loop variable of ``__init__``.
        """

        def _store_grad(grad):
            # Returning None leaves the gradient flowing through unchanged.
            holder["grad"] = grad.detach()

        def _fwd(module, inp, out):
            if not isinstance(out, (tuple, list)) or len(out) < 2:
                return
            probs = out[1]
            # Slot 1 is only an attention tensor when attentions were
            # requested; anything else (None, a cache tuple, ...) is ignored
            # instead of crashing on ``.requires_grad``.
            if not torch.is_tensor(probs):
                return
            holder["fwd"] = probs
            if probs.requires_grad:
                probs.register_hook(_store_grad)

        return _fwd

    def clear(self):
        """Forget all captured maps and gradients (hooks stay installed)."""
        for holder in self.layers:
            holder["fwd"] = None
            holder["grad"] = None

    def remove(self):
        """Detach every forward hook from the model.  Safe to call twice."""
        for hook in self._hooks:
            hook.remove()
        self._hooks = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove()
        return False  # never swallow exceptions
def _single_layer_gradcam(holder, token_idx: int = -1) -> torch.Tensor:
    """
    Compute the GradCAM relevance of one layer for a single query token.

    Args:
        holder    : dict with 'fwd' (attention map) and 'grad' (its gradient);
                    both are indexed as (batch, heads, query_tokens, key_tokens).
        token_idx : query-token position to explain (default: last token).

    Returns:
        Head-averaged ``attention * gradient`` for the chosen token, clamped
        at zero (ReLU).  For a batch of one this is a 1-D tensor with one
        entry per key token.
    """
    attn = holder["fwd"][:, :, token_idx, :]    # (batch, heads, keys)
    grad = holder["grad"][:, :, token_idx, :]   # (batch, heads, keys)
    # Average over heads, then drop only the batch axis.  A bare .squeeze()
    # would also collapse the key axis when there is a single key token and
    # hand back a 0-d tensor.
    cam = (attn * grad).mean(dim=1).squeeze(0)
    return torch.clamp(cam, min=0.0)
def _normalize1d(t: torch.Tensor) -> torch.Tensor:
    """Scale *t* so its entries sum to 1; return it untouched if the sum is not positive."""
    total = t.sum()
    return t / total if total > 0 else t
def compute_attention_flow(
    extractor: FlowExtractor,
    num_image_tokens: int = 197,
    residual_weight: float = 0.05,
    out_resolution: int = 224,
) -> np.ndarray:
    """
    Combine the per-layer GradCAM maps of *extractor* into one heatmap.

    Algorithm
    ---------
    1. For every layer holding both a forward map and a gradient, compute
       the single-layer GradCAM (already ReLU-ed by the helper).
    2. Roll the layers up, in order:
           rollout[0] = norm(cam[0])
           rollout[i] = norm(rollout[i-1]) * norm(cam[i])
                        + residual_weight * uniform
       The uniform term is a small "background floor" that stops the
       product from collapsing to zero.  It must stay small (0.05, not
       0.5): too large a value lets the uniform baseline dominate and
       every word gets the same flat heatmap.
    3. Drop token 0 ([CLS]) and reshape the remaining patch tokens into a
       square grid (196 -> 14 x 14 for the default 197 tokens).
    4. Bicubic-upsample to (out_resolution, out_resolution).
    5. Min-max normalise to [0, 1].

    Args:
        extractor        : FlowExtractor after a backward pass.
        num_image_tokens : Nominal image-token count (197 for BLIP ViT-base).
                           Kept for backward compatibility; the uniform
                           baseline is sized from the captured maps
                           themselves, so other token counts work too.
        residual_weight  : Weight of the uniform floor (keep <= 0.1).
        out_resolution   : Side length of the returned square heatmap.

    Returns:
        float32 array of shape (out_resolution, out_resolution).  All zeros
        when no layer captured both a map and a gradient, or when the
        heatmap is constant.
    """
    valid_cams = [
        _single_layer_gradcam(holder).detach()
        for holder in extractor.layers
        if holder["fwd"] is not None and holder["grad"] is not None
    ]
    if not valid_cams:
        # Nothing was captured: return an all-zero heatmap.
        return np.zeros((out_resolution, out_resolution), dtype=np.float32)

    # --- Attention Flow rollout ---
    first = valid_cams[0]
    # Size the baseline from the data, not from the nominal default, so a
    # different image resolution / patch count does not break broadcasting.
    n_tokens = first.numel()
    uniform = torch.ones(n_tokens, device=first.device) / n_tokens

    # NOTE(review): the product of two sum-normalised vectors is roughly
    # 1/N^2 per element, while the floor is residual_weight/N, so the floor
    # can still outweigh the signal away from peaks -- confirm this is the
    # intended balance.
    rollout = _normalize1d(first)
    for cam in valid_cams[1:]:
        rollout = _normalize1d(rollout) * _normalize1d(cam) + residual_weight * uniform
    rollout = torch.clamp(rollout, min=0.0)

    # Drop [CLS] (index 0); what remains are the spatial patch tokens.
    spatial = rollout[1:]
    grid_sz = math.isqrt(spatial.numel())

    # (1, 1, grid, grid) is the layout F.interpolate expects.
    hm_tensor = spatial.detach().cpu().reshape(1, 1, grid_sz, grid_sz).float()
    hm_up = F.interpolate(
        hm_tensor,
        size=(out_resolution, out_resolution),
        mode="bicubic",
        align_corners=False,
    ).squeeze()
    hm_np = hm_up.numpy()

    # Min-max normalise to [0, 1]; a constant map becomes all zeros.
    lo, hi = hm_np.min(), hm_np.max()
    if hi > lo:
        hm_np = (hm_np - lo) / (hi - lo)
    else:
        hm_np = np.zeros_like(hm_np)
    return hm_np.astype(np.float32)
# ── Main decoding loop ────────────────────────────────────────────────────────
def generate_with_flow(
    model,
    processor,
    device,
    encoder_hidden,
    encoder_mask,
    max_tokens: int = 20,
    verbose: bool = True,
) -> tuple[list[str], list[np.ndarray]]:
    """
    Greedy token-by-token decode, producing one Attention Flow heatmap per token.

    At every step the text decoder is run on the tokens generated so far, the
    arg-max token is chosen, its logit is back-propagated, and the gradients
    captured by a FlowExtractor are turned into a heatmap.  Decoding stops at
    the SEP token or after ``max_tokens`` steps.

    Args:
        model          : BlipForConditionalGeneration (eval, grad-ckpt disabled).
        processor      : BlipProcessor (only its tokenizer is used).
        device         : torch.device for the running input ids.
        encoder_hidden : image features, (1, 197, 768) from step2_encode_image.
        encoder_mask   : (1, 197) all-ones attention mask.
        max_tokens     : Maximum number of tokens to generate.
        verbose        : Print per-token progress and the final caption.

    Returns:
        tokens   : List of decoded word strings.
        heatmaps : Parallel list of (224, 224) float32 numpy heatmaps.

    The extractor's hooks are always removed from ``model`` before returning,
    including when the forward or backward pass raises.
    """
    text_cfg = model.config.text_config
    extractor = FlowExtractor(model)
    input_ids = torch.tensor(
        [[text_cfg.bos_token_id]], dtype=torch.long, device=device
    )
    tokens = []
    heatmaps = []

    # NOTE(review): the non-ASCII glyphs in the status strings below look
    # mis-encoded in this copy of the file; they are left untouched here --
    # confirm against the upstream source.
    if verbose:
        print("π Generating caption with Attention Flow heatmaps β¦")

    try:
        for step in range(max_tokens):
            model.zero_grad()
            extractor.clear()

            # Gradients are required for GradCAM even if the caller happens
            # to be inside a torch.no_grad() block.
            with torch.enable_grad():
                outputs = model.text_decoder(
                    input_ids=input_ids,
                    encoder_hidden_states=encoder_hidden,
                    encoder_attention_mask=encoder_mask,
                    output_attentions=True,
                    return_dict=True,
                )
                logits = outputs.logits[:, -1, :]
                next_token = torch.argmax(logits, dim=-1)
                next_id = next_token.item()
                if next_id == text_cfg.sep_token_id:
                    break
                # Backward from the chosen token's logit.
                logits[0, next_id].backward(retain_graph=False)

            # Compute Attention Flow across all layers.
            hm = compute_attention_flow(extractor)
            heatmaps.append(hm)

            word = processor.tokenizer.decode([next_id]).strip()
            tokens.append(word)
            if verbose:
                print(f" [{step+1:02d}] '{word}' heatmap peak={hm.max():.3f}")

            input_ids = torch.cat([input_ids, next_token.reshape(1, 1)], dim=-1)
    finally:
        # Drop references to the last autograd graph and never leave hooks
        # installed on the caller's model.
        extractor.clear()
        extractor.remove()

    if verbose:
        caption = " ".join(tokens)
        print(f"\nβ Caption: {caption}")
    return tokens, heatmaps