Spaces:
Sleeping
Sleeping
File size: 6,269 Bytes
fda8fb3 2620860 fda8fb3 2620860 fda8fb3 2620860 fda8fb3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 | from __future__ import annotations
import time
from dataclasses import asdict
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
from app.analysis.hooks import get_stored_attentions, register_hooks, remove_hooks
from app.core.model_support import add_prefix_token_type_ids, describe_model_support
from app.core.schemas import ModelCapability, RuntimeMetadata
@dataclass(slots=True)
class AttributionMatrixComputation:
    """Result bundle produced by ``compute_attribution_matrix``."""

    # Matrix intended for display: the output of ``_build_presentation_matrix``
    # (log-scaled positive entries when ``take_log`` is set, otherwise a copy
    # of ``raw_matrix``).
    matrix: np.ndarray
    # Unscaled float32 scores of shape (sentence_count, sentence_count), indexed
    # as [target_sentence, source_sentence]; only entries with target > source
    # are filled, the rest stay zero.
    raw_matrix: np.ndarray
    # Per-position next-token negative log-likelihood, detached and moved to
    # the CPU as a NumPy array (one entry fewer than the token count).
    token_nll: np.ndarray
    # Timings, model/run dimensions and capability info for this analysis run.
    runtime_metadata: RuntimeMetadata
def compute_self_token_nll(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the negative log-likelihood of every next token in the sequence.

    The logits at position ``i`` are scored against the token id at position
    ``i + 1``, so the result has one entry fewer than the sequence length.

    Args:
        logits: Tensor of shape ``[1, seq, vocab]``.
        input_ids: Tensor of shape ``[1, seq]``; entry ``i + 1`` is the target
            for the logits at position ``i``.

    Returns:
        A 1-D tensor of length ``seq - 1`` with the per-position NLL.

    Raises:
        ValueError: If the ranks are not 3 and 2, if either batch dimension is
            not 1, or if the sequence holds fewer than two tokens.
    """
    ranks_ok = logits.ndim == 3 and input_ids.ndim == 2
    if not ranks_ok:
        raise ValueError("Expected logits [batch, seq, vocab] and input_ids [batch, seq].")
    if not (logits.shape[0] == 1 and input_ids.shape[0] == 1):
        raise ValueError("Only batch size 1 is supported for the prototype.")
    sequence_length = input_ids.shape[1]
    if sequence_length < 2:
        raise ValueError("Need at least two tokens to compute next-token loss.")

    # Drop the first token from the targets and the last position from the
    # predictions so that prediction i lines up with token i + 1.
    next_token_ids = input_ids[:, 1:].unsqueeze(-1)
    predictive_log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)
    target_log_probs = predictive_log_probs.gather(-1, next_token_ids).squeeze(-1)
    return -target_log_probs[0]
def _current_memory_mb(device: torch.device) -> float | None:
if device.type != "cuda":
return None
return float(torch.cuda.memory_allocated(device) / (1024 * 1024))
def _build_presentation_matrix(raw_matrix: np.ndarray, take_log: bool) -> np.ndarray:
if not take_log:
return raw_matrix.copy()
presentation = np.zeros_like(raw_matrix)
positive = raw_matrix > 0
presentation[positive] = np.log(raw_matrix[positive] + 1e-9)
return presentation
def _sentence_attribution_matrix(
    ordered_layers: list[torch.Tensor],
    token_ranges: list[tuple[int, int]],
) -> np.ndarray:
    """Aggregate per-layer gradient-times-attention into sentence-pair scores.

    Entry ``[target, source]`` is filled only for ``target > source``. It is
    the negated sum of ``grad * attention`` over all layers, all heads and the
    (target tokens x source tokens) block, divided by the number of target
    tokens (at least 1). All other entries stay zero.

    Args:
        ordered_layers: Captured attention tensors, one per layer, each
            indexed as ``[batch, head, target_token, source_token]`` and
            carrying a populated ``.grad``.
        token_ranges: ``(start, end)`` token slice bounds per sentence.

    Returns:
        A float32 array of shape ``(len(token_ranges), len(token_ranges))``.

    Raises:
        RuntimeError: If at least one sentence pair must be scored and any
            layer has no gradient.
    """
    sentence_count = len(token_ranges)
    raw_matrix = np.zeros((sentence_count, sentence_count), dtype=np.float32)
    if sentence_count < 2:
        # No (earlier source, later target) pair exists, so gradients are
        # never read and a missing gradient is not an error here.
        return raw_matrix

    # Validate gradients and strip the batch dimension once per layer rather
    # than once per layer for every sentence pair.
    layer_pairs: list[tuple[torch.Tensor, torch.Tensor]] = []
    for attention in ordered_layers:
        grad = attention.grad
        if grad is None:
            raise RuntimeError("Attention gradient was not retained for one or more layers.")
        # detach(): the products below are read-only bookkeeping and must not
        # record an autograd graph of their own.
        layer_pairs.append((grad[0], attention.detach()[0]))

    for source_idx, (source_start, source_end) in enumerate(token_ranges):
        for target_idx in range(source_idx + 1, sentence_count):
            target_start, target_end = token_ranges[target_idx]
            total = 0.0
            for grad, attention in layer_pairs:
                block = (
                    grad[:, target_start:target_end, source_start:source_end]
                    * attention[:, target_start:target_end, source_start:source_end]
                )
                total -= block.sum().item()
            # Normalise by target length; max() guards empty target ranges.
            denominator = max(1, target_end - target_start)
            raw_matrix[target_idx, source_idx] = total / denominator
    return raw_matrix


def compute_attribution_matrix(
    input_ids: torch.Tensor,
    token_ranges: list[tuple[int, int]],
    model: Any,
    take_log: bool = True,
    max_trace_tokens: int = 0,
    max_sentences: int = 0,
) -> AttributionMatrixComputation:
    """Compute a sentence-to-sentence attribution matrix for one sequence.

    Runs a forward pass with attention outputs enabled, back-propagates the
    summed next-token negative log-likelihood, and combines the captured
    attention tensors with their gradients into a lower-triangular
    sentence-level matrix.

    Args:
        input_ids: Token ids of shape ``[1, seq]``; its device is used for
            memory reporting and cache cleanup.
        token_ranges: ``(start, end)`` token slice bounds for each sentence.
        model: Model that accepts ``input_ids``, ``output_attentions`` and
            ``return_dict`` and exposes ``config`` and ``zero_grad``.
        take_log: Forwarded to ``_build_presentation_matrix`` to choose
            between log-scaled and raw presentation values.
        max_trace_tokens: Recorded in the runtime metadata only.
        max_sentences: Recorded in the runtime metadata only.

    Returns:
        The presentation matrix, the raw matrix, the per-token NLL as a NumPy
        array, and the runtime metadata.

    Raises:
        RuntimeError: If no attention tensors were captured, or a captured
            attention tensor has no gradient when one is needed.
        ValueError: Propagated from ``compute_self_token_nll`` for
            unsupported input shapes.
    """
    device = input_ids.device
    handles = register_hooks(model)
    try:
        # Everything after hook registration lives inside the try block so
        # the hooks are removed even if this setup fails.
        model.zero_grad(set_to_none=True)
        forward_start = time.perf_counter()
        memory_before_mb = _current_memory_mb(device)

        with torch.enable_grad():
            model_inputs = add_prefix_token_type_ids(model, {"input_ids": input_ids})
            outputs = model(
                **model_inputs,
                output_attentions=True,
                return_dict=True,
            )
            forward_pass_ms = (time.perf_counter() - forward_start) * 1000.0

            logits = outputs.logits
            token_nll = compute_self_token_nll(logits, input_ids)
            loss = token_nll.sum()

            backward_start = time.perf_counter()
            loss.backward()
            backward_pass_ms = (time.perf_counter() - backward_start) * 1000.0

        attentions = get_stored_attentions()
        if not attentions:
            raise RuntimeError("No attention tensors were captured. Check eager attention mode.")

        matrix_start = time.perf_counter()
        sentence_count = len(token_ranges)
        # Layers are processed in ascending key order.
        ordered_layers = [attentions[layer_idx] for layer_idx in sorted(attentions)]
        first_attention = ordered_layers[0]
        num_layers = len(ordered_layers)
        num_heads = int(first_attention.shape[1])
        raw_matrix = _sentence_attribution_matrix(ordered_layers, token_ranges)
        matrix_computation_ms = (time.perf_counter() - matrix_start) * 1000.0

        total_analysis_ms = (
            forward_pass_ms + backward_pass_ms + matrix_computation_ms
        )
        presentation_matrix = _build_presentation_matrix(raw_matrix, take_log)
        attention_impl = getattr(model.config, "_attn_implementation", "unknown")
        capability = describe_model_support(model)
        runtime_metadata = RuntimeMetadata(
            forward_pass_ms=forward_pass_ms,
            backward_pass_ms=backward_pass_ms,
            matrix_computation_ms=matrix_computation_ms,
            total_analysis_ms=total_analysis_ms,
            num_layers=num_layers,
            num_heads=num_heads,
            sequence_length_tokens=int(input_ids.shape[1]),
            sentence_count=sentence_count,
            device=str(device),
            dtype=str(first_attention.dtype),
            attention_impl=str(attention_impl),
            max_trace_tokens=max_trace_tokens,
            max_sentences=max_sentences,
            capability=ModelCapability.model_validate(asdict(capability)),
        )

        # On CUDA, append the before/after allocation to the device label.
        memory_after_mb = _current_memory_mb(device)
        if memory_before_mb is not None and memory_after_mb is not None:
            runtime_metadata = runtime_metadata.model_copy(
                update={
                    "device": f"{runtime_metadata.device} (mem {memory_before_mb:.1f}->{memory_after_mb:.1f} MB)"
                }
            )

        return AttributionMatrixComputation(
            matrix=presentation_matrix,
            raw_matrix=raw_matrix,
            token_nll=token_nll.detach().cpu().numpy(),
            runtime_metadata=runtime_metadata,
        )
    finally:
        # Drop retained gradients and hooks so repeated calls neither
        # accumulate state nor keep large tensors alive.
        for attention in get_stored_attentions().values():
            attention.grad = None
        remove_hooks(handles)
        model.zero_grad(set_to_none=True)
        if device.type == "cuda":
            torch.cuda.empty_cache()
|