# memory-augmented-generation / test_components.py
# Uploaded by Pavantej: "Upload folder using huggingface_hub" (commit afa8aff, verified)
"""
Quick test script to verify Titans+MIRAS components
"""
import torch
from miras_memory import MIRASMemory
from projections import KeyProjection, ValueProjection
from memory_store import MemoryStore
print("=" * 50)
print("Testing Titans + MIRAS Components")
print("=" * 50)
# Test 1: Memory Module -- forward pass, loss, retention gate and stats on a
# freshly constructed MIRASMemory. `memory` is reused by the persistence test.
print("\n✓ Test 1: Memory Module")
memory = MIRASMemory(memory_dim=256, init_scale=0.01)
key_test, value_test = torch.randn(1, 256), torch.randn(1, 256)

pred = memory(key_test)
print(" - Forward pass:", pred.shape)

loss = memory.compute_loss(key_test, value_test)
print(" - Loss computation:", f"{loss.item():.4f}")

retention = memory.retention_gate(loss)
print(" - Retention gate:", f"{retention:.2f}x")

stats = memory.get_stats()
print(" - Stats:", stats)
# Test 2: Projections -- run KeyProjection(768, 256) and ValueProjection(768, 256)
# on one random (1, 768) input and report the resulting shapes.
print("\n✓ Test 2: Projection Layers")
key_proj, value_proj = KeyProjection(768, 256), ValueProjection(768, 256)
hidden = torch.randn(1, 768)
k, v = key_proj(hidden), value_proj(hidden)
for label, projected in (("Key", k), ("Value", v)):
    print(f" - {label} projection: {projected.shape}")
# Test 3: Memory Store -- save `memory` (built in Test 1), then load it into a
# freshly constructed MIRASMemory and check the weights came back.
# NOTE(review): presumably this creates/writes ./memory_test relative to the
# current working directory; nothing in this script removes it afterwards.
print("\n✓ Test 3: Memory Persistence")
store = MemoryStore(save_dir="memory_test")
# Save
store.save(memory)
print(" - Memory saved")
# Create new memory and load
memory2 = MIRASMemory(memory_dim=256, init_scale=0.01)
loaded = store.load(memory2)
print(f" - Memory loaded: {loaded}")
# Verify the round trip instead of only printing the load result.
# Assumes store.load copies the saved state into memory2 in place, and that a
# fresh MIRASMemory starts from different weights -- TODO confirm.
print(f" - Weights restored: {torch.equal(memory.W, memory2.W)}")
# Test 4: Full Pipeline -- five manual test-time learning steps on a fresh
# MIRASMemory, each on a new random (key, value) pair.
print("\n✓ Test 4: Full Test-Time Learning Pipeline")
memory3 = MIRASMemory(memory_dim=256, init_scale=0.01)
for step in range(1, 6):
    # Simulate learning
    k = torch.randn(1, 256)
    v = torch.randn(1, 256)
    loss = memory3.compute_loss(k, v)
    # The retention gate's output for this loss scales the base learning rate.
    retention = memory3.retention_gate(loss)
    lr = 1e-3 * retention
    loss.backward()
    # Manual SGD step on W only; no_grad keeps the in-place update out of the
    # autograd graph, and the gradient is cleared so steps don't accumulate.
    with torch.no_grad():
        memory3.W -= lr * memory3.W.grad
        memory3.W.grad.zero_()
    memory3.update_stats(loss)
    stats = memory3.get_stats()
    print(
        f" - Step {step}: Loss={loss.item():.4f}, "
        f"Retention={retention:.2f}x, Avg={stats['avg_loss']:.4f}"
    )
print("\n" + "=" * 50)
print("✅ ALL TESTS PASSED!")
print("=" * 50)