# abpt/src/model/future_proposal.py
# Revision f37be5a: "auto: sync run_testformer_wikitext_combo_remote.py"
from __future__ import annotations
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.model.config import ModelConfig
@dataclass
class FutureProposalCandidate:
    """One scored window considered as a future proposal for an anchor."""

    # Index of the first position of the window in the sequence.
    start: int
    # Index of the last position of the window (inclusive).
    end: int
    # Mean of the hidden states inside the window.
    repr: torch.Tensor
    # Scalar score tensor produced by a sigmoid, so it lies in [0, 1].
    score: torch.Tensor
    # Token id at the last window position, or None when no ids were given.
    root_token: int | None
class FutureProposalHead(nn.Module):
    """Propose a new representation for an anchor from the positions after it.

    Windows of several lengths are slid over a bounded horizon that starts
    right after the anchor span.  Every window is summarised by ten scalar
    features, scored by a fixed heuristic plus a small learned correction
    (``score_mlp``), and the best windows are fused with the anchor
    representation through ``repr_delta`` into a single proposal vector.

    The ``anchor`` objects passed to the methods are read for these
    attributes: ``start_idx``, ``end_idx``, ``ttl``, ``repr``,
    ``contradiction_pressure``, ``viability`` and ``descendant_coherence``.
    """

    # A window whose conflict signal is below this value is discarded.
    _MIN_CONFLICT_SIGNAL = 0.18
    # An anchor whose repair readiness is below this value yields no candidates.
    _MIN_REPAIR_READINESS = 0.35

    def __init__(self, cfg: ModelConfig):
        """Build the scoring and fusion MLPs from ``cfg``.

        Reads ``cfg.anchor_future_proposal_hidden`` (floored at 32) and
        ``cfg.d_model``; ``cfg`` itself is stored for later look-ups.
        """
        super().__init__()
        self.cfg = cfg
        hidden_dim = max(32, int(cfg.anchor_future_proposal_hidden))
        # Maps the 10 scalar window features to a single logit.
        self.score_mlp = nn.Sequential(
            nn.Linear(10, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
        )
        # Maps [anchor, window, window - anchor, window * anchor] to a
        # d_model-sized residual that is added to the window representation.
        self.repr_delta = nn.Sequential(
            nn.Linear(cfg.d_model * 4, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, cfg.d_model),
        )

    @staticmethod
    def _cosine01_tensor(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Return the mean cosine similarity of ``a`` and ``b`` rescaled to [0, 1]."""
        cosine = F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0), dim=-1).mean()
        return (cosine + 1.0) * 0.5

    def _candidate_lengths(
        self,
        span_len: int,
        available: int,
    ) -> list[int]:
        """Return the sorted, de-duplicated window lengths to try.

        The lengths are half, once, one-and-a-half and twice ``span_len``
        (the two larger ones capped at ``available``), each at least 1.
        Lengths that do not fit into ``available`` positions are dropped.
        """
        if available <= 0:
            return []
        lengths = {
            max(1, span_len // 2),
            max(1, span_len),
            max(1, min(available, span_len + max(1, span_len // 2))),
            max(1, min(available, span_len * 2)),
        }
        return sorted(length for length in lengths if 1 <= length <= available)

    def _search_bounds(
        self,
        anchor,
        seq_len: int,
    ) -> tuple[int, int, int]:
        """Return ``(start, stop, span_len)`` for the search region after the anchor.

        ``start`` is the position right after the anchor (clamped to
        ``seq_len``) and ``stop`` is exclusive.  The horizon is the larger of
        the scaled anchor ``ttl`` and the scaled span length, at least the
        span length, and at most ``cfg.anchor_future_proposal_max_horizon``.
        When nothing follows the anchor, ``start == stop``.
        """
        span_len = max(int(anchor.end_idx) - int(anchor.start_idx) + 1, 1)
        start = min(int(anchor.end_idx) + 1, seq_len)
        if start >= seq_len:
            return start, start, span_len
        base_horizon = max(
            int(float(anchor.ttl) * float(self.cfg.anchor_future_proposal_horizon_scale)),
            int(span_len * float(self.cfg.anchor_future_proposal_span_scale)),
        )
        horizon = min(max(base_horizon, span_len), int(self.cfg.anchor_future_proposal_max_horizon))
        stop = min(seq_len, start + max(horizon, 1))
        return start, stop, span_len

    def _subsample_candidates(
        self,
        candidates: list[FutureProposalCandidate],
    ) -> list[FutureProposalCandidate]:
        """Keep at most ``cfg.anchor_future_proposal_max_windows`` candidates.

        When there are too many, candidates are picked at evenly spaced list
        positions (first and last included); scores are not consulted.
        """
        max_windows = max(1, int(self.cfg.anchor_future_proposal_max_windows))
        if len(candidates) <= max_windows:
            return candidates
        idx = torch.linspace(0, len(candidates) - 1, steps=max_windows).round().long().tolist()
        return [candidates[i] for i in idx]

    def _build_candidates(
        self,
        seq_hidden: torch.Tensor,
        seq_ids: torch.Tensor | None,
        anchor,
    ) -> list[FutureProposalCandidate]:
        """Enumerate, gate and score the windows that follow ``anchor``.

        Args:
            seq_hidden: hidden states indexed by position along dim 0
                (expected to be ``(seq_len, d_model)`` -- TODO confirm with caller).
            seq_ids: token ids aligned with ``seq_hidden`` along dim 0, or
                ``None`` to skip the token-overlap feature.
            anchor: object exposing the attributes listed on the class.

        Returns:
            The surviving candidates, ordered by window length and then by
            offset, subsampled to the configured maximum.  Empty when there
            is no room after the anchor or when a gate rejects everything.
        """
        seq_len = seq_hidden.size(0)
        start, stop, span_len = self._search_bounds(anchor, seq_len)
        if stop <= start:
            return []
        lengths = self._candidate_lengths(span_len, stop - start)
        if not lengths:
            return []

        # Features that depend only on the anchor are computed once.
        pressure = seq_hidden.new_tensor(float(anchor.contradiction_pressure))
        viability_gap = seq_hidden.new_tensor(1.0 - float(anchor.viability))
        descendant_gap = seq_hidden.new_tensor(1.0 - float(anchor.descendant_coherence or 0.0))
        repair_readiness = 0.60 * pressure + 0.40 * viability_gap
        # The readiness gate is the same for every window, so a failing
        # anchor is rejected here instead of once per window.
        if float(repair_readiness.item()) < self._MIN_REPAIR_READINESS:
            return []
        readiness_term = 1.4 * (repair_readiness - 0.50)
        descendant_term = 0.5 * (descendant_gap - 0.35)

        temperature = max(float(self.cfg.anchor_future_proposal_temperature), 1e-6)
        span_scale = max(float(span_len), 1.0)
        anchor_start = int(anchor.start_idx)
        anchor_end = int(anchor.end_idx)

        # Mean step-to-step change inside the anchor span (None for a
        # single-position or empty span).
        anchor_delta_mean = None
        anchor_hidden_span = seq_hidden[anchor_start: anchor_end + 1]
        if anchor_hidden_span.size(0) > 1:
            anchor_delta = anchor_hidden_span[1:] - anchor_hidden_span[:-1]
            if anchor_delta.numel() > 0:
                anchor_delta_mean = anchor_delta.mean(dim=0)

        # Distinct token ids of the anchor span; None means "no ids given".
        if seq_ids is None:
            anchor_token_set = None
            anchor_token_count = 1
            no_overlap = seq_hidden.new_tensor(0.0)
        else:
            anchor_token_set = {int(token) for token in seq_ids[anchor_start: anchor_end + 1].tolist()}
            anchor_token_count = max(len(anchor_token_set), 1)
            no_overlap = None

        candidates: list[FutureProposalCandidate] = []
        for length in lengths:
            for offset in range(start, stop - length + 1):
                window_hidden = seq_hidden[offset: offset + length]
                window_mean = window_hidden.mean(dim=0)

                mean_sim = self._cosine01_tensor(anchor.repr, window_mean)
                contrast = 1.0 - mean_sim

                # Compare the direction of change inside the window with the
                # one inside the anchor; fall back to the mean similarity
                # when either side has fewer than two positions.
                if anchor_delta_mean is not None and window_hidden.size(0) > 1:
                    window_delta = window_hidden[1:] - window_hidden[:-1]
                    transition_sim = self._cosine01_tensor(anchor_delta_mean, window_delta.mean(dim=0))
                else:
                    transition_sim = mean_sim

                if anchor_token_set is None:
                    token_overlap = no_overlap
                    root_token = None
                else:
                    window_ids = seq_ids[offset: offset + length]
                    window_token_set = {int(token) for token in window_ids.tolist()}
                    token_overlap = seq_hidden.new_tensor(
                        len(anchor_token_set & window_token_set) / anchor_token_count
                    )
                    root_token = int(window_ids[-1].item())

                conflict_signal = 0.55 * contrast + 0.25 * (1.0 - transition_sim) + 0.20 * (1.0 - token_overlap)
                if float(conflict_signal.item()) < self._MIN_CONFLICT_SIGNAL:
                    continue

                # How tightly the window positions cluster around their mean.
                coherence = (
                    (F.cosine_similarity(window_hidden, window_mean.unsqueeze(0), dim=-1) + 1.0) * 0.5
                ).mean()

                # Agreement between the window and whatever remains of the
                # search region after it; coherence stands in when nothing remains.
                tail_hidden = seq_hidden[offset + length: stop]
                if tail_hidden.numel() > 0:
                    tail_support = self._cosine01_tensor(window_mean, tail_hidden.mean(dim=0))
                else:
                    tail_support = coherence

                distance = max(0, offset - anchor_end)
                distance_decay = seq_hidden.new_tensor(1.0 / (1.0 + distance / span_scale))
                plausibility = 0.45 * coherence + 0.35 * tail_support + 0.20 * distance_decay

                feature_vec = torch.stack(
                    [
                        contrast,
                        mean_sim,
                        transition_sim,
                        coherence,
                        tail_support,
                        token_overlap,
                        distance_decay,
                        pressure,
                        viability_gap,
                        descendant_gap,
                    ],
                    dim=0,
                ).to(device=seq_hidden.device, dtype=seq_hidden.dtype)
                # The learned term is down-weighted so it only corrects the heuristic.
                learned_logit = 0.25 * self.score_mlp(feature_vec.unsqueeze(0)).squeeze(0).squeeze(-1)
                heuristic_logit = (
                    2.4 * (conflict_signal - 0.35)
                    + 2.0 * (plausibility - 0.55)
                    + readiness_term
                    + descendant_term
                )
                score = torch.sigmoid((heuristic_logit + learned_logit) / temperature)
                candidates.append(
                    FutureProposalCandidate(
                        start=offset,
                        end=offset + length - 1,
                        repr=window_mean,
                        score=score,
                        root_token=root_token,
                    )
                )
        return self._subsample_candidates(candidates)

    def propose(
        self,
        seq_hidden: torch.Tensor,
        seq_ids: torch.Tensor | None,
        anchor,
    ) -> dict | None:
        """Return a fused proposal for ``anchor`` or ``None`` if nothing qualifies.

        ``None`` is returned when no candidate survives or when the best
        score is below ``cfg.anchor_future_proposal_threshold``.  Otherwise
        the top-k candidates (k is clamped to at least 1) are fused with the
        anchor representation and combined with softmax weights.

        Returns:
            A dict with keys ``repr``, ``proposal_type``, ``proposal_score``,
            ``proposal_score_tensor``, ``proposal_span``,
            ``proposal_root_token`` and ``proposal_candidate_count``.  Span,
            root token and score describe the single best candidate.
        """
        candidates = self._build_candidates(seq_hidden=seq_hidden, seq_ids=seq_ids, anchor=anchor)
        if not candidates:
            return None
        scores = torch.stack([candidate.score for candidate in candidates], dim=0)
        best_score, best_idx = scores.max(dim=0)
        if float(best_score.item()) < float(self.cfg.anchor_future_proposal_threshold):
            return None

        # A configured top-k of zero or less would select nothing and make
        # the stack below fail, so at least the best candidate is always used.
        topk = max(1, min(int(self.cfg.anchor_future_proposal_topk), len(candidates)))
        top_scores, top_idx = torch.topk(scores, k=topk)
        top_weights = torch.softmax(
            top_scores / max(float(self.cfg.anchor_future_proposal_temperature), 1e-6),
            dim=0,
        )
        top_repr = torch.stack([candidates[i].repr for i in top_idx.tolist()], dim=0)
        anchor_repr = anchor.repr.unsqueeze(0).expand_as(top_repr)
        fusion_in = torch.cat(
            [anchor_repr, top_repr, top_repr - anchor_repr, top_repr * anchor_repr],
            dim=-1,
        )
        fused_repr = top_repr + float(self.cfg.anchor_future_proposal_residual_scale) * self.repr_delta(fusion_in)
        proposal_repr = (top_weights.unsqueeze(-1) * fused_repr).sum(dim=0)
        best_candidate = candidates[int(best_idx.item())]
        return {
            "repr": proposal_repr,
            "proposal_type": "future_window_head",
            "proposal_score": float(best_score.item()),
            "proposal_score_tensor": best_score,
            "proposal_span": (best_candidate.start, best_candidate.end),
            "proposal_root_token": best_candidate.root_token,
            "proposal_candidate_count": len(candidates),
        }