File size: 5,449 Bytes
8125804
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from __future__ import annotations

import torch
import torch.nn as nn

from src.model.anchor_types import AnchorCandidate, AnchorRecord, AnchorState, RevisionDecision
from src.model.config import ModelConfig


class AnchorMemory(nn.Module):
    """Bookkeeping helper for per-batch lists of ``AnchorRecord`` objects.

    The class creates records from candidates, refreshes their support and
    time-to-live, applies revision decisions, and summarises the records for
    diagnostics. It registers no parameters or buffers of its own; its only
    state is the config and a running id counter.
    """

    def __init__(self, cfg: ModelConfig):
        """Store the config and start the anchor id counter at zero.

        Args:
            cfg: Model configuration. This class reads ``anchor_ttl_init``
                and ``anchor_support_decay`` from it.
        """
        super().__init__()
        self.cfg = cfg
        # Plain Python attribute (not a registered buffer or parameter).
        self._next_anchor_id = 0

    def add_candidates(
        self,
        candidates: list[list[AnchorCandidate]],
        anchors: list[list[AnchorRecord]] | None = None,
    ) -> list[list[AnchorRecord]]:
        """Append one new ``AnchorRecord`` per candidate to its batch list.

        Every new record starts in ``AnchorState.CANDIDATE`` with support and
        viability equal to the candidate score, zero contradiction pressure,
        zero descendant statistics, and a TTL of ``cfg.anchor_ttl_init``.

        Args:
            candidates: Per-batch lists of candidates to register.
            anchors: Existing per-batch anchor lists, extended in place. When
                ``None``, an empty list is created for every batch entry.

        Returns:
            The per-batch anchor lists (the same list objects that were
            passed in, when ``anchors`` was given).
        """
        if anchors is None:
            anchors = [[] for _ in candidates]

        # NOTE(review): zip stops at the shorter sequence, so batch entries
        # without a counterpart in the other list are skipped silently.
        for batch_anchors, batch_candidates in zip(anchors, candidates):
            for candidate in batch_candidates:
                # Convert once; support and viability share this start value,
                # and a tensor score would otherwise be read back twice.
                initial = self._to_float(candidate.score)
                batch_anchors.append(
                    AnchorRecord(
                        id=self._next_anchor_id,
                        start_idx=candidate.start_idx,
                        end_idx=candidate.end_idx,
                        repr=candidate.repr,
                        score=candidate.score,
                        state=AnchorState.CANDIDATE,
                        support=initial,
                        contradiction_pressure=0.0,
                        viability=initial,
                        ttl=float(self.cfg.anchor_ttl_init),
                        descendant_mass=0.0,
                        descendant_coherence=0.0,
                    )
                )
                self._next_anchor_id += 1
        return anchors

    def update_support(
        self,
        anchors: list[list[AnchorRecord]],
        detector_scores: torch.Tensor | None = None,
    ) -> list[list[AnchorRecord]]:
        """Blend each anchor's support with a current observation.

        The new support is
        ``decay * support + (1 - decay) * current`` with
        ``decay = cfg.anchor_support_decay``. ``current`` is
        ``detector_scores[batch_idx, anchor.end_idx]`` when that position
        exists, otherwise the anchor's own ``score``.

        Args:
            anchors: Per-batch anchor lists; records are updated in place.
            detector_scores: Optional tensor indexed as
                ``[batch_idx, position]``.

        Returns:
            The same ``anchors`` object.
        """
        decay = self.cfg.anchor_support_decay
        for batch_idx, batch_anchors in enumerate(anchors):
            for anchor in batch_anchors:
                # A negative end_idx would wrap around to the end of the row,
                # so it is treated like any other out-of-range position.
                in_range = (
                    detector_scores is not None
                    and 0 <= anchor.end_idx < detector_scores.size(1)
                )
                if in_range:
                    current = float(detector_scores[batch_idx, anchor.end_idx].item())
                else:
                    current = self._to_float(anchor.score)
                anchor.support = decay * self._to_float(anchor.support) + (1.0 - decay) * current
        return anchors

    def update_ttl(self, anchors: list[list[AnchorRecord]]) -> list[list[AnchorRecord]]:
        """Decrease every anchor's TTL by one, clamping at zero.

        Args:
            anchors: Per-batch anchor lists; records are updated in place.

        Returns:
            The same ``anchors`` object.
        """
        for batch_anchors in anchors:
            for anchor in batch_anchors:
                anchor.ttl = max(self._to_float(anchor.ttl) - 1.0, 0.0)
        return anchors

    def apply_revision(
        self,
        anchors: list[list[AnchorRecord]],
        decisions: list[RevisionDecision],
    ) -> list[list[AnchorRecord]]:
        """Apply revision decisions to the anchors they reference by id.

        A matched anchor takes the decision's ``new_state``. In addition,
        ``"retire"`` sets viability to ``0.0`` and ``"downgrade"`` caps it at
        ``0.5``; any other action leaves viability unchanged. Anchors without
        a decision are untouched.

        Args:
            anchors: Per-batch anchor lists; records are updated in place.
            decisions: Decisions keyed by ``anchor_id``. If several decisions
                share an id, the last one in the list wins.

        Returns:
            The same ``anchors`` object.
        """
        by_id = {decision.anchor_id: decision for decision in decisions}
        for batch_anchors in anchors:
            for anchor in batch_anchors:
                decision = by_id.get(anchor.id)
                if decision is None:
                    continue
                anchor.state = decision.new_state
                if decision.action == "retire":
                    anchor.viability = 0.0
                elif decision.action == "downgrade":
                    anchor.viability = min(self._to_float(anchor.viability), 0.5)
        return anchors

    def get_active_anchors(
        self,
        anchors: list[list[AnchorRecord]],
    ) -> list[list[AnchorRecord]]:
        """Return new per-batch lists containing only the active anchors.

        An anchor is kept when its state is CANDIDATE, PROVISIONAL, CONFIRMED
        or DECAYING. The input lists are not modified.
        """
        active_states = {
            AnchorState.CANDIDATE,
            AnchorState.PROVISIONAL,
            AnchorState.CONFIRMED,
            AnchorState.DECAYING,
        }
        return [
            [anchor for anchor in batch_anchors if anchor.state in active_states]
            for batch_anchors in anchors
        ]

    def export_diagnostics(self, anchors: list[list[AnchorRecord]]) -> dict:
        """Summarise all anchors across the batch as a flat dictionary.

        Returns:
            A dict with the per-state counts (keyed by ``state.value``), the
            number of anchors not in ``DEAD_END``, the number in ``DEAD_END``,
            and the means of score, contradiction pressure, viability,
            descendant mass and descendant coherence. All means are ``0.0``
            and all counts are ``0`` when there are no anchors.
        """
        flat = [anchor for batch in anchors for anchor in batch]

        state_counts = {state.value: 0 for state in AnchorState}
        for anchor in flat:
            state_counts[anchor.state.value] += 1

        to_float = self._to_float
        mean = self._mean
        return {
            # NOTE(review): counts every state other than DEAD_END, which is
            # a different criterion from the explicit set used by
            # get_active_anchors - confirm the two are meant to agree.
            "num_active": sum(anchor.state != AnchorState.DEAD_END for anchor in flat),
            "state_counts": state_counts,
            "mean_anchor_score": mean([to_float(anchor.score) for anchor in flat]),
            "mean_contradiction_pressure": mean(
                [to_float(anchor.contradiction_pressure) for anchor in flat]
            ),
            "mean_viability": mean([to_float(anchor.viability) for anchor in flat]),
            # `or 0.0` lets a falsy value such as None count as zero.
            "mean_descendant_mass": mean(
                [to_float(anchor.descendant_mass or 0.0) for anchor in flat]
            ),
            "mean_descendant_coherence": mean(
                [to_float(anchor.descendant_coherence or 0.0) for anchor in flat]
            ),
            "dead_end_count": state_counts[AnchorState.DEAD_END.value],
        }

    @staticmethod
    def _mean(values: list[float]) -> float:
        """Return the arithmetic mean of ``values``, or ``0.0`` when empty."""
        return sum(values) / len(values) if values else 0.0

    @staticmethod
    def _to_float(value: torch.Tensor | float) -> float:
        """Convert a single-element tensor or a number to a Python float.

        Tensors are detached before being read with ``Tensor.item()``.
        """
        if isinstance(value, torch.Tensor):
            return float(value.detach().item())
        return float(value)