# CASWiT / utils / attention_viz.py — Image Segmentation
# Author: antoine.carreaud67 — clean release (commit 36b4539)
"""
Cross-attention visualization for CASWiT model.
This module provides utilities to visualize cross-attention maps between
HR and LR branches at different encoder stages.
"""
from typing import Dict, List, Tuple
import math
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
def _to_numpy(img: torch.Tensor) -> np.ndarray:
"""
Convert normalized tensor to numpy array for visualization.
Args:
img: (1,3,H,W) tensor in normalized space [-1,1]
Returns:
uint8 HxWx3 array for plotting
"""
x = img.detach().float().cpu()[0]
# Undo Normalize(mean=0.5, std=0.5) -> x*0.5 + 0.5
x = x * 0.5 + 0.5
x = torch.clamp(x, 0, 1)
x = (x.permute(1, 2, 0).numpy() * 255.0).astype(np.uint8)
return x
def _pixel_to_token(x: int, y: int, W_img: int, H_img: int,
W_tokens: int, H_tokens: int) -> int:
"""
Map HR pixel (x,y) to linear token index for a grid (H_tokens, W_tokens).
Uses a ratio-based mapping which is correct for uniform patch embeddings.
"""
tx = min(max(int(math.floor(x * W_tokens / max(W_img, 1))), 0), W_tokens - 1)
ty = min(max(int(math.floor(y * H_tokens / max(H_img, 1))), 0), H_tokens - 1)
return ty * W_tokens + tx
class _CrossAttnTap:
"""Hook storage for capturing cross-attention weights."""
def __init__(self):
# Per stage we store: attn (B, N_q, N_k) averaged over heads, and grids
self.attn_by_stage: Dict[int, torch.Tensor] = {}
self.hr_grid_by_stage: Dict[int, Tuple[int, int]] = {}
self.lr_grid_by_stage: Dict[int, Tuple[int, int]] = {}
self._handles: List[torch.utils.hooks.RemovableHandle] = []
def register(self, model) -> None:
"""Register hooks on model to capture attention weights."""
# Locate the list of CrossFusionBlock modules
blocks = getattr(model, 'cross_attn_blocks', None)
if blocks is None:
raise RuntimeError('Model has no attribute cross_attn_blocks')
for s, block in enumerate(blocks):
# Hook on the CrossFusionBlock to get H/W of inputs (x_hr, x_lr)
def fwd_hook(stage_idx: int):
def _f(module, inputs, output):
# inputs: (x_hr, x_lr)
x_hr, x_lr = inputs[0], inputs[1]
_, _, Hh, Wh = x_hr.shape
_, _, Hl, Wl = x_lr.shape
self.hr_grid_by_stage[stage_idx] = (Hh, Wh)
self.lr_grid_by_stage[stage_idx] = (Hl, Wl)
return _f
self._handles.append(block.register_forward_hook(fwd_hook(s)))
# Hook on the internal nn.MultiheadAttention to grab attn weights
mha = getattr(block, 'attn', None)
if mha is None:
raise RuntimeError(f'CrossFusionBlock at stage {s} has no attn module')
def attn_hook(stage_idx: int):
def _f(module, inputs, output):
# output is a tuple: (attn_out, attn_weights)
if isinstance(output, tuple) and len(output) == 2:
attn_w = output[1] # shape: (B, N_q, N_k) averaged over heads
self.attn_by_stage[stage_idx] = attn_w.detach()
return _f
self._handles.append(mha.register_forward_hook(attn_hook(s)))
def clear(self):
"""Clear stored attention weights."""
self.attn_by_stage.clear()
self.hr_grid_by_stage.clear()
self.lr_grid_by_stage.clear()
def remove(self):
"""Remove all registered hooks."""
for h in self._handles:
h.remove()
self._handles.clear()
@torch.no_grad()
def viz_cross_attention(
    model: torch.nn.Module,
    img_hr: torch.Tensor,   # (1,3,H,W) normalized with mean=std=0.5
    img_lr: torch.Tensor,   # (1,3,h,w) normalized
    pixel_xy: Tuple[int, int],
    save_path: str = 'attn_maps.png',
    overlay_alpha: float = 0.55,
    dpi: int = 180,
    show_titles: bool = True,
):
    """
    Visualize cross-attention maps for a given HR pixel location.

    Runs one forward pass with capture hooks installed and saves a multi-panel
    PNG: one panel per cross-attention stage. Attention is averaged over heads
    (default behavior of nn.MultiheadAttention).

    Args:
        model: CASWiT model (a DDP-wrapped model is unwrapped automatically)
        img_hr: HR input image (1, 3, H, W), normalized with mean=std=0.5
        img_lr: LR input image (1, 3, h, w), normalized
        pixel_xy: (x, y) pixel coordinates in HR image space
        save_path: path of the PNG to write
        overlay_alpha: alpha transparency of the attention overlay
        dpi: DPI of the saved figure
        show_titles: whether to draw per-stage titles

    Raises:
        RuntimeError: if the model has no cross-attention blocks, or the
            forward pass produced no attention weights.
    """
    was_training = model.training
    model.eval()
    # If user passed a DDP-wrapped model, unwrap
    if hasattr(model, 'module') and not hasattr(model, 'cross_attn_blocks'):
        model = model.module
    tap = _CrossAttnTap()
    try:
        # Register inside the try block so hooks installed before a partial
        # failure (e.g. a block missing its `attn` module) are still removed
        # by the finally clause (bug fix: previously they leaked).
        tap.register(model)
        tap.clear()
        # Forward to populate hooks
        device = next(model.parameters()).device
        img_hr = img_hr.to(device)
        img_lr = img_lr.to(device)
        _ = model(img_hr, img_lr)

        H_img, W_img = img_hr.shape[-2:]
        px, py = pixel_xy
        px = int(np.clip(px, 0, W_img - 1))
        py = int(np.clip(py, 0, H_img - 1))

        # Prepare base images for overlays: (H, W, 3) uint8 in [0, 255]
        base_hr = _to_numpy(img_hr)
        base_lr = _to_numpy(img_lr)

        stages = sorted(tap.attn_by_stage.keys())
        if len(stages) == 0:
            raise RuntimeError('No attention captured. Ensure a forward pass reached the cross-attention blocks.')
        n = len(stages)

        # Figure with a dedicated column for the shared colorbar
        fig = plt.figure(figsize=(4.0*n, 4.2), dpi=dpi)
        gs = fig.add_gridspec(nrows=1, ncols=n+1, width_ratios=[1]*n + [0.04], wspace=0.05)
        axes = [fig.add_subplot(gs[0, i]) for i in range(n)]
        cax = fig.add_subplot(gs[0, -1])  # axis reserved for colorbar
        hm = None
        for i, s in enumerate(stages):
            attn = tap.attn_by_stage[s]       # (B, N_q, N_k)
            (Hh, Wh) = tap.hr_grid_by_stage[s]
            (Hl, Wl) = tap.lr_grid_by_stage[s]
            attn0 = attn[0]                   # batch 0: (N_q, N_k)
            # Query index of the selected HR pixel (note: width first in tokens)
            q_idx = _pixel_to_token(px, py, W_img, H_img, Wh, Hh)
            row = attn0[q_idx]                # (N_k,)
            # Keys come from the LR branch, so reshape to the LR token grid
            attn_map = row.view(Hl, Wl)
            # Min-max normalize for visualization (guard against a flat map)
            attn_map = attn_map - attn_map.min()
            denom = float(attn_map.max().item()) if float(attn_map.max().item()) > 0 else 1.0
            attn_map = attn_map / denom
            # Upsample token map to the LR image size for overlaying
            attn_up = F.interpolate(
                attn_map[None, None, ...],
                size=base_lr.shape[:2],
                mode='bilinear',
                align_corners=False
            )[0, 0]
            attn_np = attn_up.detach().cpu().numpy()

            ax = axes[i]
            ax.imshow(base_lr)
            hm = ax.imshow(attn_np, cmap='jet', alpha=overlay_alpha, vmin=0.0, vmax=1.0)
            # Approx HR->LR pixel mapping for the query marker (simple ratio)
            hx, hy = base_hr.shape[1], base_hr.shape[0]
            lx, ly = base_lr.shape[1], base_lr.shape[0]
            px_lr = int(round(px * lx / max(hx, 1)))
            py_lr = int(round(py * ly / max(hy, 1)))
            ax.scatter([px_lr], [py_lr], s=18, c='white', marker='o',
                       linewidths=0.5, edgecolors='black')
            if show_titles:
                ax.set_title(f'Stage {s+1}: HR→LR attn', fontsize=10)
            ax.set_axis_off()

        # Shared colorbar in its dedicated axis
        cbar = fig.colorbar(hm, cax=cax)
        cbar.set_label('Attention')
        # Save PNG
        fig.savefig(save_path, bbox_inches='tight', format='png')
        plt.close(fig)
    finally:
        # Always detach hooks and restore the original training mode
        tap.remove()
        if was_training:
            model.train()