"""
CES Activation Persistence — Save and restore node activation state across sessions.
Writes a JSON sidecar file alongside the main msgpack checkpoint that
captures each node's voltage, last-fired step, and excitability. On
restore, a temporal decay is applied based on elapsed wall-clock time so
that stale activations fade naturally.
Usage::

    from activation_persistence import ActivationPersistence

    ap = ActivationPersistence(ces_config)
    ap.save(graph, "/path/to/checkpoint.msgpack")
    # ... later ...
    ap.restore(graph, "/path/to/checkpoint.msgpack")
# ---- Changelog ----
# [2026-02-22] Claude (Opus 4.6) — Initial implementation.
# What: ActivationPersistence with capture/save/restore, exponential
# temporal decay, max_entries bounding, and auto-save timer.
# Why: The main msgpack checkpoint resets all voltages to resting
# potential. This sidecar preserves "warm" activations across
# sessions so the SNN resumes where it left off rather than
# starting cold.
# -------------------
"""
from __future__ import annotations
import json
import logging
import threading
import time
from pathlib import Path
from typing import Any, Dict, Optional
from ces_config import CESConfig
logger = logging.getLogger("neurograph.ces.persistence")
class ActivationPersistence:
    """Manage activation-state sidecar files for cross-session persistence.

    The sidecar is a JSON file whose path is the checkpoint path plus the
    configured ``sidecar_suffix``.  For every captured node it stores the
    voltage, last spike time and intrinsic excitability.  On restore, the
    stored voltages are decayed according to the wall-clock time that has
    passed since the sidecar was written.

    Args:
        ces_config: ``CESConfig`` whose ``persistence`` attribute supplies
            ``max_entries``, ``sidecar_suffix``, ``decay_per_hour``,
            ``min_activation`` and ``auto_save_interval``.
    """

    def __init__(self, ces_config: "CESConfig") -> None:
        self._cfg = ces_config.persistence

        # Statistics surfaced through get_stats().
        self._last_save_time: Optional[float] = None
        self._last_decay_applied: float = 0.0
        self._entries_saved: int = 0
        self._sidecar_path: Optional[str] = None

        # Auto-save timer state.  A ``None`` graph doubles as the "stopped"
        # flag checked by the timer callback.
        self._auto_save_timer: Optional[threading.Timer] = None
        self._auto_save_graph: Any = None
        self._auto_save_checkpoint_path: Optional[str] = None

    # -- Public API ---------------------------------------------------------

    def capture(self, graph: Any) -> Dict[str, Dict[str, Any]]:
        """Capture the current activation state from ``graph``.

        Nodes whose voltage is exactly ``0.0`` or equal to their own
        ``resting_potential`` are left out to keep the sidecar compact.  If
        more than ``max_entries`` nodes remain, only those with the largest
        absolute voltage are kept.

        Args:
            graph: Object exposing a ``nodes`` mapping of node id to node.

        Returns:
            Dict keyed by node id; each value holds ``voltage``,
            ``last_spike_time``, ``excitability`` and the wall-clock
            ``timestamp`` of this capture.
        """
        now = time.time()
        entries: Dict[str, Dict[str, Any]] = {}
        # Iterate over a snapshot: the auto-save timer runs this method on a
        # background thread, and a node added or removed elsewhere during a
        # live iteration would abort the capture with a RuntimeError.
        for node_id, node in list(graph.nodes.items()):
            voltage = node.voltage
            if voltage == 0.0 or voltage == node.resting_potential:
                continue
            entries[node_id] = {
                "voltage": voltage,
                "last_spike_time": node.last_spike_time,
                "excitability": node.intrinsic_excitability,
                "timestamp": now,
            }

        # Bound to max_entries by descending absolute voltage.
        max_entries = self._cfg.max_entries
        if len(entries) > max_entries:
            strongest = sorted(
                entries,
                key=lambda nid: abs(entries[nid]["voltage"]),
                reverse=True,
            )[:max_entries]
            entries = {nid: entries[nid] for nid in strongest}
        return entries

    def save(self, graph: Any, checkpoint_path: str) -> str:
        """Capture and write the activation sidecar next to the checkpoint.

        The data is written to a temporary file that is then renamed over
        the sidecar, so an interrupted write (for example the daemon
        auto-save thread being killed at interpreter exit) cannot leave a
        truncated sidecar behind.  Write failures are logged as warnings and
        never raised; the statistics are only updated on success.

        Args:
            graph: The ``Graph`` instance to capture from.
            checkpoint_path: Path to the main checkpoint file (``str`` or
                path-like).

        Returns:
            The sidecar path, whether or not the write succeeded.
        """
        entries = self.capture(graph)
        sidecar_path = self._sidecar_path_for(checkpoint_path)
        data = {
            "version": "1.0",
            "saved_at": time.time(),
            "timestep": graph.timestep,
            "entries": entries,
        }

        # Per-thread temp name so a manual save and an auto-save running at
        # the same moment do not write into the same temporary file.
        tmp_path = f"{sidecar_path}.{threading.get_ident()}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(data, f)
            Path(tmp_path).replace(sidecar_path)
        except Exception as exc:
            # Deliberately broad: persistence is best-effort and must not
            # take the caller down.
            logger.warning("Failed to save activation sidecar: %s", exc)
            try:
                Path(tmp_path).unlink()
            except OSError:
                pass  # Temp file was never created or is already gone.
            return sidecar_path

        self._last_save_time = data["saved_at"]
        self._entries_saved = len(entries)
        self._sidecar_path = sidecar_path
        logger.info(
            "Activation sidecar saved: %d entries to %s",
            len(entries),
            sidecar_path,
        )
        return sidecar_path

    def restore(self, graph: Any, checkpoint_path: str) -> int:
        """Restore activation state from the sidecar with temporal decay.

        Applies exponential decay based on elapsed wall-clock time:
        ``activation *= (1 - decay_per_hour) ^ elapsed_hours``

        Entries below ``min_activation`` after decay are discarded.
        Entries for nodes no longer in the graph are skipped, as are entries
        that are malformed (not an object, or without a numeric voltage).
        A missing, unreadable or structurally invalid sidecar restores
        nothing and returns ``0`` instead of raising.

        Args:
            graph: The ``Graph`` instance to restore into.
            checkpoint_path: Path to the main checkpoint file (``str`` or
                path-like).

        Returns:
            Number of nodes restored.
        """
        sidecar_path = self._sidecar_path_for(checkpoint_path)
        if not Path(sidecar_path).exists():
            logger.debug("No activation sidecar at %s", sidecar_path)
            return 0

        try:
            with open(sidecar_path, encoding="utf-8") as f:
                data = json.load(f)
        except Exception as exc:
            logger.warning("Failed to read activation sidecar: %s", exc)
            return 0

        if not isinstance(data, dict):
            logger.warning(
                "Failed to read activation sidecar: %s",
                "top-level JSON value is not an object",
            )
            return 0

        entries = data.get("entries", {})
        if not isinstance(entries, dict):
            entries = {}

        # A missing or non-numeric timestamp is treated as "saved just now",
        # which matches the default used when the key is absent.
        saved_at = data.get("saved_at")
        if isinstance(saved_at, bool) or not isinstance(saved_at, (int, float)):
            saved_at = time.time()

        # Apply temporal decay.  A sidecar stamped in the future (clock
        # skew) is clamped to zero elapsed time.
        elapsed_hours = max(0.0, (time.time() - saved_at) / 3600.0)
        self._last_decay_applied = elapsed_hours
        entries = self._apply_decay(entries, elapsed_hours)

        # Inject into graph.
        # NOTE(review): ``last_spike_time`` is stored in the sidecar but is
        # not written back here -- confirm whether that is intentional.
        restored = 0
        for node_id, state in entries.items():
            node = graph.nodes.get(node_id)
            if node is None:
                continue
            voltage = self._voltage_of(state)
            if voltage is None:
                continue
            node.voltage = voltage
            if "excitability" in state:
                node.intrinsic_excitability = state["excitability"]
            restored += 1

        self._sidecar_path = sidecar_path
        logger.info(
            "Restored %d/%d activation entries (%.1fh decay applied)",
            restored,
            len(entries),
            elapsed_hours,
        )
        return restored

    def start_auto_save(self, graph: Any, checkpoint_path: str) -> None:
        """Start periodic auto-saving of activation state.

        Calling this again while auto-save is already running retargets it:
        the pending timer is cancelled and a fresh one is scheduled, so only
        one auto-save chain is ever active.

        Args:
            graph: The ``Graph`` instance to save on every tick.
            checkpoint_path: Checkpoint path passed through to ``save``.
        """
        self._auto_save_graph = graph
        self._auto_save_checkpoint_path = checkpoint_path
        self._schedule_auto_save()

    def stop_auto_save(self) -> None:
        """Stop the auto-save timer and drop the references it was using."""
        if self._auto_save_timer is not None:
            self._auto_save_timer.cancel()
            self._auto_save_timer = None
        self._auto_save_graph = None
        self._auto_save_checkpoint_path = None

    def get_stats(self) -> Dict[str, Any]:
        """Return persistence statistics.

        Returns:
            Dict with ``entries_saved`` and ``last_save_time`` from the last
            successful save, ``last_decay_applied`` (elapsed hours used by
            the last restore, rounded to two decimals) and ``sidecar_path``
            (the sidecar most recently written or read).
        """
        return {
            "entries_saved": self._entries_saved,
            "last_save_time": self._last_save_time,
            "last_decay_applied": round(self._last_decay_applied, 2),
            "sidecar_path": self._sidecar_path,
        }

    # -- Internal -----------------------------------------------------------

    def _sidecar_path_for(self, checkpoint_path: str) -> str:
        """Compute the sidecar file path for a given checkpoint path.

        ``str()`` coercion lets callers pass a ``pathlib.Path`` as well as a
        plain string.
        """
        return str(checkpoint_path) + self._cfg.sidecar_suffix

    @staticmethod
    def _voltage_of(state: Any) -> Optional[float]:
        """Return the entry's voltage, or ``None`` if the entry is malformed."""
        if not isinstance(state, dict):
            return None
        voltage = state.get("voltage")
        # bool is a subclass of int; a JSON true/false is not a voltage.
        if isinstance(voltage, bool) or not isinstance(voltage, (int, float)):
            return None
        return voltage

    def _apply_decay(
        self,
        entries: Dict[str, Dict[str, Any]],
        elapsed_hours: float,
    ) -> Dict[str, Dict[str, Any]]:
        """Apply exponential decay and prune sub-threshold entries.

        Args:
            entries: Mapping of node id to stored state.
            elapsed_hours: Hours since the state was saved.

        Returns:
            ``entries`` itself when ``elapsed_hours`` is not positive;
            otherwise a new dict with decayed voltages, omitting malformed
            entries and entries whose decayed absolute voltage is below
            ``min_activation``.
        """
        if elapsed_hours <= 0:
            return entries

        # Clamp the base at zero: with decay_per_hour > 1 the base would be
        # negative, and a negative float raised to a fractional power is a
        # complex number in Python 3.
        base = max(0.0, 1.0 - self._cfg.decay_per_hour)
        decay_factor = base ** elapsed_hours
        min_act = self._cfg.min_activation

        result: Dict[str, Dict[str, Any]] = {}
        for node_id, state in entries.items():
            voltage = self._voltage_of(state)
            if voltage is None:
                continue
            decayed_voltage = voltage * decay_factor
            if abs(decayed_voltage) < min_act:
                continue
            result[node_id] = {**state, "voltage": decayed_voltage}
        return result

    def _schedule_auto_save(self) -> None:
        """Schedule the next auto-save tick, replacing any pending timer."""
        if self._auto_save_graph is None:
            return

        # Cancel whatever timer is currently recorded so that repeated
        # start_auto_save() calls cannot leave several chains running.  When
        # called from the tick itself the recorded timer has already fired
        # and cancelling it is a no-op.
        previous = self._auto_save_timer
        if previous is not None:
            previous.cancel()

        timer = threading.Timer(
            self._cfg.auto_save_interval,
            self._auto_save_tick,
        )
        timer.daemon = True
        self._auto_save_timer = timer
        timer.start()

    def _auto_save_tick(self) -> None:
        """Auto-save callback: save once, then schedule the next tick."""
        # Snapshot both fields once so a concurrent stop_auto_save() cannot
        # clear them between the check and the save() call.
        graph = self._auto_save_graph
        checkpoint_path = self._auto_save_checkpoint_path

        # Check if we've been stopped before doing anything -- prevents
        # orphan timer chains that keep rescheduling after stop_auto_save().
        if graph is None or checkpoint_path is None:
            return
        try:
            self.save(graph, checkpoint_path)
        except Exception as exc:
            logger.warning("Auto-save failed: %s", exc)
        self._schedule_auto_save()
|