from __future__ import annotations from collections.abc import Iterable from typing import Any import torch from app.core.model_support import get_decoder_layers _ATTENTION_STORE: dict[int, torch.Tensor] = {} def clear_stored_attentions() -> None: _ATTENTION_STORE.clear() def get_stored_attentions() -> dict[int, torch.Tensor]: return dict(_ATTENTION_STORE) def _extract_attention_tensor(output: Any) -> torch.Tensor | None: if isinstance(output, torch.Tensor): return output if output.dim() == 4 else None if isinstance(output, dict): for value in output.values(): if isinstance(value, torch.Tensor) and value.dim() == 4: return value if isinstance(output, Iterable) and not isinstance(output, (str, bytes)): for item in output: if isinstance(item, torch.Tensor) and item.dim() == 4: return item return None def _get_attention_impl(model: Any) -> str | None: config = getattr(model, "config", None) if config is None: return None return getattr(config, "_attn_implementation", None) or getattr( config, "attn_implementation", None, ) def make_attn_hook(layer_idx: int): def hook(_module: Any, _inputs: Any, output: Any) -> None: attn = _extract_attention_tensor(output) if attn is None: return if attn.dim() != 4: raise RuntimeError(f"Expected 4D attention tensor at layer {layer_idx}, got {attn.shape}.") attn.retain_grad() _ATTENTION_STORE[layer_idx] = attn return hook def register_hooks(model: Any) -> list[Any]: clear_stored_attentions() layers, _layer_path, attention_attr = get_decoder_layers(model) handles: list[Any] = [] for layer_idx, layer in enumerate(layers): self_attn = getattr(layer, attention_attr, None) if self_attn is None: raise RuntimeError(f"Layer {layer_idx} does not expose {attention_attr}.") handles.append(self_attn.register_forward_hook(make_attn_hook(layer_idx))) return handles def remove_hooks(handles: list[Any]) -> None: for handle in handles: handle.remove() clear_stored_attentions()