File size: 4,178 Bytes
fda8fb3
 
 
 
 
 
 
 
 
2620860
fda8fb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2620860
fda8fb3
2620860
fda8fb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from __future__ import annotations

from typing import Any

import numpy as np
import torch
from scipy.stats import pearsonr, spearmanr

from app.analysis.suppression import compute_self_token_nll
from app.core.model_support import add_prefix_token_type_ids
from app.core.schemas import TopEdge, ValidationMetadata


def _nll_slice_for_token_range(token_range: tuple[int, int]) -> slice:
    """Translate a ``(start, end)`` token range into a slice shifted left by one.

    Each bound is decremented by one and clamped at zero, so a range starting
    at token 0 produces a slice that also starts at index 0.

    NOTE(review): the one-position shift presumably reflects that the per-token
    NLL array has no entry for the first token -- confirm against
    ``compute_self_token_nll``.
    """
    lower, upper = (max(0, bound - 1) for bound in token_range)
    return slice(lower, upper)


def build_exact_suppression_mask(
    *,
    sequence_length: int,
    source_range: tuple[int, int],
    target_range: tuple[int, int],
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Build an additive causal mask that also blocks one source/target span.

    The mask holds ``0`` where attention is allowed and ``torch.finfo(dtype).min``
    where it is blocked. Two regions are blocked:

    * every entry strictly above the main diagonal (column index greater than
      row index), i.e. the future positions;
    * the rectangle whose rows lie in ``target_range`` and whose columns lie in
      ``source_range`` (both half-open ``(start, end)`` ranges).

    Args:
        sequence_length: Side length of the square mask.
        source_range: Column span to block for the target rows.
        target_range: Row span whose view of the source columns is blocked.
        device: Device on which the mask is allocated.
        dtype: Floating-point dtype of the mask; it also determines the fill value.

    Returns:
        A tensor of shape ``(1, 1, sequence_length, sequence_length)``.
    """
    positions = torch.arange(sequence_length, device=device)
    # blocked[row, col] is True wherever that row must not see that column;
    # start with the strictly-upper triangle (col > row).
    blocked = positions.unsqueeze(0) > positions.unsqueeze(1)

    src_lo, src_hi = source_range
    tgt_lo, tgt_hi = target_range
    blocked[tgt_lo:tgt_hi, src_lo:src_hi] = True

    additive = torch.zeros((sequence_length, sequence_length), device=device, dtype=dtype)
    additive = additive.masked_fill(blocked, torch.finfo(dtype).min)
    # Prepend two singleton axes.
    return additive[None, None]


def compute_exact_edge_score(
    *,
    model: Any,
    input_ids: torch.Tensor,
    source_range: tuple[int, int],
    target_range: tuple[int, int],
    baseline_token_nll: np.ndarray,
) -> float:
    """Measure the NLL change on the target span when one edge is suppressed.

    Runs a single gradient-free forward pass of ``model`` with an attention
    mask that blocks the ``target_range`` rows from the ``source_range``
    columns, recomputes the per-token NLL from the resulting logits, and
    compares it with the baseline over the slice derived from ``target_range``.

    Args:
        model: Callable model exposing ``parameters()``; its first parameter's
            dtype is used for the mask. It is called with ``attention_mask``,
            ``output_attentions=False`` and ``return_dict=True``, and its
            output must expose ``.logits``.
        input_ids: Token ids; ``input_ids.shape[1]`` is used as the sequence
            length and ``input_ids.device`` as the mask device.
        source_range: ``(start, end)`` token span whose columns are blocked.
        target_range: ``(start, end)`` token span whose NLL change is scored.
        baseline_token_nll: Per-token NLL of the unsuppressed pass, indexed
            with the same slice as the suppressed NLL.

    Returns:
        Suppressed NLL sum minus baseline NLL sum over the target slice.
    """
    param_dtype = next(model.parameters()).dtype
    suppression_mask = build_exact_suppression_mask(
        sequence_length=int(input_ids.shape[1]),
        source_range=source_range,
        target_range=target_range,
        device=input_ids.device,
        dtype=param_dtype,
    )

    with torch.no_grad():
        forward_kwargs = add_prefix_token_type_ids(model, {"input_ids": input_ids})
        result = model(
            **forward_kwargs,
            attention_mask=suppression_mask,
            output_attentions=False,
            return_dict=True,
        )

    token_nll = compute_self_token_nll(result.logits, input_ids).detach().cpu().numpy()
    window = _nll_slice_for_token_range(target_range)
    suppressed_total = token_nll[window].sum()
    baseline_total = baseline_token_nll[window].sum()
    return float(suppressed_total - baseline_total)


def _finite_or_none(statistic: Any) -> "float | None":
    """Return ``statistic`` as a float, or ``None`` when it is NaN or infinite."""
    return float(statistic) if np.isfinite(statistic) else None


def validate_top_edges(
    *,
    model: Any,
    input_ids: torch.Tensor,
    token_ranges: list[tuple[int, int]],
    top_edges: list[TopEdge],
    baseline_token_nll: np.ndarray,
    top_k: int,
) -> ValidationMetadata:
    """Compare attributed edge scores against exact suppression scores.

    For each of the first ``top_k`` entries of ``top_edges`` the exact score is
    computed with ``compute_exact_edge_score`` and paired with the edge's
    attributed ``score``. Pearson and Spearman correlations between the two
    score lists are reported when at least two edges were compared.

    Args:
        model: Model forwarded to ``compute_exact_edge_score``.
        input_ids: Token ids forwarded to ``compute_exact_edge_score``.
        token_ranges: Token ranges indexed by ``edge.source_sentence_idx`` and
            ``edge.target_sentence_idx``.
        top_edges: Candidate edges; only the first ``top_k`` are evaluated.
        baseline_token_nll: Per-token NLL of the unsuppressed forward pass.
        top_k: Maximum number of edges to validate.

    Returns:
        ``ValidationMetadata`` with ``enabled=False`` and ``top_k=0`` when
        ``top_k <= 0`` or ``top_edges`` is empty. If any edge evaluation
        raises, metadata with no compared edges and the error text in
        ``notes``. Otherwise metadata carrying the correlations and the
        compared edges re-scored with their exact values. A correlation is
        ``None`` when fewer than two edges were compared or when the
        statistic is not finite.
    """
    if top_k <= 0 or not top_edges:
        return ValidationMetadata(enabled=False, top_k=0)

    exact_scores: list[float] = []
    attributed_scores: list[float] = []
    compared_edges: list[TopEdge] = []

    # Deliberately best-effort: any failure in the model forward pass (or an
    # out-of-range sentence index) is reported in ``notes`` rather than raised.
    try:
        for edge in top_edges[:top_k]:
            exact_score = compute_exact_edge_score(
                model=model,
                input_ids=input_ids,
                source_range=token_ranges[edge.source_sentence_idx],
                target_range=token_ranges[edge.target_sentence_idx],
                baseline_token_nll=baseline_token_nll,
            )
            exact_scores.append(exact_score)
            attributed_scores.append(edge.score)
            compared_edges.append(
                TopEdge(
                    source_sentence_idx=edge.source_sentence_idx,
                    target_sentence_idx=edge.target_sentence_idx,
                    score=exact_score,
                )
            )
    except Exception as exc:  # pragma: no cover - environment/model dependent
        return ValidationMetadata(
            enabled=True,
            top_k=top_k,
            compared_edges=[],
            notes=f"Exact suppression validation failed: {exc}",
        )

    pearson = None
    spearman = None
    if len(exact_scores) >= 2:
        # Constant or non-finite score vectors make the correlation undefined
        # (SciPy yields NaN). Report that as None, like the "too few edges"
        # case, so NaN never reaches the metadata.
        pearson = _finite_or_none(pearsonr(attributed_scores, exact_scores).statistic)
        spearman = _finite_or_none(spearmanr(attributed_scores, exact_scores).statistic)

    return ValidationMetadata(
        enabled=True,
        top_k=top_k,
        pearson=pearson,
        spearman=spearman,
        compared_edges=compared_edges,
        notes="Exact suppression compares sentence-level NLL deltas for selected edges.",
    )