NuWave / nuwave /substrate /activation_persistence.py
Executor-Tyrant-Framework's picture
Sync from GitHub: 3263f1f0fcd371ae72208539b2fa4acd56585269
8cac211 verified
"""
CES Activation Persistence β€” Save and restore node activation state across sessions.
Writes a JSON sidecar file alongside the main msgpack checkpoint that
captures each node's voltage, last-fired step, and excitability. On
restore, a temporal decay is applied based on elapsed wall-clock time so
that stale activations fade naturally.
Usage::
from activation_persistence import ActivationPersistence
ap = ActivationPersistence(ces_config)
ap.save(graph, "/path/to/checkpoint.msgpack")
# ... later ...
ap.restore(graph, "/path/to/checkpoint.msgpack")
# ---- Changelog ----
# [2026-02-22] Claude (Opus 4.6) β€” Initial implementation.
# What: ActivationPersistence with capture/save/restore, exponential
# temporal decay, max_entries bounding, and auto-save timer.
# Why: The main msgpack checkpoint resets all voltages to resting
# potential. This sidecar preserves "warm" activations across
# sessions so the SNN resumes where it left off rather than
# starting cold.
# -------------------
"""
from __future__ import annotations
import json
import logging
import threading
import time
from pathlib import Path
from typing import Any, Dict, Optional
from ces_config import CESConfig
logger = logging.getLogger("neurograph.ces.persistence")
class ActivationPersistence:
"""Manages activation state sidecar files for cross-session persistence.
Args:
ces_config: ``CESConfig`` with persistence parameters.
"""
def __init__(self, ces_config: CESConfig) -> None:
self._cfg = ces_config.persistence
self._last_save_time: Optional[float] = None
self._last_decay_applied: float = 0.0
self._entries_saved: int = 0
self._sidecar_path: Optional[str] = None
# Auto-save timer state
self._auto_save_timer: Optional[threading.Timer] = None
self._auto_save_graph: Any = None
self._auto_save_checkpoint_path: Optional[str] = None
# ── Public API ─────────────────────────────────────────────────────
def capture(self, graph: Any) -> Dict[str, Dict[str, Any]]:
"""Capture current activation state from the graph.
Returns a dict keyed by node_id with voltage, last_spike_time,
intrinsic_excitability, and a wall-clock timestamp.
Only includes nodes with non-zero voltage to keep the sidecar
compact. Bounded to ``max_entries`` by voltage magnitude.
"""
now = time.time()
entries: Dict[str, Dict[str, Any]] = {}
for node_id, node in graph.nodes.items():
if node.voltage == 0.0 or node.voltage == node.resting_potential:
continue
entries[node_id] = {
"voltage": node.voltage,
"last_spike_time": node.last_spike_time,
"excitability": node.intrinsic_excitability,
"timestamp": now,
}
# Bound to max_entries by descending absolute voltage
max_entries = self._cfg.max_entries
if len(entries) > max_entries:
sorted_ids = sorted(
entries.keys(),
key=lambda nid: abs(entries[nid]["voltage"]),
reverse=True,
)
entries = {nid: entries[nid] for nid in sorted_ids[:max_entries]}
return entries
def save(self, graph: Any, checkpoint_path: str) -> str:
"""Capture and write activation sidecar next to the checkpoint.
Args:
graph: The ``Graph`` instance to capture from.
checkpoint_path: Path to the main checkpoint file.
Returns:
Path to the written sidecar file.
"""
entries = self.capture(graph)
sidecar_path = self._sidecar_path_for(checkpoint_path)
data = {
"version": "1.0",
"saved_at": time.time(),
"timestep": graph.timestep,
"entries": entries,
}
try:
with open(sidecar_path, "w") as f:
json.dump(data, f)
self._last_save_time = data["saved_at"]
self._entries_saved = len(entries)
self._sidecar_path = sidecar_path
logger.info(
"Activation sidecar saved: %d entries to %s",
len(entries),
sidecar_path,
)
except Exception as exc:
logger.warning("Failed to save activation sidecar: %s", exc)
return sidecar_path
def restore(self, graph: Any, checkpoint_path: str) -> int:
"""Restore activation state from sidecar with temporal decay.
Applies exponential decay based on elapsed wall-clock time:
``activation *= (1 - decay_per_hour) ^ elapsed_hours``
Entries below ``min_activation`` after decay are discarded.
Entries for nodes no longer in the graph are skipped.
Args:
graph: The ``Graph`` instance to restore into.
checkpoint_path: Path to the main checkpoint file.
Returns:
Number of nodes restored.
"""
sidecar_path = self._sidecar_path_for(checkpoint_path)
if not Path(sidecar_path).exists():
logger.debug("No activation sidecar at %s", sidecar_path)
return 0
try:
with open(sidecar_path) as f:
data = json.load(f)
except Exception as exc:
logger.warning("Failed to read activation sidecar: %s", exc)
return 0
entries = data.get("entries", {})
saved_at = data.get("saved_at", time.time())
# Apply temporal decay
elapsed_hours = (time.time() - saved_at) / 3600.0
if elapsed_hours < 0:
elapsed_hours = 0.0
self._last_decay_applied = elapsed_hours
entries = self._apply_decay(entries, elapsed_hours)
# Inject into graph
restored = 0
for node_id, state in entries.items():
node = graph.nodes.get(node_id)
if node is None:
continue
node.voltage = state["voltage"]
if "excitability" in state:
node.intrinsic_excitability = state["excitability"]
restored += 1
self._sidecar_path = sidecar_path
logger.info(
"Restored %d/%d activation entries (%.1fh decay applied)",
restored,
len(entries),
elapsed_hours,
)
return restored
def start_auto_save(self, graph: Any, checkpoint_path: str) -> None:
"""Start periodic auto-saving of activation state."""
self._auto_save_graph = graph
self._auto_save_checkpoint_path = checkpoint_path
self._schedule_auto_save()
def stop_auto_save(self) -> None:
"""Stop the auto-save timer."""
if self._auto_save_timer is not None:
self._auto_save_timer.cancel()
self._auto_save_timer = None
self._auto_save_graph = None
self._auto_save_checkpoint_path = None
def get_stats(self) -> Dict[str, Any]:
"""Return persistence statistics."""
return {
"entries_saved": self._entries_saved,
"last_save_time": self._last_save_time,
"last_decay_applied": round(self._last_decay_applied, 2),
"sidecar_path": self._sidecar_path,
}
# ── Internal ───────────────────────────────────────────────────────
def _sidecar_path_for(self, checkpoint_path: str) -> str:
"""Compute the sidecar file path for a given checkpoint path."""
return checkpoint_path + self._cfg.sidecar_suffix
def _apply_decay(
self,
entries: Dict[str, Dict[str, Any]],
elapsed_hours: float,
) -> Dict[str, Dict[str, Any]]:
"""Apply exponential decay and prune sub-threshold entries."""
if elapsed_hours <= 0:
return entries
decay_factor = (1.0 - self._cfg.decay_per_hour) ** elapsed_hours
min_act = self._cfg.min_activation
result: Dict[str, Dict[str, Any]] = {}
for node_id, state in entries.items():
decayed_voltage = state["voltage"] * decay_factor
if abs(decayed_voltage) < min_act:
continue
result[node_id] = {
**state,
"voltage": decayed_voltage,
}
return result
def _schedule_auto_save(self) -> None:
"""Schedule the next auto-save tick."""
if self._auto_save_graph is None:
return
self._auto_save_timer = threading.Timer(
self._cfg.auto_save_interval,
self._auto_save_tick,
)
self._auto_save_timer.daemon = True
self._auto_save_timer.start()
def _auto_save_tick(self) -> None:
"""Auto-save callback."""
# Check if we've been stopped before doing anything β€” prevents orphan
# timer chains that keep rescheduling after stop_auto_save().
if self._auto_save_graph is None or self._auto_save_checkpoint_path is None:
return
try:
self.save(self._auto_save_graph, self._auto_save_checkpoint_path)
except Exception as exc:
logger.warning("Auto-save failed: %s", exc)
self._schedule_auto_save()