"""Test the OpenMythos model: forward pass, backward pass, all variants.""" import sys import os sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) import torch from open_mythos_hf import ( OpenMythosConfig, OpenMythosForCausalLM, mythos_tiny, mythos_140m, ) def count_params(model): return sum(p.numel() for p in model.parameters()) def test_forward_backward(name, config): print(f"\n{'='*60}") print(f"Testing: {name}") print(f" n_embd={config.n_embd}, heads={config.n_heads}, " f"prelude={config.n_layers_in_prelude}, rec={config.n_layers_in_recurrent_block}, " f"coda={config.n_layers_in_coda}") print(f" injection={config.injection_type}, MLA={config.use_mla}, MoE={config.use_moe}") model = OpenMythosForCausalLM(config) total_params = count_params(model) print(f" Parameters: {total_params:,} ({total_params / 1e6:.1f}M)") # Forward pass (training mode — uses Poisson sampling) model.train() B, S = 2, min(128, config.block_size) input_ids = torch.randint(0, config.vocab_size, (B, S)) labels = input_ids.clone() print(f" Input shape: [{B}, {S}]") output = model(input_ids=input_ids, labels=labels) print(f" Output logits shape: {output.logits.shape}") print(f" Loss: {output.loss.item():.4f}") print(f" Num recurrence steps: {output.num_steps}") # Backward pass output.loss.backward() grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5 print(f" Gradient norm: {grad_norm:.4f}") # Check gradients exist for key components has_prelude_grad = any(p.grad is not None and p.grad.norm() > 0 for p in model.prelude.parameters()) has_rec_grad = any(p.grad is not None and p.grad.norm() > 0 for p in model.recurrent_block.parameters()) has_coda_grad = any(p.grad is not None and p.grad.norm() > 0 for p in model.coda.parameters()) has_lm_head_grad = model.lm_head.weight.grad is not None and model.lm_head.weight.grad.norm() > 0 print(f" Gradients: prelude={has_prelude_grad}, recurrent={has_rec_grad}, " f"coda={has_coda_grad}, lm_head={has_lm_head_grad}") # Inference mode (fixed steps) model.eval() with torch.no_grad(): output_eval = model(input_ids=input_ids, num_steps=config.mean_recurrence) print(f" Eval logits shape: {output_eval.logits.shape}") print(f" Eval num_steps: {output_eval.num_steps}") print(f" ✅ {name} PASSED") return True def test_lti_stability(): """Test that LTI injection maintains spectral radius < 1.""" from open_mythos_hf.modeling import LTIInjection print(f"\n{'='*60}") print("Testing LTI Injection stability") lti = LTIInjection(256, 256) A_bar, B_bar = lti.get_AB_bar() max_A = A_bar.abs().max().item() min_A = A_bar.abs().min().item() print(f" A_bar range: [{min_A:.6f}, {max_A:.6f}]") assert max_A < 1.0, f"Spectral radius >= 1: {max_A}" print(f" ✅ Spectral radius < 1 guaranteed: ρ(Ā) = {max_A:.6f}") # Test convergence: repeated application should not diverge h = torch.randn(1, 16, 256) e = torch.randn(1, 16, 256) norms = [] for i in range(100): h = lti(h, e) norms.append(h.norm().item()) print(f" State norm after 100 steps: {norms[-1]:.2f} " f"(converges: {norms[-1] < norms[0] * 100})") print(f" ✅ LTI stability PASSED") def test_mla(): """Test Multi-Latent Attention.""" from open_mythos_hf.modeling import MultiLatentAttention, precompute_freqs_cis print(f"\n{'='*60}") print("Testing Multi-Latent Attention") config = OpenMythosConfig( n_embd=512, n_heads=8, head_dim=64, use_mla=True, kv_lora_rank=128, q_lora_rank=256, rope_head_dim=32, ) mla = MultiLatentAttention(config) x = torch.randn(2, 32, 512) freqs_cis = precompute_freqs_cis(64, 64) out = mla(x, freqs_cis) print(f" Input: {x.shape}, Output: {out.shape}") assert out.shape == x.shape, f"Shape mismatch: {out.shape} != {x.shape}" # Check compression ratio full_kv = 2 * config.n_heads * config.head_dim * config.n_embd # standard K + V compressed_kv = config.kv_lora_rank * config.n_embd # compressed print(f" KV cache reduction: {full_kv} → {compressed_kv} " f"({compressed_kv / full_kv * 100:.1f}%)") print(f" ✅ MLA PASSED") def test_moe(): """Test Sparse MoE.""" from open_mythos_hf.modeling import SparseMoE print(f"\n{'='*60}") print("Testing Sparse MoE") config = OpenMythosConfig( n_embd=256, n_heads=4, head_dim=64, use_moe=True, n_routed_experts=8, n_shared_experts=2, moe_top_k=2, moe_intermediate_size=512, intermediate_size=1024, ) moe = SparseMoE(config) x = torch.randn(2, 16, 256) out = moe(x) print(f" Input: {x.shape}, Output: {out.shape}") assert out.shape == x.shape # Backward out.sum().backward() print(f" Backward passed") print(f" ✅ MoE PASSED") if __name__ == "__main__": print("=" * 60) print("OpenMythos HF Reimplementation — Test Suite") print("=" * 60) # Component tests test_lti_stability() test_mla() test_moe() # Full model tests # 1. Tiny with LTI injection (no MLA, no MoE) test_forward_backward("Tiny (LTI)", mythos_tiny()) # 2. Tiny with linear injection (Huginn-style) config_linear = mythos_tiny() config_linear.injection_type = "linear" test_forward_backward("Tiny (Linear injection)", config_linear) # 3. With MLA config_mla = mythos_tiny() config_mla.use_mla = True config_mla.kv_lora_rank = 64 config_mla.q_lora_rank = 128 config_mla.rope_head_dim = 32 config_mla.head_dim = 64 # need head_dim > rope_head_dim test_forward_backward("Tiny (MLA)", config_mla) # 4. With MoE config_moe = mythos_tiny() config_moe.use_moe = True config_moe.n_routed_experts = 4 config_moe.n_shared_experts = 1 config_moe.moe_top_k = 2 config_moe.moe_intermediate_size = 256 test_forward_backward("Tiny (MoE)", config_moe) # 5. Full combo config_full = mythos_tiny() config_full.injection_type = "lti" config_full.use_mla = True config_full.kv_lora_rank = 64 config_full.q_lora_rank = 128 config_full.rope_head_dim = 32 config_full.head_dim = 64 config_full.use_moe = True config_full.n_routed_experts = 4 config_full.n_shared_experts = 1 config_full.moe_top_k = 2 config_full.moe_intermediate_size = 256 test_forward_backward("Tiny (Full: LTI + MLA + MoE)", config_full) print(f"\n{'='*60}") print("ALL TESTS PASSED ✅") print(f"{'='*60}")