File size: 2,689 Bytes
afa8aff | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 | """
Memory Persistence
Handles saving and loading memory state to/from disk so the brain
remembers across sessions.
"""
import torch
import json
import os
from pathlib import Path
from datetime import datetime
class MemoryStore:
    """Manages persistent storage of memory state.

    Two files are kept in ``save_dir``:

    * ``memory.pt``     -- a ``torch.save`` checkpoint holding ``W``,
      ``update_count`` and ``total_loss``;
    * ``metadata.json`` -- a small human-readable summary (timestamp,
      ``memory_dim``, number of updates, average loss).

    Each file is first written to a ``.tmp`` sibling and then moved into
    place with ``os.replace``, so an interrupted save never leaves a
    truncated file where the previous good copy used to be.
    """

    def __init__(self, save_dir="memory"):
        """
        Args:
            save_dir: Directory (str or path-like) that holds the saved files.
                It is created, including any missing parent directories, if
                it does not exist yet.
        """
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.memory_path = self.save_dir / "memory.pt"
        self.metadata_path = self.save_dir / "metadata.json"

    @staticmethod
    def _as_number(value):
        """Return ``value.item()`` for tensor-like scalars, otherwise ``value`` unchanged."""
        return value.item() if hasattr(value, "item") else value

    @staticmethod
    def _on_device_of(value, reference):
        """Move tensor ``value`` to ``reference``'s device when both are tensors."""
        if isinstance(value, torch.Tensor) and isinstance(reference, torch.Tensor):
            return value.to(reference.device)
        return value

    @staticmethod
    def _temp_path(path):
        """Sibling path written first, then atomically renamed onto ``path``."""
        return path.with_name(path.name + ".tmp")

    def save(self, memory_module):
        """
        Save memory state to disk.

        Args:
            memory_module: MIRASMemory instance. Its ``W`` (tensor/parameter),
                ``update_count``, ``total_loss`` and ``memory_dim`` are read;
                ``update_count``/``total_loss`` may be one-element tensors or
                plain Python numbers.
        """
        updates = self._as_number(memory_module.update_count)
        total_loss = self._as_number(memory_module.total_loss)

        # Save memory weights. Write-then-rename: a crash mid-write leaves the
        # previous checkpoint intact instead of a truncated memory.pt.
        tmp_memory = self._temp_path(self.memory_path)
        torch.save({
            'W': memory_module.W.data,
            'update_count': memory_module.update_count,
            'total_loss': memory_module.total_loss,
        }, tmp_memory)
        os.replace(tmp_memory, self.memory_path)

        # Save metadata (same write-then-rename pattern).
        metadata = {
            'last_updated': datetime.now().isoformat(),
            'memory_dim': memory_module.memory_dim,
            'updates': updates,
            # max(..., 1) avoids a division by zero before the first update.
            'avg_loss': total_loss / max(updates, 1),
        }
        tmp_metadata = self._temp_path(self.metadata_path)
        with open(tmp_metadata, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2)
        os.replace(tmp_metadata, self.metadata_path)

        print(f"💾 Memory saved: {updates} updates")

    def load(self, memory_module):
        """
        Load memory state from disk.

        The checkpoint is read onto the CPU and every field is validated
        before ``memory_module`` is modified, so a failed load leaves the
        module exactly as it was.

        Args:
            memory_module: MIRASMemory instance to load into
        Returns:
            bool: True if loaded successfully, False otherwise (no saved file,
            unreadable checkpoint, missing keys, or a saved ``W`` whose shape
            differs from the module's current ``W``).
        """
        if not self.memory_path.exists():
            print("🆕 No saved memory found. Starting fresh!")
            return False
        try:
            # map_location="cpu" lets a checkpoint written on a GPU machine be
            # read anywhere; tensors are moved to the module's devices below.
            checkpoint = torch.load(self.memory_path, map_location="cpu")
            saved_w = checkpoint['W']
            update_count = checkpoint['update_count']
            total_loss = checkpoint['total_loss']
            if saved_w.shape != memory_module.W.shape:
                raise ValueError(
                    f"saved W has shape {tuple(saved_w.shape)}, "
                    f"expected {tuple(memory_module.W.shape)}"
                )
        except Exception as e:
            # Deliberate best-effort boundary: any unreadable checkpoint means
            # "start from the current (fresh) state".
            print(f"⚠️ Error loading memory: {e}. Starting fresh!")
            return False

        # copy_ writes into the existing W, keeping its current device and dtype.
        memory_module.W.data.copy_(saved_w)
        memory_module.update_count = self._on_device_of(
            update_count, memory_module.update_count
        )
        memory_module.total_loss = self._on_device_of(
            total_loss, memory_module.total_loss
        )
        print(f"✅ Memory loaded: {self._as_number(memory_module.update_count)} updates")
        return True

    def get_metadata(self):
        """Get metadata about saved memory.

        Returns:
            The parsed contents of ``metadata.json`` as a dict, or None when
            the file does not exist or cannot be decoded as JSON.
        """
        try:
            with open(self.metadata_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except (FileNotFoundError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError,
            # e.g. a metadata file left corrupt by an older non-atomic save.
            return None
|