cot-anc / app /analysis /suppression.py
BART-ender's picture
Switch default model to HRM-Text-1B
2620860 verified
from __future__ import annotations
import time
from dataclasses import asdict
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
from app.analysis.hooks import get_stored_attentions, register_hooks, remove_hooks
from app.core.model_support import add_prefix_token_type_ids, describe_model_support
from app.core.schemas import ModelCapability, RuntimeMetadata
@dataclass(slots=True)
class AttributionMatrixComputation:
    """Result bundle produced by ``compute_attribution_matrix``."""

    # Matrix intended for display: log-scaled when ``take_log`` was requested,
    # otherwise a copy of ``raw_matrix``.
    matrix: np.ndarray
    # Unscaled sentence-by-sentence attribution scores (float32, square).
    raw_matrix: np.ndarray
    # Per-position next-token negative log-likelihood, moved to a CPU array.
    token_nll: np.ndarray
    # Timings plus model/runtime description recorded during the analysis.
    runtime_metadata: RuntimeMetadata
def compute_self_token_nll(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the next-token negative log-likelihood for every position.

    Entry ``i`` of the result is the negative log-probability that
    ``logits[0, i]`` assigns to ``input_ids[0, i + 1]``; the last logit row has
    no target and is dropped.

    Args:
        logits: Tensor of shape ``[1, seq, vocab]``.
        input_ids: Integer tensor of shape ``[1, seq]`` with the same ``seq``.

    Returns:
        A 1-D tensor with ``seq - 1`` entries, still attached to the autograd
        graph of ``logits``.

    Raises:
        ValueError: If the ranks are wrong, the batch size is not 1, fewer
            than two tokens are given, or the sequence lengths disagree.
    """
    if logits.ndim != 3 or input_ids.ndim != 2:
        raise ValueError("Expected logits [batch, seq, vocab] and input_ids [batch, seq].")
    if logits.shape[0] != 1 or input_ids.shape[0] != 1:
        raise ValueError("Only batch size 1 is supported for the prototype.")
    if input_ids.shape[1] < 2:
        raise ValueError("Need at least two tokens to compute next-token loss.")
    # Without this check ``gather`` silently truncates when the logits are
    # longer than the ids and fails with an opaque size error when shorter.
    if logits.shape[1] != input_ids.shape[1]:
        raise ValueError(
            "logits and input_ids must have the same sequence length; "
            f"got {logits.shape[1]} and {input_ids.shape[1]}."
        )

    # Row i predicts token i + 1, so drop the last row and the first token.
    shifted_logits = logits[:, :-1, :]
    shifted_targets = input_ids[:, 1:]
    log_probs = torch.log_softmax(shifted_logits, dim=-1)
    gathered = log_probs.gather(-1, shifted_targets.unsqueeze(-1)).squeeze(-1)
    return -gathered[0]
def _current_memory_mb(device: torch.device) -> float | None:
if device.type != "cuda":
return None
return float(torch.cuda.memory_allocated(device) / (1024 * 1024))
def _build_presentation_matrix(raw_matrix: np.ndarray, take_log: bool) -> np.ndarray:
if not take_log:
return raw_matrix.copy()
presentation = np.zeros_like(raw_matrix)
positive = raw_matrix > 0
presentation[positive] = np.log(raw_matrix[positive] + 1e-9)
return presentation
def _synchronize_if_cuda(device: torch.device) -> None:
    """Wait for queued CUDA work on *device*; do nothing for other devices."""
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def _accumulate_pairwise_attribution(
    ordered_layers: list[torch.Tensor],
    token_ranges: list[tuple[int, int]],
) -> np.ndarray:
    """Build the raw sentence-by-sentence matrix from attention gradients."""
    sentence_count = len(token_ranges)
    raw_matrix = np.zeros((sentence_count, sentence_count), dtype=np.float32)
    if sentence_count < 2:
        # No (source, later target) pair exists, so nothing is evaluated.
        return raw_matrix

    # Validate every layer once instead of on every matrix cell.
    layers_with_grads = []
    for attention in ordered_layers:
        grad = attention.grad
        if grad is None:
            raise RuntimeError("Attention gradient was not retained for one or more layers.")
        layers_with_grads.append((attention, grad))

    # The reduction is read-only, so keep autograd from recording it.
    with torch.no_grad():
        for source_idx, (source_start, source_end) in enumerate(token_ranges):
            later_targets = enumerate(
                token_ranges[source_idx + 1:], start=source_idx + 1
            )
            for target_idx, (target_start, target_end) in later_targets:
                total = 0.0
                for attention, grad in layers_with_grads:
                    total += -(
                        grad[0, :, target_start:target_end, source_start:source_end]
                        * attention[0, :, target_start:target_end, source_start:source_end]
                    ).sum().item()
                # Normalise by the number of target tokens (at least 1).
                denominator = max(1, target_end - target_start)
                raw_matrix[target_idx, source_idx] = total / denominator
    return raw_matrix


def compute_attribution_matrix(
    input_ids: torch.Tensor,
    token_ranges: list[tuple[int, int]],
    model: Any,
    take_log: bool = True,
    max_trace_tokens: int = 0,
    max_sentences: int = 0,
) -> AttributionMatrixComputation:
    """Score how strongly each earlier sentence feeds each later sentence.

    The model is run once with ``output_attentions=True``, the summed
    self-token negative log-likelihood is back-propagated, and for every pair
    of ranges where the target comes after the source the value
    ``-(grad * attention)`` is summed over all captured layers, all heads, the
    target rows and the source columns. The sum is divided by the target token
    count and stored at ``raw_matrix[target, source]``; all other cells stay 0.

    Args:
        input_ids: Token ids of shape ``[1, seq]``; its device is used for
            memory reporting and cache cleanup.
        token_ranges: ``(start, end)`` slice bounds per sentence, applied to
            the last two axes of each attention tensor (indexed here as
            ``[batch, head, target position, source position]``).
        model: Model called with ``output_attentions=True`` and
            ``return_dict=True``; must expose ``logits`` on its output and a
            ``config`` attribute.
        take_log: Whether the presentation matrix is log-scaled.
        max_trace_tokens: Recorded in the runtime metadata only.
        max_sentences: Recorded in the runtime metadata only.

    Returns:
        The presentation matrix, the raw matrix, the per-token NLL as a NumPy
        array and the runtime metadata.

    Raises:
        RuntimeError: If no attention tensors were captured, or if a captured
            attention has no gradient while at least two ranges are given.
        ValueError: Propagated from ``compute_self_token_nll`` for invalid
            logits / input id shapes.
    """
    device = input_ids.device
    handles = register_hooks(model)
    # Everything after hook registration runs under ``try`` so the hooks are
    # always removed again, even when the setup itself fails.
    try:
        model.zero_grad(set_to_none=True)
        memory_before_mb = _current_memory_mb(device)

        forward_start = time.perf_counter()
        with torch.enable_grad():
            model_inputs = add_prefix_token_type_ids(model, {"input_ids": input_ids})
            outputs = model(
                **model_inputs,
                output_attentions=True,
                return_dict=True,
            )
            # CUDA kernels are launched asynchronously; wait for them so the
            # timer covers the forward pass rather than just its launch.
            _synchronize_if_cuda(device)
            forward_pass_ms = (time.perf_counter() - forward_start) * 1000.0

            token_nll = compute_self_token_nll(outputs.logits, input_ids)
            loss = token_nll.sum()

            backward_start = time.perf_counter()
            loss.backward()
            # Same reasoning as above, otherwise the backward cost would be
            # billed to the matrix computation phase.
            _synchronize_if_cuda(device)
            backward_pass_ms = (time.perf_counter() - backward_start) * 1000.0

        attentions = get_stored_attentions()
        if not attentions:
            raise RuntimeError("No attention tensors were captured. Check eager attention mode.")

        matrix_start = time.perf_counter()
        sentence_count = len(token_ranges)
        ordered_layers = [attentions[layer_idx] for layer_idx in sorted(attentions)]
        first_attention = ordered_layers[0]
        num_layers = len(ordered_layers)
        num_heads = int(first_attention.shape[1])
        raw_matrix = _accumulate_pairwise_attribution(ordered_layers, token_ranges)
        matrix_computation_ms = (time.perf_counter() - matrix_start) * 1000.0

        total_analysis_ms = (
            forward_pass_ms + backward_pass_ms + matrix_computation_ms
        )
        presentation_matrix = _build_presentation_matrix(raw_matrix, take_log)
        attention_impl = getattr(model.config, "_attn_implementation", "unknown")
        capability = describe_model_support(model)
        runtime_metadata = RuntimeMetadata(
            forward_pass_ms=forward_pass_ms,
            backward_pass_ms=backward_pass_ms,
            matrix_computation_ms=matrix_computation_ms,
            total_analysis_ms=total_analysis_ms,
            num_layers=num_layers,
            num_heads=num_heads,
            sequence_length_tokens=int(input_ids.shape[1]),
            sentence_count=sentence_count,
            device=str(device),
            dtype=str(first_attention.dtype),
            attention_impl=str(attention_impl),
            max_trace_tokens=max_trace_tokens,
            max_sentences=max_sentences,
            capability=ModelCapability.model_validate(asdict(capability)),
        )

        # When memory figures are available they are appended to the device
        # label rather than stored in a separate field.
        memory_after_mb = _current_memory_mb(device)
        if memory_before_mb is not None and memory_after_mb is not None:
            runtime_metadata = runtime_metadata.model_copy(
                update={
                    "device": f"{runtime_metadata.device} (mem {memory_before_mb:.1f}->{memory_after_mb:.1f} MB)"
                }
            )

        return AttributionMatrixComputation(
            matrix=presentation_matrix,
            raw_matrix=raw_matrix,
            token_nll=token_nll.detach().cpu().numpy(),
            runtime_metadata=runtime_metadata,
        )
    finally:
        # Drop retained gradients and hooks so repeated calls start clean.
        for attention in get_stored_attentions().values():
            attention.grad = None
        remove_hooks(handles)
        model.zero_grad(set_to_none=True)
        if device.type == "cuda":
            torch.cuda.empty_cache()