# Residue from the web file viewer this module was copied from:
#   Spaces status: Sleeping
#   File size: 2,242 Bytes
#   Revision: fda8fb3 (79 lines in the original file view)
from __future__ import annotations
from collections.abc import Iterable
from typing import Any
import torch
from app.core.model_support import get_decoder_layers
# Module-level cache filled by the forward hooks built in make_attn_hook:
# key = layer index, value = the 4-D tensor captured from that layer's output.
_ATTENTION_STORE: dict[int, torch.Tensor] = {}
def clear_stored_attentions() -> None:
    """Empty the module-level attention cache in place.

    The same dict object is kept (only its entries are dropped), so any code
    holding a reference to ``_ATTENTION_STORE`` sees it become empty.
    """
    _ATTENTION_STORE.clear()
def get_stored_attentions() -> dict[int, torch.Tensor]:
    """Return a shallow copy of the captured attention tensors.

    Returns:
        A new dict mapping layer index to the stored tensor. Adding or
        removing keys on the result does not affect the module-level cache;
        the tensor objects themselves are shared, not cloned.
    """
    snapshot = _ATTENTION_STORE.copy()
    return snapshot
def _extract_attention_tensor(output: Any) -> torch.Tensor | None:
if isinstance(output, torch.Tensor):
return output if output.dim() == 4 else None
if isinstance(output, dict):
for value in output.values():
if isinstance(value, torch.Tensor) and value.dim() == 4:
return value
if isinstance(output, Iterable) and not isinstance(output, (str, bytes)):
for item in output:
if isinstance(item, torch.Tensor) and item.dim() == 4:
return item
return None
def _get_attention_impl(model: Any) -> str | None:
config = getattr(model, "config", None)
if config is None:
return None
return getattr(config, "_attn_implementation", None) or getattr(
config,
"attn_implementation",
None,
)
def make_attn_hook(layer_idx: int):
    """Build a forward hook that records one layer's attention tensor.

    Args:
        layer_idx: Key under which the captured tensor is stored in
            ``_ATTENTION_STORE``; also used in the error message.

    Returns:
        A callable with the ``(module, inputs, output)`` forward-hook
        signature. It returns ``None``, so the module's output is unchanged.
    """

    def hook(_module: Any, _inputs: Any, output: Any) -> None:
        attn = _extract_attention_tensor(output)
        if attn is not None:
            # Defensive re-check of the rank before the tensor is stored.
            if attn.dim() != 4:
                raise RuntimeError(f"Expected 4D attention tensor at layer {layer_idx}, got {attn.shape}.")
            # Ask autograd to keep this tensor's .grad after backward.
            # NOTE(review): per the PyTorch docs retain_grad() raises when the
            # tensor does not require grad (e.g. under no_grad) - confirm that
            # callers always run the forward pass with gradients enabled.
            attn.retain_grad()
            _ATTENTION_STORE[layer_idx] = attn
        # Output without a 4-D tensor: nothing is stored for this layer.

    return hook
def register_hooks(model: Any) -> list[Any]:
    """Attach an attention-capturing forward hook to every decoder layer.

    The attention cache is cleared first, so tensors from an earlier run are
    never mixed with those of the next forward pass.

    Args:
        model: Model passed to ``get_decoder_layers``, which supplies the
            layer sequence and the name of each layer's attention attribute.

    Returns:
        The hook handles, one per layer in layer order. Pass them to
        ``remove_hooks`` to detach the hooks again.

    Raises:
        RuntimeError: If a layer does not expose the attention attribute.
            Hooks already attached to earlier layers are removed before the
            error propagates, so the model is left without partial hooks.
    """
    clear_stored_attentions()
    layers, _layer_path, attention_attr = get_decoder_layers(model)
    handles: list[Any] = []
    try:
        for layer_idx, layer in enumerate(layers):
            self_attn = getattr(layer, attention_attr, None)
            if self_attn is None:
                raise RuntimeError(f"Layer {layer_idx} does not expose {attention_attr}.")
            handles.append(self_attn.register_forward_hook(make_attn_hook(layer_idx)))
    except Exception:
        # Roll back: the caller never receives `handles` on failure, so any
        # hook registered so far would otherwise stay attached to the model.
        for handle in handles:
            handle.remove()
        raise
    return handles
def remove_hooks(handles: list[Any]) -> None:
    """Detach previously registered hooks and drop the captured tensors.

    Args:
        handles: Objects exposing ``remove()``, such as the list returned by
            ``register_hooks``.
    """
    for registered in handles:
        registered.remove()
    # Once the hooks are gone the cached tensors are discarded as well.
    clear_stored_attentions()
|