""" Checkpoint management for neural memory state. Like `docker commit` but for neural memory weights. """ import hashlib import json from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path from typing import List import torch import torch.nn as nn @dataclass class CheckpointInfo: """Metadata for a saved checkpoint.""" tag: str created_at: str size_mb: float weight_hash: str description: str = "" class CheckpointManager: """ Manages saving and restoring neural memory checkpoints. Provides Docker-like semantics for memory state management. """ def __init__(self, checkpoint_dir: str = "/app/checkpoints"): """ Initialize checkpoint manager. Args: checkpoint_dir: Directory to store checkpoints """ self.checkpoint_dir = Path(checkpoint_dir) self.checkpoint_dir.mkdir(parents=True, exist_ok=True) self.metadata_file = self.checkpoint_dir / "metadata.json" self._load_metadata() def _load_metadata(self) -> None: """Load checkpoint metadata from disk.""" if self.metadata_file.exists(): with self.metadata_file.open() as f: self.metadata = json.load(f) else: self.metadata = {"checkpoints": {}} def _save_metadata(self) -> None: """Save checkpoint metadata to disk.""" with self.metadata_file.open("w") as f: json.dump(self.metadata, f, indent=2) def _compute_hash(self, model: nn.Module) -> str: """Compute hash of model weights for integrity verification.""" hasher = hashlib.sha256() for param in model.parameters(): # Use string representation instead of numpy to avoid numpy dependency data_str = str(param.data.cpu().flatten().tolist()) hasher.update(data_str.encode()) return hasher.hexdigest()[:16] def checkpoint(self, model: nn.Module, tag: str, description: str = "") -> CheckpointInfo: """ Save current learned state as a named checkpoint. Args: model: Neural memory model to checkpoint tag: Name for this checkpoint (e.g., "v1.0", "pre-experiment") description: Optional description Returns: CheckpointInfo with metadata """ checkpoint_path = self.checkpoint_dir / f"{tag}.pt" # Save model state torch.save(model.state_dict(), checkpoint_path) # Compute metadata size_mb = checkpoint_path.stat().st_size / (1024 * 1024) weight_hash = self._compute_hash(model) info = CheckpointInfo( tag=tag, created_at=datetime.now(timezone.utc).isoformat(), size_mb=round(size_mb, 2), weight_hash=weight_hash, description=description, ) # Update metadata self.metadata["checkpoints"][tag] = { "created_at": info.created_at, "size_mb": info.size_mb, "weight_hash": info.weight_hash, "description": info.description, } self._save_metadata() return info def restore(self, model: nn.Module, tag: str) -> CheckpointInfo: """ Restore memory to a previous checkpoint. Args: model: Neural memory model to restore into tag: Checkpoint tag to restore Returns: CheckpointInfo of restored checkpoint """ checkpoint_path = self.checkpoint_dir / f"{tag}.pt" if not checkpoint_path.exists(): raise ValueError(f"Checkpoint '{tag}' not found") # Load state (handle CPU/GPU compatibility) state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True) model.load_state_dict(state_dict) # Return metadata meta = self.metadata["checkpoints"].get(tag, {}) return CheckpointInfo( tag=tag, created_at=meta.get("created_at", ""), size_mb=meta.get("size_mb", 0), weight_hash=self._compute_hash(model), description=meta.get("description", ""), ) def list_checkpoints(self) -> List[CheckpointInfo]: """List all available checkpoints.""" checkpoints = [] for tag, meta in self.metadata["checkpoints"].items(): checkpoints.append( CheckpointInfo( tag=tag, created_at=meta.get("created_at", ""), size_mb=meta.get("size_mb", 0), weight_hash=meta.get("weight_hash", ""), description=meta.get("description", ""), ) ) return sorted(checkpoints, key=lambda x: x.created_at, reverse=True) def delete(self, tag: str) -> bool: """Delete a checkpoint.""" checkpoint_path = self.checkpoint_dir / f"{tag}.pt" if checkpoint_path.exists(): checkpoint_path.unlink() if tag in self.metadata["checkpoints"]: del self.metadata["checkpoints"][tag] self._save_metadata() return True return False