"""
Memory Persistence

Handles saving and loading memory state to/from disk so the brain
remembers across sessions.
"""
|
|
|
| import torch
|
| import json
|
| import os
|
| from pathlib import Path
|
| from datetime import datetime
|
|
|
|
|
class MemoryStore:
    """Manages persistent storage of memory state.

    Two files are kept in ``save_dir``:

    * ``memory.pt``     -- a ``torch.save`` checkpoint with the keys
      ``'W'``, ``'update_count'`` and ``'total_loss'``.
    * ``metadata.json`` -- a small human-readable summary of the last save.

    Both files are written atomically (write a temporary file, then rename it
    over the target) so a crash mid-save cannot leave a truncated checkpoint
    that would later be discarded by ``load``.
    """

    def __init__(self, save_dir="memory"):
        """Create the store and make sure ``save_dir`` exists.

        Args:
            save_dir: Directory (str or path-like) holding the memory files.
                Missing parent directories are created as well.
        """
        self.save_dir = Path(save_dir)
        # parents=True so nested locations such as "data/memory" also work.
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.memory_path = self.save_dir / "memory.pt"
        self.metadata_path = self.save_dir / "metadata.json"

    @staticmethod
    def _atomic_write(path, write_fn):
        """Run ``write_fn(tmp_path)``, then atomically move the result to ``path``.

        On any failure the temporary file is removed and the exception is
        re-raised, leaving the previous contents of ``path`` untouched.
        """
        tmp_path = path.with_name(path.name + ".tmp")
        try:
            write_fn(tmp_path)
            # Path.replace is os.replace: an atomic rename when source and
            # destination live on the same filesystem (same directory here).
            tmp_path.replace(path)
        except BaseException:
            tmp_path.unlink(missing_ok=True)
            raise

    @staticmethod
    def _match_device(value, current):
        """Return ``value`` moved onto ``current``'s device when both are tensors."""
        if torch.is_tensor(value) and torch.is_tensor(current):
            return value.to(current.device)
        return value

    def save(self, memory_module):
        """
        Save memory state to disk.

        Args:
            memory_module: MIRASMemory instance. Reads ``W`` (a tensor or
                parameter), ``update_count``, ``total_loss`` and ``memory_dim``.
                ``update_count``/``total_loss`` may be 0-d tensors or plain
                numbers.
        """
        updates = int(memory_module.update_count)
        checkpoint = {
            # Stored on CPU so the file loads on any machine; load() copies it
            # back onto whatever device the module's W lives on.
            'W': memory_module.W.data.cpu(),
            'update_count': memory_module.update_count,
            'total_loss': memory_module.total_loss,
        }
        self._atomic_write(self.memory_path, lambda p: torch.save(checkpoint, p))

        metadata = {
            'last_updated': datetime.now().isoformat(),
            'memory_dim': memory_module.memory_dim,
            'updates': updates,
            # max(..., 1) avoids dividing by zero before the first update.
            'avg_loss': float(memory_module.total_loss) / max(updates, 1),
        }

        def _write_metadata(p):
            with open(p, 'w', encoding='utf-8') as f:
                json.dump(metadata, f, indent=2)

        self._atomic_write(self.metadata_path, _write_metadata)

        print(f"💾 Memory saved: {updates} updates")

    def load(self, memory_module):
        """
        Load memory state from disk.

        The checkpoint is fully read and validated before the module is
        touched, so a failed load leaves ``memory_module`` unchanged.

        Args:
            memory_module: MIRASMemory instance to load into. Its ``W`` is
                updated in place (keeping its device/dtype); ``update_count``
                and ``total_loss`` are replaced.

        Returns:
            bool: True if loaded successfully, False otherwise (no checkpoint,
            unreadable checkpoint, missing keys, or a ``W`` shape mismatch).
        """
        if not self.memory_path.exists():
            print("🆕 No saved memory found. Starting fresh!")
            return False

        try:
            # map_location="cpu" lets checkpoints written on a GPU machine load
            # on a CPU-only one. NOTE: torch.load unpickles the file; only point
            # the store at directories you trust.
            checkpoint = torch.load(self.memory_path, map_location="cpu")
            saved_W = checkpoint['W']
            saved_count = checkpoint['update_count']
            saved_loss = checkpoint['total_loss']

            target_W = memory_module.W.data
            if tuple(saved_W.shape) != tuple(target_W.shape):
                raise ValueError(
                    f"saved W has shape {tuple(saved_W.shape)}, "
                    f"expected {tuple(target_W.shape)}"
                )

            # Everything is read and validated; only now mutate the module.
            target_W.copy_(saved_W)
            memory_module.update_count = self._match_device(
                saved_count, getattr(memory_module, 'update_count', None))
            memory_module.total_loss = self._match_device(
                saved_loss, getattr(memory_module, 'total_loss', None))
        except Exception as e:
            # Deliberate best-effort: any problem means "start fresh".
            print(f"⚠️ Error loading memory: {e}. Starting fresh!")
            return False

        print(f"✅ Memory loaded: {int(memory_module.update_count)} updates")
        return True

    def get_metadata(self):
        """Get metadata about saved memory.

        Returns:
            dict | None: The parsed ``metadata.json``, or None when the file
            is missing or does not contain valid JSON.
        """
        try:
            with open(self.metadata_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            return None
|
|
|