File size: 6,269 Bytes
fda8fb3
 
 
 
 
 
 
 
 
 
 
2620860
fda8fb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2620860
fda8fb3
2620860
fda8fb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
from __future__ import annotations

import time
from dataclasses import asdict
from dataclasses import dataclass
from typing import Any

import numpy as np
import torch

from app.analysis.hooks import get_stored_attentions, register_hooks, remove_hooks
from app.core.model_support import add_prefix_token_type_ids, describe_model_support
from app.core.schemas import ModelCapability, RuntimeMetadata


@dataclass(slots=True)
class AttributionMatrixComputation:
    """Result bundle produced by ``compute_attribution_matrix``."""

    # Matrix intended for display: a copy of ``raw_matrix``, or its
    # log-transformed form when the computation was run with ``take_log``.
    matrix: np.ndarray
    # Untransformed sentence-by-sentence attribution scores, indexed as
    # ``[target_sentence, source_sentence]``.
    raw_matrix: np.ndarray
    # Per-token next-token negative log-likelihood, detached and moved to CPU.
    token_nll: np.ndarray
    # Timings, model/sequence dimensions, and capability info for the run.
    runtime_metadata: RuntimeMetadata


def compute_self_token_nll(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the per-token next-token negative log-likelihood of a sequence.

    The logits at position ``i`` are scored against the token at position
    ``i + 1``, so the result has one entry fewer than the sequence length.

    Args:
        logits: Tensor of rank 3, ``[batch, seq, vocab]``, with batch size 1.
        input_ids: Tensor of rank 2, ``[batch, seq]``, with batch size 1.

    Returns:
        A 1-D tensor holding the negative log-probability of each token
        (from the second one onward) under the preceding position's logits.

    Raises:
        ValueError: If the ranks are wrong, the batch size is not 1, or the
            sequence has fewer than two tokens.
    """
    if not (logits.ndim == 3 and input_ids.ndim == 2):
        raise ValueError("Expected logits [batch, seq, vocab] and input_ids [batch, seq].")
    if not (logits.shape[0] == 1 and input_ids.shape[0] == 1):
        raise ValueError("Only batch size 1 is supported for the prototype.")
    if input_ids.shape[1] < 2:
        raise ValueError("Need at least two tokens to compute next-token loss.")

    # Drop the last prediction (nothing follows it) and the first token
    # (nothing predicts it) so predictions and targets line up.
    next_token_ids = input_ids[:, 1:].unsqueeze(-1)
    prediction_log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)
    target_log_probs = prediction_log_probs.gather(-1, next_token_ids).squeeze(-1)
    return -target_log_probs[0]


def _current_memory_mb(device: torch.device) -> float | None:
    if device.type != "cuda":
        return None
    return float(torch.cuda.memory_allocated(device) / (1024 * 1024))


def _build_presentation_matrix(raw_matrix: np.ndarray, take_log: bool) -> np.ndarray:
    if not take_log:
        return raw_matrix.copy()
    presentation = np.zeros_like(raw_matrix)
    positive = raw_matrix > 0
    presentation[positive] = np.log(raw_matrix[positive] + 1e-9)
    return presentation


def compute_attribution_matrix(
    input_ids: torch.Tensor,
    token_ranges: list[tuple[int, int]],
    model: Any,
    take_log: bool = True,
    max_trace_tokens: int = 0,
    max_sentences: int = 0,
) -> AttributionMatrixComputation:
    """Compute a sentence-to-sentence attribution matrix from attention gradients.

    The model is run once on ``input_ids`` with attention outputs enabled, the
    summed next-token negative log-likelihood is back-propagated, and for each
    ordered sentence pair the attention tensors captured by the registered
    hooks are combined with their gradients:

        score[target, source] = -sum(grad * attention) / max(1, target_len)

    where the sum runs over every captured layer, every head, the target
    sentence's token rows and the source sentence's token columns. Only pairs
    with ``target_idx > source_idx`` are filled; every other entry stays zero.

    Args:
        input_ids: Token ids of shape ``[1, seq]`` (batch size 1 is enforced
            by ``compute_self_token_nll``).
        token_ranges: One ``(start, end)`` token span per sentence, used as
            slice bounds (``end`` is exclusive).
        model: Model called as ``model(**inputs, output_attentions=True,
            return_dict=True)``; it must expose ``zero_grad`` and ``config``.
        take_log: Whether the presentation matrix is log-transformed.
        max_trace_tokens: Recorded in the runtime metadata only; not enforced
            here.
        max_sentences: Recorded in the runtime metadata only; not enforced
            here.

    Returns:
        An ``AttributionMatrixComputation`` holding the presentation matrix,
        the raw matrix, the per-token NLL as a NumPy array, and the runtime
        metadata (timings, layer/head counts, device, dtype, capability).

    Raises:
        RuntimeError: If no attention tensors were captured, or a captured
            attention tensor has no gradient.
        ValueError: Propagated from ``compute_self_token_nll`` for invalid
            logits/input shapes.
    """
    device = input_ids.device
    handles = register_hooks(model)

    try:
        # Everything after hook registration runs inside the ``try`` so that
        # the ``finally`` clause removes the hooks even when setup fails.
        model.zero_grad(set_to_none=True)
        # Sample memory before starting the timer so the probe itself is not
        # counted as forward-pass time.
        memory_before_mb = _current_memory_mb(device)
        forward_start = time.perf_counter()

        with torch.enable_grad():
            model_inputs = add_prefix_token_type_ids(model, {"input_ids": input_ids})
            outputs = model(
                **model_inputs,
                output_attentions=True,
                return_dict=True,
            )
            forward_pass_ms = (time.perf_counter() - forward_start) * 1000.0

            logits = outputs.logits
            token_nll = compute_self_token_nll(logits, input_ids)
            loss = token_nll.sum()

            backward_start = time.perf_counter()
            loss.backward()
            backward_pass_ms = (time.perf_counter() - backward_start) * 1000.0

        attentions = get_stored_attentions()
        if not attentions:
            raise RuntimeError("No attention tensors were captured. Check eager attention mode.")

        matrix_start = time.perf_counter()
        sentence_count = len(token_ranges)
        raw_matrix = np.zeros((sentence_count, sentence_count), dtype=np.float32)

        # Process layers in ascending key order so ``first_attention`` is
        # deterministic.
        ordered_layers = [attentions[layer_idx] for layer_idx in sorted(attentions)]
        first_attention = ordered_layers[0]
        num_layers = len(ordered_layers)
        # NOTE(review): dimension 1 is treated as the head axis; presumably the
        # tensors are [batch, heads, query, key] -- confirm against the hooks.
        num_heads = int(first_attention.shape[1])

        for source_idx, (source_start, source_end) in enumerate(token_ranges):
            for target_idx, (target_start, target_end) in enumerate(token_ranges):
                # Only later sentences are scored against earlier ones.
                if target_idx <= source_idx:
                    continue

                total = 0.0
                for attention in ordered_layers:
                    grad = attention.grad
                    if grad is None:
                        raise RuntimeError("Attention gradient was not retained for one or more layers.")
                    # Negated gradient-times-attention over the block where the
                    # target sentence's rows meet the source sentence's columns.
                    total += -(
                        grad[0, :, target_start:target_end, source_start:source_end]
                        * attention[0, :, target_start:target_end, source_start:source_end]
                    ).sum().item()

                # Normalise by the target span length; ``max`` guards against
                # an empty span.
                denominator = max(1, target_end - target_start)
                raw_matrix[target_idx, source_idx] = total / denominator

        matrix_computation_ms = (time.perf_counter() - matrix_start) * 1000.0
        total_analysis_ms = (
            forward_pass_ms + backward_pass_ms + matrix_computation_ms
        )
        presentation_matrix = _build_presentation_matrix(raw_matrix, take_log)

        attention_impl = getattr(model.config, "_attn_implementation", "unknown")
        capability = describe_model_support(model)
        runtime_metadata = RuntimeMetadata(
            forward_pass_ms=forward_pass_ms,
            backward_pass_ms=backward_pass_ms,
            matrix_computation_ms=matrix_computation_ms,
            total_analysis_ms=total_analysis_ms,
            num_layers=num_layers,
            num_heads=num_heads,
            sequence_length_tokens=int(input_ids.shape[1]),
            sentence_count=sentence_count,
            device=str(device),
            dtype=str(first_attention.dtype),
            attention_impl=str(attention_impl),
            max_trace_tokens=max_trace_tokens,
            max_sentences=max_sentences,
            capability=ModelCapability.model_validate(asdict(capability)),
        )

        # Memory readings exist only on CUDA; when present they are appended
        # to the device string.
        memory_after_mb = _current_memory_mb(device)
        if memory_before_mb is not None and memory_after_mb is not None:
            runtime_metadata = runtime_metadata.model_copy(
                update={
                    "device": f"{runtime_metadata.device} (mem {memory_before_mb:.1f}->{memory_after_mb:.1f} MB)"
                }
            )

        return AttributionMatrixComputation(
            matrix=presentation_matrix,
            raw_matrix=raw_matrix,
            token_nll=token_nll.detach().cpu().numpy(),
            runtime_metadata=runtime_metadata,
        )
    finally:
        # Drop the gradients held on the captured attentions, detach the
        # hooks, and clear parameter gradients whether or not the run
        # succeeded.
        for attention in get_stored_attentions().values():
            attention.grad = None
        remove_hooks(handles)
        model.zero_grad(set_to_none=True)
        if device.type == "cuda":
            torch.cuda.empty_cache()