# open-mythos-hf / test_model.py
# maidacundo's picture
# Add test suite
# d4614fe verified
"""Test the OpenMythos model: forward pass, backward pass, all variants."""
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import torch
from open_mythos_hf import (
OpenMythosConfig,
OpenMythosForCausalLM,
mythos_tiny,
mythos_140m,
)
def count_params(model):
    """Return the total number of elements across all parameters of ``model``.

    ``model`` only needs a ``parameters()`` method yielding objects that
    expose ``numel()``; a model without parameters yields 0.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
def _has_nonzero_grad(params):
    """Return True if any parameter in ``params`` holds a gradient with non-zero norm."""
    return any(p.grad is not None and p.grad.norm() > 0 for p in params)


def test_forward_backward(name, config):
    """Run one training step and one eval pass on a model built from ``config``.

    The function builds ``OpenMythosForCausalLM(config)``, feeds it a random
    batch in train mode with ``labels`` equal to the inputs, backpropagates the
    loss, reports gradient statistics, and finally runs a ``no_grad`` forward
    pass in eval mode with ``num_steps=config.mean_recurrence``.

    Args:
        name: Label used in the printed report and in assertion messages.
        config: Model configuration; the attributes read here are ``n_embd``,
            ``n_heads``, the three ``n_layers_in_*`` fields, ``injection_type``,
            ``use_mla``, ``use_moe``, ``block_size``, ``vocab_size`` and
            ``mean_recurrence``.

    Returns:
        ``True`` when every check passed.

    Raises:
        AssertionError: If the loss or the global gradient norm is not finite,
            or if a component that owns trainable parameters (prelude,
            recurrent block, coda, lm_head) received no non-zero gradient.
    """
    import math

    print(f"\n{'='*60}")
    print(f"Testing: {name}")
    print(f" n_embd={config.n_embd}, heads={config.n_heads}, "
          f"prelude={config.n_layers_in_prelude}, rec={config.n_layers_in_recurrent_block}, "
          f"coda={config.n_layers_in_coda}")
    print(f" injection={config.injection_type}, MLA={config.use_mla}, MoE={config.use_moe}")

    model = OpenMythosForCausalLM(config)
    total_params = count_params(model)
    print(f" Parameters: {total_params:,} ({total_params / 1e6:.1f}M)")

    # Forward pass (training mode — uses Poisson sampling)
    model.train()
    B, S = 2, min(128, config.block_size)  # never exceed the configured block size
    input_ids = torch.randint(0, config.vocab_size, (B, S))
    labels = input_ids.clone()
    print(f" Input shape: [{B}, {S}]")
    output = model(input_ids=input_ids, labels=labels)
    loss_value = output.loss.item()
    print(f" Output logits shape: {output.logits.shape}")
    print(f" Loss: {loss_value:.4f}")
    print(f" Num recurrence steps: {output.num_steps}")
    # A NaN/inf loss must fail the test instead of being reported as PASSED.
    assert math.isfinite(loss_value), f"{name}: non-finite training loss {loss_value}"

    # Backward pass
    output.loss.backward()
    grad_norm = sum(
        p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None
    ) ** 0.5
    print(f" Gradient norm: {grad_norm:.4f}")
    assert math.isfinite(grad_norm), f"{name}: non-finite gradient norm {grad_norm}"

    # Check gradients exist for key components
    components = {
        "prelude": list(model.prelude.parameters()),
        "recurrent": list(model.recurrent_block.parameters()),
        "coda": list(model.coda.parameters()),
        "lm_head": [model.lm_head.weight],
    }
    has_prelude_grad = _has_nonzero_grad(components["prelude"])
    has_rec_grad = _has_nonzero_grad(components["recurrent"])
    has_coda_grad = _has_nonzero_grad(components["coda"])
    # Routed through the helper so this flag prints as a plain bool like the
    # others rather than as a 0-dim tensor.
    has_lm_head_grad = _has_nonzero_grad(components["lm_head"])
    print(f" Gradients: prelude={has_prelude_grad}, recurrent={has_rec_grad}, "
          f"coda={has_coda_grad}, lm_head={has_lm_head_grad}")

    # Enforce the check: only components that actually own trainable
    # parameters are required to receive a gradient, so a config with an
    # empty or frozen component does not fail spuriously.
    grad_flags = {
        "prelude": has_prelude_grad,
        "recurrent": has_rec_grad,
        "coda": has_coda_grad,
        "lm_head": has_lm_head_grad,
    }
    starved = [
        label
        for label, params in components.items()
        if any(p.requires_grad for p in params) and not grad_flags[label]
    ]
    assert not starved, f"{name}: no gradient reached {', '.join(starved)}"

    # Inference mode (fixed steps)
    model.eval()
    with torch.no_grad():
        output_eval = model(input_ids=input_ids, num_steps=config.mean_recurrence)
    print(f" Eval logits shape: {output_eval.logits.shape}")
    print(f" Eval num_steps: {output_eval.num_steps}")
    print(f" ✅ {name} PASSED")
    return True
def test_lti_stability():
    """Test that LTI injection keeps its state-transition magnitudes below 1.

    Two checks are made on a freshly constructed ``LTIInjection(256, 256)``:

    1. The largest absolute entry of ``A_bar`` (from ``get_AB_bar()``) is
       below 1. The test reports this value as the spectral radius, which
       presumably relies on ``A_bar`` being diagonal/element-wise — confirm
       against ``LTIInjection``.
    2. Applying the module 100 times with a fixed input ``e`` leaves the
       state norm below 100x the norm after the first step.

    Raises:
        AssertionError: If either check fails.
    """
    from open_mythos_hf.modeling import LTIInjection

    print(f"\n{'='*60}")
    print("Testing LTI Injection stability")
    lti = LTIInjection(256, 256)
    A_bar, _ = lti.get_AB_bar()  # B_bar is not needed for the bound check
    abs_A = A_bar.abs()
    max_A = abs_A.max().item()
    min_A = abs_A.min().item()
    print(f" A_bar range: [{min_A:.6f}, {max_A:.6f}]")
    assert max_A < 1.0, f"Spectral radius >= 1: {max_A}"
    print(f" ✅ Spectral radius < 1 guaranteed: ρ(Ā) = {max_A:.6f}")

    # Test convergence: repeated application should not diverge
    h = torch.randn(1, 16, 256)
    e = torch.randn(1, 16, 256)
    norms = []
    # Only the state values matter here; skip autograd bookkeeping so the
    # 100-step rollout does not accumulate a graph.
    with torch.no_grad():
        for _ in range(100):
            h = lti(h, e)
            norms.append(h.norm().item())
    bounded = norms[-1] < norms[0] * 100
    print(f" State norm after 100 steps: {norms[-1]:.2f} "
          f"(converges: {bounded})")
    # Previously this flag was only printed, so a diverging (or NaN) state
    # still reported PASSED. A NaN norm compares False and fails here too.
    assert bounded, (
        f"State norm grew from {norms[0]:.2f} to {norms[-1]:.2f} over 100 steps"
    )
    print(" ✅ LTI stability PASSED")
def test_mla():
    """Test Multi-Latent Attention.

    Builds a ``MultiLatentAttention`` layer from a small MLA-enabled config,
    runs a random ``(2, 32, 512)`` input through it with precomputed rotary
    frequencies, and checks that the output shape equals the input shape.
    A size comparison between a standard K+V projection and the compressed
    latent projection is printed for information only.

    Raises:
        AssertionError: If the output shape differs from the input shape.
    """
    from open_mythos_hf.modeling import MultiLatentAttention, precompute_freqs_cis

    print(f"\n{'='*60}")
    print("Testing Multi-Latent Attention")
    config = OpenMythosConfig(
        n_embd=512, n_heads=8, head_dim=64,
        use_mla=True, kv_lora_rank=128, q_lora_rank=256, rope_head_dim=32,
    )
    mla = MultiLatentAttention(config)
    x = torch.randn(2, 32, 512)
    freqs_cis = precompute_freqs_cis(64, 64)
    out = mla(x, freqs_cis)
    print(f" Input: {x.shape}, Output: {out.shape}")
    assert out.shape == x.shape, f"Shape mismatch: {out.shape} != {x.shape}"

    # Check compression ratio
    full_kv = 2 * config.n_heads * config.head_dim * config.n_embd  # standard K + V
    compressed_kv = config.kv_lora_rank * config.n_embd  # compressed
    # The separator between the two sizes was missing, which fused them into
    # a single unreadable number in the report.
    print(f" KV cache reduction: {full_kv} → {compressed_kv} "
          f"({compressed_kv / full_kv * 100:.1f}%)")
    print(" ✅ MLA PASSED")
def test_moe():
    """Test Sparse MoE.

    Instantiates ``SparseMoE`` from a small MoE-enabled config, checks that a
    random ``(2, 16, 256)`` input comes back with the same shape, and runs a
    backward pass through the summed output.
    """
    from open_mythos_hf.modeling import SparseMoE

    separator = "=" * 60
    print("\n" + separator)
    print("Testing Sparse MoE")

    moe_config = OpenMythosConfig(
        n_embd=256,
        n_heads=4,
        head_dim=64,
        use_moe=True,
        n_routed_experts=8,
        n_shared_experts=2,
        moe_top_k=2,
        moe_intermediate_size=512,
        intermediate_size=1024,
    )
    layer = SparseMoE(moe_config)

    x = torch.randn(2, 16, 256)
    out = layer(x)
    print(f" Input: {x.shape}, Output: {out.shape}")
    assert out.shape == x.shape

    # Backward: reduce to a scalar so autograd can run through the layer.
    scalar = out.sum()
    scalar.backward()
    print(" Backward passed")
    print(" ✅ MoE PASSED")
if __name__ == "__main__":
    banner = "=" * 60
    print(banner)
    print("OpenMythos HF Reimplementation — Test Suite")
    print(banner)

    # Component tests run first, in a fixed order.
    for component_test in (test_lti_stability, test_mla, test_moe):
        component_test()

    # Attribute overrides shared by several full-model variants. Dict order
    # matters only in that it mirrors the order the fields are assigned in.
    mla_overrides = {
        "use_mla": True,
        "kv_lora_rank": 64,
        "q_lora_rank": 128,
        "rope_head_dim": 32,
        "head_dim": 64,  # need head_dim > rope_head_dim
    }
    moe_overrides = {
        "use_moe": True,
        "n_routed_experts": 4,
        "n_shared_experts": 1,
        "moe_top_k": 2,
        "moe_intermediate_size": 256,
    }

    # Full model tests: (report label, overrides applied on top of mythos_tiny()).
    variants = [
        # 1. Tiny with LTI injection (no MLA, no MoE)
        ("Tiny (LTI)", {}),
        # 2. Tiny with linear injection (Huginn-style)
        ("Tiny (Linear injection)", {"injection_type": "linear"}),
        # 3. With MLA
        ("Tiny (MLA)", mla_overrides),
        # 4. With MoE
        ("Tiny (MoE)", moe_overrides),
        # 5. Full combo
        (
            "Tiny (Full: LTI + MLA + MoE)",
            {"injection_type": "lti", **mla_overrides, **moe_overrides},
        ),
    ]
    for label, overrides in variants:
        # Each variant starts from a fresh tiny config so overrides never leak.
        variant_config = mythos_tiny()
        for field, value in overrides.items():
            setattr(variant_config, field, value)
        test_forward_backward(label, variant_config)

    print(f"\n{banner}")
    print("ALL TESTS PASSED ✅")
    print(banner)