# test_forward.py — from the mlx-llada2-uni repository
# (uploaded by treadon via huggingface_hub; commit 4c93fbd, verified)
"""Smoke test: build the MLX model, run a random-token forward pass, check shapes.
No weight loading — just verifies the graph runs end-to-end.
"""
import mlx.core as mx
from llada2.model import LLaDA2Config, LLaDA2Model
def main() -> None:
    """Build a tiny LLaDA2 model and check that a forward pass on random
    tokens yields logits of shape (batch, seq_len, vocab_size).

    No pretrained weights are loaded — this only verifies that the MLX
    graph builds and executes end-to-end.
    """
    # Shrunk configuration (few layers, heads, and experts) so the graph
    # constructs and evaluates quickly.
    config = LLaDA2Config(
        vocab_size=1024,
        hidden_size=128,
        intermediate_size=256,
        num_hidden_layers=3,
        num_attention_heads=4,
        num_key_value_heads=2,
        head_dim=32,
        max_position_embeddings=64,
        rope_theta=10000.0,
        partial_rotary_factor=0.5,
        num_experts=16,
        num_shared_experts=1,
        num_experts_per_tok=2,
        n_group=4,
        topk_group=2,
        routed_scaling_factor=1.0,
        moe_intermediate_size=64,
        first_k_dense_replace=1,
    )
    model = LLaDA2Model(config)
    # Force materialization of the randomly-initialized parameters
    # (MLX is lazy; eval realizes the arrays).
    mx.eval(model.parameters())

    # Single-batch forward pass on random token ids.
    batch, seq_len = 1, 16
    input_ids = mx.random.randint(0, config.vocab_size, shape=(batch, seq_len))
    logits = model(input_ids)
    mx.eval(logits)  # execute the graph

    print(f"input_ids shape: {input_ids.shape}")
    print(f"logits shape: {logits.shape}")
    assert logits.shape == (batch, seq_len, config.vocab_size), f"unexpected logits shape: {logits.shape}"
    print("OK: forward pass returns correct shape.")
# Script entry point: run the smoke test when executed directly.
if __name__ == "__main__":
    main()