from __future__ import annotations
from typing import Any
import torch
# Lowercased filler words. A hint phrase of at most two words made up
# entirely of these entries is rejected by is_informative_hint_text().
_STOPWORD_HINTS = {
    "a",
    "an",
    "the",
    "and",
    "or",
    "to",
    "of",
    "in",
    "on",
    "by",
    "for",
    "that",
    "same",
}
def is_informative_hint_text(text: str) -> bool:
    """Decide whether *text* is a usable hint.

    Rejects blank strings, strings without any alphanumeric character, and
    phrases of at most two words consisting entirely of stopwords.
    """
    normalized = text.strip().lower()
    if not normalized:
        return False
    if not any(ch.isalnum() for ch in normalized):
        return False
    # Hyphens count as word separators for the stopword test.
    tokens = normalized.replace("-", " ").split()
    if not tokens:
        return False
    # A one- or two-word phrase made only of filler words carries no signal.
    if len(tokens) <= 2 and all(tok in _STOPWORD_HINTS for tok in tokens):
        return False
    return True
def decode_span_text(tokenizer: Any, token_ids: list[int]) -> str:
    """Decode *token_ids* into display text with newlines escaped.

    Falls back to a space-joined list of the raw ids when no tokenizer is
    available.
    """
    if tokenizer is None:
        return " ".join(map(str, token_ids))
    try:
        decoded = tokenizer.decode(token_ids, skip_special_tokens=False)
    except TypeError:
        # Some tokenizer implementations do not accept the keyword argument.
        decoded = tokenizer.decode(token_ids)
    return decoded.replace("\n", "\\n")
def safe_decode_token(tokenizer: Any, token_id: int) -> str:
    """Decode a single token id into display text with newlines escaped.

    Returns the id itself as a string when no tokenizer is available.
    """
    if tokenizer is None:
        return str(token_id)
    try:
        decoded = tokenizer.decode([token_id], skip_special_tokens=False)
    except TypeError:
        # Older tokenizer APIs lack the skip_special_tokens keyword.
        decoded = tokenizer.decode([token_id])
    return decoded.replace("\n", "\\n")
def spans_overlap(span_a: dict[str, Any], span_b: dict[str, Any]) -> bool:
    """Return True when the inclusive [start, end] ranges of the spans intersect."""
    a_start, a_end = int(span_a["start"]), int(span_a["end"])
    b_start, b_end = int(span_b["start"]), int(span_b["end"])
    # Two inclusive intervals intersect iff each starts no later than the other ends.
    return a_start <= b_end and b_start <= a_end
def extract_high_influence_spans(
    scores: torch.Tensor,
    input_ids: torch.Tensor,
    tokenizer: Any,
    min_score: float,
    top_spans: int,
) -> list[dict[str, Any]]:
    """Group above-threshold positions into contiguous spans and rank them.

    Positions with score >= *min_score* are merged into runs of consecutive
    indices. Each run is summarized (mean/max score, token ids, decoded text)
    and the best *top_spans* runs — ordered by mean score, then length, then
    max score, descending — are returned.
    """
    threshold = float(min_score)
    hot_indices = [
        position
        for position, raw in enumerate(scores.tolist())
        if float(raw) >= threshold
    ]
    if not hot_indices:
        return []
    # Merge consecutive hot indices into inclusive (first, last) runs.
    runs: list[tuple[int, int]] = []
    run_start = run_end = hot_indices[0]
    for position in hot_indices[1:]:
        if position == run_end + 1:
            run_end = position
        else:
            runs.append((run_start, run_end))
            run_start = run_end = position
    runs.append((run_start, run_end))
    summaries: list[dict[str, Any]] = []
    for first, last in runs:
        window = scores[first : last + 1]
        ids = [int(tok.item()) for tok in input_ids[first : last + 1]]
        summaries.append(
            {
                "start": int(first),
                "end": int(last),
                "length": int(last - first + 1),
                "mean_score": float(window.mean().item()),
                "max_score": float(window.max().item()),
                "token_ids": ids,
                "text": decode_span_text(tokenizer, ids),
            }
        )
    summaries.sort(
        key=lambda entry: (entry["mean_score"], entry["length"], entry["max_score"]),
        reverse=True,
    )
    return summaries[:top_spans]
def compute_span_anchor_overlap(
    future_spans: list[dict[str, Any]],
    active_anchor_spans: list[dict[str, int]],
) -> dict[str, float]:
    """Report the fraction of future spans and anchor spans that intersect.

    Returns a dict with the ratio of future spans touching at least one
    anchor span, and the ratio of anchor spans touching at least one future
    span. Both ratios are 0.0 when there are no future spans.
    """
    if not future_spans:
        return {
            "future_span_overlap_ratio": 0.0,
            "anchor_span_overlap_ratio": 0.0,
        }
    hit_future = sum(
        any(spans_overlap(candidate, anchor) for anchor in active_anchor_spans)
        for candidate in future_spans
    )
    hit_anchor = sum(
        any(spans_overlap(anchor, candidate) for candidate in future_spans)
        for anchor in active_anchor_spans
    )
    anchor_ratio = (
        hit_anchor / max(len(active_anchor_spans), 1) if active_anchor_spans else 0.0
    )
    return {
        "future_span_overlap_ratio": hit_future / max(len(future_spans), 1),
        "anchor_span_overlap_ratio": anchor_ratio,
    }
def build_future_hint_candidates(
    future_spans: list[dict[str, Any]],
    active_anchor_spans: list[dict[str, int]],
) -> list[dict[str, Any]]:
    """Filter future spans down to informative, non-anchor-overlapping hints.

    Spans that touch any active anchor span or whose text fails the
    informativeness check are dropped; survivors are returned best-first,
    ordered by (mean_score, length, max_score) descending.
    """
    candidates: list[dict[str, Any]] = []
    for candidate in future_spans:
        touches_anchor = any(
            spans_overlap(candidate, anchor) for anchor in active_anchor_spans
        )
        if touches_anchor:
            continue
        if not is_informative_hint_text(str(candidate["text"])):
            continue
        candidates.append(
            {
                "start": int(candidate["start"]),
                "end": int(candidate["end"]),
                "text": candidate["text"],
                "mean_score": float(candidate["mean_score"]),
                "max_score": float(candidate["max_score"]),
                "length": int(candidate["length"]),
            }
        )
    candidates.sort(
        key=lambda entry: (entry["mean_score"], entry["length"], entry["max_score"]),
        reverse=True,
    )
    return candidates
def build_auxiliary_future_proposals(
    hidden: torch.Tensor,
    input_ids: torch.Tensor,
    future_hint_candidates: list[dict[str, Any]],
    tokenizer: Any,
    max_candidates: int = 3,
) -> list[dict[str, Any]]:
    """Turn the strongest future-hint spans into auxiliary proposal records.

    At most *max_candidates* hints are consumed. Each hint's span indices are
    clamped to the sequence length, its hidden states are mean-pooled into a
    detached representation vector, and the span tokens are decoded for
    display.
    """
    last_valid = hidden.size(0) - 1
    records: list[dict[str, Any]] = []
    for hint in future_hint_candidates[:max_candidates]:
        lo = max(0, min(int(hint["start"]), last_valid))
        hi = max(lo, min(int(hint["end"]), last_valid))
        span_states = hidden[lo : hi + 1]
        ids_in_span = [int(tok.item()) for tok in input_ids[lo : hi + 1]]
        records.append(
            {
                "proposal_type": "future_hint_span",
                "proposal_score": float(hint["mean_score"]),
                "proposal_span": (lo, hi),
                "proposal_root_token": ids_in_span[-1] if ids_in_span else None,
                "proposal_text": decode_span_text(tokenizer, ids_in_span),
                "repr": span_states.mean(dim=0).detach(),
            }
        )
    return records
def summarize_auxiliary_proposals(
    proposal_batches: list[list[dict[str, Any]]],
) -> dict[str, float]:
    """Aggregate per-batch proposal lists into count and score statistics.

    All score-derived values are 0.0 when no proposals are present.
    """
    batch_sizes = [len(batch) for batch in proposal_batches]
    scores = [
        float(entry["proposal_score"])
        for batch in proposal_batches
        for entry in batch
    ]
    total = sum(batch_sizes)
    return {
        "proposal_count": int(total),
        "batch_with_proposals_count": int(sum(size > 0 for size in batch_sizes)),
        "mean_proposal_count_per_batch": float(total / max(len(batch_sizes), 1)),
        "mean_proposal_score": float(sum(scores) / max(len(scores), 1)) if scores else 0.0,
        "max_proposal_score": max(scores) if scores else 0.0,
    }