File size: 6,254 Bytes
8125804
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
from __future__ import annotations

from typing import Any

import torch

# Lowercase filler words that, on their own, make a useless hint string.
_STOPWORD_HINTS = set(
    "a an the and or to of in on by for that same".split()
)


def is_informative_hint_text(text: str) -> bool:
    """Return True when *text* has enough content to serve as a hint.

    Rejects empty/whitespace-only strings, strings with no alphanumeric
    characters, and one- or two-word strings made entirely of stopwords.
    """
    normalized = text.strip().lower()
    if not normalized:
        return False
    has_alnum = any(ch.isalnum() for ch in normalized)
    if not has_alnum:
        return False
    # Hyphens act as word separators for the stopword check.
    tokens = normalized.replace("-", " ").split()
    if not tokens:
        return False
    if len(tokens) <= 2 and all(tok in _STOPWORD_HINTS for tok in tokens):
        return False
    return True


def decode_span_text(tokenizer: Any, token_ids: list[int]) -> str:
    """Decode *token_ids* to a single-line string.

    Falls back to space-joined raw ids when *tokenizer* is None, and
    retries without keyword arguments for tokenizers whose ``decode``
    does not accept ``skip_special_tokens``. Newlines are escaped so the
    result stays on one line.
    """
    if tokenizer is None:
        return " ".join(map(str, token_ids))
    try:
        decoded = tokenizer.decode(token_ids, skip_special_tokens=False)
    except TypeError:
        # Older/simpler tokenizers: positional-only decode.
        decoded = tokenizer.decode(token_ids)
    return decoded.replace("\n", "\\n")


def safe_decode_token(tokenizer: Any, token_id: int) -> str:
    """Decode one token id to a single-line string.

    Returns the id itself as text when *tokenizer* is None; otherwise
    decodes it, tolerating tokenizers whose ``decode`` lacks the
    ``skip_special_tokens`` keyword, and escapes newlines.
    """
    if tokenizer is None:
        return str(token_id)
    ids = [token_id]
    try:
        decoded = tokenizer.decode(ids, skip_special_tokens=False)
    except TypeError:
        # Fallback for decode() signatures without keyword support.
        decoded = tokenizer.decode(ids)
    return decoded.replace("\n", "\\n")


def spans_overlap(span_a: dict[str, Any], span_b: dict[str, Any]) -> bool:
    """Return True when the inclusive [start, end] ranges of the spans intersect."""
    a_start, a_end = int(span_a["start"]), int(span_a["end"])
    b_start, b_end = int(span_b["start"]), int(span_b["end"])
    # Two closed intervals intersect iff each starts before the other ends.
    return a_start <= b_end and b_start <= a_end


def extract_high_influence_spans(
    scores: torch.Tensor,
    input_ids: torch.Tensor,
    tokenizer: Any,
    min_score: float,
    top_spans: int,
) -> list[dict[str, Any]]:
    """Group consecutive above-threshold token positions into ranked spans.

    Positions where ``scores[i] >= min_score`` are merged into maximal
    runs of consecutive indices; each run becomes a span dict with its
    score statistics, token ids, and decoded text. Spans are ranked by
    (mean score, length, max score) descending and the top ``top_spans``
    are returned. Empty list when nothing clears the threshold.
    """
    threshold = float(min_score)
    hot = [pos for pos, value in enumerate(scores.tolist()) if float(value) >= threshold]
    if not hot:
        return []

    # Merge consecutive indices into inclusive (first, last) runs.
    runs: list[tuple[int, int]] = []
    run_first = run_last = hot[0]
    for pos in hot[1:]:
        if pos != run_last + 1:
            runs.append((run_first, run_last))
            run_first = pos
        run_last = pos
    runs.append((run_first, run_last))

    described: list[dict[str, Any]] = []
    for first, last in runs:
        window = scores[first : last + 1]
        ids = [int(tok.item()) for tok in input_ids[first : last + 1]]
        # Decode inline (mirrors decode_span_text): raw ids without a
        # tokenizer, otherwise decode with a keyword-compat fallback and
        # newline escaping.
        if tokenizer is None:
            text = " ".join(str(tok_id) for tok_id in ids)
        else:
            try:
                text = tokenizer.decode(ids, skip_special_tokens=False)
            except TypeError:
                text = tokenizer.decode(ids)
            text = text.replace("\n", "\\n")
        described.append(
            {
                "start": int(first),
                "end": int(last),
                "length": int(last - first + 1),
                "mean_score": float(window.mean().item()),
                "max_score": float(window.max().item()),
                "token_ids": ids,
                "text": text,
            }
        )

    described.sort(
        key=lambda span: (span["mean_score"], span["length"], span["max_score"]),
        reverse=True,
    )
    return described[:top_spans]


def compute_span_anchor_overlap(
    future_spans: list[dict[str, Any]],
    active_anchor_spans: list[dict[str, int]],
) -> dict[str, float]:
    """Measure mutual coverage between future spans and anchor spans.

    Returns the fraction of future spans touching at least one anchor
    and the fraction of anchors touching at least one future span.
    Both ratios are 0.0 when there are no future spans.
    """
    if not future_spans:
        return {
            "future_span_overlap_ratio": 0.0,
            "anchor_span_overlap_ratio": 0.0,
        }

    def touches(left: dict[str, Any], right: dict[str, Any]) -> bool:
        # Inclusive-interval intersection test (same as spans_overlap).
        return (
            int(left["start"]) <= int(right["end"])
            and int(right["start"]) <= int(left["end"])
        )

    covered_future = sum(
        1
        for span in future_spans
        if any(touches(span, anchor) for anchor in active_anchor_spans)
    )
    covered_anchors = sum(
        1
        for anchor in active_anchor_spans
        if any(touches(anchor, span) for span in future_spans)
    )
    anchor_ratio = (
        covered_anchors / len(active_anchor_spans) if active_anchor_spans else 0.0
    )
    return {
        "future_span_overlap_ratio": covered_future / len(future_spans),
        "anchor_span_overlap_ratio": anchor_ratio,
    }


def build_future_hint_candidates(
    future_spans: list[dict[str, Any]],
    active_anchor_spans: list[dict[str, int]],
) -> list[dict[str, Any]]:
    """Select anchor-disjoint, informative future spans as hint candidates.

    A span is kept only when it overlaps no active anchor span and its
    text passes ``is_informative_hint_text``. Candidates are ordered by
    (mean score, length, max score) descending.
    """
    candidates: list[dict[str, Any]] = []
    for candidate in future_spans:
        overlaps_anchor = any(
            spans_overlap(candidate, anchor) for anchor in active_anchor_spans
        )
        if overlaps_anchor:
            continue
        if not is_informative_hint_text(str(candidate["text"])):
            continue
        candidates.append(
            {
                "start": int(candidate["start"]),
                "end": int(candidate["end"]),
                "text": candidate["text"],
                "mean_score": float(candidate["mean_score"]),
                "max_score": float(candidate["max_score"]),
                "length": int(candidate["length"]),
            }
        )
    candidates.sort(
        key=lambda hint: (hint["mean_score"], hint["length"], hint["max_score"]),
        reverse=True,
    )
    return candidates


def build_auxiliary_future_proposals(
    hidden: torch.Tensor,
    input_ids: torch.Tensor,
    future_hint_candidates: list[dict[str, Any]],
    tokenizer: Any,
    max_candidates: int = 3,
) -> list[dict[str, Any]]:
    """Turn the strongest hint spans into proposal records.

    Each of the first ``max_candidates`` hints yields a proposal dict
    carrying its clamped span, last token id, decoded text, and the
    mean-pooled hidden-state representation of the span.
    """
    last_index = hidden.size(0) - 1
    proposals: list[dict[str, Any]] = []
    for hint in future_hint_candidates[:max_candidates]:
        # Clamp the hint span to valid positions in the hidden sequence.
        lo = max(0, min(int(hint["start"]), last_index))
        hi = max(lo, min(int(hint["end"]), last_index))
        ids = [int(tok.item()) for tok in input_ids[lo : hi + 1]]
        pooled = hidden[lo : hi + 1].mean(dim=0).detach()
        proposals.append(
            {
                "proposal_type": "future_hint_span",
                "proposal_score": float(hint["mean_score"]),
                "proposal_span": (lo, hi),
                "proposal_root_token": ids[-1] if ids else None,
                "proposal_text": decode_span_text(tokenizer, ids),
                "repr": pooled,
            }
        )
    return proposals


def summarize_auxiliary_proposals(
    proposal_batches: list[list[dict[str, Any]]],
) -> dict[str, float]:
    """Aggregate count and score statistics across per-batch proposal lists.

    Score statistics fall back to 0.0 when no proposals exist at all.
    """
    per_batch = [len(batch) for batch in proposal_batches]
    total = sum(per_batch)

    scores: list[float] = []
    for batch in proposal_batches:
        scores.extend(float(item["proposal_score"]) for item in batch)

    mean_score = float(sum(scores) / len(scores)) if scores else 0.0
    return {
        "proposal_count": int(total),
        "batch_with_proposals_count": int(sum(1 for count in per_batch if count)),
        "mean_proposal_count_per_batch": float(total / max(len(per_batch), 1)),
        "mean_proposal_score": mean_score,
        "max_proposal_score": max(scores) if scores else 0.0,
    }