# mlx-llada2-uni / test_generate.py
# Upload test_generate.py with huggingface_hub — commit 74b832e (verified), author: treadon
"""Smoke test: run generate_text on a tiny randomly-initialized model.
Won't produce coherent text (weights are random) but verifies the loop runs.
"""
import mlx.core as mx
from llada2.model import LLaDA2Config, LLaDA2Model
from llada2.generate import generate_text
def _tiny_config() -> LLaDA2Config:
    """Return a deliberately tiny LLaDA2 config so the smoke test runs quickly.

    Sizes are arbitrary small values. The special-token ids (50, 51, 52) are
    all inside ``vocab_size=200`` so any id the model emits is representable.
    """
    return LLaDA2Config(
        vocab_size=200,
        hidden_size=128,
        intermediate_size=256,
        num_hidden_layers=3,
        num_attention_heads=4,
        num_key_value_heads=2,
        head_dim=32,
        max_position_embeddings=128,  # comfortably above prompt (4) + gen_length (16)
        rope_theta=10000.0,
        partial_rotary_factor=0.5,
        num_experts=16,
        num_shared_experts=1,
        num_experts_per_tok=2,
        n_group=4,  # 16 experts / 4 groups; presumably must divide evenly — confirm in model
        topk_group=2,
        routed_scaling_factor=1.0,
        moe_intermediate_size=64,
        first_k_dense_replace=1,
        pad_token_id=50,
        mask_token_id=51,
        eos_token_id=52,
    )


def main(seed: int = 0) -> None:
    """Run ``generate_text`` end-to-end on a tiny randomly-initialized model.

    The generated text is meaningless (weights are random); the point is that
    the generation loop completes and returns token ids valid for the vocab.

    Args:
        seed: Seed for MLX's global RNG, so the random weights — and therefore
            the generated ids — are reproducible when a run fails.

    Raises:
        RuntimeError: If any returned id lies outside ``[0, cfg.vocab_size)``.
    """
    # Seed before building the model; assumes parameter init draws from
    # mx.random (true for MLX's built-in layers) — TODO confirm for LLaDA2Model.
    mx.random.seed(seed)

    cfg = _tiny_config()
    model = LLaDA2Model(cfg)
    # MLX is lazy: force the parameters to materialize so init errors surface here.
    mx.eval(model.parameters())

    # A single prompt of four arbitrary in-vocab token ids.
    prompt_ids = mx.array([[10, 20, 30, 40]], dtype=mx.int32)
    out = generate_text(
        model,
        prompt_ids,
        gen_length=16,
        block_length=8,  # presumably 16 / 8 = 2 blocks — confirm in generate_text
        steps_per_block=4,
        temperature=0.0,
        threshold=0.5,
        mask_token_id=cfg.mask_token_id,
        eos_token_id=cfg.eos_token_id,
        verbose=True,
    )
    mx.eval(out)
    print(f"output shape: {out.shape}")
    print(f"output ids: {out[0].tolist()}")

    # Explicit check rather than `assert`, so it still runs under `python -O`.
    # Flatten so the check does not depend on the exact output rank.
    all_ids = out.reshape(-1).tolist()
    out_of_range = [t for t in all_ids if not 0 <= t < cfg.vocab_size]
    if out_of_range:
        raise RuntimeError(
            f"generated ids outside [0, {cfg.vocab_size}): {out_of_range}"
        )
    print("OK: generation loop completed")
if __name__ == "__main__":
    # Run the smoke test only when executed as a script, never on import.
    main()