Spaces:
Paused
Paused
| from __future__ import annotations | |
| import math | |
| import torch.nn as nn | |
| from src.model.anchor_types import AnchorRecord, AnchorState | |
| from src.model.config import ModelConfig | |
class ViabilityTracker(nn.Module):
    """Score anchors and propose the next lifecycle state for each of them.

    For every anchor a raw score is built as a weighted combination of its
    support, contradiction pressure, an age penalty derived from ``ttl``,
    and its descendant mass/coherence.  The raw score is squashed with a
    logistic function into a viability value, which is then compared with
    the thresholds in ``cfg`` to choose the next ``AnchorState``.
    """

    def __init__(self, cfg: ModelConfig):
        """Keep a reference to the config that supplies weights and thresholds."""
        super().__init__()
        self.cfg = cfg

    def forward(
        self,
        anchors: list[list[AnchorRecord]],
        contradiction: dict,
    ) -> dict:
        """Compute viability and the proposed next state for every anchor.

        Args:
            anchors: Nested list of anchor records (outer list is iterated as
                "batch" entries).  Each record is read for ``id``,
                ``support``, ``contradiction_pressure``, ``ttl``,
                ``descendant_mass``, ``descendant_coherence`` and ``state``.
                Its ``viability`` attribute is overwritten in place.
            contradiction: Mapping that must contain the key
                ``"contradiction_pressure"``, itself a mapping from anchor id
                to a pressure value.  Ids missing from that mapping fall back
                to the anchor's own ``contradiction_pressure``.

        Returns:
            A dict with two entries, both keyed by anchor id:
            ``"viability"`` (float score) and ``"state_updates"`` (the
            proposed next ``AnchorState``).  Because the keys are bare ids,
            an id that occurs in several inner lists keeps only the value
            computed last.

        Raises:
            KeyError: If ``contradiction`` has no ``"contradiction_pressure"``.
        """
        pressure_map = contradiction["contradiction_pressure"]
        viability: dict[int, float] = {}
        state_updates: dict[int, AnchorState] = {}

        for batch_anchors in anchors:
            for anchor in batch_anchors:
                support = self._to_float(anchor.support)
                pressure = float(
                    pressure_map.get(
                        anchor.id,
                        self._to_float(anchor.contradiction_pressure),
                    )
                )
                # Convert ttl once; it feeds both the age penalty and the
                # expiry checks in the state transitions.
                ttl = self._to_float(anchor.ttl)
                # The max() keeps the denominator at 1 or above, so the
                # penalty stays within (0, 1] even for ttl <= 0.
                age_penalty = 1.0 / max(ttl + 1.0, 1.0)
                # Falsy values (None, 0) are treated as "no descendants".
                descendant_mass = self._to_float(anchor.descendant_mass or 0.0)
                descendant_coherence = self._to_float(
                    anchor.descendant_coherence or 0.0
                )

                raw = (
                    self.cfg.anchor_viability_alpha * support
                    - self.cfg.anchor_viability_beta * pressure
                    - self.cfg.anchor_age_gamma * age_penalty
                    + self.cfg.anchor_descendant_mass_delta * descendant_mass
                    + self.cfg.anchor_descendant_coherence_eta * descendant_coherence
                )
                current_viability = self._sigmoid(raw)

                anchor.viability = current_viability
                viability[anchor.id] = current_viability
                state_updates[anchor.id] = self._next_state(
                    anchor.state, current_viability, pressure, ttl
                )

        return {
            "viability": viability,
            "state_updates": state_updates,
        }

    def _next_state(
        self,
        state: AnchorState,
        viability: float,
        pressure: float,
        ttl: float,
    ) -> AnchorState:
        """Return the state that follows ``state`` given the current metrics.

        Branch order matters: within each state the first matching rule
        wins.  Any state other than CANDIDATE, PROVISIONAL, CONFIRMED or
        DECAYING maps to DEAD_END.
        """
        cfg = self.cfg

        if state == AnchorState.CANDIDATE:
            if (
                viability >= cfg.anchor_confirm_threshold
                and pressure <= cfg.anchor_contradiction_threshold
            ):
                return AnchorState.CONFIRMED
            if viability >= cfg.anchor_candidate_promote_threshold:
                return AnchorState.PROVISIONAL
            if pressure >= cfg.anchor_dead_end_threshold and ttl <= 1.0:
                return AnchorState.DEAD_END
            return AnchorState.CANDIDATE

        if state == AnchorState.PROVISIONAL:
            if (
                viability >= cfg.anchor_confirm_threshold
                and pressure <= cfg.anchor_contradiction_threshold
            ):
                return AnchorState.CONFIRMED
            if (
                viability <= cfg.anchor_revision_threshold
                and pressure >= cfg.anchor_contradiction_threshold
            ):
                return AnchorState.DEAD_END
            return AnchorState.PROVISIONAL

        if state == AnchorState.CONFIRMED:
            # Expiry is checked before the contradiction rule, so an expiring
            # confirmed anchor decays rather than dying outright.
            if ttl <= 1.0:
                return AnchorState.DECAYING
            if (
                viability <= cfg.anchor_revision_threshold
                and pressure >= cfg.anchor_contradiction_threshold
            ):
                return AnchorState.DEAD_END
            return AnchorState.CONFIRMED

        if state == AnchorState.DECAYING:
            if viability <= cfg.anchor_revision_threshold:
                return AnchorState.DEAD_END
            return AnchorState.DECAYING

        return AnchorState.DEAD_END

    @staticmethod
    def _sigmoid(x: float) -> float:
        """Return the logistic function of ``x`` without overflowing.

        ``math.exp`` raises ``OverflowError`` for arguments above roughly
        709, so the exponent is always taken of a non-positive number.
        """
        if x >= 0.0:
            return 1.0 / (1.0 + math.exp(-x))
        exp_x = math.exp(x)
        return exp_x / (1.0 + exp_x)

    @staticmethod
    def _to_float(value: float) -> float:
        """Coerce ``value`` to a built-in ``float`` via ``float(value)``.

        Declared static because it takes no ``self`` but is always invoked
        as ``self._to_float(...)``.
        """
        return float(value)