| """
|
| Quick test script to verify Titans+MIRAS components
|
| """
|
|
|
| import torch
|
| from miras_memory import MIRASMemory
|
| from projections import KeyProjection, ValueProjection
|
| from memory_store import MemoryStore
|
|
|
# Opening banner: a rule line, the suite title, then another rule line.
print("=" * 50, "Testing Titans + MIRAS Components", "=" * 50, sep="\n")
|
|
|
|
|
# --- Test 1: core memory module ---------------------------------------------
# Builds a fresh MIRASMemory and runs one random key/value pair through its
# forward pass, loss, retention gate and stats reporting, printing each result.
print("\n✓ Test 1: Memory Module")
memory = MIRASMemory(memory_dim=256, init_scale=0.01)
# Key is drawn before value, matching the original RNG consumption order.
key_test, value_test = torch.randn(1, 256), torch.randn(1, 256)

print(f" - Forward pass: {memory(key_test).shape}")

loss = memory.compute_loss(key_test, value_test)
print(f" - Loss computation: {loss.item():.4f}")

# The gate is driven by the loss just computed and is reported as a multiplier.
retention = memory.retention_gate(loss)
print(f" - Retention gate: {retention:.2f}x")

stats = memory.get_stats()
print(f" - Stats: {stats}")
|
|
|
|
|
# --- Test 2: projection layers ----------------------------------------------
# Pushes one random 768-wide hidden vector through the key and value projection
# layers and reports the shape each one produces.
print("\n✓ Test 2: Projection Layers")
key_proj, value_proj = KeyProjection(768, 256), ValueProjection(768, 256)

hidden = torch.randn(1, 768)
k, v = key_proj(hidden), value_proj(hidden)
for label, projected in ((" - Key projection", k), (" - Value projection", v)):
    print(f"{label}: {projected.shape}")
|
|
|
|
|
# --- Test 3: persistence -----------------------------------------------------
# Saves the Test 1 `memory` through a MemoryStore, then loads it into a fresh
# MIRASMemory instance and prints whatever `store.load` returns.
# NOTE(review): save_dir is a relative path, so the store presumably writes under
# ./memory_test in the current working directory — confirm and clean up if needed.
# NOTE(review): `loaded` is only printed, never checked; confirm what
# MemoryStore.load returns before turning this into an assertion.
print("\n✓ Test 3: Memory Persistence")
store = MemoryStore(save_dir="memory_test")

store.save(memory)
print(" - Memory saved")  # plain literal: no placeholders, so no f-string needed

memory2 = MIRASMemory(memory_dim=256, init_scale=0.01)
loaded = store.load(memory2)
print(f" - Memory loaded: {loaded}")
|
|
|
|
|
# --- Test 4: test-time learning loop ------------------------------------------
# Runs five manual gradient steps on a fresh memory. Each step draws a random
# key/value pair, computes the loss, scales the base learning rate by the
# retention gate for that loss, applies the update to `memory3.W` and records stats.
print("\n✓ Test 4: Full Test-Time Learning Pipeline")
memory3 = MIRASMemory(memory_dim=256, init_scale=0.01)
base_lr = 1e-3  # learning rate before retention-gate scaling

for step in range(1, 6):
    k = torch.randn(1, 256)
    v = torch.randn(1, 256)

    loss = memory3.compute_loss(k, v)
    retention = memory3.retention_gate(loss)
    lr = base_lr * retention

    # NOTE(review): assumes memory3.W is a leaf tensor with requires_grad=True,
    # so backward() populates memory3.W.grad — TODO confirm in MIRASMemory.
    loss.backward()
    # Manual SGD step; no_grad keeps the in-place weight update out of autograd.
    with torch.no_grad():
        memory3.W -= lr * memory3.W.grad
        # backward() accumulates into .grad, so reset it before the next step.
        memory3.W.grad.zero_()
    memory3.update_stats(loss)

    stats = memory3.get_stats()
    print(
        f" - Step {step}: Loss={loss.item():.4f}, "
        f"Retention={retention:.2f}x, Avg={stats['avg_loss']:.4f}"
    )
|
|
|
# Closing banner.
# NOTE(review): this is printed unconditionally — the sections above only print
# results, so reaching here means nothing raised, not that values were validated.
print("\n" + "=" * 50, "✅ ALL TESTS PASSED!", "=" * 50, sep="\n")
|
|
|