"""

Bayesian Long-Term Potentiation (LTP) Feedback Loop (Phase 4.0)

================================================================

Replaces the simple Hebbian LTP update with a Bayesian reliability model.



Core idea:

  Each synaptic connection and each memory node maintains a Beta distribution

  over its "true reliability" p ~ Beta(α, β).



  - α = accumulated success evidence (hits, correct retrievals)

  - β = accumulated failure evidence (misses, wrong retrievals, decay)



  The posterior mean  E[p] = α / (α + β)  is used as the reliability estimate.

  The posterior variance Var[p] = αβ / ((α+β)²(α+β+1)) reflects uncertainty.



  On each firing:

    success → α += 1  (evidence for reliability)

    failure → β += 1  (evidence for unreliability)



  The LTP update on MemoryNode follows the same Beta model, where:

    - "success" events: retrieval that helped produce a good answer

    - "failure" events: retrieval miss, low-EIG storage, or forced decay



Benefits over plain Hebbian:

  - Uncertainty-aware: new synapses have wide credible intervals → exploration bonus

  - Natural regularization: α and β act as pseudo-counts preventing overconfidence

  - Compatible with existing strength/ltp_strength fields (posterior mean replaces raw strength)



Public API:

    updater = BayesianLTPUpdater()

    updater.observe_synapse(synapse, success=True)

    strength = updater.posterior_mean(synapse)

    uncertainty = updater.posterior_uncertainty(synapse)

"""

from __future__ import annotations

import math
from dataclasses import dataclass, field
from loguru import logger


# ------------------------------------------------------------------ #
#  Beta distribution helpers                                          #
# ------------------------------------------------------------------ #

def _beta_mean(alpha: float, beta: float) -> float:
    """E[p] = α / (α + β)."""
    total = alpha + beta
    if total <= 0:
        return 0.5
    return alpha / total


def _beta_variance(alpha: float, beta: float) -> float:
    """Var[p] = αβ / ((α+β)²(α+β+1))."""
    total = alpha + beta
    if total <= 0:
        return 0.25   # Maximum possible variance for a distribution on [0, 1]
    return (alpha * beta) / (total * total * (total + 1.0))


def _beta_std(alpha: float, beta: float) -> float:
    return math.sqrt(_beta_variance(alpha, beta))


def _beta_upper_credible(alpha: float, beta: float, z: float = 1.65) -> float:
    """
    Approximate upper credible bound using a normal approximation.
    z=1.65 ≈ 95th percentile.  Used for UCB-style exploration bonus.
    """
    return min(1.0, _beta_mean(alpha, beta) + z * _beta_std(alpha, beta))


# ------------------------------------------------------------------ #
#  Mixin state stored alongside SynapticConnection / MemoryNode      #
# ------------------------------------------------------------------ #

@dataclass
class BayesianState:
    """
    Lightweight Beta distribution state for Bayesian LTP.
    Stored as extra fields; zero overhead when not used.

    alpha / beta_count default to 1.0, i.e. the uninformative Beta(1, 1)
    prior; seed them with larger pseudo-counts to encode an informative prior.
    """
    alpha: float = 1.0       # success pseudo-count
    beta_count: float = 1.0  # failure pseudo-count (renamed to avoid a clash with scipy's `beta`)

    def observe(self, success: bool, strength: float = 1.0) -> None:
        """

        Update posterior given an observation.



        Args:

            success: True → α += strength, False → β += strength

            strength: Fractional evidence weight (default 1.0).

        """
        if success:
            self.alpha += strength
        else:
            self.beta_count += strength

    @property
    def mean(self) -> float:
        return _beta_mean(self.alpha, self.beta_count)

    @property
    def uncertainty(self) -> float:
        """Standard deviation of the posterior."""
        return _beta_std(self.alpha, self.beta_count)

    @property
    def upper_credible(self) -> float:
        """90th percentile upper bound (UCB exploration bonus)."""
        return _beta_upper_credible(self.alpha, self.beta_count)

    @property
    def total_observations(self) -> float:
        # Subtract initial priors so total_observations = 0 when untouched
        return (self.alpha - 1.0) + (self.beta_count - 1.0)

    def to_dict(self) -> dict:
        return {"alpha": self.alpha, "beta": self.beta_count}

    @classmethod
    def from_dict(cls, d: dict) -> "BayesianState":
        return cls(alpha=d.get("alpha", 1.0), beta_count=d.get("beta", 1.0))


# ------------------------------------------------------------------ #
#  Core updater                                                       #
# ------------------------------------------------------------------ #

class BayesianLTPUpdater:
    """

    Manages Bayesian LTP state for synapses and memory nodes.



    Attach BayesianState to objects lazily to avoid changing data-class

    signatures across the codebase.

    """

    _ATTR = "_bayes"   # attribute name injected onto target objects

    # ---- Synapse helpers ------------------------------------------ #

    def get_synapse_state(self, synapse) -> BayesianState:
        """Get (or create) BayesianState for a SynapticConnection."""
        if not hasattr(synapse, self._ATTR):
            # Bootstrap from firing history: alpha ∝ successes, beta ∝ failures,
            # so the initial posterior mean tracks the observed success rate.
            fc = max(synapse.fire_count, 1)
            sc = min(max(synapse.success_count, 0), fc)  # clamp so failures never go negative
            alpha = 1.0 + sc
            beta_count = 1.0 + (fc - sc)
            object.__setattr__(synapse, self._ATTR, BayesianState(alpha=alpha, beta_count=beta_count))
        return getattr(synapse, self._ATTR)

    def observe_synapse(self, synapse, success: bool, weight: float = 1.0) -> None:
        """
        Update the Bayesian posterior for a synapse and synchronize it back
        to the SynapticConnection.strength field (as the posterior mean).
        """
        state = self.get_synapse_state(synapse)
        state.observe(success=success, strength=weight)
        # Write posterior mean back to the canonical `.strength` field
        synapse.strength = state.mean
        logger.debug(
            f"Synapse ({synapse.neuron_a_id[:8]} → {synapse.neuron_b_id[:8]}) "
            f"Bayesian update — success={success} "
            f"α={state.alpha:.2f} β={state.beta_count:.2f} "
            f"→ p_mean={state.mean:.4f} ± {state.uncertainty:.4f}"
        )

    def synapse_strength_ucb(self, synapse) -> float:
        """

        Return the UCB (Upper Credible Bound) strength for exploration.

        Prefer under-explored synapses during associative spreading.

        """
        state = self.get_synapse_state(synapse)
        return state.upper_credible

    # ---- MemoryNode helpers --------------------------------------- #

    def get_node_state(self, node) -> BayesianState:
        """Get (or create) BayesianState for a MemoryNode."""
        if not hasattr(node, self._ATTR):
            # Bootstrap from epistemic + pragmatic values
            ev = getattr(node, "epistemic_value", 0.5)
            pv = getattr(node, "pragmatic_value", 0.0)
            combined = min(max((ev + pv) / 2.0, 0.0), 1.0)  # clamp so pseudo-counts stay non-negative
            ac = max(getattr(node, "access_count", 1), 1)
            alpha = 1.0 + combined * ac
            beta_count = 1.0 + (1.0 - combined) * ac
            object.__setattr__(node, self._ATTR, BayesianState(alpha=alpha, beta_count=beta_count))
        return getattr(node, self._ATTR)

    def observe_node_retrieval(self, node, helpful: bool, eig_signal: float = 1.0) -> float:
        """
        Record a retrieval outcome for a MemoryNode.

        Args:
            node: MemoryNode instance.
            helpful: Was this retrieval actually useful?
            eig_signal: Epistemic Information Gain from context (0–1).
                        Used as evidence weight: higher EIG → stronger update.

        Returns:
            Updated posterior mean LTP strength.
        """
        state = self.get_node_state(node)
        state.observe(success=helpful, strength=eig_signal)
        # Synchronize back to node.ltp_strength
        node.ltp_strength = state.mean
        logger.debug(
            f"Node {node.id[:8]} Bayesian retrieval update — helpful={helpful} "
            f"eig={eig_signal:.3f} → ltp={node.ltp_strength:.4f}"
        )
        return node.ltp_strength

    def node_ltp_ucb(self, node) -> float:
        """UCB estimate for node retrieval priority (exploration bonus)."""
        state = self.get_node_state(node)
        return state.upper_credible

    # ---- Serialization helpers ----------------------------------- #

    def synapse_to_dict(self, synapse) -> dict:
        """Serialize Bayesian state for persistence."""
        state = self.get_synapse_state(synapse)
        return state.to_dict()

    def synapse_from_dict(self, synapse, d: dict) -> None:
        """Restore Bayesian state from persisted dict."""
        state = BayesianState.from_dict(d)
        object.__setattr__(synapse, self._ATTR, state)
        synapse.strength = state.mean


# Module-level singleton
_UPDATER: BayesianLTPUpdater | None = None


def get_bayesian_updater() -> BayesianLTPUpdater:
    """Get the global Bayesian LTP updater singleton."""
    global _UPDATER
    if _UPDATER is None:
        _UPDATER = BayesianLTPUpdater()
    return _UPDATER
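

# ------------------------------------------------------------------ #
#  Worked example (smoke check)                                       #
# ------------------------------------------------------------------ #
# A minimal sanity check of the Beta formulas used above, run only when
# the module is executed directly. The counts (8 successes, 2 failures)
# are illustrative, not taken from real data.

```python
if __name__ == "__main__":
    import math

    # Start from the uninformative Beta(1, 1) prior, then observe
    # 8 successes and 2 failures → posterior Beta(9, 3).
    alpha, beta_count = 1.0 + 8.0, 1.0 + 2.0
    total = alpha + beta_count

    mean = alpha / total                                    # E[p] = 9/12 = 0.75
    variance = (alpha * beta_count) / (total * total * (total + 1.0))
    std = math.sqrt(variance)
    ucb = min(1.0, mean + 1.65 * std)                       # ≈95th-percentile bound

    print(f"mean={mean:.4f} std={std:.4f} ucb={ucb:.4f}")
    # → mean=0.7500 std=0.1201 ucb=0.9482
```

Note how the UCB sits well above the mean after only 10 observations; as
evidence accumulates, `std` shrinks and the bound converges to the mean.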