| """Smoke test: build the MLX model, run a random-token forward pass, check shapes. | |
| No weight loading — just verifies the graph runs end-to-end. | |
| """ | |
| import mlx.core as mx | |
| from llada2.model import LLaDA2Config, LLaDA2Model | |
def main():
    """Run a shape-only smoke test of the LLaDA2 MLX model.

    Builds a deliberately small ``LLaDA2Config`` (few layers, few experts),
    constructs ``LLaDA2Model`` from it without loading any weights, pushes one
    batch of random token ids through the model, and verifies that the
    returned logits have shape ``(batch_size, seq_len, vocab_size)``.

    Raises:
        AssertionError: If the logits shape differs from the expected shape.
            Raised explicitly (not via the ``assert`` statement) so the check
            still runs when Python is started with ``-O``.
    """
    # Tiny-config version (fewer experts, fewer layers) for quick graph sanity
    cfg = LLaDA2Config(
        vocab_size=1024,
        hidden_size=128,
        intermediate_size=256,
        num_hidden_layers=3,
        num_attention_heads=4,
        num_key_value_heads=2,
        head_dim=32,
        max_position_embeddings=64,
        rope_theta=10000.0,
        partial_rotary_factor=0.5,
        num_experts=16,
        num_shared_experts=1,
        num_experts_per_tok=2,
        n_group=4,
        topk_group=2,
        routed_scaling_factor=1.0,
        moe_intermediate_size=64,
        first_k_dense_replace=1,
    )
    model = LLaDA2Model(cfg)
    # Force evaluation of the parameters up front so that any failure here is
    # attributed to model construction rather than to the forward pass.
    mx.eval(model.parameters())

    # Forward pass on random token ids drawn from [0, vocab_size).
    batch_size, seq_len = 1, 16
    input_ids = mx.random.randint(0, cfg.vocab_size, shape=(batch_size, seq_len))
    logits = model(input_ids)
    mx.eval(logits)

    print(f"input_ids shape: {input_ids.shape}")
    print(f"logits shape: {logits.shape}")

    expected_shape = (batch_size, seq_len, cfg.vocab_size)
    # Explicit raise instead of `assert`: assert statements are removed under
    # `python -O`, which would let this smoke test report OK unconditionally.
    if logits.shape != expected_shape:
        raise AssertionError(f"unexpected logits shape: {logits.shape}")
    print("OK: forward pass returns correct shape.")
# Script entry point: run the smoke test only when this file is executed
# directly, so importing the module has no side effects.
if __name__ == "__main__":
    main()