| """ |
| Cross-attention visualization for CASWiT model. |
| |
| This module provides utilities to visualize cross-attention maps between |
| HR and LR branches at different encoder stages. |
| """ |
|
|
| from typing import Dict, List, Tuple |
| import math |
| import torch |
| import torch.nn.functional as F |
| import matplotlib.pyplot as plt |
| import numpy as np |
|
|
|
|
def _to_numpy(img: torch.Tensor) -> np.ndarray:
    """
    Convert normalized tensor to numpy array for visualization.

    Args:
        img: (1,3,H,W) tensor in normalized space [-1,1]

    Returns:
        uint8 HxWx3 array for plotting
    """
    x = img.detach().float().cpu()[0]
    # Undo the [-1, 1] normalization, then convert to uint8 HWC for matplotlib.
    x = x * 0.5 + 0.5
    x = torch.clamp(x, 0, 1)
    x = (x.permute(1, 2, 0).numpy() * 255.0).astype(np.uint8)
    return x


def _pixel_to_token(x: int, y: int, W_img: int, H_img: int,
                    W_tokens: int, H_tokens: int) -> int:
    """
    Map HR pixel (x,y) to linear token index for a grid (H_tokens, W_tokens).

    Uses a ratio-based mapping which is correct for uniform patch embeddings.
    """
    tx = min(max(int(math.floor(x * W_tokens / max(W_img, 1))), 0), W_tokens - 1)
    ty = min(max(int(math.floor(y * H_tokens / max(H_img, 1))), 0), H_tokens - 1)
    return ty * W_tokens + tx


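# Worked example for _pixel_to_token (hypothetical numbers, not tied to any
# particular CASWiT configuration): for a 256x256 HR image embedded into a
# 16x16 token grid, pixel (x=128, y=64) maps to tx = floor(128*16/256) = 8 and
# ty = floor(64*16/256) = 4, i.e. linear token index 4*16 + 8 = 72.

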
class _CrossAttnTap:
    """Hook storage for capturing cross-attention weights."""

    def __init__(self):
        self.attn_by_stage: Dict[int, torch.Tensor] = {}
        self.hr_grid_by_stage: Dict[int, Tuple[int, int]] = {}
        self.lr_grid_by_stage: Dict[int, Tuple[int, int]] = {}
        self._handles: List[torch.utils.hooks.RemovableHandle] = []

    def register(self, model) -> None:
        """Register hooks on model to capture attention weights."""
        blocks = getattr(model, 'cross_attn_blocks', None)
        if blocks is None:
            raise RuntimeError('Model has no attribute cross_attn_blocks')

        for s, block in enumerate(blocks):
            # Record the HR/LR token grid sizes from the block inputs.
            # Assumes the block is called as block(x_hr, x_lr) with NCHW feature maps.
            def fwd_hook(stage_idx: int):
                def _f(module, inputs, output):
                    x_hr, x_lr = inputs[0], inputs[1]
                    _, _, Hh, Wh = x_hr.shape
                    _, _, Hl, Wl = x_lr.shape
                    self.hr_grid_by_stage[stage_idx] = (Hh, Wh)
                    self.lr_grid_by_stage[stage_idx] = (Hl, Wl)
                return _f
            self._handles.append(block.register_forward_hook(fwd_hook(s)))

            mha = getattr(block, 'attn', None)
            if mha is None:
                raise RuntimeError(f'CrossFusionBlock at stage {s} has no attn module')

            # Capture the attention weights returned by the attention module.
            def attn_hook(stage_idx: int):
                def _f(module, inputs, output):
                    # nn.MultiheadAttention returns (attn_output, attn_weights)
                    # when called with need_weights=True; the weights are None otherwise.
                    if isinstance(output, tuple) and len(output) == 2:
                        attn_w = output[1]
                        if attn_w is not None:
                            self.attn_by_stage[stage_idx] = attn_w.detach()
                return _f
            self._handles.append(mha.register_forward_hook(attn_hook(s)))

    def clear(self):
        """Clear stored attention weights and grid sizes."""
        self.attn_by_stage.clear()
        self.hr_grid_by_stage.clear()
        self.lr_grid_by_stage.clear()

    def remove(self):
        """Remove all registered hooks."""
        for h in self._handles:
            h.remove()
        self._handles.clear()


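# Minimal usage sketch for _CrossAttnTap (this mirrors what viz_cross_attention
# does internally; `model`, `img_hr`, and `img_lr` are placeholders):
#
#     tap = _CrossAttnTap()
#     tap.register(model)
#     with torch.no_grad():
#         model(img_hr, img_lr)
#     attn = tap.attn_by_stage[0]   # (B, N_hr_tokens, N_lr_tokens), head-averaged
#     tap.remove()

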
@torch.no_grad()
def viz_cross_attention(
    model: torch.nn.Module,
    img_hr: torch.Tensor,
    img_lr: torch.Tensor,
    pixel_xy: Tuple[int, int],
    save_path: str = 'attn_maps.png',
    overlay_alpha: float = 0.55,
    dpi: int = 180,
    show_titles: bool = True,
):
    """
    Visualize cross-attention maps for a given pixel location.

    Runs a forward pass and saves a multi-panel PNG: one panel per cross-attn stage.
    The attention is averaged over heads (default behavior of nn.MultiheadAttention).

    Args:
        model: CASWiT model (a DDP/DataParallel wrapper is unwrapped automatically)
        img_hr: HR input image [1, 3, H, W]
        img_lr: LR input image [1, 3, h, w]
        pixel_xy: (x, y) pixel coordinates in HR image space
        save_path: Path to save visualization
        overlay_alpha: Alpha transparency for attention overlay
        dpi: DPI for saved figure
        show_titles: Whether to show stage titles
    """
    was_training = model.training
    model.eval()

    # Unwrap DDP/DataParallel if the wrapper itself does not expose the blocks.
    if hasattr(model, 'module') and not hasattr(model, 'cross_attn_blocks'):
        model = model.module

    tap = _CrossAttnTap()
    tap.register(model)
    tap.clear()

    try:
        # Move inputs to the model's device and run a forward pass to fire the hooks.
        device = next(model.parameters()).device
        img_hr = img_hr.to(device)
        img_lr = img_lr.to(device)
        _ = model(img_hr, img_lr)

        H_img, W_img = img_hr.shape[-2:]
        px, py = pixel_xy
        px = int(np.clip(px, 0, W_img - 1))
        py = int(np.clip(py, 0, H_img - 1))

        # Base images for plotting.
        base_hr = _to_numpy(img_hr)
        base_lr = _to_numpy(img_lr)

        stages = sorted(tap.attn_by_stage.keys())
        if len(stages) == 0:
            raise RuntimeError('No attention captured. Ensure the forward pass reaches the '
                               'cross-attention blocks and that the attention module returns '
                               'weights (need_weights=True).')

        n = len(stages)

        # One panel per stage plus a narrow column for the shared colorbar.
        fig = plt.figure(figsize=(4.0 * n, 4.2), dpi=dpi)
        gs = fig.add_gridspec(nrows=1, ncols=n + 1, width_ratios=[1] * n + [0.04], wspace=0.05)

        axes = [fig.add_subplot(gs[0, i]) for i in range(n)]
        cax = fig.add_subplot(gs[0, -1])

        hm = None
        for i, s in enumerate(stages):
            attn = tap.attn_by_stage[s]
            (Hh, Wh) = tap.hr_grid_by_stage[s]
            (Hl, Wl) = tap.lr_grid_by_stage[s]

            # Attention row of the HR query token that contains the chosen pixel.
            attn0 = attn[0]
            if attn0.dim() == 3:
                # Per-head weights were captured (average_attn_weights=False); average them.
                attn0 = attn0.mean(dim=0)
            q_idx = _pixel_to_token(px, py, W_img, H_img, Wh, Hh)
            row = attn0[q_idx]

            attn_map = row.reshape(Hl, Wl)

            # Min-max normalize to [0, 1] for display.
            attn_map = attn_map - attn_map.min()
            max_val = float(attn_map.max().item())
            attn_map = attn_map / (max_val if max_val > 0 else 1.0)

            # Upsample the token-grid map to LR image resolution.
            attn_up = F.interpolate(
                attn_map[None, None, ...],
                size=base_lr.shape[:2],
                mode='bilinear',
                align_corners=False
            )[0, 0]
            attn_np = attn_up.detach().cpu().numpy()

            ax = axes[i]
            ax.imshow(base_lr)
            hm = ax.imshow(attn_np, cmap='jet', alpha=overlay_alpha, vmin=0.0, vmax=1.0)

            # Mark the query pixel, rescaled from HR to LR image coordinates.
            hx, hy = base_hr.shape[1], base_hr.shape[0]
            lx, ly = base_lr.shape[1], base_lr.shape[0]
            px_lr = int(round(px * lx / max(hx, 1)))
            py_lr = int(round(py * ly / max(hy, 1)))
            ax.scatter([px_lr], [py_lr], s=18, c='white', marker='o',
                       linewidths=0.5, edgecolors='black')

            if show_titles:
                ax.set_title(f'Stage {s + 1}: HR→LR attn', fontsize=10)
            ax.set_axis_off()

        # Shared colorbar for the normalized attention maps.
        cbar = fig.colorbar(hm, cax=cax)
        cbar.set_label('Attention')

        # Save the figure and release it.
        fig.savefig(save_path, bbox_inches='tight', format='png')
        plt.close(fig)

    finally:
        # Always remove hooks and restore the original train/eval mode.
        tap.remove()
        if was_training:
            model.train()


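# ---------------------------------------------------------------------------
# Minimal smoke test. This is a sketch using a toy stand-in model, NOT the
# real CASWiT architecture: the toy only mimics the interface this module
# relies on (a `cross_attn_blocks` list whose blocks take (x_hr, x_lr) NCHW
# feature maps and expose an `attn` nn.MultiheadAttention). All shapes and
# names below are illustrative assumptions.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.nn as nn

    class _ToyCrossBlock(nn.Module):
        """Toy cross-fusion block: HR tokens attend to LR tokens."""
        def __init__(self, dim: int = 16):
            super().__init__()
            self.attn = nn.MultiheadAttention(dim, num_heads=2, batch_first=True)

        def forward(self, x_hr, x_lr):
            b, c, hh, wh = x_hr.shape
            q = x_hr.flatten(2).transpose(1, 2)    # (B, Hh*Wh, C)
            kv = x_lr.flatten(2).transpose(1, 2)   # (B, Hl*Wl, C)
            out, _ = self.attn(q, kv, kv, need_weights=True)
            return out.transpose(1, 2).reshape(b, c, hh, wh)

    class _ToyModel(nn.Module):
        """Two conv 'patch embeddings' feeding a single cross-attention stage."""
        def __init__(self):
            super().__init__()
            self.embed_hr = nn.Conv2d(3, 16, kernel_size=8, stride=8)
            self.embed_lr = nn.Conv2d(3, 16, kernel_size=4, stride=4)
            self.cross_attn_blocks = nn.ModuleList([_ToyCrossBlock(16)])

        def forward(self, img_hr, img_lr):
            f_hr = self.embed_hr(img_hr)
            f_lr = self.embed_lr(img_lr)
            for blk in self.cross_attn_blocks:
                f_hr = blk(f_hr, f_lr)
            return f_hr

    toy = _ToyModel()
    hr = torch.rand(1, 3, 64, 64) * 2 - 1   # pretend-normalized HR input
    lr = torch.rand(1, 3, 32, 32) * 2 - 1   # pretend-normalized LR input
    viz_cross_attention(toy, hr, lr, pixel_xy=(32, 16), save_path='attn_maps_demo.png')
    print('Wrote attn_maps_demo.png')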