"""Smoke test: build the MLX model, run a random-token forward pass, check shapes.

No weight loading — this only verifies that the graph runs end-to-end on a
tiny configuration and returns logits of the expected shape.
"""

import mlx.core as mx

from llada2.model import LLaDA2Config, LLaDA2Model


def main(batch_size: int = 1, seq_len: int = 16, seed: int = 0) -> None:
    """Run one forward pass of a tiny LLaDA2 model on random token ids.

    Args:
        batch_size: Number of sequences in the random input batch.
        seq_len: Tokens per sequence; must not exceed the tiny config's
            ``max_position_embeddings``.
        seed: Seed for MLX's global RNG. It is set before the model is built
            and before the input ids are sampled, so runs are reproducible.

    Raises:
        ValueError: If ``batch_size`` or ``seq_len`` is not positive, or if
            ``seq_len`` exceeds ``max_position_embeddings``.
        AssertionError: If the logits do not have shape
            ``(batch_size, seq_len, vocab_size)`` or contain NaN/Inf values.
    """
    # Tiny-config version (fewer experts, fewer layers) for quick graph sanity.
    cfg = LLaDA2Config(
        vocab_size=1024,
        hidden_size=128,
        intermediate_size=256,
        num_hidden_layers=3,
        num_attention_heads=4,
        num_key_value_heads=2,
        head_dim=32,
        max_position_embeddings=64,
        rope_theta=10000.0,
        partial_rotary_factor=0.5,
        num_experts=16,
        num_shared_experts=1,
        num_experts_per_tok=2,
        n_group=4,
        topk_group=2,
        routed_scaling_factor=1.0,
        moe_intermediate_size=64,
        first_k_dense_replace=1,
    )

    if batch_size <= 0 or seq_len <= 0:
        raise ValueError(
            f"batch_size and seq_len must be positive, got {batch_size}, {seq_len}"
        )
    if seq_len > cfg.max_position_embeddings:
        raise ValueError(
            f"seq_len={seq_len} exceeds max_position_embeddings="
            f"{cfg.max_position_embeddings}"
        )

    # Seed before building the model so both (any RNG-based) parameter init
    # and the sampled input ids are deterministic across runs.
    mx.random.seed(seed)

    model = LLaDA2Model(cfg)
    # MLX is lazy: force parameter materialization before the forward pass.
    mx.eval(model.parameters())

    # Forward pass on random token ids in [0, vocab_size).
    input_ids = mx.random.randint(0, cfg.vocab_size, shape=(batch_size, seq_len))
    logits = model(input_ids)
    mx.eval(logits)

    print(f"input_ids shape: {input_ids.shape}")
    print(f"logits shape: {logits.shape}")

    # Explicit raises instead of `assert`: `python -O` strips assert statements,
    # which would silently turn this smoke test into a no-op check.
    expected_shape = (batch_size, seq_len, cfg.vocab_size)
    if tuple(logits.shape) != expected_shape:
        raise AssertionError(f"unexpected logits shape: {logits.shape}")
    if mx.any(mx.isnan(logits)).item() or mx.any(mx.isinf(logits)).item():
        raise AssertionError("logits contain NaN or Inf values")

    print("OK: forward pass returns correct shape.")


if __name__ == "__main__":
    main()