"""
Memory Persistence

Handles saving and loading memory state to/from disk so the brain
remembers across sessions.
"""

import torch
import json
from pathlib import Path
from datetime import datetime


class MemoryStore:
    """Manages persistent storage of memory state."""
    
    def __init__(self, save_dir="memory"):
        self.save_dir = Path(save_dir)
        # parents=True lets save_dir be a nested path such as "runs/a/memory"
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.memory_path = self.save_dir / "memory.pt"
        self.metadata_path = self.save_dir / "metadata.json"
    
    def save(self, memory_module):
        """
        Save memory state to disk.

        Args:
            memory_module: MIRASMemory instance
        """
        # Save memory weights and running statistics
        torch.save({
            'W': memory_module.W.detach(),
            'update_count': memory_module.update_count,
            'total_loss': memory_module.total_loss,
        }, self.memory_path)

        # int()/float() accept both 0-d tensors and plain Python numbers,
        # so the metadata stays JSON-serializable either way.
        updates = int(memory_module.update_count)
        avg_loss = float(memory_module.total_loss) / max(updates, 1)

        # Save human-readable metadata alongside the weights
        metadata = {
            'last_updated': datetime.now().isoformat(),
            'memory_dim': memory_module.memory_dim,
            'updates': updates,
            'avg_loss': avg_loss,
        }

        with open(self.metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)

        print(f"💾 Memory saved: {updates} updates")
    
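    def load(self, memory_module):
        """
        Load memory state from disk.

        Args:
            memory_module: MIRASMemory instance to load into

        Returns:
            bool: True if loaded successfully, False otherwise
        """
        if not self.memory_path.exists():
            print("🆕 No saved memory found. Starting fresh!")
            return False

        try:
            # map_location places the saved tensors on the device W lives on,
            # so a checkpoint written on a GPU still loads on a CPU-only machine.
            checkpoint = torch.load(self.memory_path, map_location=memory_module.W.device)
            saved_W = checkpoint['W']
            if saved_W.shape != memory_module.W.shape:
                raise ValueError(
                    f"saved W has shape {tuple(saved_W.shape)}, "
                    f"expected {tuple(memory_module.W.shape)}"
                )
            # Copy in place under no_grad so W keeps its Parameter identity,
            # dtype and device instead of being silently replaced.
            with torch.no_grad():
                memory_module.W.copy_(saved_W)
            memory_module.update_count = checkpoint['update_count']
            memory_module.total_loss = checkpoint['total_loss']

            print(f"✅ Memory loaded: {int(memory_module.update_count)} updates")
            return True
        except Exception as e:
            print(f"⚠️ Error loading memory: {e}. Starting fresh!")
            return False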
    
    def get_metadata(self):
        """Get metadata about saved memory."""
        if not self.metadata_path.exists():
            return None
        
        with open(self.metadata_path, 'r') as f:
            return json.load(f)
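

# --- Hypothetical stand-in for MIRASMemory (illustration only) ---
# MIRASMemory itself is not defined in this file. _StubMemory is an assumed,
# minimal stand-in that exposes only the attributes MemoryStore reads and
# writes: a weight Parameter `W`, scalar tensor buffers `update_count` and
# `total_loss`, and an integer `memory_dim`. The square (memory_dim, memory_dim)
# shape of W, the buffer dtypes and the record_update() helper are all
# assumptions for illustration, not MIRASMemory's actual definition.
# Registering the counters as buffers puts them in state_dict() and makes
# them follow the module across .to(device) calls.
class _StubMemory(torch.nn.Module):
    def __init__(self, memory_dim=8):
        super().__init__()
        self.memory_dim = memory_dim
        # Assumed square memory matrix, initialised to zeros
        self.W = torch.nn.Parameter(torch.zeros(memory_dim, memory_dim))
        # 0-d tensors, so MemoryStore's int()/float()/.item() calls work on them
        self.register_buffer('update_count', torch.tensor(0, dtype=torch.long))
        self.register_buffer('total_loss', torch.tensor(0.0))

    def record_update(self, loss):
        """Hypothetical helper: count one update and accumulate its loss."""
        self.update_count += 1
        self.total_loss += float(loss)


# Example round trip (kept as comments so importing this module has no side effects):
#   store = MemoryStore(save_dir="memory")
#   mem = _StubMemory(memory_dim=8)
#   mem.record_update(0.5)
#   store.save(mem)
#   fresh = _StubMemory(memory_dim=8)
#   store.load(fresh)   # True; fresh.update_count.item() == 1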