MnemoCore / src /mnemocore /core /bayesian_ltp.py
Granis87's picture
Initial upload of MnemoCore
dbb04e4 verified
"""
Bayesian Long-Term Potentiation (LTP) Feedback Loop (Phase 4.0)
================================================================
Replaces the simple Hebbian LTP update with a Bayesian reliability model.
Core idea:
Each synaptic connection and each memory node maintains a Beta distribution
over its "true reliability" p ~ Beta(α, β).
- α = accumulated success evidence (hits, correct retrievals)
- β = accumulated failure evidence (misses, wrong retrievals, decay)
The posterior mean E[p] = α / (α + β) is used as the reliability estimate.
The posterior variance Var[p] = αβ / ((α+β)²(α+β+1)) reflects uncertainty.
On each firing:
success → α += 1 (evidence for reliability)
failure → β += 1 (evidence for unreliability)
The LTP update on MemoryNode follows the same Beta model, where:
- "success" events: retrieval that helped produce a good answer
- "failure" events: retrieval miss, low-EIG storage, or forced decay
Benefits over plain Hebbian:
- Uncertainty-aware: new synapses have wide credible intervals → exploration bonus
- Natural regularization: α and β act as pseudo-counts preventing overconfidence
- Compatible with existing strength/ltp_strength fields (posterior mean replaces raw strength)
Public API:
updater = BayesianLTPUpdater()
updater.observe_synapse(synapse, success=True)
strength = updater.posterior_mean(synapse)
uncertainty = updater.posterior_uncertainty(synapse)
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field
from loguru import logger
# ------------------------------------------------------------------ #
# Beta distribution helpers #
# ------------------------------------------------------------------ #
def _beta_mean(alpha: float, beta: float) -> float:
"""E[p] = α / (α + β)."""
total = alpha + beta
if total <= 0:
return 0.5
return alpha / total
def _beta_variance(alpha: float, beta: float) -> float:
"""Var[p] = αβ / ((α+β)²(α+β+1))."""
total = alpha + beta
if total <= 0:
return 0.25 # Maximum variance of Beta(1,1)
return (alpha * beta) / (total * total * (total + 1.0))
def _beta_std(alpha: float, beta: float) -> float:
return math.sqrt(_beta_variance(alpha, beta))
def _beta_upper_credible(alpha: float, beta: float, z: float = 1.65) -> float:
"""
Approximate upper credible bound using normal approximation.
z=1.65 ≈ 90th percentile. Used for UCB-style exploration bonus.
"""
return min(1.0, _beta_mean(alpha, beta) + z * _beta_std(alpha, beta))
# ------------------------------------------------------------------ #
# Mixin state stored alongside SynapticConnection / MemoryNode #
# ------------------------------------------------------------------ #
@dataclass
class BayesianState:
"""
Lightweight Beta distribution state for Bayesian LTP.
Stored as extra fields; zero overhead when not used.
alpha_prior / beta_prior: informative priors (default: uninformative Beta(1,1))
"""
alpha: float = 1.0 # success pseudo-count
beta_count: float = 1.0 # failure pseudo-count (renamed to avoid clash with scipy.beta)
def observe(self, success: bool, strength: float = 1.0) -> None:
"""
Update posterior given an observation.
Args:
success: True → α += strength, False → β += strength
strength: Fractional evidence weight (default 1.0).
"""
if success:
self.alpha += strength
else:
self.beta_count += strength
@property
def mean(self) -> float:
return _beta_mean(self.alpha, self.beta_count)
@property
def uncertainty(self) -> float:
"""Standard deviation of the posterior."""
return _beta_std(self.alpha, self.beta_count)
@property
def upper_credible(self) -> float:
"""90th percentile upper bound (UCB exploration bonus)."""
return _beta_upper_credible(self.alpha, self.beta_count)
@property
def total_observations(self) -> float:
# Subtract initial priors so total_observations = 0 when untouched
return (self.alpha - 1.0) + (self.beta_count - 1.0)
def to_dict(self) -> dict:
return {"alpha": self.alpha, "beta": self.beta_count}
@classmethod
def from_dict(cls, d: dict) -> "BayesianState":
return cls(alpha=d.get("alpha", 1.0), beta_count=d.get("beta", 1.0))
# ------------------------------------------------------------------ #
# Core updater #
# ------------------------------------------------------------------ #
class BayesianLTPUpdater:
"""
Manages Bayesian LTP state for synapses and memory nodes.
Attach BayesianState to objects lazily to avoid changing data-class
signatures across the codebase.
"""
_ATTR = "_bayes" # attribute name injected onto target objects
# ---- Synapse helpers ------------------------------------------ #
def get_synapse_state(self, synapse) -> BayesianState:
"""Get (or create) BayesianState for a SynapticConnection."""
if not hasattr(synapse, self._ATTR):
# Bootstrap from existing strength as evidence ratio
s = synapse.strength
# Seed: alpha ∝ successes, beta ∝ failures, total = fire_count
fc = max(synapse.fire_count, 1)
sc = max(synapse.success_count, 0)
alpha = 1.0 + sc
beta_count = 1.0 + (fc - sc)
object.__setattr__(synapse, self._ATTR, BayesianState(alpha=alpha, beta_count=beta_count))
return getattr(synapse, self._ATTR)
def observe_synapse(self, synapse, success: bool, weight: float = 1.0) -> None:
"""
Update Bayesian posterior for a synapse and synchronize back to
the SynapticConnection.strength field (as posterior mean).
"""
state = self.get_synapse_state(synapse)
state.observe(success=success, strength=weight)
# Write posterior mean back to the canonical `.strength` field
synapse.strength = state.mean
logger.debug(
f"Synapse ({synapse.neuron_a_id[:8]}{synapse.neuron_b_id[:8]}) "
f"Bayesian update — success={success} "
f"α={state.alpha:.2f} β={state.beta_count:.2f} "
f"→ p_mean={state.mean:.4f} ± {state.uncertainty:.4f}"
)
def synapse_strength_ucb(self, synapse) -> float:
"""
Return the UCB (Upper Credible Bound) strength for exploration.
Prefer under-explored synapses during associative spreading.
"""
state = self.get_synapse_state(synapse)
return state.upper_credible
# ---- MemoryNode helpers --------------------------------------- #
def get_node_state(self, node) -> BayesianState:
"""Get (or create) BayesianState for a MemoryNode."""
if not hasattr(node, self._ATTR):
# Bootstrap from epistemic + pragmatic values
ev = getattr(node, "epistemic_value", 0.5)
pv = getattr(node, "pragmatic_value", 0.0)
combined = (ev + pv) / 2.0
ac = max(getattr(node, "access_count", 1), 1)
alpha = 1.0 + combined * ac
beta_count = 1.0 + (1.0 - combined) * ac
object.__setattr__(node, self._ATTR, BayesianState(alpha=alpha, beta_count=beta_count))
return getattr(node, self._ATTR)
def observe_node_retrieval(
self, node, helpful: bool, eig_signal: float = 1.0
) -> float:
"""
Record a retrieval outcome for a MemoryNode.
Args:
node: MemoryNode instance.
helpful: Was this retrieval actually useful?
eig_signal: Epistemic Information Gain from context (0–1).
Used as evidence weight: higher EIG → stronger update.
Returns:
Updated posterior mean LTP strength.
"""
state = self.get_node_state(node)
state.observe(success=helpful, strength=eig_signal)
# Synchronize back to node.ltp_strength
node.ltp_strength = state.mean
logger.debug(
f"Node {node.id[:8]} Bayesian retrieval update — helpful={helpful} "
f"eig={eig_signal:.3f} → ltp={node.ltp_strength:.4f}"
)
return node.ltp_strength
def node_ltp_ucb(self, node) -> float:
"""UCB estimate for node retrieval priority (exploration bonus)."""
state = self.get_node_state(node)
return state.upper_credible
# ---- Serialization helpers ----------------------------------- #
def synapse_to_dict(self, synapse) -> dict:
"""Serialize Bayesian state for persistence."""
state = self.get_synapse_state(synapse)
return state.to_dict()
def synapse_from_dict(self, synapse, d: dict) -> None:
"""Restore Bayesian state from persisted dict."""
state = BayesianState.from_dict(d)
object.__setattr__(synapse, self._ATTR, state)
synapse.strength = state.mean
# Module-level singleton
_UPDATER: BayesianLTPUpdater | None = None
def get_bayesian_updater() -> BayesianLTPUpdater:
"""Get the global Bayesian LTP updater singleton."""
global _UPDATER
if _UPDATER is None:
_UPDATER = BayesianLTPUpdater()
return _UPDATER