File size: 2,242 Bytes
fda8fb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from __future__ import annotations

from collections.abc import Iterable
from typing import Any

import torch

from app.core.model_support import get_decoder_layers

# Module-level store of captured attention tensors, keyed by layer index.
# Written by the forward hooks built in make_attn_hook (each forward pass
# overwrites the entry for its layer) and emptied by clear_stored_attentions.
_ATTENTION_STORE: dict[int, torch.Tensor] = {}


def clear_stored_attentions() -> None:
    """Remove every captured attention tensor from the module-level store.

    The store is cleared in place, so the same dict object keeps being used
    by the hooks afterwards.
    """
    _ATTENTION_STORE.clear()


def get_stored_attentions() -> dict[int, torch.Tensor]:
    """Return a shallow snapshot of the captured attention tensors.

    The returned dict is a new object, so adding or removing keys on it does
    not affect the module-level store. The tensors themselves are shared,
    not cloned.
    """
    return _ATTENTION_STORE.copy()


def _extract_attention_tensor(output: Any) -> torch.Tensor | None:
    """Pick the first 4-D tensor out of a module's forward output.

    The search order is:

    1. If ``output`` is itself a tensor, it is returned only when it is 4-D;
       otherwise the result is ``None`` (its elements are not inspected).
    2. If ``output`` is a dict, its values are scanned in order.
    3. Otherwise, or when no dict value matched, ``output`` is iterated
       directly (for a dict that means its keys). ``str`` and ``bytes`` are
       never iterated.

    Only top-level items are examined; nested containers are not descended.
    NOTE(review): a 4-D tensor is presumably the attention-weight matrix, but
    nothing here verifies that beyond its rank.

    Returns:
        The first matching tensor, or ``None`` when nothing matches.
    """

    def _is_rank4_tensor(candidate: Any) -> bool:
        return isinstance(candidate, torch.Tensor) and candidate.dim() == 4

    if isinstance(output, torch.Tensor):
        return output if _is_rank4_tensor(output) else None

    if isinstance(output, dict):
        from_values = next(
            (value for value in output.values() if _is_rank4_tensor(value)),
            None,
        )
        if from_values is not None:
            return from_values

    is_text = isinstance(output, (str, bytes))
    if isinstance(output, Iterable) and not is_text:
        return next((item for item in output if _is_rank4_tensor(item)), None)

    return None


def _get_attention_impl(model: Any) -> str | None:
    """Look up the attention implementation name recorded on a model's config.

    Reads ``model.config._attn_implementation`` first and falls back to
    ``model.config.attn_implementation`` when the former is missing or falsy.

    Returns:
        The configured value, or ``None`` when the model has no ``config``
        or neither attribute is set.
    """
    config = getattr(model, "config", None)
    if config is None:
        return None

    private_impl = getattr(config, "_attn_implementation", None)
    if private_impl:
        return private_impl

    return getattr(config, "attn_implementation", None)


def make_attn_hook(layer_idx: int):
    """Build a forward hook that records one layer's attention tensor.

    The returned callable has the ``(module, inputs, output)`` signature of a
    PyTorch forward hook. On each call it extracts the first 4-D tensor from
    ``output`` via ``_extract_attention_tensor`` and stores it in
    ``_ATTENTION_STORE`` under ``layer_idx``, replacing any previous entry.
    Outputs that contain no 4-D tensor are ignored.

    Args:
        layer_idx: Key under which this layer's tensor is stored.

    Returns:
        The hook function; it always returns ``None`` so the module's output
        is left unmodified.

    Raises:
        RuntimeError: From the hook, if the extracted tensor is not 4-D.
    """

    def hook(_module: Any, _inputs: Any, output: Any) -> None:
        attn = _extract_attention_tensor(output)
        if attn is None:
            return
        # Defensive invariant: the extractor is expected to return only 4-D
        # tensors, but the store's consumers rely on that rank.
        if attn.dim() != 4:
            raise RuntimeError(f"Expected 4D attention tensor at layer {layer_idx}, got {attn.shape}.")
        # retain_grad() raises on tensors that do not require grad (e.g. a
        # forward pass under torch.no_grad()). Capturing the tensor should
        # still work in that case, so only request gradient retention when
        # autograd is actually tracking it.
        if attn.requires_grad:
            attn.retain_grad()
        _ATTENTION_STORE[layer_idx] = attn

    return hook


def register_hooks(model: Any) -> list[Any]:
    """Attach an attention-capturing forward hook to every decoder layer.

    The attention store is cleared first. The layers and the name of their
    attention attribute come from ``get_decoder_layers(model)``; one hook
    built by ``make_attn_hook`` is registered on each layer's attention
    module, keyed by the layer's position in the returned sequence.

    Args:
        model: Model object accepted by ``get_decoder_layers``.

    Returns:
        The hook handles, in layer order. Pass them to ``remove_hooks`` to
        detach the hooks again.

    Raises:
        RuntimeError: If a layer does not expose the attention attribute.
            Hooks registered on earlier layers are removed before the error
            propagates, so a failed call leaves no hooks behind.
    """
    clear_stored_attentions()
    layers, _layer_path, attention_attr = get_decoder_layers(model)

    handles: list[Any] = []
    try:
        for layer_idx, layer in enumerate(layers):
            self_attn = getattr(layer, attention_attr, None)
            if self_attn is None:
                raise RuntimeError(f"Layer {layer_idx} does not expose {attention_attr}.")
            handles.append(self_attn.register_forward_hook(make_attn_hook(layer_idx)))
    except Exception:
        # The caller never receives `handles` on failure, so undo the partial
        # registration here; otherwise the hooks would stay on the model.
        for handle in handles:
            handle.remove()
        raise
    return handles


def remove_hooks(handles: list[Any]) -> None:
    """Detach every hook in ``handles``, then empty the attention store.

    Args:
        handles: Objects exposing a ``remove()`` method, such as the list
            returned by ``register_hooks``. The list itself is not modified.
    """
    for registered in handles:
        registered.remove()
    clear_stored_attentions()