Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from typing import Any | |
| import numpy as np | |
| import torch | |
| from scipy.stats import pearsonr, spearmanr | |
| from app.analysis.suppression import compute_self_token_nll | |
| from app.core.model_support import add_prefix_token_type_ids | |
| from app.core.schemas import TopEdge, ValidationMetadata | |
def _nll_slice_for_token_range(token_range: tuple[int, int]) -> slice:
    """Translate a half-open ``(start, end)`` token range into a per-token-NLL slice.

    Both bounds are shifted down by one and clamped at zero. The NLL array is
    presumably offset by one position relative to the tokens (entry ``i`` scoring
    token ``i + 1``, since the first token has no prediction) — TODO confirm
    against ``compute_self_token_nll``.
    """
    lower, upper = (max(0, bound - 1) for bound in token_range)
    return slice(lower, upper)
def build_exact_suppression_mask(
    *,
    sequence_length: int,
    source_range: tuple[int, int],
    target_range: tuple[int, int],
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Build a causal attention mask that also hides one token span from another.

    Entries are either ``0`` (allowed) or ``torch.finfo(dtype).min`` (blocked).
    Row ``i`` / column ``j`` is blocked when ``j > i``, and additionally for every
    row in ``target_range`` and column in ``source_range``. The mask is presumably
    consumed by the model as an additive attention bias — TODO confirm for the
    transformers version in use.

    Args:
        sequence_length: Side length ``L`` of the square mask.
        source_range: Half-open ``(start, end)`` column span to hide.
        target_range: Half-open ``(start, end)`` row span that loses access to it.
        device: Device the mask is allocated on.
        dtype: Floating dtype of the mask; its ``finfo.min`` is the blocked value.

    Returns:
        Tensor of shape ``(1, 1, L, L)``.
    """
    blocked = torch.finfo(dtype).min
    n = sequence_length
    # Fill everything with the blocked value, then keep only the strictly-upper
    # triangle (future positions); the diagonal and below become zero.
    additive = torch.full((n, n), blocked, device=device, dtype=dtype).triu(diagonal=1)
    (src_lo, src_hi), (tgt_lo, tgt_hi) = source_range, target_range
    # Target rows may no longer read from source columns.
    additive[tgt_lo:tgt_hi, src_lo:src_hi] = blocked
    return additive[None, None, :, :]
def compute_exact_edge_score(
    *,
    model: Any,
    input_ids: torch.Tensor,
    source_range: tuple[int, int],
    target_range: tuple[int, int],
    baseline_token_nll: np.ndarray,
) -> float:
    """Measure how much hiding one token span from another changes the target's NLL.

    Re-runs ``model`` on ``input_ids`` with a mask from
    ``build_exact_suppression_mask`` that stops ``target_range`` tokens from
    attending to ``source_range`` tokens. It then compares the summed per-token NLL
    over the target span with the same span of ``baseline_token_nll``.

    Args:
        model: Model called as ``model(**inputs, attention_mask=..., return_dict=True)``
            and expected to return an object with ``.logits``.
        input_ids: Token ids; ``shape[1]`` is used as the sequence length
            (presumably ``(1, seq_len)`` — TODO confirm batch size 1 is required).
        source_range: Half-open token span whose influence is removed.
        target_range: Half-open token span whose NLL is measured.
        baseline_token_nll: Per-token NLL of the unmasked run, indexed the same way
            as the array returned by ``compute_self_token_nll``.

    Returns:
        ``sum(suppressed NLL) - sum(baseline NLL)`` over the target span; positive
        values mean the target became harder to predict without the source.
    """
    # The mask dtype follows the model's (first) parameter dtype so that
    # finfo(...).min is the right "blocked" value for that precision.
    model_dtype = next(model.parameters()).dtype
    attention_mask = build_exact_suppression_mask(
        sequence_length=int(input_ids.shape[1]),
        source_range=source_range,
        target_range=target_range,
        device=input_ids.device,
        dtype=model_dtype,
    )
    with torch.no_grad():
        model_inputs = add_prefix_token_type_ids(model, {"input_ids": input_ids})
        # NOTE(review): relies on the model accepting a custom 4-D float mask;
        # how such masks are interpreted differs across transformers versions.
        outputs = model(
            **model_inputs,
            attention_mask=attention_mask,
            output_attentions=False,
            return_dict=True,
        )
    # Upcast before leaving torch: NumPy has no bfloat16 dtype (``.numpy()`` raises
    # on bf16 tensors), and half-precision sums lose accuracy.
    suppressed_nll = (
        compute_self_token_nll(outputs.logits, input_ids).detach().float().cpu().numpy()
    )
    nll_slice = _nll_slice_for_token_range(target_range)
    suppressed_total = suppressed_nll[nll_slice].sum(dtype=np.float64)
    baseline_total = np.asarray(baseline_token_nll)[nll_slice].sum(dtype=np.float64)
    return float(suppressed_total - baseline_total)
def _correlation_or_none(correlate: Any, xs: list[float], ys: list[float]) -> float | None:
    """Return ``correlate(xs, ys).statistic`` as a float, or ``None`` when undefined.

    A correlation is undefined with fewer than two points or when either series is
    constant. SciPy returns NaN (and warns) in those cases, and also returns NaN when
    inputs contain NaN. NaN is reported as ``None`` here so callers see a single
    "no value" marker.
    """
    if len(xs) < 2 or len(set(xs)) < 2 or len(set(ys)) < 2:
        return None
    value = float(correlate(xs, ys).statistic)
    return None if np.isnan(value) else value


def validate_top_edges(
    *,
    model: Any,
    input_ids: torch.Tensor,
    token_ranges: list[tuple[int, int]],
    top_edges: list[TopEdge],
    baseline_token_nll: np.ndarray,
    top_k: int,
) -> ValidationMetadata:
    """Check attributed edge scores against exact suppression for the first ``top_k`` edges.

    For each selected edge, ``compute_exact_edge_score`` hides the source sentence's
    tokens from the target sentence's tokens and measures the target NLL change. The
    attributed scores (``edge.score``) are then correlated with those exact scores.

    Args:
        model: Model forwarded to ``compute_exact_edge_score``.
        input_ids: Token ids of the analysed text.
        token_ranges: ``token_ranges[i]`` is the half-open token span of sentence ``i``.
        top_edges: Edges to validate; only the first ``top_k`` are used.
        baseline_token_nll: Per-token NLL of the unmasked run.
        top_k: Number of leading edges to validate; ``<= 0`` disables validation.

    Returns:
        ``ValidationMetadata`` with Pearson/Spearman correlations (``None`` when
        undefined), the edges re-scored with their exact values, and a note. If any
        exact computation fails, no edges or correlations are reported and ``notes``
        carries the error.
    """
    if top_k <= 0 or not top_edges:
        return ValidationMetadata(enabled=False, top_k=0)

    exact_scores: list[float] = []
    attributed_scores: list[float] = []
    compared_edges: list[TopEdge] = []
    try:
        for edge in top_edges[:top_k]:
            exact_score = compute_exact_edge_score(
                model=model,
                input_ids=input_ids,
                source_range=token_ranges[edge.source_sentence_idx],
                target_range=token_ranges[edge.target_sentence_idx],
                baseline_token_nll=baseline_token_nll,
            )
            exact_scores.append(exact_score)
            attributed_scores.append(edge.score)
            compared_edges.append(
                TopEdge(
                    source_sentence_idx=edge.source_sentence_idx,
                    target_sentence_idx=edge.target_sentence_idx,
                    score=exact_score,
                )
            )
    except Exception as exc:  # pragma: no cover - environment/model dependent
        # Deliberate best-effort: a failed forward pass or bad range is reported in
        # the metadata instead of aborting the caller's analysis.
        return ValidationMetadata(
            enabled=True,
            top_k=top_k,
            compared_edges=[],
            notes=f"Exact suppression validation failed: {exc}",
        )

    return ValidationMetadata(
        enabled=True,
        top_k=top_k,
        pearson=_correlation_or_none(pearsonr, attributed_scores, exact_scores),
        spearman=_correlation_or_none(spearmanr, attributed_scores, exact_scores),
        compared_edges=compared_edges,
        notes="Exact suppression compares sentence-level NLL deltas for selected edges.",
    )