# File: ultron/tests/test_ultron.py
# Residue from the hosting page (avatar alt text, commit message, commit hash):
#   trojan0x's picture
#   Add tests/test_ultron.py
#   a1e7f4a verified
"""Test suite for Ultron model — verifies forward pass, generation, stability, and all variants."""
import sys
import torch
import torch.nn.functional as F
import time
sys.path.insert(0, "/app")
from ultron.model import Ultron, UltronConfig
from ultron.variants import ultron_small, ultron_base, ultron_medium, ultron_medium_moe
def test_basic_forward():
    """Smoke-test a forward pass on a minimal ``attn_type="gqa"`` config.

    Builds a small ``Ultron`` model with ``use_moe=False``, feeds it one
    batch of random token ids and checks that:

    * the logits have shape ``(batch, seq_len, vocab_size)``;
    * ``model.get_spectral_radius()`` returns a value strictly below 1.

    Raises:
        AssertionError: If either check fails.
    """
    print("=" * 60)
    print("TEST: Basic forward pass")
    batch_size, seq_len = 2, 32
    cfg = UltronConfig(
        vocab_size=1000, dim=128, n_heads=4, n_kv_heads=2,
        max_seq_len=128, prelude_layers=1, coda_layers=1,
        recurrent_layers=2, max_loop_iters=4, lora_rank=4,
        attn_type="gqa", use_moe=False,
    )
    model = Ultron(cfg)
    total_params = model.get_num_params(non_embedding=False)
    print(f" Config: dim={cfg.dim}, heads={cfg.n_heads}, recurrent_layers={cfg.recurrent_layers}, loops={cfg.max_loop_iters}")
    print(f" Parameters: {total_params:,}")

    ids = torch.randint(0, cfg.vocab_size, (batch_size, seq_len))
    logits = model(ids)
    expected_shape = (batch_size, seq_len, cfg.vocab_size)
    assert logits.shape == expected_shape, f"Wrong shape: {logits.shape}"
    print(f" Logits shape: {logits.shape} ✓")

    # Check stability: the value reported by the model must stay below 1.
    rho = model.get_spectral_radius()
    assert rho < 1.0, f"Spectral radius {rho} >= 1!"
    print(f" Spectral radius ρ(A) = {rho:.6f} (< 1 ✓)")
    print(" PASSED ✓\n")
def test_mla_attention():
    """Smoke-test a forward pass with ``attn_type="mla"`` (Multi-Latent Attention).

    Builds a small ``Ultron`` model configured with the MLA-specific ranks
    and head dimensions, runs one batch of random token ids through it and
    checks that the logits have shape ``(batch, seq_len, vocab_size)``.

    Raises:
        AssertionError: If the logits shape is wrong.
    """
    print("=" * 60)
    print("TEST: MLA attention")
    batch_size, seq_len = 2, 32
    cfg = UltronConfig(
        vocab_size=1000, dim=128, n_heads=4, n_kv_heads=4,
        max_seq_len=128, prelude_layers=1, coda_layers=1,
        recurrent_layers=2, max_loop_iters=4, lora_rank=4,
        attn_type="mla",
        kv_lora_rank=32, q_lora_rank=64,
        qk_rope_head_dim=16, qk_nope_head_dim=16, v_head_dim=16,
    )
    model = Ultron(cfg)

    ids = torch.randint(0, cfg.vocab_size, (batch_size, seq_len))
    logits = model(ids)
    expected_shape = (batch_size, seq_len, cfg.vocab_size)
    # Same failure message as test_basic_forward so a mismatch is self-explanatory.
    assert logits.shape == expected_shape, f"Wrong shape: {logits.shape}"
    print(f" Logits shape: {logits.shape} ✓")
    print(f" Parameters: {model.get_num_params():,}")
    print(" PASSED ✓\n")
def test_moe():
    """Smoke-test a forward pass with ``use_moe=True``.

    Builds a small ``Ultron`` model with 4 experts, 1 shared expert and
    2 experts per token, runs one batch of random token ids through it and
    checks that the logits have shape ``(batch, seq_len, vocab_size)``.

    Raises:
        AssertionError: If the logits shape is wrong.
    """
    print("=" * 60)
    print("TEST: MoE FFN")
    batch_size, seq_len = 2, 16
    cfg = UltronConfig(
        vocab_size=1000, dim=128, n_heads=4, n_kv_heads=2,
        max_seq_len=128, prelude_layers=1, coda_layers=1,
        recurrent_layers=2, max_loop_iters=4, lora_rank=4,
        attn_type="gqa",
        use_moe=True, n_experts=4, n_shared_experts=1,
        n_experts_per_tok=2, expert_dim=64,
    )
    model = Ultron(cfg)

    ids = torch.randint(0, cfg.vocab_size, (batch_size, seq_len))
    logits = model(ids)
    expected_shape = (batch_size, seq_len, cfg.vocab_size)
    # Same failure message as test_basic_forward so a mismatch is self-explanatory.
    assert logits.shape == expected_shape, f"Wrong shape: {logits.shape}"
    print(f" Logits shape: {logits.shape} ✓")
    print(f" Parameters: {model.get_num_params():,}")
    print(" PASSED ✓\n")
def test_generation():
    """Smoke-test ``model.generate`` on a small model in eval mode.

    Generates ``new_tokens`` tokens from a random prompt of ``prompt_len``
    tokens and checks that the returned tensor has shape
    ``(1, prompt_len + new_tokens)``.

    NOTE(review): the original description mentions KV caching; whether
    ``generate`` actually uses a cache is not visible from this file.

    Raises:
        AssertionError: If the output shape is wrong.
    """
    print("=" * 60)
    print("TEST: Autoregressive generation")
    prompt_len, new_tokens = 8, 16
    cfg = UltronConfig(
        vocab_size=1000, dim=128, n_heads=4, n_kv_heads=2,
        max_seq_len=256, prelude_layers=1, coda_layers=1,
        recurrent_layers=2, max_loop_iters=4, lora_rank=4,
    )
    model = Ultron(cfg).eval()

    prompt = torch.randint(0, cfg.vocab_size, (1, prompt_len))
    output = model.generate(
        prompt, max_new_tokens=new_tokens, n_loops=4, temperature=1.0, top_k=10
    )
    expected_shape = (1, prompt_len + new_tokens)
    assert output.shape == expected_shape, f"Expected {expected_shape}, got {output.shape}"
    print(f" Generated shape: {output.shape} ✓")
    # The prompt is exactly prompt_len tokens long, so no slicing is needed.
    print(f" Prompt: {prompt[0].tolist()}")
    print(f" Generated: {output[0, prompt_len:].tolist()}")
    print(" PASSED ✓\n")
def test_depth_extrapolation():
    """Check that the model runs with more loops than ``max_loop_iters``.

    The config sets ``max_loop_iters=4``; the same input is run with
    ``n_loops`` of 4, 8 and 16 in eval mode. The test checks that:

    * all three outputs share the same shape;
    * all three outputs are finite;
    * the 4-loop and 8-loop logits differ (``atol=1e-4``).

    Raises:
        AssertionError: If any of the checks fails.
    """
    print("=" * 60)
    print("TEST: Depth extrapolation")
    cfg = UltronConfig(
        vocab_size=1000, dim=128, n_heads=4, n_kv_heads=2,
        max_seq_len=128, prelude_layers=1, coda_layers=1,
        recurrent_layers=2, max_loop_iters=4, lora_rank=4,
    )
    model = Ultron(cfg).eval()
    ids = torch.randint(0, cfg.vocab_size, (1, 16))

    # Inference only: skip autograd bookkeeping for the deep unrolls.
    with torch.no_grad():
        logits_4 = model(ids, n_loops=4)    # normal depth
        logits_8 = model(ids, n_loops=8)    # 2x the configured loops
        logits_16 = model(ids, n_loops=16)  # even deeper

    assert logits_4.shape == logits_8.shape == logits_16.shape, (
        f"Shape mismatch: {logits_4.shape}, {logits_8.shape}, {logits_16.shape}"
    )
    for n_loops, logits in ((4, logits_4), (8, logits_8), (16, logits_16)):
        assert torch.isfinite(logits).all(), f"Non-finite logits at n_loops={n_loops}"
    # Results should differ (different loop counts = different outputs)
    assert not torch.allclose(logits_4, logits_8, atol=1e-4), (
        "4-loop and 8-loop logits are identical"
    )
    print(f" 4 loops → logit mean: {logits_4.mean():.4f}")
    print(f" 8 loops → logit mean: {logits_8.mean():.4f}")
    print(f" 16 loops → logit mean: {logits_16.mean():.4f}")
    print(" PASSED ✓\n")
def test_act_halting():
    """Smoke-test a forward pass with ACT halting configured.

    Runs a 16-loop forward pass in eval mode on a model built with
    ``act_threshold=0.99`` and ``max_loop_iters=16``, then checks that the
    logits have shape ``(batch, seq_len, vocab_size)`` and are finite.

    NOTE(review): this does not verify that halting stops early. The number
    of loop iterations actually executed is not exposed to this test, so
    only "the ACT path runs and produces sane output" is covered.

    Raises:
        AssertionError: If the logits have the wrong shape or are non-finite.
    """
    print("=" * 60)
    print("TEST: ACT halting behavior")
    batch_size, seq_len = 1, 16
    cfg = UltronConfig(
        vocab_size=1000, dim=128, n_heads=4, n_kv_heads=2,
        max_seq_len=128, prelude_layers=1, coda_layers=1,
        recurrent_layers=2, max_loop_iters=16,
        lora_rank=4, act_threshold=0.99,
    )
    model = Ultron(cfg).eval()
    ids = torch.randint(0, cfg.vocab_size, (batch_size, seq_len))

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(ids, n_loops=16)

    expected_shape = (batch_size, seq_len, cfg.vocab_size)
    assert logits.shape == expected_shape, f"Wrong shape: {logits.shape}"
    assert torch.isfinite(logits).all(), "Non-finite logits with ACT halting"
    print(f" Logits shape: {logits.shape} ✓")
    print(f" ACT threshold: {cfg.act_threshold}")
    print(" PASSED ✓\n")
def test_backward():
    """Check that a backward pass populates gradients on key parameters.

    Runs one forward/backward/optimizer step with random inputs and labels,
    then checks that:

    * the loss is finite;
    * ``model.recurrent.injection.log_A``, ``.B`` and ``.C`` have a gradient;
    * ``model.recurrent.lora.B`` has a gradient when ``model.recurrent.lora``
      is not ``None``.

    Raises:
        AssertionError: If the loss is non-finite or a gradient is missing.
    """
    print("=" * 60)
    print("TEST: Backward pass / gradient flow")
    cfg = UltronConfig(
        vocab_size=1000, dim=128, n_heads=4, n_kv_heads=2,
        max_seq_len=128, prelude_layers=1, coda_layers=1,
        recurrent_layers=2, max_loop_iters=4, lora_rank=4,
    )
    model = Ultron(cfg)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    ids = torch.randint(0, cfg.vocab_size, (2, 32))
    labels = torch.randint(0, cfg.vocab_size, (2, 32))
    logits = model(ids)
    # Use the module-level ``F`` alias, as test_training_loop does.
    loss = F.cross_entropy(logits.view(-1, cfg.vocab_size), labels.view(-1))
    assert torch.isfinite(loss), f"Non-finite loss: {loss.item()}"
    loss.backward()
    optimizer.step()
    print(f" Loss: {loss.item():.4f}")

    # Check key gradients exist (the optimizer step above does not clear them).
    injection = model.recurrent.injection
    assert injection.log_A.grad is not None, "No gradient on log_A!"
    assert injection.B.grad is not None, "No gradient on B!"
    assert injection.C.grad is not None, "No gradient on C!"
    if model.recurrent.lora is not None:
        assert model.recurrent.lora.B.grad is not None, "No gradient on LoRA B!"
    print(f" log_A grad norm: {injection.log_A.grad.norm():.6f}")
    print(f" B grad norm: {injection.B.grad.norm():.6f}")
    print(f" C grad norm: {injection.C.grad.norm():.6f}")
    print(" PASSED ✓\n")
def test_gradient_checkpointing():
    """Check that forward/backward works with ``gradient_checkpointing=True``.

    Runs one forward and backward pass on random inputs and labels, then
    checks that the loss is finite and that at least one model parameter
    received a gradient.

    Raises:
        AssertionError: If the loss is non-finite or no gradient was produced.
    """
    print("=" * 60)
    print("TEST: Gradient checkpointing")
    cfg = UltronConfig(
        vocab_size=1000, dim=128, n_heads=4, n_kv_heads=2,
        max_seq_len=128, prelude_layers=1, coda_layers=1,
        recurrent_layers=2, max_loop_iters=4, lora_rank=4,
        gradient_checkpointing=True,
    )
    model = Ultron(cfg)

    ids = torch.randint(0, cfg.vocab_size, (2, 16))
    labels = torch.randint(0, cfg.vocab_size, (2, 16))
    logits = model(ids)
    # Use the module-level ``F`` alias, as test_training_loop does.
    loss = F.cross_entropy(logits.view(-1, cfg.vocab_size), labels.view(-1))
    assert torch.isfinite(loss), f"Non-finite loss: {loss.item()}"
    loss.backward()

    # "No exception" alone would also pass if backward silently did nothing.
    assert any(p.grad is not None for p in model.parameters()), (
        "No parameter received a gradient with checkpointing enabled"
    )
    print(f" Loss: {loss.item():.4f}")
    print(f" Grad checkpointing enabled: {cfg.gradient_checkpointing}")
    print(" PASSED ✓\n")
def test_variant_param_counts():
    """Report parameter counts and check stability for the small/base variants.

    For each of ``ultron_small`` and ``ultron_base`` this builds the model,
    prints total and non-embedding parameter counts, and asserts that
    ``model.get_spectral_radius()`` is strictly below 1.

    NOTE(review): ``ultron_medium`` and ``ultron_medium_moe`` are imported by
    this file but not exercised here, presumably because of their size.
    Confirm before adding them.

    Raises:
        AssertionError: If a variant reports a spectral radius >= 1.
    """
    print("=" * 60)
    print("TEST: Variant parameter counts")
    variant_factories = {
        "ultron_small": ultron_small,
        "ultron_base": ultron_base,
    }
    for name, make_cfg in variant_factories.items():
        cfg = make_cfg()
        model = Ultron(cfg)
        total = model.get_num_params(non_embedding=False)
        non_emb = model.get_num_params(non_embedding=True)
        rho = model.get_spectral_radius()
        print(f" {name}: {total:>12,} total | {non_emb:>12,} non-emb | ρ(A)={rho:.6f}")
        assert rho < 1.0, f"{name}: spectral radius {rho} >= 1!"
        # Drop the model before building the next one to keep peak memory down.
        del model
    print(" PASSED ✓\n")
def test_training_loop():
    """Run a 10-step training loop and check that training makes progress.

    The loop repeatedly fits one fixed, seeded batch with next-token
    prediction (inputs ``ids[:, :-1]``, targets ``ids[:, 1:]``). A fixed
    batch is used because fresh random tokens on every step carry nothing to
    learn, so comparing the first and last loss would mostly compare batch
    noise and make the test flaky.

    Checks after training:

    * ``model.get_spectral_radius()`` is still strictly below 1;
    * the last-step loss is lower than the first-step loss.

    Raises:
        AssertionError: If either check fails.
    """
    print("=" * 60)
    print("TEST: Mini training loop (10 steps)")
    # Seed so model init and the batch are reproducible from run to run.
    torch.manual_seed(0)
    cfg = UltronConfig(
        vocab_size=1000, dim=128, n_heads=4, n_kv_heads=2,
        max_seq_len=128, prelude_layers=1, coda_layers=1,
        recurrent_layers=2, max_loop_iters=4, lora_rank=4,
    )
    model = Ultron(cfg)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    ids = torch.randint(0, cfg.vocab_size, (4, 64))
    inputs, targets = ids[:, :-1], ids[:, 1:]

    losses = []
    for _ in range(10):
        logits = model(inputs)
        loss = F.cross_entropy(logits.reshape(-1, cfg.vocab_size), targets.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    rho = model.get_spectral_radius()
    print(f" Step 0 loss: {losses[0]:.4f}")
    print(f" Step 9 loss: {losses[-1]:.4f}")
    # Assert before printing the check mark so the output never claims a
    # property that has not been verified.
    assert rho < 1.0, f"Spectral radius exploded: {rho}"
    print(f" ρ(A) after training: {rho:.6f} (< 1 ✓)")
    assert losses[-1] < losses[0], (
        f"Loss didn't decrease! ({losses[0]:.4f} -> {losses[-1]:.4f})"
    )
    print(" PASSED ✓\n")
if __name__ == "__main__":
    # Script entry point: run every test in order. The first AssertionError
    # aborts the run, so the final banner is only reached when all pass.
    print("\n🤖 ULTRON TEST SUITE\n")
    # perf_counter is monotonic, so elapsed time is immune to clock changes.
    start = time.perf_counter()
    for test_fn in (
        test_basic_forward,
        test_mla_attention,
        test_moe,
        test_generation,
        test_depth_extrapolation,
        test_act_halting,
        test_backward,
        test_gradient_checkpointing,
        test_variant_param_counts,
        test_training_loop,
    ):
        test_fn()
    elapsed = time.perf_counter() - start
    print("=" * 60)
    print(f"🎉 ALL TESTS PASSED in {elapsed:.1f}s")
    print("=" * 60)