"""
Quick test script to verify Titans+MIRAS components
"""
import torch
from miras_memory import MIRASMemory
from projections import KeyProjection, ValueProjection
from memory_store import MemoryStore
# Opening banner: a 50-character rule, the run title, then another rule.
print("=" * 50, "Testing Titans + MIRAS Components", "=" * 50, sep="\n")
# Test 1: Memory Module -- exercise the forward pass, the loss, the retention
# gate and the stats accessor of a fresh MIRASMemory on one random key/value pair.
print("\n✓ Test 1: Memory Module")
memory = MIRASMemory(memory_dim=256, init_scale=0.01)
key_test, value_test = torch.randn(1, 256), torch.randn(1, 256)

pred = memory(key_test)
print(f" - Forward pass: {pred.shape}")

# Loss for the (key_test, value_test) pair; nothing is backpropagated here --
# the gradient step is exercised in Test 4.
loss = memory.compute_loss(key_test, value_test)
print(f" - Loss computation: {loss.item():.4f}")

# The gate's output is reported as a multiplier with two decimals.
retention = memory.retention_gate(loss)
print(f" - Retention gate: {retention:.2f}x")

stats = memory.get_stats()
print(f" - Stats: {stats}")
# Test 2: Projections -- push one random (1, 768) hidden vector through the key
# and value projection layers (768 -> 256 constructor args) and report shapes.
print("\n✓ Test 2: Projection Layers")
key_proj, value_proj = KeyProjection(768, 256), ValueProjection(768, 256)
hidden = torch.randn(1, 768)
k, v = key_proj(hidden), value_proj(hidden)
for label, projected in (("Key", k), ("Value", v)):
    print(f" - {label} projection: {projected.shape}")
# Test 3: Memory Store -- save the Test 1 memory via MemoryStore, load it into a
# freshly initialised MIRASMemory, and check that the weights round-tripped.
print("\n✓ Test 3: Memory Persistence")
store = MemoryStore(save_dir="memory_test")
# Save
store.save(memory)
print(" - Memory saved")
# Create new memory and load
memory2 = MIRASMemory(memory_dim=256, init_scale=0.01)
loaded = store.load(memory2)
print(f" - Memory loaded: {loaded}")
# The return value of load() alone does not show that the parameters were
# restored. memory2 starts from its own random init, so W only matches if the
# saved state was actually loaded. Assumes load() restores into memory2 in
# place -- confirm against MemoryStore. Reported, not asserted, so the script
# still runs to the end.
weights_match = torch.allclose(memory.W, memory2.W)
print(f" - Weights match: {weights_match}")
# Test 4: Full Pipeline -- five manual test-time learning steps on a fresh
# memory. Each step computes the loss for a random key/value pair, scales the
# step size by the retention gate, takes one gradient-descent step on W, and
# records the loss in the memory's running stats.
print("\n✓ Test 4: Full Test-Time Learning Pipeline")
memory3 = MIRASMemory(memory_dim=256, init_scale=0.01)
for i in range(5):
    # Simulate learning
    k = torch.randn(1, 256)
    v = torch.randn(1, 256)
    loss = memory3.compute_loss(k, v)
    retention = memory3.retention_gate(loss)
    # Base step size 1e-3, scaled by the retention gate's output.
    lr = 1e-3 * retention
    loss.backward()
    # Manual SGD step: the in-place update of W must not be recorded by
    # autograd, and the gradient is cleared so the next step starts fresh.
    with torch.no_grad():
        memory3.W -= lr * memory3.W.grad
        memory3.W.grad.zero_()
    memory3.update_stats(loss)
    stats = memory3.get_stats()
    print(f" - Step {i+1}: Loss={loss.item():.4f}, Retention={retention:.2f}x, Avg={stats['avg_loss']:.4f}")
# Closing banner. Reached only when none of the tests above raised; the tests
# print diagnostics rather than asserting values, so this is a smoke-test pass.
print("\n" + "=" * 50)
print("✅ ALL TESTS PASSED!")
print("=" * 50)