cot-anc / app /analysis /hooks.py
BART-ender's picture
Deploy Thought Anchors
fda8fb3 verified
from __future__ import annotations
from collections.abc import Iterable
from typing import Any
import torch
from app.core.model_support import get_decoder_layers
# Module-level registry of captured attention tensors, keyed by decoder-layer
# index. Written by the hooks built in make_attn_hook, emptied by
# clear_stored_attentions. It is a plain module global, so it is shared by
# every caller in the process.
_ATTENTION_STORE: dict[int, torch.Tensor] = {}
def clear_stored_attentions() -> None:
    """Empty the module-level attention store in place.

    The existing dict object is mutated rather than rebound, so the hooks
    (which look the global up by name) keep writing into the same store.
    """
    _ATTENTION_STORE.clear()
def get_stored_attentions() -> dict[int, torch.Tensor]:
    """Return a snapshot of the captured attentions.

    The result is a new dict mapping layer index to tensor. Later hook writes
    or clears do not affect it. The tensors themselves are shared with the
    store, not cloned.
    """
    snapshot = _ATTENTION_STORE.copy()
    return snapshot
def _extract_attention_tensor(output: Any) -> torch.Tensor | None:
if isinstance(output, torch.Tensor):
return output if output.dim() == 4 else None
if isinstance(output, dict):
for value in output.values():
if isinstance(value, torch.Tensor) and value.dim() == 4:
return value
if isinstance(output, Iterable) and not isinstance(output, (str, bytes)):
for item in output:
if isinstance(item, torch.Tensor) and item.dim() == 4:
return item
return None
def _get_attention_impl(model: Any) -> str | None:
config = getattr(model, "config", None)
if config is None:
return None
return getattr(config, "_attn_implementation", None) or getattr(
config,
"attn_implementation",
None,
)
def make_attn_hook(layer_idx: int):
    """Build a forward hook that records one layer's attention tensor.

    The returned callable has the ``(module, inputs, output)`` signature used
    by ``torch.nn.Module.register_forward_hook``. On each call it takes the
    first 4D tensor found in ``output`` (via ``_extract_attention_tensor``)
    and stores it in ``_ATTENTION_STORE[layer_idx]``, replacing any earlier
    entry for that layer. If ``output`` holds no 4D tensor, the call does
    nothing.

    When autograd tracks the tensor, ``retain_grad()`` is called so its
    ``.grad`` is kept after ``backward()``. Tensors produced under
    ``torch.no_grad()`` / ``torch.inference_mode()`` are stored without that
    call, because ``retain_grad()`` raises on tensors that do not require
    grad.

    Args:
        layer_idx: Key under which this layer's tensor is stored.

    Returns:
        The hook function. It returns ``None``, so the module output is left
        unchanged.
    """
    def hook(_module: Any, _inputs: Any, output: Any) -> None:
        # The extractor only ever returns a 4D tensor or None.
        attn = _extract_attention_tensor(output)
        if attn is None:
            return
        # retain_grad() raises RuntimeError when requires_grad is False (e.g. a
        # forward pass under torch.no_grad()), which would abort the model's
        # forward. Only request the gradient when autograd is tracking it.
        if attn.requires_grad:
            attn.retain_grad()
        _ATTENTION_STORE[layer_idx] = attn

    return hook
def register_hooks(model: Any) -> list[Any]:
    """Attach an attention-capturing forward hook to every decoder layer.

    The function first clears any previously stored attentions. It then
    locates the decoder layers with ``get_decoder_layers``. Finally it
    registers ``make_attn_hook(layer_idx)`` as a forward hook on each layer's
    attention submodule, whose attribute name also comes from
    ``get_decoder_layers``.

    Args:
        model: The model to instrument.

    Returns:
        The handles returned by ``register_forward_hook``, one per layer, in
        layer order. Pass them to ``remove_hooks`` to detach.

    Raises:
        RuntimeError: If any layer lacks the attention attribute (or it is
            ``None``). All layers are checked before any hook is registered,
            so no hooks are left attached in that case.
    """
    clear_stored_attentions()
    layers, _layer_path, attention_attr = get_decoder_layers(model)

    # Resolve every attention module up front. Registering inside the same
    # loop would leave hooks on earlier layers attached, with no handle
    # returned to the caller, when a later layer fails the check.
    attention_modules: list[Any] = []
    for layer_idx, layer in enumerate(layers):
        self_attn = getattr(layer, attention_attr, None)
        if self_attn is None:
            raise RuntimeError(f"Layer {layer_idx} does not expose {attention_attr}.")
        attention_modules.append(self_attn)

    return [
        self_attn.register_forward_hook(make_attn_hook(layer_idx))
        for layer_idx, self_attn in enumerate(attention_modules)
    ]
def remove_hooks(handles: list[Any]) -> None:
    """Detach previously registered hooks, then empty the attention store.

    Args:
        handles: Hook handles, presumably those returned by ``register_hooks``.
            Each must provide a ``remove()`` method.
    """
    for registered_hook in handles:
        registered_hook.remove()
    # Also forget any attentions captured while the hooks were active.
    clear_stored_attentions()