"""Test suite for the Ultron model.

Covers the forward pass (GQA, MLA and MoE variants), autoregressive
generation, depth extrapolation, ACT halting, gradient flow, gradient
checkpointing, variant parameter counts, and a short end-to-end training loop.

The tests are plain ``test_*`` functions, so pytest can collect them; the file
can also be run directly as a script (see the ``__main__`` block at the end).
"""

import sys
import time

import torch
import torch.nn.functional as F

# NOTE(review): hard-coded container path so that ``ultron`` is importable when
# run from the /app image; adjust if the repository lives elsewhere.
sys.path.insert(0, "/app")

from ultron.model import Ultron, UltronConfig  # noqa: E402
from ultron.variants import (  # noqa: E402
    ultron_base,
    ultron_medium,  # imported but not currently exercised below
    ultron_medium_moe,  # imported but not currently exercised below
    ultron_small,
)

# Fixed seed so random inputs and initial weights (and therefore the outcome of
# value-dependent assertions such as "loss decreased") are reproducible.
SEED = 0

# Keyword arguments shared by the tiny configs used in every test; individual
# tests override only the settings they specifically exercise.
_TINY_CONFIG = dict(
    vocab_size=1000,
    dim=128,
    n_heads=4,
    n_kv_heads=2,
    max_seq_len=128,
    prelude_layers=1,
    coda_layers=1,
    recurrent_layers=2,
    max_loop_iters=4,
    lora_rank=4,
)


def _tiny_config(**overrides):
    """Return an ``UltronConfig`` built from ``_TINY_CONFIG`` updated with *overrides*."""
    return UltronConfig(**{**_TINY_CONFIG, **overrides})


def _start(title):
    """Print the test banner and reseed torch's global RNG."""
    print("=" * 60)
    print(f"TEST: {title}")
    torch.manual_seed(SEED)


def _passed():
    """Print the per-test success footer."""
    print(" PASSED ✓\n")


def test_basic_forward():
    """Forward pass with a minimal GQA config returns (batch, seq, vocab) logits."""
    _start("Basic forward pass")
    cfg = _tiny_config(attn_type="gqa", use_moe=False)
    model = Ultron(cfg)
    total_params = model.get_num_params(non_embedding=False)
    print(
        f" Config: dim={cfg.dim}, heads={cfg.n_heads}, "
        f"recurrent_layers={cfg.recurrent_layers}, loops={cfg.max_loop_iters}"
    )
    print(f" Parameters: {total_params:,}")

    ids = torch.randint(0, cfg.vocab_size, (2, 32))
    logits = model(ids)
    assert logits.shape == (2, 32, cfg.vocab_size), f"Wrong shape: {logits.shape}"
    print(f" Logits shape: {logits.shape} ✓")

    # Stability check: the spectral radius reported by the model must stay
    # below 1 (presumably so repeated loop iterations cannot diverge).
    rho = model.get_spectral_radius()
    assert rho < 1.0, f"Spectral radius {rho} >= 1!"
    print(f" Spectral radius ρ(A) = {rho:.6f} (< 1 ✓)")
    _passed()


def test_mla_attention():
    """Forward pass with Multi-Latent Attention (``attn_type="mla"``)."""
    _start("MLA attention")
    cfg = _tiny_config(
        n_kv_heads=4,
        attn_type="mla",
        kv_lora_rank=32,
        q_lora_rank=64,
        qk_rope_head_dim=16,
        qk_nope_head_dim=16,
        v_head_dim=16,
    )
    model = Ultron(cfg)
    ids = torch.randint(0, cfg.vocab_size, (2, 32))
    logits = model(ids)
    assert logits.shape == (2, 32, cfg.vocab_size), f"Wrong shape: {logits.shape}"
    print(f" Logits shape: {logits.shape} ✓")
    print(f" Parameters: {model.get_num_params():,}")
    _passed()


def test_moe():
    """Forward pass with an MoE FFN in the recurrent block."""
    _start("MoE FFN")
    cfg = _tiny_config(
        attn_type="gqa",
        use_moe=True,
        n_experts=4,
        n_shared_experts=1,
        n_experts_per_tok=2,
        expert_dim=64,
    )
    model = Ultron(cfg)
    ids = torch.randint(0, cfg.vocab_size, (2, 16))
    logits = model(ids)
    assert logits.shape == (2, 16, cfg.vocab_size), f"Wrong shape: {logits.shape}"
    print(f" Logits shape: {logits.shape} ✓")
    print(f" Parameters: {model.get_num_params():,}")
    _passed()


def test_generation():
    """``generate`` appends exactly ``max_new_tokens`` tokens to the prompt."""
    _start("Autoregressive generation")
    cfg = _tiny_config(max_seq_len=256)
    model = Ultron(cfg).eval()

    prompt_len, max_new_tokens = 8, 16
    prompt = torch.randint(0, cfg.vocab_size, (1, prompt_len))
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model.generate(
            prompt, max_new_tokens=max_new_tokens, n_loops=4, temperature=1.0, top_k=10
        )

    expected = (1, prompt_len + max_new_tokens)
    assert output.shape == expected, f"Expected {expected}, got {output.shape}"
    print(f" Generated shape: {output.shape} ✓")
    print(f" Prompt: {prompt[0].tolist()[:prompt_len]}")
    print(f" Generated: {output[0, prompt_len:].tolist()}")
    _passed()


def test_depth_extrapolation():
    """Run more loop iterations at inference than the configured ``max_loop_iters``.

    The config sets ``max_loop_iters=4``; the model is evaluated at 4, 8 and 16
    loops. Output shapes must agree, and a different loop count must give
    different logits.
    """
    _start("Depth extrapolation")
    cfg = _tiny_config()
    model = Ultron(cfg).eval()
    ids = torch.randint(0, cfg.vocab_size, (1, 16))

    with torch.no_grad():
        logits_4 = model(ids, n_loops=4)  # configured depth
        logits_8 = model(ids, n_loops=8)  # 2x the configured depth
        logits_16 = model(ids, n_loops=16)  # 4x the configured depth

    assert logits_4.shape == logits_8.shape == logits_16.shape, "Shape changed with depth"
    assert not torch.allclose(logits_4, logits_8, atol=1e-4), (
        "4-loop and 8-loop logits are identical"
    )
    print(f" 4 loops → logit mean: {logits_4.mean():.4f}")
    print(f" 8 loops → logit mean: {logits_8.mean():.4f}")
    print(f" 16 loops → logit mean: {logits_16.mean():.4f}")
    _passed()


def test_act_halting():
    """Deep (16-loop) eval forward pass with ``act_threshold=0.99``.

    Checks that the output has the right shape and is finite.

    TODO(review): this does not yet verify that halting actually stops early;
    that requires exposing the model's per-position halting state.
    """
    _start("ACT halting behavior")
    cfg = _tiny_config(max_loop_iters=16, act_threshold=0.99)
    model = Ultron(cfg).eval()
    ids = torch.randint(0, cfg.vocab_size, (1, 16))

    with torch.no_grad():
        logits = model(ids, n_loops=16)

    assert logits.shape == (1, 16, cfg.vocab_size), f"Wrong shape: {logits.shape}"
    assert torch.isfinite(logits).all(), "Non-finite logits after 16 loops"
    print(f" Logits shape: {logits.shape} ✓")
    print(f" ACT threshold: {cfg.act_threshold}")
    _passed()


def test_backward():
    """Gradients reach the recurrent injection (and LoRA) parameters."""
    _start("Backward pass / gradient flow")
    cfg = _tiny_config()
    model = Ultron(cfg)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    ids = torch.randint(0, cfg.vocab_size, (2, 32))
    labels = torch.randint(0, cfg.vocab_size, (2, 32))
    logits = model(ids)
    # reshape (not view) so this also works if logits are non-contiguous.
    loss = F.cross_entropy(logits.reshape(-1, cfg.vocab_size), labels.reshape(-1))
    loss.backward()
    optimizer.step()  # .grad is kept: the optimizer is never zeroed here
    print(f" Loss: {loss.item():.4f}")

    injection = model.recurrent.injection
    injection_params = ("log_A", "B", "C")
    for name in injection_params:
        assert getattr(injection, name).grad is not None, f"No gradient on {name}!"
    if model.recurrent.lora is not None:
        assert model.recurrent.lora.B.grad is not None, "No gradient on LoRA B!"

    for name in injection_params:
        print(f" {name} grad norm: {getattr(injection, name).grad.norm():.6f}")
    _passed()


def test_gradient_checkpointing():
    """Forward + backward succeed with ``gradient_checkpointing=True``."""
    _start("Gradient checkpointing")
    cfg = _tiny_config(gradient_checkpointing=True)
    model = Ultron(cfg)

    ids = torch.randint(0, cfg.vocab_size, (2, 16))
    labels = torch.randint(0, cfg.vocab_size, (2, 16))
    logits = model(ids)
    loss = F.cross_entropy(logits.reshape(-1, cfg.vocab_size), labels.reshape(-1))
    loss.backward()

    assert torch.isfinite(loss), f"Non-finite loss: {loss.item()}"
    assert any(p.grad is not None for p in model.parameters()), (
        "No parameter received a gradient"
    )
    print(f" Loss: {loss.item():.4f}")
    print(f" Grad checkpointing enabled: {cfg.gradient_checkpointing}")
    _passed()


def test_variant_param_counts():
    """Report parameter counts for the named variants and check stability."""
    _start("Variant parameter counts")
    variants = {
        "ultron_small": ultron_small(),
        "ultron_base": ultron_base(),
    }
    for name, cfg in variants.items():
        model = Ultron(cfg)
        total = model.get_num_params(non_embedding=False)
        non_emb = model.get_num_params(non_embedding=True)
        rho = model.get_spectral_radius()
        print(f" {name}: {total:>12,} total | {non_emb:>12,} non-emb | ρ(A)={rho:.6f}")
        assert rho < 1.0, f"{name}: spectral radius {rho} >= 1!"
    _passed()


def test_training_loop():
    """Ten AdamW steps on random data; loss must drop and ρ(A) stay below 1."""
    _start("Mini training loop (10 steps)")
    cfg = _tiny_config()
    model = Ultron(cfg)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    losses = []
    for _ in range(10):
        ids = torch.randint(0, cfg.vocab_size, (4, 64))
        # Next-token objective: predict ids[:, 1:] from ids[:, :-1].
        targets = ids[:, 1:]
        logits = model(ids[:, :-1])
        loss = F.cross_entropy(logits.reshape(-1, cfg.vocab_size), targets.reshape(-1))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        losses.append(loss.item())

    rho = model.get_spectral_radius()
    print(f" Step 0 loss: {losses[0]:.4f}")
    print(f" Step 9 loss: {losses[-1]:.4f}")
    # Assert before printing the ✓ so a failure never reports success first.
    assert rho < 1.0, f"Spectral radius exploded: {rho}"
    print(f" ρ(A) after training: {rho:.6f} (< 1 ✓)")
    assert losses[-1] < losses[0], "Loss didn't decrease!"
    _passed()


# Order in which the script entry point runs the tests.
_ALL_TESTS = (
    test_basic_forward,
    test_mla_attention,
    test_moe,
    test_generation,
    test_depth_extrapolation,
    test_act_halting,
    test_backward,
    test_gradient_checkpointing,
    test_variant_param_counts,
    test_training_loop,
)


if __name__ == "__main__":
    print("\n🤖 ULTRON TEST SUITE\n")
    start = time.perf_counter()  # monotonic clock for elapsed-time measurement
    for test in _ALL_TESTS:
        test()
    elapsed = time.perf_counter() - start
    print("=" * 60)
    print(f"🎉 ALL TESTS PASSED in {elapsed:.1f}s")
    print("=" * 60)