abpt / src /model /anchor_viability.py
Search
feat: add src/ module for script imports
8125804
from __future__ import annotations
import math
import torch.nn as nn
from src.model.anchor_types import AnchorRecord, AnchorState
from src.model.config import ModelConfig
class ViabilityTracker(nn.Module):
    """Score each anchor's viability and propose its next ``AnchorState``.

    The viability is a logistic squash of a linear combination of the
    anchor's support, contradiction pressure, an age penalty derived from
    its TTL, and its descendant mass / coherence. All weights and thresholds
    are read from ``cfg``.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg

    def forward(
        self,
        anchors: list[list[AnchorRecord]],
        contradiction: dict,
    ) -> dict:
        """Compute viability and a proposed next state for every anchor.

        Args:
            anchors: Nested list of anchor records (outer list is iterated
                as batches, inner list as the anchors of that batch).
            contradiction: Mapping that must contain the key
                ``"contradiction_pressure"``, itself a mapping from anchor
                id to pressure. Anchors missing from that mapping fall back
                to their own ``contradiction_pressure`` attribute.

        Returns:
            A dict with two entries keyed by anchor id:
            ``"viability"`` (float in [0, 1]) and ``"state_updates"``
            (the proposed next ``AnchorState``).

        Side effects:
            Each anchor's ``viability`` attribute is overwritten with the
            freshly computed value. The anchor's ``state`` is NOT modified;
            the proposed state is only reported in the return value.

        Raises:
            KeyError: If ``contradiction`` has no ``"contradiction_pressure"``.
        """
        cfg = self.cfg
        pressure_map = contradiction["contradiction_pressure"]
        viability: dict[int, float] = {}
        state_updates: dict[int, AnchorState] = {}

        for batch_anchors in anchors:
            for anchor in batch_anchors:
                support = self._to_float(anchor.support)
                # The externally supplied pressure wins over the value
                # stored on the anchor record.
                pressure = float(
                    pressure_map.get(
                        anchor.id,
                        self._to_float(anchor.contradiction_pressure),
                    )
                )
                ttl = self._to_float(anchor.ttl)
                # The max() guard keeps the denominator >= 1 so a TTL below
                # zero can neither divide by zero nor flip the sign.
                age_penalty = 1.0 / max(ttl + 1.0, 1.0)
                # ``or 0.0`` treats a falsy value (e.g. None) as zero.
                descendant_mass = self._to_float(anchor.descendant_mass or 0.0)
                descendant_coherence = self._to_float(
                    anchor.descendant_coherence or 0.0
                )

                raw = (
                    cfg.anchor_viability_alpha * support
                    - cfg.anchor_viability_beta * pressure
                    - cfg.anchor_age_gamma * age_penalty
                    + cfg.anchor_descendant_mass_delta * descendant_mass
                    + cfg.anchor_descendant_coherence_eta * descendant_coherence
                )
                current_viability = self._sigmoid(raw)

                anchor.viability = current_viability
                viability[anchor.id] = current_viability
                state_updates[anchor.id] = self._next_state(
                    anchor.state, current_viability, pressure, ttl
                )

        return {
            "viability": viability,
            "state_updates": state_updates,
        }

    def _next_state(
        self,
        state: AnchorState,
        viability: float,
        pressure: float,
        ttl: float,
    ) -> AnchorState:
        """Return the proposed successor of ``state`` given the current scores."""
        cfg = self.cfg

        if state == AnchorState.CANDIDATE:
            if (
                viability >= cfg.anchor_confirm_threshold
                and pressure <= cfg.anchor_contradiction_threshold
            ):
                return AnchorState.CONFIRMED
            if viability >= cfg.anchor_candidate_promote_threshold:
                return AnchorState.PROVISIONAL
            if pressure >= cfg.anchor_dead_end_threshold and ttl <= 1.0:
                return AnchorState.DEAD_END
            return AnchorState.CANDIDATE

        if state == AnchorState.PROVISIONAL:
            if (
                viability >= cfg.anchor_confirm_threshold
                and pressure <= cfg.anchor_contradiction_threshold
            ):
                return AnchorState.CONFIRMED
            if (
                viability <= cfg.anchor_revision_threshold
                and pressure >= cfg.anchor_contradiction_threshold
            ):
                return AnchorState.DEAD_END
            return AnchorState.PROVISIONAL

        if state == AnchorState.CONFIRMED:
            # TTL expiry is checked before the contradiction test, so an
            # expiring confirmed anchor decays rather than dying outright.
            if ttl <= 1.0:
                return AnchorState.DECAYING
            if (
                viability <= cfg.anchor_revision_threshold
                and pressure >= cfg.anchor_contradiction_threshold
            ):
                return AnchorState.DEAD_END
            return AnchorState.CONFIRMED

        if state == AnchorState.DECAYING:
            if viability <= cfg.anchor_revision_threshold:
                return AnchorState.DEAD_END
            return AnchorState.DECAYING

        # Any state not handled above is reported as a dead end.
        return AnchorState.DEAD_END

    @staticmethod
    def _sigmoid(x: float) -> float:
        """Return the logistic function ``1 / (1 + e**-x)`` without overflowing.

        ``math.exp`` raises ``OverflowError`` for arguments above roughly
        709.78, which happens here when ``x`` is strongly negative. In that
        regime the true value is indistinguishable from 0.0, so return the
        limit instead of propagating the error.
        """
        try:
            return 1.0 / (1.0 + math.exp(-x))
        except OverflowError:
            return 0.0

    @staticmethod
    def _to_float(value: float) -> float:
        """Coerce ``value`` to a built-in ``float`` via ``float()``."""
        return float(value)