File size: 2,186 Bytes
afa8aff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""

Quick test script to verify Titans+MIRAS components

"""

import torch
from miras_memory import MIRASMemory
from projections import KeyProjection, ValueProjection
from memory_store import MemoryStore

# Banner: a rule line, the suite title, and another rule line, one per row.
print("=" * 50, "Testing Titans + MIRAS Components", "=" * 50, sep="\n")

# Test 1: Memory Module -- run a single random key/value pair through the
# forward pass, loss, retention gate and stats accessors of MIRASMemory.
print("\n✓ Test 1: Memory Module")
memory = MIRASMemory(memory_dim=256, init_scale=0.01)
key_test = torch.randn(1, 256)
value_test = torch.randn(1, 256)

pred = memory(key_test)
print(f"  - Forward pass: {pred.shape}")

loss = memory.compute_loss(key_test, value_test)
print(f"  - Loss computation: {loss.item():.4f}")
# The script otherwise only prints values, so a NaN/inf loss would still end
# in "ALL TESTS PASSED"; fail loudly here instead.
assert torch.isfinite(loss).all(), f"compute_loss returned a non-finite loss: {loss.item()}"

retention = memory.retention_gate(loss)
print(f"  - Retention gate: {retention:.2f}x")

stats = memory.get_stats()
print(f"  - Stats: {stats}")

# Test 2: Projection Layers -- build the key and value projections, both as
# (768, 256), then push one random 768-wide hidden vector through each and
# report the resulting shapes.
print("\n✓ Test 2: Projection Layers")
# Construction order (key first, then value) is kept ahead of the randn call
# below so the random stream is consumed exactly as before.
key_proj, value_proj = KeyProjection(768, 256), ValueProjection(768, 256)

hidden = torch.randn(1, 768)
k, v = key_proj(hidden), value_proj(hidden)
print(f"  - Key projection: {k.shape}")
print(f"  - Value projection: {v.shape}")

# Test 3: Memory Persistence -- save the Test 1 memory through MemoryStore,
# then load into a freshly constructed MIRASMemory.
# NOTE(review): save_dir="memory_test" looks like a path relative to the CWD,
# and nothing here removes it afterwards -- confirm whether cleanup is wanted.
# NOTE(review): the round trip is not verified; `loaded` is only printed.
# Consider comparing memory2.W with memory.W once MemoryStore.load's
# contract (in-place restore? return value?) is confirmed.
print("\n✓ Test 3: Memory Persistence")
store = MemoryStore(save_dir="memory_test")

# Save
store.save(memory)
print("  - Memory saved")  # plain literal: the old f-prefix had no placeholders

# Create a new memory and load the saved state into it
memory2 = MIRASMemory(memory_dim=256, init_scale=0.01)
loaded = store.load(memory2)
print(f"  - Memory loaded: {loaded}")

# Test 4: Full Test-Time Learning Pipeline -- five manual gradient steps on a
# fresh memory, each on a new random key/value pair, with the step size scaled
# by the retention gate's output for that step's loss.
print("\n✓ Test 4: Full Test-Time Learning Pipeline")
memory3 = MIRASMemory(memory_dim=256, init_scale=0.01)
base_lr = 1e-3  # base step size; multiplied by the retention gate each step

for i in range(5):
    # Simulate learning on a random key/value pair
    k = torch.randn(1, 256)
    v = torch.randn(1, 256)

    loss = memory3.compute_loss(k, v)
    retention = memory3.retention_gate(loss)
    lr = base_lr * retention

    loss.backward()
    with torch.no_grad():
        # Manual SGD step on the memory weights, then clear the gradient so
        # it does not accumulate into the next step.
        memory3.W -= lr * memory3.W.grad
        memory3.W.grad.zero_()
        # no_grad() does not strip grad_fn from the already-built `loss`;
        # hand the stats tracker a detached copy so that, if it stores the
        # tensor, it cannot keep this step's autograd graph alive.
        memory3.update_stats(loss.detach())

    stats = memory3.get_stats()
    print(f"  - Step {i+1}: Loss={loss.item():.4f}, Retention={retention:.2f}x, Avg={stats['avg_loss']:.4f}")

# Closing banner: blank line, rule, success message, rule.
print("\n" + "=" * 50, "✅ ALL TESTS PASSED!", "=" * 50, sep="\n")