""" Quick test script to verify Titans+MIRAS components """ import torch from miras_memory import MIRASMemory from projections import KeyProjection, ValueProjection from memory_store import MemoryStore print("=" * 50) print("Testing Titans + MIRAS Components") print("=" * 50) # Test 1: Memory Module print("\n✓ Test 1: Memory Module") memory = MIRASMemory(memory_dim=256, init_scale=0.01) key_test = torch.randn(1, 256) value_test = torch.randn(1, 256) pred = memory(key_test) print(f" - Forward pass: {pred.shape}") loss = memory.compute_loss(key_test, value_test) print(f" - Loss computation: {loss.item():.4f}") retention = memory.retention_gate(loss) print(f" - Retention gate: {retention:.2f}x") stats = memory.get_stats() print(f" - Stats: {stats}") # Test 2: Projections print("\n✓ Test 2: Projection Layers") key_proj = KeyProjection(768, 256) value_proj = ValueProjection(768, 256) hidden = torch.randn(1, 768) k = key_proj(hidden) v = value_proj(hidden) print(f" - Key projection: {k.shape}") print(f" - Value projection: {v.shape}") # Test 3: Memory Store print("\n✓ Test 3: Memory Persistence") store = MemoryStore(save_dir="memory_test") # Save store.save(memory) print(f" - Memory saved") # Create new memory and load memory2 = MIRASMemory(memory_dim=256, init_scale=0.01) loaded = store.load(memory2) print(f" - Memory loaded: {loaded}") # Test 4: Full Pipeline print("\n✓ Test 4: Full Test-Time Learning Pipeline") memory3 = MIRASMemory(memory_dim=256, init_scale=0.01) for i in range(5): # Simulate learning k = torch.randn(1, 256) v = torch.randn(1, 256) loss = memory3.compute_loss(k, v) retention = memory3.retention_gate(loss) lr = 1e-3 * retention loss.backward() with torch.no_grad(): memory3.W -= lr * memory3.W.grad memory3.W.grad.zero_() memory3.update_stats(loss) stats = memory3.get_stats() print(f" - Step {i+1}: Loss={loss.item():.4f}, Retention={retention:.2f}x, Avg={stats['avg_loss']:.4f}") print("\n" + "=" * 50) print("✅ ALL TESTS PASSED!") print("=" * 50)