from __future__ import annotations import time from dataclasses import asdict from dataclasses import dataclass from typing import Any import numpy as np import torch from app.analysis.hooks import get_stored_attentions, register_hooks, remove_hooks from app.core.model_support import add_prefix_token_type_ids, describe_model_support from app.core.schemas import ModelCapability, RuntimeMetadata @dataclass(slots=True) class AttributionMatrixComputation: matrix: np.ndarray raw_matrix: np.ndarray token_nll: np.ndarray runtime_metadata: RuntimeMetadata def compute_self_token_nll(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: if logits.ndim != 3 or input_ids.ndim != 2: raise ValueError("Expected logits [batch, seq, vocab] and input_ids [batch, seq].") if logits.shape[0] != 1 or input_ids.shape[0] != 1: raise ValueError("Only batch size 1 is supported for the prototype.") if input_ids.shape[1] < 2: raise ValueError("Need at least two tokens to compute next-token loss.") shifted_logits = logits[:, :-1, :] shifted_targets = input_ids[:, 1:] log_probs = torch.log_softmax(shifted_logits, dim=-1) gathered = log_probs.gather(-1, shifted_targets.unsqueeze(-1)).squeeze(-1) return -gathered[0] def _current_memory_mb(device: torch.device) -> float | None: if device.type != "cuda": return None return float(torch.cuda.memory_allocated(device) / (1024 * 1024)) def _build_presentation_matrix(raw_matrix: np.ndarray, take_log: bool) -> np.ndarray: if not take_log: return raw_matrix.copy() presentation = np.zeros_like(raw_matrix) positive = raw_matrix > 0 presentation[positive] = np.log(raw_matrix[positive] + 1e-9) return presentation def compute_attribution_matrix( input_ids: torch.Tensor, token_ranges: list[tuple[int, int]], model: Any, take_log: bool = True, max_trace_tokens: int = 0, max_sentences: int = 0, ) -> AttributionMatrixComputation: device = input_ids.device handles = register_hooks(model) model.zero_grad(set_to_none=True) forward_start = time.perf_counter() memory_before_mb = _current_memory_mb(device) try: with torch.enable_grad(): model_inputs = add_prefix_token_type_ids(model, {"input_ids": input_ids}) outputs = model( **model_inputs, output_attentions=True, return_dict=True, ) forward_pass_ms = (time.perf_counter() - forward_start) * 1000.0 logits = outputs.logits token_nll = compute_self_token_nll(logits, input_ids) loss = token_nll.sum() backward_start = time.perf_counter() loss.backward() backward_pass_ms = (time.perf_counter() - backward_start) * 1000.0 attentions = get_stored_attentions() if not attentions: raise RuntimeError("No attention tensors were captured. Check eager attention mode.") matrix_start = time.perf_counter() sentence_count = len(token_ranges) raw_matrix = np.zeros((sentence_count, sentence_count), dtype=np.float32) ordered_layers = [attentions[layer_idx] for layer_idx in sorted(attentions)] first_attention = ordered_layers[0] num_layers = len(ordered_layers) num_heads = int(first_attention.shape[1]) for source_idx, (source_start, source_end) in enumerate(token_ranges): for target_idx, (target_start, target_end) in enumerate(token_ranges): if target_idx <= source_idx: continue total = 0.0 for attention in ordered_layers: grad = attention.grad if grad is None: raise RuntimeError("Attention gradient was not retained for one or more layers.") total += -( grad[0, :, target_start:target_end, source_start:source_end] * attention[0, :, target_start:target_end, source_start:source_end] ).sum().item() denominator = max(1, target_end - target_start) raw_matrix[target_idx, source_idx] = total / denominator matrix_computation_ms = (time.perf_counter() - matrix_start) * 1000.0 total_analysis_ms = ( forward_pass_ms + backward_pass_ms + matrix_computation_ms ) presentation_matrix = _build_presentation_matrix(raw_matrix, take_log) attention_impl = getattr(model.config, "_attn_implementation", "unknown") capability = describe_model_support(model) runtime_metadata = RuntimeMetadata( forward_pass_ms=forward_pass_ms, backward_pass_ms=backward_pass_ms, matrix_computation_ms=matrix_computation_ms, total_analysis_ms=total_analysis_ms, num_layers=num_layers, num_heads=num_heads, sequence_length_tokens=int(input_ids.shape[1]), sentence_count=sentence_count, device=str(device), dtype=str(first_attention.dtype), attention_impl=str(attention_impl), max_trace_tokens=max_trace_tokens, max_sentences=max_sentences, capability=ModelCapability.model_validate(asdict(capability)), ) memory_after_mb = _current_memory_mb(device) if memory_before_mb is not None and memory_after_mb is not None: runtime_metadata = runtime_metadata.model_copy( update={ "device": f"{runtime_metadata.device} (mem {memory_before_mb:.1f}->{memory_after_mb:.1f} MB)" } ) return AttributionMatrixComputation( matrix=presentation_matrix, raw_matrix=raw_matrix, token_nll=token_nll.detach().cpu().numpy(), runtime_metadata=runtime_metadata, ) finally: for attention in get_stored_attentions().values(): attention.grad = None remove_hooks(handles) model.zero_grad(set_to_none=True) if device.type == "cuda": torch.cuda.empty_cache()