| """Test the OpenMythos model: forward pass, backward pass, all variants.""" |
|
|
| import sys |
| import os |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
|
|
| import torch |
| from open_mythos_hf import ( |
| OpenMythosConfig, |
| OpenMythosForCausalLM, |
| mythos_tiny, |
| mythos_140m, |
| ) |
|
|
| def count_params(model): |
| return sum(p.numel() for p in model.parameters()) |
|
|
| def test_forward_backward(name, config): |
| print(f"\n{'='*60}") |
| print(f"Testing: {name}") |
| print(f" n_embd={config.n_embd}, heads={config.n_heads}, " |
| f"prelude={config.n_layers_in_prelude}, rec={config.n_layers_in_recurrent_block}, " |
| f"coda={config.n_layers_in_coda}") |
| print(f" injection={config.injection_type}, MLA={config.use_mla}, MoE={config.use_moe}") |
| |
| model = OpenMythosForCausalLM(config) |
| total_params = count_params(model) |
| print(f" Parameters: {total_params:,} ({total_params / 1e6:.1f}M)") |
| |
| |
| model.train() |
| B, S = 2, min(128, config.block_size) |
| input_ids = torch.randint(0, config.vocab_size, (B, S)) |
| labels = input_ids.clone() |
| |
| print(f" Input shape: [{B}, {S}]") |
| |
| output = model(input_ids=input_ids, labels=labels) |
| print(f" Output logits shape: {output.logits.shape}") |
| print(f" Loss: {output.loss.item():.4f}") |
| print(f" Num recurrence steps: {output.num_steps}") |
| |
| |
| output.loss.backward() |
| grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5 |
| print(f" Gradient norm: {grad_norm:.4f}") |
| |
| |
| has_prelude_grad = any(p.grad is not None and p.grad.norm() > 0 |
| for p in model.prelude.parameters()) |
| has_rec_grad = any(p.grad is not None and p.grad.norm() > 0 |
| for p in model.recurrent_block.parameters()) |
| has_coda_grad = any(p.grad is not None and p.grad.norm() > 0 |
| for p in model.coda.parameters()) |
| has_lm_head_grad = model.lm_head.weight.grad is not None and model.lm_head.weight.grad.norm() > 0 |
| |
| print(f" Gradients: prelude={has_prelude_grad}, recurrent={has_rec_grad}, " |
| f"coda={has_coda_grad}, lm_head={has_lm_head_grad}") |
| |
| |
| model.eval() |
| with torch.no_grad(): |
| output_eval = model(input_ids=input_ids, num_steps=config.mean_recurrence) |
| print(f" Eval logits shape: {output_eval.logits.shape}") |
| print(f" Eval num_steps: {output_eval.num_steps}") |
| |
| print(f" ✅ {name} PASSED") |
| return True |
|
|
|
|
| def test_lti_stability(): |
| """Test that LTI injection maintains spectral radius < 1.""" |
| from open_mythos_hf.modeling import LTIInjection |
| |
| print(f"\n{'='*60}") |
| print("Testing LTI Injection stability") |
| |
| lti = LTIInjection(256, 256) |
| A_bar, B_bar = lti.get_AB_bar() |
| |
| max_A = A_bar.abs().max().item() |
| min_A = A_bar.abs().min().item() |
| print(f" A_bar range: [{min_A:.6f}, {max_A:.6f}]") |
| assert max_A < 1.0, f"Spectral radius >= 1: {max_A}" |
| print(f" ✅ Spectral radius < 1 guaranteed: ρ(Ā) = {max_A:.6f}") |
| |
| |
| h = torch.randn(1, 16, 256) |
| e = torch.randn(1, 16, 256) |
| norms = [] |
| for i in range(100): |
| h = lti(h, e) |
| norms.append(h.norm().item()) |
| |
| print(f" State norm after 100 steps: {norms[-1]:.2f} " |
| f"(converges: {norms[-1] < norms[0] * 100})") |
| print(f" ✅ LTI stability PASSED") |
|
|
|
|
| def test_mla(): |
| """Test Multi-Latent Attention.""" |
| from open_mythos_hf.modeling import MultiLatentAttention, precompute_freqs_cis |
| |
| print(f"\n{'='*60}") |
| print("Testing Multi-Latent Attention") |
| |
| config = OpenMythosConfig( |
| n_embd=512, n_heads=8, head_dim=64, |
| use_mla=True, kv_lora_rank=128, q_lora_rank=256, rope_head_dim=32, |
| ) |
| mla = MultiLatentAttention(config) |
| |
| x = torch.randn(2, 32, 512) |
| freqs_cis = precompute_freqs_cis(64, 64) |
| out = mla(x, freqs_cis) |
| |
| print(f" Input: {x.shape}, Output: {out.shape}") |
| assert out.shape == x.shape, f"Shape mismatch: {out.shape} != {x.shape}" |
| |
| |
| full_kv = 2 * config.n_heads * config.head_dim * config.n_embd |
| compressed_kv = config.kv_lora_rank * config.n_embd |
| print(f" KV cache reduction: {full_kv} → {compressed_kv} " |
| f"({compressed_kv / full_kv * 100:.1f}%)") |
| print(f" ✅ MLA PASSED") |
|
|
|
|
| def test_moe(): |
| """Test Sparse MoE.""" |
| from open_mythos_hf.modeling import SparseMoE |
| |
| print(f"\n{'='*60}") |
| print("Testing Sparse MoE") |
| |
| config = OpenMythosConfig( |
| n_embd=256, n_heads=4, head_dim=64, |
| use_moe=True, n_routed_experts=8, n_shared_experts=2, |
| moe_top_k=2, moe_intermediate_size=512, |
| intermediate_size=1024, |
| ) |
| moe = SparseMoE(config) |
| |
| x = torch.randn(2, 16, 256) |
| out = moe(x) |
| |
| print(f" Input: {x.shape}, Output: {out.shape}") |
| assert out.shape == x.shape |
| |
| |
| out.sum().backward() |
| print(f" Backward passed") |
| print(f" ✅ MoE PASSED") |
|
|
|
|
| if __name__ == "__main__": |
| print("=" * 60) |
| print("OpenMythos HF Reimplementation — Test Suite") |
| print("=" * 60) |
| |
| |
| test_lti_stability() |
| test_mla() |
| test_moe() |
| |
| |
| |
| test_forward_backward("Tiny (LTI)", mythos_tiny()) |
| |
| |
| config_linear = mythos_tiny() |
| config_linear.injection_type = "linear" |
| test_forward_backward("Tiny (Linear injection)", config_linear) |
| |
| |
| config_mla = mythos_tiny() |
| config_mla.use_mla = True |
| config_mla.kv_lora_rank = 64 |
| config_mla.q_lora_rank = 128 |
| config_mla.rope_head_dim = 32 |
| config_mla.head_dim = 64 |
| test_forward_backward("Tiny (MLA)", config_mla) |
| |
| |
| config_moe = mythos_tiny() |
| config_moe.use_moe = True |
| config_moe.n_routed_experts = 4 |
| config_moe.n_shared_experts = 1 |
| config_moe.moe_top_k = 2 |
| config_moe.moe_intermediate_size = 256 |
| test_forward_backward("Tiny (MoE)", config_moe) |
| |
| |
| config_full = mythos_tiny() |
| config_full.injection_type = "lti" |
| config_full.use_mla = True |
| config_full.kv_lora_rank = 64 |
| config_full.q_lora_rank = 128 |
| config_full.rope_head_dim = 32 |
| config_full.head_dim = 64 |
| config_full.use_moe = True |
| config_full.n_routed_experts = 4 |
| config_full.n_shared_experts = 1 |
| config_full.moe_top_k = 2 |
| config_full.moe_intermediate_size = 256 |
| test_forward_backward("Tiny (Full: LTI + MLA + MoE)", config_full) |
| |
| print(f"\n{'='*60}") |
| print("ALL TESTS PASSED ✅") |
| print(f"{'='*60}") |
|
|