Spaces:
Sleeping
Sleeping
File size: 4,178 Bytes
from __future__ import annotations
from typing import Any
import numpy as np
import torch
from scipy.stats import pearsonr, spearmanr
from app.analysis.suppression import compute_self_token_nll
from app.core.model_support import add_prefix_token_type_ids
from app.core.schemas import TopEdge, ValidationMetadata
def _nll_slice_for_token_range(token_range: tuple[int, int]) -> slice:
start, end = token_range
return slice(max(0, start - 1), max(0, end - 1))
def build_exact_suppression_mask(
    *,
    sequence_length: int,
    source_range: tuple[int, int],
    target_range: tuple[int, int],
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Build a causal mask that also blocks one source span for one target span.

    The result holds ``0`` at permitted (row, column) pairs and
    ``torch.finfo(dtype).min`` at blocked ones. A pair is blocked when the
    column lies strictly after the row (a future position), or when the row
    falls in ``target_range`` and the column falls in ``source_range``.
    Both ranges are half-open ``(start, end)`` index pairs.

    Args:
        sequence_length: Side length of the square mask.
        source_range: Column span to hide from the target rows.
        target_range: Row span that must not see the source columns.
        device: Device on which the mask is allocated.
        dtype: Floating-point dtype of the mask.

    Returns:
        A tensor of shape ``(1, 1, sequence_length, sequence_length)``.
    """
    shape = (sequence_length, sequence_length)
    blocked_value = torch.finfo(dtype).min

    # True strictly above the main diagonal, i.e. at every future position.
    blocked = torch.ones(shape, device=device, dtype=torch.bool).triu(diagonal=1)

    # Additionally block the (target rows) x (source columns) rectangle.
    (src_lo, src_hi), (tgt_lo, tgt_hi) = source_range, target_range
    blocked[tgt_lo:tgt_hi, src_lo:src_hi] = True

    additive = torch.zeros(shape, device=device, dtype=dtype).masked_fill(
        blocked, blocked_value
    )
    # Add two leading singleton dimensions.
    return additive[None, None, :, :]
def compute_exact_edge_score(
    *,
    model: Any,
    input_ids: torch.Tensor,
    source_range: tuple[int, int],
    target_range: tuple[int, int],
    baseline_token_nll: np.ndarray,
) -> float:
    """Return how much the target span's summed NLL changes under suppression.

    The model is run once, without gradients, using the mask from
    ``build_exact_suppression_mask`` so that rows in ``target_range`` are
    blocked from columns in ``source_range``. Per-token NLL values are then
    summed over the slice given by ``_nll_slice_for_token_range(target_range)``
    and compared with the same slice of ``baseline_token_nll``.

    Args:
        model: Callable model exposing ``parameters()``; it is invoked with
            ``attention_mask``, ``output_attentions=False`` and
            ``return_dict=True``, and its output must provide ``.logits``.
        input_ids: Token ids; dimension 1 is used as the sequence length.
        source_range: ``(start, end)`` span hidden from the target span.
        target_range: ``(start, end)`` span whose NLL is measured.
        baseline_token_nll: Per-token NLL from the unsuppressed run.
            Presumably indexed the same way as the output of
            ``compute_self_token_nll`` -- verify against the caller.

    Returns:
        Suppressed NLL sum minus baseline NLL sum over the target slice.
    """
    # The mask uses the dtype of the model's first parameter.
    param_dtype = next(model.parameters()).dtype
    seq_len = int(input_ids.shape[1])
    suppression_mask = build_exact_suppression_mask(
        sequence_length=seq_len,
        source_range=source_range,
        target_range=target_range,
        device=input_ids.device,
        dtype=param_dtype,
    )

    with torch.no_grad():
        forward_kwargs = add_prefix_token_type_ids(model, {"input_ids": input_ids})
        result = model(
            **forward_kwargs,
            attention_mask=suppression_mask,
            output_attentions=False,
            return_dict=True,
        )

    token_nll = compute_self_token_nll(result.logits, input_ids).detach().cpu().numpy()
    target_slice = _nll_slice_for_token_range(target_range)
    suppressed_total = token_nll[target_slice].sum()
    baseline_total = baseline_token_nll[target_slice].sum()
    return float(suppressed_total - baseline_total)
def _finite_or_none(value: Any) -> float | None:
    """Convert a correlation statistic to ``float``, mapping NaN/inf to ``None``."""
    number = float(value)
    return number if np.isfinite(number) else None


def validate_top_edges(
    *,
    model: Any,
    input_ids: torch.Tensor,
    token_ranges: list[tuple[int, int]],
    top_edges: list[TopEdge],
    baseline_token_nll: np.ndarray,
    top_k: int,
) -> ValidationMetadata:
    """Re-score the first ``top_k`` edges exactly and correlate with their scores.

    For each selected edge, ``compute_exact_edge_score`` is called with the
    token ranges of the edge's source and target sentences. The resulting
    scores are compared with the edges' own ``score`` values using Pearson
    and Spearman correlation.

    Args:
        model: Model forwarded to ``compute_exact_edge_score``.
        input_ids: Token ids forwarded to ``compute_exact_edge_score``.
        token_ranges: Token ``(start, end)`` range per sentence, indexed by
            the edges' ``source_sentence_idx`` / ``target_sentence_idx``.
        top_edges: Candidate edges; only the first ``top_k`` are validated.
        baseline_token_nll: Per-token NLL of the unsuppressed run.
        top_k: Maximum number of edges to validate.

    Returns:
        ``ValidationMetadata`` with ``enabled=False`` when ``top_k <= 0`` or
        ``top_edges`` is empty. Otherwise ``enabled=True`` with ``top_k`` set
        to the requested value (not the number of edges actually compared).
        ``compared_edges`` carries the exact scores. ``pearson`` and
        ``spearman`` are ``None`` when fewer than two edges were compared or
        when the statistic is not finite (e.g. one score list is constant).
        If any exact computation raises, ``compared_edges`` is empty and
        ``notes`` holds the error message.
    """
    if top_k <= 0 or not top_edges:
        return ValidationMetadata(enabled=False, top_k=0)

    selected_edges = top_edges[:top_k]
    exact_scores: list[float] = []
    attributed_scores: list[float] = []
    compared_edges: list[TopEdge] = []
    try:
        for edge in selected_edges:
            exact_score = compute_exact_edge_score(
                model=model,
                input_ids=input_ids,
                source_range=token_ranges[edge.source_sentence_idx],
                target_range=token_ranges[edge.target_sentence_idx],
                baseline_token_nll=baseline_token_nll,
            )
            exact_scores.append(exact_score)
            attributed_scores.append(edge.score)
            compared_edges.append(
                TopEdge(
                    source_sentence_idx=edge.source_sentence_idx,
                    target_sentence_idx=edge.target_sentence_idx,
                    score=exact_score,
                )
            )
    except Exception as exc:  # pragma: no cover - environment/model dependent
        # Deliberate best-effort: validation is optional, so a failure is
        # reported in the metadata instead of propagating to the caller.
        return ValidationMetadata(
            enabled=True,
            top_k=top_k,
            compared_edges=[],
            notes=f"Exact suppression validation failed: {exc}",
        )

    pearson: float | None = None
    spearman: float | None = None
    if len(exact_scores) >= 2:
        # SciPy yields NaN when either input is constant; report that as
        # "no correlation available" rather than leaking NaN to consumers.
        pearson = _finite_or_none(pearsonr(attributed_scores, exact_scores).statistic)
        spearman = _finite_or_none(spearmanr(attributed_scores, exact_scores).statistic)

    return ValidationMetadata(
        enabled=True,
        top_k=top_k,
        pearson=pearson,
        spearman=spearman,
        compared_edges=compared_edges,
        notes="Exact suppression compares sentence-level NLL deltas for selected edges.",
    )
|