Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import time | |
| from dataclasses import asdict | |
| from dataclasses import dataclass | |
| from typing import Any | |
| import numpy as np | |
| import torch | |
| from app.analysis.hooks import get_stored_attentions, register_hooks, remove_hooks | |
| from app.core.model_support import add_prefix_token_type_ids, describe_model_support | |
| from app.core.schemas import ModelCapability, RuntimeMetadata | |
@dataclass
class AttributionMatrixComputation:
    """Result bundle produced by ``compute_attribution_matrix``.

    Declared as a dataclass so that it can be constructed with keyword
    arguments (``matrix=...``, ``raw_matrix=...``, ...), which is how
    ``compute_attribution_matrix`` builds it.
    """

    # Presentation matrix: log-scaled when ``take_log`` was requested,
    # otherwise a copy of ``raw_matrix``.
    matrix: np.ndarray
    # Unscaled sentence-by-sentence attribution scores.
    raw_matrix: np.ndarray
    # Per-position next-token negative log-likelihood, detached and on the CPU.
    token_nll: np.ndarray
    # Timings, model/sequence dimensions and capability info for the run.
    runtime_metadata: RuntimeMetadata
def compute_self_token_nll(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the next-token negative log-likelihood at each position.

    Entry ``i`` of the result is ``-log p(input_ids[0, i + 1])`` under the
    softmax of ``logits[0, i]``, so the result has ``seq - 1`` entries.

    Args:
        logits: Tensor of shape ``[1, seq, vocab]``.
        input_ids: Tensor of shape ``[1, seq]`` holding the token ids.

    Returns:
        A 1-D tensor of length ``seq - 1``.

    Raises:
        ValueError: If the ranks are wrong, the batch size is not 1, or the
            sequence has fewer than two tokens.
    """
    if (logits.ndim, input_ids.ndim) != (3, 2):
        raise ValueError("Expected logits [batch, seq, vocab] and input_ids [batch, seq].")
    if not (logits.shape[0] == 1 and input_ids.shape[0] == 1):
        raise ValueError("Only batch size 1 is supported for the prototype.")
    if input_ids.shape[1] < 2:
        raise ValueError("Need at least two tokens to compute next-token loss.")

    # Logits at position i score the token at position i + 1, so drop the
    # last logit row and the first token to line the two up.
    predictor_logits = logits[:, :-1, :]
    next_token_ids = input_ids[:, 1:].unsqueeze(-1)
    target_log_probs = torch.log_softmax(predictor_logits, dim=-1).gather(-1, next_token_ids)
    return -target_log_probs.squeeze(-1)[0]
def _current_memory_mb(device: torch.device) -> float | None:
    """Return the CUDA memory currently allocated on ``device`` in MiB.

    Returns ``None`` for any device whose type is not ``"cuda"``.
    """
    if device.type == "cuda":
        allocated_bytes = torch.cuda.memory_allocated(device)
        return float(allocated_bytes / (1024 * 1024))
    return None
def _build_presentation_matrix(raw_matrix: np.ndarray, take_log: bool) -> np.ndarray:
    """Build the matrix that is shown to the user from the raw scores.

    Args:
        raw_matrix: Raw attribution scores; never modified.
        take_log: When true, strictly positive entries are replaced by
            ``log(value + 1e-9)`` and all other entries become zero.
            When false, an unmodified copy is returned.

    Returns:
        A new array with the same shape and dtype as ``raw_matrix``.
    """
    if take_log:
        is_positive = raw_matrix > 0
        scaled = np.zeros_like(raw_matrix)
        # Only positive scores have a logarithm; the epsilon keeps very
        # small values away from log(0).
        scaled[is_positive] = np.log(raw_matrix[is_positive] + 1e-9)
        return scaled
    return raw_matrix.copy()
def _synchronize_if_cuda(device: torch.device) -> None:
    """Wait for queued CUDA work on ``device`` to finish; no-op on other devices."""
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def compute_attribution_matrix(
    input_ids: torch.Tensor,
    token_ranges: list[tuple[int, int]],
    model: Any,
    take_log: bool = True,
    max_trace_tokens: int = 0,
    max_sentences: int = 0,
) -> AttributionMatrixComputation:
    """Compute a sentence-to-sentence attribution matrix from attention gradients.

    Runs one forward pass with ``output_attentions=True``, back-propagates the
    summed next-token negative log-likelihood, and then, for every pair of
    sentences where the target comes after the source, sums
    ``-(attention.grad * attention)`` over all captured layers and heads inside
    the ``[target tokens, source tokens]`` block. The sum is divided by the
    number of target tokens and stored at ``raw_matrix[target_idx, source_idx]``;
    all other cells stay zero.

    Args:
        input_ids: Token ids of shape ``[1, seq]``; its device is the one
            reported in the metadata and used for memory/timing probes.
        token_ranges: One ``(start, end)`` pair per sentence, used as
            half-open slice bounds into the attention maps.
        model: Model called with ``input_ids`` (plus whatever
            ``add_prefix_token_type_ids`` adds), ``output_attentions=True`` and
            ``return_dict=True``. Must expose ``config`` and ``zero_grad``.
        take_log: Forwarded to ``_build_presentation_matrix``.
        max_trace_tokens: Recorded in the runtime metadata; not used otherwise.
        max_sentences: Recorded in the runtime metadata; not used otherwise.

    Returns:
        An ``AttributionMatrixComputation`` with the presentation matrix, the
        raw matrix, the per-token NLL as a NumPy array, and runtime metadata.

    Raises:
        RuntimeError: If no attention tensors were captured, or a captured
            attention tensor has no gradient.
        ValueError: Propagated from ``compute_self_token_nll`` for unsupported
            input shapes.
    """
    device = input_ids.device
    # NOTE(review): register_hooks presumably makes the attention tensors
    # available through get_stored_attentions() with gradients retained --
    # confirm against app.analysis.hooks.
    handles = register_hooks(model)
    try:
        # Everything after hook registration lives inside the try so the
        # hooks are always removed, even if this setup fails.
        model.zero_grad(set_to_none=True)
        memory_before_mb = _current_memory_mb(device)

        # CUDA kernels are launched asynchronously. Without synchronising,
        # perf_counter() would only measure launch overhead and the real cost
        # would be billed to the first blocking call (the .item() below).
        _synchronize_if_cuda(device)
        forward_start = time.perf_counter()
        # Force grad mode on so the backward pass works even when the caller
        # is running under torch.no_grad().
        with torch.enable_grad():
            model_inputs = add_prefix_token_type_ids(model, {"input_ids": input_ids})
            outputs = model(
                **model_inputs,
                output_attentions=True,
                return_dict=True,
            )
            _synchronize_if_cuda(device)
            forward_pass_ms = (time.perf_counter() - forward_start) * 1000.0

            token_nll = compute_self_token_nll(outputs.logits, input_ids)
            loss = token_nll.sum()

            backward_start = time.perf_counter()
            loss.backward()
            _synchronize_if_cuda(device)
            backward_pass_ms = (time.perf_counter() - backward_start) * 1000.0

        attentions = get_stored_attentions()
        if not attentions:
            raise RuntimeError("No attention tensors were captured. Check eager attention mode.")

        matrix_start = time.perf_counter()
        sentence_count = len(token_ranges)
        raw_matrix = np.zeros((sentence_count, sentence_count), dtype=np.float32)
        ordered_layers = [attentions[layer_idx] for layer_idx in sorted(attentions)]
        first_attention = ordered_layers[0]
        num_layers = len(ordered_layers)
        num_heads = int(first_attention.shape[1])

        for source_idx, (source_start, source_end) in enumerate(token_ranges):
            # Only later sentences can be targets of an earlier source, so
            # just the strict lower triangle of the matrix is filled.
            for target_idx in range(source_idx + 1, sentence_count):
                target_start, target_end = token_ranges[target_idx]
                total = 0.0
                for attention in ordered_layers:
                    grad = attention.grad
                    if grad is None:
                        raise RuntimeError(
                            "Attention gradient was not retained for one or more layers."
                        )
                    block_score = (
                        grad[0, :, target_start:target_end, source_start:source_end]
                        * attention[0, :, target_start:target_end, source_start:source_end]
                    ).sum().item()
                    total -= block_score
                # Normalise by target length; max() guards empty ranges.
                denominator = max(1, target_end - target_start)
                raw_matrix[target_idx, source_idx] = total / denominator
        matrix_computation_ms = (time.perf_counter() - matrix_start) * 1000.0

        total_analysis_ms = forward_pass_ms + backward_pass_ms + matrix_computation_ms
        presentation_matrix = _build_presentation_matrix(raw_matrix, take_log)
        attention_impl = getattr(model.config, "_attn_implementation", "unknown")
        capability = describe_model_support(model)
        runtime_metadata = RuntimeMetadata(
            forward_pass_ms=forward_pass_ms,
            backward_pass_ms=backward_pass_ms,
            matrix_computation_ms=matrix_computation_ms,
            total_analysis_ms=total_analysis_ms,
            num_layers=num_layers,
            num_heads=num_heads,
            sequence_length_tokens=int(input_ids.shape[1]),
            sentence_count=sentence_count,
            device=str(device),
            dtype=str(first_attention.dtype),
            attention_impl=str(attention_impl),
            max_trace_tokens=max_trace_tokens,
            max_sentences=max_sentences,
            capability=ModelCapability.model_validate(asdict(capability)),
        )

        # On CUDA the before/after allocation is appended to the device label.
        memory_after_mb = _current_memory_mb(device)
        if memory_before_mb is not None and memory_after_mb is not None:
            runtime_metadata = runtime_metadata.model_copy(
                update={
                    "device": f"{runtime_metadata.device} (mem {memory_before_mb:.1f}->{memory_after_mb:.1f} MB)"
                }
            )

        return AttributionMatrixComputation(
            matrix=presentation_matrix,
            raw_matrix=raw_matrix,
            token_nll=token_nll.detach().cpu().numpy(),
            runtime_metadata=runtime_metadata,
        )
    finally:
        # Drop gradients held on the captured attention tensors, detach the
        # hooks and clear parameter gradients before handing the model back.
        for attention in get_stored_attentions().values():
            attention.grad = None
        remove_hooks(handles)
        model.zero_grad(set_to_none=True)
        if device.type == "cuda":
            torch.cuda.empty_cache()