# src/model/future_span_hints.py
# (from commit 8125804: "feat: add src/ module for script imports")
from __future__ import annotations
from typing import Any
import torch
# Lowercase filler words considered too generic to carry signal on their own;
# is_informative_hint_text rejects 1-2 word spans made up entirely of these.
_STOPWORD_HINTS = {
    "a",
    "an",
    "the",
    "and",
    "or",
    "to",
    "of",
    "in",
    "on",
    "by",
    "for",
    "that",
    "same",
}
def is_informative_hint_text(text: str) -> bool:
    """Decide whether a decoded span is worth surfacing as a hint.

    Rejects blank strings, strings with no alphanumeric characters, and
    one- or two-word snippets consisting entirely of stopwords.
    """
    normalized = text.strip().lower()
    if not normalized:
        return False
    if not any(character.isalnum() for character in normalized):
        return False
    # Hyphens split words so "state-of-the-art" is judged word by word.
    tokens = normalized.replace("-", " ").split()
    if not tokens:
        return False
    if len(tokens) > 2:
        return True
    return not all(token in _STOPWORD_HINTS for token in tokens)
def decode_span_text(tokenizer: Any, token_ids: list[int]) -> str:
    """Decode a span of token ids into display text with newlines escaped.

    With no tokenizer the ids themselves are joined as a fallback string.
    Tokenizers whose decode() lacks the skip_special_tokens keyword are
    retried with a bare call.
    """
    if tokenizer is None:
        rendered = " ".join(str(tid) for tid in token_ids)
    else:
        try:
            rendered = tokenizer.decode(token_ids, skip_special_tokens=False)
        except TypeError:
            rendered = tokenizer.decode(token_ids)
    return rendered.replace("\n", "\\n")
def safe_decode_token(tokenizer: Any, token_id: int) -> str:
    """Best-effort decode of a single token id, with newlines escaped.

    Without a tokenizer the numeric id is rendered as a string. Tokenizers
    whose decode() lacks the skip_special_tokens keyword get a bare retry.
    """
    if tokenizer is None:
        return str(token_id)
    try:
        piece = tokenizer.decode([token_id], skip_special_tokens=False)
    except TypeError:
        # Older tokenizer interface without the keyword argument.
        piece = tokenizer.decode([token_id])
    return piece.replace("\n", "\\n")
def spans_overlap(span_a: dict[str, Any], span_b: dict[str, Any]) -> bool:
    """Closed-interval intersection test on the spans' "start"/"end" keys."""
    a_start, a_end = int(span_a["start"]), int(span_a["end"])
    b_start, b_end = int(span_b["start"]), int(span_b["end"])
    return a_start <= b_end and b_start <= a_end
def extract_high_influence_spans(
    scores: torch.Tensor,
    input_ids: torch.Tensor,
    tokenizer: Any,
    min_score: float,
    top_spans: int,
) -> list[dict[str, Any]]:
    """Group above-threshold token positions into contiguous, scored spans.

    Positions with score >= min_score are merged into runs of consecutive
    indices; each run is summarized (mean/max score, token ids, decoded
    text) and the top_spans best runs — ranked by mean score, then length,
    then max score — are returned.
    """
    threshold = float(min_score)
    hot = [pos for pos, value in enumerate(scores.tolist()) if float(value) >= threshold]
    if not hot:
        return []
    # Merge consecutive hot indices into (first, last) runs.
    runs: list[tuple[int, int]] = []
    run_start = run_end = hot[0]
    for pos in hot[1:]:
        if pos == run_end + 1:
            run_end = pos
        else:
            runs.append((run_start, run_end))
            run_start = run_end = pos
    runs.append((run_start, run_end))
    described: list[dict[str, Any]] = []
    for first, last in runs:
        window = scores[first : last + 1]
        ids = [int(token.item()) for token in input_ids[first : last + 1]]
        described.append(
            {
                "start": int(first),
                "end": int(last),
                "length": int(last - first + 1),
                "mean_score": float(window.mean().item()),
                "max_score": float(window.max().item()),
                "token_ids": ids,
                "text": decode_span_text(tokenizer, ids),
            }
        )
    described.sort(
        key=lambda span: (span["mean_score"], span["length"], span["max_score"]),
        reverse=True,
    )
    return described[:top_spans]
def compute_span_anchor_overlap(
    future_spans: list[dict[str, Any]],
    active_anchor_spans: list[dict[str, int]],
) -> dict[str, float]:
    """Fraction of future spans touching any anchor, and vice versa.

    Both ratios are 0.0 when there are no future spans; the anchor ratio is
    also 0.0 when there are no anchor spans.
    """
    if not future_spans:
        return {
            "future_span_overlap_ratio": 0.0,
            "anchor_span_overlap_ratio": 0.0,
        }
    touched_future = sum(
        any(spans_overlap(future, anchor) for anchor in active_anchor_spans)
        for future in future_spans
    )
    touched_anchor = sum(
        any(spans_overlap(anchor, future) for future in future_spans)
        for anchor in active_anchor_spans
    )
    if active_anchor_spans:
        anchor_ratio = touched_anchor / len(active_anchor_spans)
    else:
        anchor_ratio = 0.0
    return {
        "future_span_overlap_ratio": touched_future / len(future_spans),
        "anchor_span_overlap_ratio": anchor_ratio,
    }
def build_future_hint_candidates(
    future_spans: list[dict[str, Any]],
    active_anchor_spans: list[dict[str, int]],
) -> list[dict[str, Any]]:
    """Filter future spans down to informative, anchor-disjoint hint dicts.

    Spans overlapping any active anchor or failing the informativeness
    check are dropped; survivors are normalized and ranked by mean score,
    then length, then max score (descending).
    """
    candidates: list[dict[str, Any]] = []
    for span in future_spans:
        touches_anchor = any(
            spans_overlap(span, anchor) for anchor in active_anchor_spans
        )
        if touches_anchor or not is_informative_hint_text(str(span["text"])):
            continue
        candidates.append(
            {
                "start": int(span["start"]),
                "end": int(span["end"]),
                "text": span["text"],
                "mean_score": float(span["mean_score"]),
                "max_score": float(span["max_score"]),
                "length": int(span["length"]),
            }
        )
    candidates.sort(
        key=lambda hint: (hint["mean_score"], hint["length"], hint["max_score"]),
        reverse=True,
    )
    return candidates
def build_auxiliary_future_proposals(
    hidden: torch.Tensor,
    input_ids: torch.Tensor,
    future_hint_candidates: list[dict[str, Any]],
    tokenizer: Any,
    max_candidates: int = 3,
) -> list[dict[str, Any]]:
    """Turn the top future-hint spans into auxiliary proposal records.

    Each proposal carries the span boundaries (clamped into the valid token
    range), the mean hidden state over the span as its representation, the
    decoded span text, and the span's final token id as the proposal root.

    Args:
        hidden: Per-token hidden states; the leading dim indexes tokens
            (assumed aligned with input_ids — confirm against caller).
        input_ids: Token ids aligned with hidden's leading dim.
        future_hint_candidates: Hint dicts with "start"/"end"/"mean_score".
        tokenizer: Optional tokenizer used for span text decoding.
        max_candidates: Number of top-ranked hints to convert.

    Returns:
        A list of proposal dicts; empty when there are no candidates or no
        tokens to attach them to.
    """
    seq_len = hidden.size(0)
    # Guard: with zero tokens every span would clamp to an empty slice and
    # the mean over it would be NaN, so bail out early instead.
    if seq_len == 0:
        return []
    proposals: list[dict[str, Any]] = []
    for hint in future_hint_candidates[:max_candidates]:
        # Clamp the hint span into [0, seq_len - 1]; end never precedes start.
        start = max(0, min(int(hint["start"]), seq_len - 1))
        end = max(start, min(int(hint["end"]), seq_len - 1))
        span_hidden = hidden[start : end + 1]
        span_ids = [int(token.item()) for token in input_ids[start : end + 1]]
        proposals.append(
            {
                "proposal_type": "future_hint_span",
                "proposal_score": float(hint["mean_score"]),
                "proposal_span": (start, end),
                "proposal_root_token": span_ids[-1] if span_ids else None,
                "proposal_text": decode_span_text(tokenizer, span_ids),
                "repr": span_hidden.mean(dim=0).detach(),
            }
        )
    return proposals
def summarize_auxiliary_proposals(
    proposal_batches: list[list[dict[str, Any]]],
) -> dict[str, float]:
    """Aggregate proposal counts and scores across per-batch proposal lists."""
    total = 0
    nonempty_batches = 0
    scores: list[float] = []
    for batch in proposal_batches:
        total += len(batch)
        if batch:
            nonempty_batches += 1
        scores.extend(float(item["proposal_score"]) for item in batch)
    # Guard the per-batch mean against an empty outer list.
    batch_count = max(len(proposal_batches), 1)
    return {
        "proposal_count": int(total),
        "batch_with_proposals_count": int(nonempty_batches),
        "mean_proposal_count_per_batch": float(total / batch_count),
        "mean_proposal_score": float(sum(scores) / len(scores)) if scores else 0.0,
        "max_proposal_score": max(scores) if scores else 0.0,
    }