from __future__ import annotations from typing import Any import numpy as np import torch from scipy.stats import pearsonr, spearmanr from app.analysis.suppression import compute_self_token_nll from app.core.model_support import add_prefix_token_type_ids from app.core.schemas import TopEdge, ValidationMetadata def _nll_slice_for_token_range(token_range: tuple[int, int]) -> slice: start, end = token_range return slice(max(0, start - 1), max(0, end - 1)) def build_exact_suppression_mask( *, sequence_length: int, source_range: tuple[int, int], target_range: tuple[int, int], device: torch.device, dtype: torch.dtype, ) -> torch.Tensor: fill_value = torch.finfo(dtype).min mask = torch.zeros((sequence_length, sequence_length), device=device, dtype=dtype) future_positions = torch.triu( torch.ones((sequence_length, sequence_length), device=device, dtype=torch.bool), diagonal=1, ) mask = mask.masked_fill(future_positions, fill_value) source_start, source_end = source_range target_start, target_end = target_range mask[target_start:target_end, source_start:source_end] = fill_value return mask.unsqueeze(0).unsqueeze(0) def compute_exact_edge_score( *, model: Any, input_ids: torch.Tensor, source_range: tuple[int, int], target_range: tuple[int, int], baseline_token_nll: np.ndarray, ) -> float: model_dtype = next(model.parameters()).dtype attention_mask = build_exact_suppression_mask( sequence_length=int(input_ids.shape[1]), source_range=source_range, target_range=target_range, device=input_ids.device, dtype=model_dtype, ) with torch.no_grad(): model_inputs = add_prefix_token_type_ids(model, {"input_ids": input_ids}) outputs = model( **model_inputs, attention_mask=attention_mask, output_attentions=False, return_dict=True, ) suppressed_nll = compute_self_token_nll(outputs.logits, input_ids).detach().cpu().numpy() nll_slice = _nll_slice_for_token_range(target_range) return float(suppressed_nll[nll_slice].sum() - baseline_token_nll[nll_slice].sum()) def validate_top_edges( *, model: Any, input_ids: torch.Tensor, token_ranges: list[tuple[int, int]], top_edges: list[TopEdge], baseline_token_nll: np.ndarray, top_k: int, ) -> ValidationMetadata: if top_k <= 0 or not top_edges: return ValidationMetadata(enabled=False, top_k=0) selected_edges = top_edges[:top_k] exact_scores: list[float] = [] attributed_scores: list[float] = [] compared_edges: list[TopEdge] = [] try: for edge in selected_edges: exact_score = compute_exact_edge_score( model=model, input_ids=input_ids, source_range=token_ranges[edge.source_sentence_idx], target_range=token_ranges[edge.target_sentence_idx], baseline_token_nll=baseline_token_nll, ) exact_scores.append(exact_score) attributed_scores.append(edge.score) compared_edges.append( TopEdge( source_sentence_idx=edge.source_sentence_idx, target_sentence_idx=edge.target_sentence_idx, score=exact_score, ) ) except Exception as exc: # pragma: no cover - environment/model dependent return ValidationMetadata( enabled=True, top_k=top_k, compared_edges=[], notes=f"Exact suppression validation failed: {exc}", ) pearson = float(pearsonr(attributed_scores, exact_scores).statistic) if len(exact_scores) >= 2 else None spearman = float(spearmanr(attributed_scores, exact_scores).statistic) if len(exact_scores) >= 2 else None return ValidationMetadata( enabled=True, top_k=top_k, pearson=pearson, spearman=spearman, compared_edges=compared_edges, notes="Exact suppression compares sentence-level NLL deltas for selected edges.", )