# abpt/src/model/anchor_monitor.py
# Synced at commit f37be5a ("auto: sync run_testformer_wikitext_combo_remote.py").
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.model.anchor_types import AnchorRecord
from src.model.config import ModelConfig
class ContradictionMonitor(nn.Module):
    """Measure how strongly the context after each anchor contradicts it.

    For every anchor, the monitor inspects a bounded look-ahead window that
    starts right after the anchor span. It combines several heuristic signals
    (hidden-state drift, token repetition, pattern re-occurrence) into one
    ``contradiction_pressure`` clamped to ``[0, 1]``. Results are written back
    onto each anchor and also returned keyed by ``anchor.id``.

    The module declares no parameters. Every score is converted to a Python
    float via ``.item()``, so nothing computed here carries gradients.
    """

    # Maps a token id to the root id of its regime family (e.g. 13 -> 11).
    # NOTE(review): the meaning of these ids comes from the data generator,
    # which is not visible in this file.
    _REGIME_ROOT_ALIAS: dict[int, int] = {
        11: 11,
        13: 11,
        16: 11,
        21: 21,
        22: 21,
        23: 21,
        31: 31,
        32: 31,
        33: 31,
        41: 41,
        42: 41,
        43: 41,
        44: 41,
        51: 51,
        52: 51,
        53: 51,
    }

    # For each regime root, the token ids that count as "compatible" with it.
    # Ids 14 and 15 have no alias entry above, and 15 is shared by several
    # roots. Presumably these are cross-regime bridge tokens; confirm with the
    # data generator.
    _REGIME_COMPATIBILITY: dict[int, set[int]] = {
        11: {11, 13, 16},
        21: {14, 15, 21, 22, 23},
        31: {15, 31, 32, 33},
        41: {41, 42, 43, 44},
        51: {15, 51, 52, 53},
    }

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        # Only ``cfg.anchor_domain_mode`` is read by this class.
        self.cfg = cfg

    @staticmethod
    def _cosine01(a: torch.Tensor, b: torch.Tensor) -> float:
        """Return the cosine similarity of ``a`` and ``b`` rescaled from [-1, 1] to [0, 1]."""
        cosine = F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0), dim=-1).mean()
        return float(((cosine + 1.0) * 0.5).item())

    def _domain_mode(self) -> str:
        """Return ``"synthetic"`` or ``"real"``; any other config value falls back to ``"real"``."""
        if self.cfg.anchor_domain_mode in {"synthetic", "real"}:
            return self.cfg.anchor_domain_mode
        return "real"

    def forward(
        self,
        hidden: torch.Tensor,
        anchors: list[list[AnchorRecord]],
        aux: dict | None = None,
    ) -> dict:
        """Score every anchor against the context that follows it.

        Args:
            hidden: per-batch hidden states; ``hidden[b]`` is indexed by
                sequence position along dim 0 (presumably shaped
                ``(batch, seq, dim)``, TODO confirm).
            anchors: ``anchors[b]`` lists the anchors of batch item ``b``.
            aux: optional extras; only ``aux["input_ids"]`` is read, and it is
                indexed per batch item like ``hidden``.

        Returns:
            A dict with ``"contradiction_pressure"`` (``anchor.id`` -> float
            in [0, 1]) and ``"pressure_components"`` (``anchor.id`` -> dict of
            the individual signals).

        Side effects:
            Sets ``contradiction_pressure``, ``descendant_mass`` and
            ``descendant_coherence`` on each anchor.
        """
        aux = aux or {}
        input_ids: torch.Tensor | None = aux.get("input_ids")
        pressure_by_anchor: dict[int, float] = {}
        pressure_components: dict[int, dict[str, float]] = {}
        domain_mode = self._domain_mode()
        for batch_idx, batch_anchors in enumerate(anchors):
            seq_hidden = hidden[batch_idx]
            seq_ids = None if input_ids is None else input_ids[batch_idx]
            for anchor in batch_anchors:
                components = self._score_anchor(seq_hidden, seq_ids, anchor, domain_mode)
                contradiction = components["self_contradiction"]
                anchor.contradiction_pressure = contradiction
                anchor.descendant_mass = components["descendant_mass"]
                anchor.descendant_coherence = components["descendant_coherence"]
                # NOTE(review): keyed by anchor.id only; ids are assumed to be
                # unique across batch items, otherwise later entries overwrite.
                pressure_by_anchor[anchor.id] = contradiction
                pressure_components[anchor.id] = components
        return {
            "contradiction_pressure": pressure_by_anchor,
            "pressure_components": pressure_components,
        }

    def _score_anchor(
        self,
        seq_hidden: torch.Tensor,
        seq_ids: torch.Tensor | None,
        anchor: AnchorRecord,
        domain_mode: str,
    ) -> dict[str, float]:
        """Compute all contradiction signals for one anchor in one sequence.

        The ``"self_contradiction"`` entry is the combined score, clamped to
        [0, 1].
        """
        T = seq_hidden.size(0)
        span_len = max(anchor.end_idx - anchor.start_idx + 1, 1)
        # Look ahead 4x the TTL or 4x the span length, whichever is longer.
        horizon = max(int(float(anchor.ttl) * 4), span_len * 4)
        start = min(anchor.end_idx + 1, T)
        stop = min(start + horizon, T)

        if start >= stop:
            # Nothing follows the anchor in this sequence, so there is no
            # evidence either way: every contradiction signal is zero.
            hidden_contradiction = 0.0
            token_contradiction = 0.0
            pattern_contradiction = 0.0
            future_shift = 0.0
            similarity = 1.0
            descendant_mass = 0.0
            descendant_coherence = 0.0
        else:
            future = seq_hidden[start:stop]
            mean_future = future.mean(dim=0, keepdim=True)
            similarity = float(
                F.cosine_similarity(anchor.repr.unsqueeze(0), mean_future, dim=-1).mean().item()
            )
            future_shift = float((future - anchor.repr.unsqueeze(0)).norm(dim=-1).mean().item())
            # Map cosine in [-1, 1] to a contradiction score in [0, 1].
            hidden_contradiction = max(0.0, (1.0 - similarity) / 2.0)
            if seq_ids is None:
                # No token ids available: reuse the hidden-state signal.
                token_contradiction = hidden_contradiction
                pattern_contradiction = hidden_contradiction
                descendant_mass = max(0.0, 1.0 - hidden_contradiction)
                descendant_coherence = max(0.0, similarity)
            elif domain_mode == "synthetic":
                # Fraction of future tokens that differ from the anchor's last token.
                anchor_token = int(seq_ids[anchor.end_idx].item())
                future_tokens = seq_ids[start:stop]
                match_ratio = float((future_tokens == anchor_token).float().mean().item())
                token_contradiction = 1.0 - match_ratio
                pattern_contradiction, descendant_mass, descendant_coherence = (
                    self._pattern_stats_synthetic(seq_ids, anchor, start, stop)
                )
            else:
                (
                    token_contradiction,
                    pattern_contradiction,
                    descendant_mass,
                    descendant_coherence,
                ) = self._pattern_stats_real(
                    seq_hidden=seq_hidden,
                    seq_ids=seq_ids,
                    anchor=anchor,
                    start=start,
                    stop=stop,
                )

        # Synthetic data leans on exact token patterns; real data leans on
        # hidden-state similarity.
        if seq_ids is None:
            contradiction = hidden_contradiction
        elif domain_mode == "synthetic":
            contradiction = (
                0.20 * hidden_contradiction
                + 0.20 * token_contradiction
                + 0.60 * pattern_contradiction
            )
        else:
            contradiction = (
                0.55 * hidden_contradiction
                + 0.15 * token_contradiction
                + 0.30 * pattern_contradiction
            )
        contradiction = float(max(0.0, min(1.0, contradiction)))

        return {
            "future_shift": future_shift,
            "local_similarity": similarity,
            "hidden_contradiction": hidden_contradiction,
            "token_contradiction": token_contradiction,
            "pattern_contradiction": pattern_contradiction,
            "descendant_mass": descendant_mass,
            "descendant_coherence": descendant_coherence,
            "self_contradiction": contradiction,
        }

    @staticmethod
    def _pattern_stats_synthetic(
        seq_ids: torch.Tensor,
        anchor: AnchorRecord,
        start: int,
        stop: int,
    ) -> tuple[float, float, float]:
        """Score how well the anchor's token pattern recurs in ``seq_ids[start:stop]``.

        A window of the anchor's span length slides over the future tokens.
        Each window is scored against the anchor span by exact match, set
        overlap, position-weighted match, and regime compatibility.

        Returns:
            ``(pattern_contradiction, descendant_mass, descendant_coherence)``.
            If fewer future tokens than the span length remain, the result is
            ``(1.0, 0.0, 0.0)``.
        """
        anchor_span = seq_ids[anchor.start_idx: anchor.end_idx + 1]
        span_len = anchor_span.numel()
        future_tokens = seq_ids[start:stop]
        if future_tokens.numel() < span_len:
            # NOTE(review): a too-short future yields maximal contradiction,
            # whereas the no-future case in _score_anchor yields 0.0. Confirm
            # that this asymmetry is intended.
            return 1.0, 0.0, 0.0

        # Anchor statistics are loop-invariant, so compute them once rather
        # than once per window.
        anchor_token_set = set(anchor_span.tolist())
        anchor_set_size = max(len(anchor_token_set), 1)
        root_token = ContradictionMonitor.infer_reference_root(anchor_span)
        # Linearly decaying weights make earlier span positions count more.
        pos_weights = torch.linspace(1.0, 0.4, steps=span_len, device=anchor_span.device)
        pos_weights = pos_weights / pos_weights.sum()

        sims: list[float] = []
        root_hits: list[float] = []
        regime_hits: list[float] = []
        for offset in range(future_tokens.numel() - span_len + 1):
            window = future_tokens[offset: offset + span_len]
            matches = (window == anchor_span).float()
            exact_match = float(matches.mean().item())
            overlap = len(set(window.tolist()) & anchor_token_set) / anchor_set_size
            if root_token is None:
                root_persistence = overlap
            else:
                root_persistence = float((window == root_token).float().mean().item())
            positional_match = float((matches * pos_weights).sum().item())
            regime_compatibility = ContradictionMonitor.regime_compatibility_score(
                window_tokens=window,
                anchor_span=anchor_span,
                root_token=root_token,
            )
            similarity = (
                0.25 * exact_match
                + 0.15 * overlap
                + 0.35 * positional_match
                + 0.25 * regime_compatibility
            )
            sims.append(similarity)
            root_hits.append(root_persistence)
            regime_hits.append(regime_compatibility)

        best_similarity = max(sims) if sims else 0.0
        mean_root_persistence = sum(root_hits) / max(len(root_hits), 1)
        mean_regime_compatibility = sum(regime_hits) / max(len(regime_hits), 1)
        # Fraction of windows similar enough to count as a "descendant" of the anchor.
        descendant_mass = sum(sim >= 0.6 for sim in sims) / max(len(sims), 1)
        descendant_coherence = (
            0.60 * (sum(sims) / max(len(sims), 1))
            + 0.25 * mean_root_persistence
            + 0.15 * mean_regime_compatibility
        )
        pattern_contradiction = 1.0 - (
            0.60 * best_similarity
            + 0.20 * mean_root_persistence
            + 0.20 * mean_regime_compatibility
        )
        return float(pattern_contradiction), float(descendant_mass), float(descendant_coherence)

    def _pattern_stats_real(
        self,
        seq_hidden: torch.Tensor,
        seq_ids: torch.Tensor,
        anchor: AnchorRecord,
        start: int,
        stop: int,
    ) -> tuple[float, float, float, float]:
        """Score how well the anchor span recurs in the future hidden states and tokens.

        A window of the anchor's span length slides over positions
        ``[start, stop)``. Each window is compared with the anchor by the
        cosine of mean hidden states, the cosine of step-to-step hidden
        deltas, and token-set overlap.

        Returns:
            ``(token_contradiction, pattern_contradiction, descendant_mass,
            descendant_coherence)``. If fewer future positions than the span
            length remain, the result is ``(1.0, 1.0, 0.0, 0.0)``.
        """
        anchor_hidden_span = seq_hidden[anchor.start_idx: anchor.end_idx + 1]
        anchor_token_span = seq_ids[anchor.start_idx: anchor.end_idx + 1]
        span_len = anchor_hidden_span.size(0)
        future_hidden = seq_hidden[start:stop]
        future_tokens = seq_ids[start:stop]
        if future_hidden.size(0) < span_len:
            # NOTE(review): maximal contradiction for a too-short future, unlike
            # the 0.0 reported when there is no future at all.
            return 1.0, 1.0, 0.0, 0.0

        # Anchor statistics are loop-invariant, so compute them once.
        anchor_mean = anchor_hidden_span.mean(dim=0)
        # Step-to-step hidden deltas inside the span (empty for a 1-token span).
        anchor_delta = anchor_hidden_span[1:] - anchor_hidden_span[:-1]
        has_transitions = anchor_delta.numel() > 0
        anchor_delta_flat = anchor_delta.flatten()
        anchor_token_set = set(anchor_token_span.tolist())
        anchor_set_size = max(len(anchor_token_set), 1)

        sims: list[float] = []
        token_overlaps: list[float] = []
        transition_sims: list[float] = []
        for offset in range(future_hidden.size(0) - span_len + 1):
            window_hidden = future_hidden[offset: offset + span_len]
            window_tokens = future_tokens[offset: offset + span_len]
            mean_sim = self._cosine01(anchor_mean, window_hidden.mean(dim=0))
            if has_transitions:
                window_delta = window_hidden[1:] - window_hidden[:-1]
                transition_sim = self._cosine01(anchor_delta_flat, window_delta.flatten())
            else:
                # A single-token span has no transitions to compare.
                transition_sim = mean_sim
            token_overlap = len(set(window_tokens.tolist()) & anchor_token_set) / anchor_set_size
            similarity = 0.55 * mean_sim + 0.25 * transition_sim + 0.20 * token_overlap
            sims.append(similarity)
            token_overlaps.append(token_overlap)
            transition_sims.append(transition_sim)

        best_similarity = max(sims) if sims else 0.0
        mean_similarity = sum(sims) / max(len(sims), 1)
        mean_overlap = sum(token_overlaps) / max(len(token_overlaps), 1)
        mean_transition = sum(transition_sims) / max(len(transition_sims), 1)
        # Fraction of windows similar enough to count as a "descendant" of the anchor.
        descendant_mass = sum(sim >= 0.68 for sim in sims) / max(len(sims), 1)
        descendant_coherence = 0.65 * mean_similarity + 0.20 * mean_transition + 0.15 * mean_overlap
        pattern_contradiction = 1.0 - (0.75 * best_similarity + 0.25 * mean_similarity)
        token_contradiction = 1.0 - max(mean_overlap, best_similarity)
        return (
            float(token_contradiction),
            float(pattern_contradiction),
            float(descendant_mass),
            float(descendant_coherence),
        )

    @classmethod
    def resolve_regime_root_from_ids(
        cls,
        seq_ids: torch.Tensor,
        span_start: int,
        span_end: int,
    ) -> int | None:
        """Return the regime root of ``seq_ids[span_start:span_end + 1]`` (inclusive end)."""
        span_tokens = seq_ids[span_start: span_end + 1]
        return cls.resolve_regime_root_from_span(span_tokens)

    @classmethod
    def resolve_regime_root_from_span(
        cls,
        span_tokens: torch.Tensor,
    ) -> int | None:
        """Return the alias root of the first token in ``span_tokens`` that has one, else ``None``."""
        for token in span_tokens.tolist():
            root = cls._REGIME_ROOT_ALIAS.get(int(token))
            if root is not None:
                return root
        return None

    @classmethod
    def infer_reference_root(
        cls,
        span_tokens: torch.Tensor,
    ) -> int | None:
        """Return a reference token for ``span_tokens``.

        The regime root is preferred when one resolves. Otherwise the most
        frequent token is used; ties go to the smallest id, because
        ``torch.unique`` sorts its output and ``argmax`` returns the first
        maximum. An empty span returns ``None``.
        """
        root = cls.resolve_regime_root_from_span(span_tokens)
        if root is not None:
            return root
        if span_tokens.numel() == 0:
            return None
        unique_tokens, counts = torch.unique(span_tokens, return_counts=True)
        return int(unique_tokens[torch.argmax(counts)].item())

    @classmethod
    def regime_compatibility_score(
        cls,
        window_tokens: torch.Tensor,
        anchor_span: torch.Tensor,
        root_token: int | None,
    ) -> float:
        """Score in [0, 1] how compatible ``window_tokens`` is with the anchor's regime.

        ``window_tokens`` and ``anchor_span`` are compared elementwise, so they
        are expected to have the same length. When ``root_token`` has an entry
        in ``_REGIME_COMPATIBILITY``, membership in that allowed set dominates
        the score. Otherwise the score falls back to exact match and overlap.
        An empty window scores 0.0.
        """
        if window_tokens.numel() == 0:
            return 0.0
        window_list = [int(token) for token in window_tokens.tolist()]
        anchor_token_set = {int(token) for token in anchor_span.tolist()}
        overlap = len(set(window_list) & anchor_token_set) / max(len(anchor_token_set), 1)
        exact_match = float((window_tokens == anchor_span).float().mean().item())
        if root_token is None:
            root_persistence = overlap
            alias_compatibility = 0.0
        else:
            # Share of window tokens equal to the root itself.
            root_persistence = sum(token == root_token for token in window_list) / max(len(window_list), 1)
            # Share of window tokens whose alias family is this root.
            alias_compatibility = sum(
                cls._REGIME_ROOT_ALIAS.get(token) == root_token for token in window_list
            ) / max(len(window_list), 1)
        allowed_tokens = cls._REGIME_COMPATIBILITY.get(root_token) if root_token is not None else None
        if allowed_tokens is None:
            return float(
                0.45 * exact_match
                + 0.30 * overlap
                + 0.10 * alias_compatibility
                + 0.15 * root_persistence
            )
        hard_compatibility = sum(token in allowed_tokens for token in window_list) / max(len(window_list), 1)
        return float(
            0.55 * hard_compatibility
            + 0.20 * overlap
            + 0.15 * alias_compatibility
            + 0.10 * root_persistence
        )