"""Text generation CLI for MLX LLaDA2.0-Uni backbone.

Usage:
    python generate.py --prompt "Hello world" [--gen-length 64] [--steps-per-block 16]
"""

import argparse
import json
import time
from pathlib import Path

import mlx.core as mx
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

from llada2.generate import generate_text
from llada2.model import LLaDA2Config, LLaDA2Model
from llada2.weights import load_weights_into_model

# Files pulled from the Hub repo: weights, config, and everything the tokenizer
# may need to load.
_ALLOW_PATTERNS = [
    "model-*.safetensors",
    "model.safetensors.index.json",
    "config.json",
    "tokenizer*",
    "special_tokens_map.json",
    # Newer transformers versions may save the chat template as a standalone
    # file instead of inside tokenizer_config.json.
    "chat_template*",
    # The tokenizer is loaded with trust_remote_code=True, so any custom code
    # files shipped with the repo must also be present locally.
    # NOTE(review): assumes such code, if any, is in *.py files — confirm.
    "*.py",
]


def apply_chat_template(tokenizer, prompt: str) -> str:
    """Render *prompt* as a single user turn with the tokenizer's chat template.

    Returns the templated text (``tokenize=False``), not token ids. The
    template's generation prompt is appended (``add_generation_prompt=True``)
    so the model continues as the assistant.
    """
    messages = [{"role": "user", "content": prompt}]
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options for a single generation run."""
    parser = argparse.ArgumentParser(
        description="Generate text with the MLX LLaDA2.0-Uni backbone.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--prompt", type=str, required=True,
                        help="User message, wrapped with the chat template.")
    parser.add_argument("--gen-length", type=int, default=128,
                        help="gen_length passed to generate_text.")
    parser.add_argument("--block-length", type=int, default=32,
                        help="block_length passed to generate_text.")
    parser.add_argument("--steps-per-block", type=int, default=16,
                        help="steps_per_block passed to generate_text.")
    parser.add_argument("--temperature", type=float, default=0.0,
                        help="temperature passed to generate_text.")
    parser.add_argument("--threshold", type=float, default=0.95,
                        help="threshold passed to generate_text.")
    parser.add_argument("--repo-id", default="inclusionAI/LLaDA2.0-Uni",
                        help="Hugging Face Hub repository to download.")
    return parser.parse_args()


def _fetch_snapshot(repo_id: str) -> Path:
    """Download (or reuse the cached) model files and return the local snapshot dir."""
    return Path(snapshot_download(repo_id, allow_patterns=_ALLOW_PATTERNS))


def _load_model(snap: Path):
    """Build the model from ``config.json`` in *snap* and load its bf16 weights.

    Returns ``(config, model)``.
    """
    cfg_data = json.loads((snap / "config.json").read_text())
    config = LLaDA2Config.from_hf(cfg_data)
    model = LLaDA2Model(config)

    print(f"[gen] loading weights from {snap}")
    t0 = time.perf_counter()
    load_weights_into_model(model, snap, dtype=mx.bfloat16)
    print(f"[gen] weights loaded in {time.perf_counter() - t0:.1f}s")
    return config, model


def main():
    """CLI entry point: fetch the model, apply the chat template, generate, and print."""
    args = _parse_args()

    print("[gen] fetching model files…")
    snap = _fetch_snapshot(args.repo_id)

    tokenizer = AutoTokenizer.from_pretrained(str(snap), trust_remote_code=True)
    config, model = _load_model(snap)

    prompt_text = apply_chat_template(tokenizer, args.prompt)
    print(f"\n[gen] prompt (chat-templated):\n{prompt_text!r}\n")
    # The chat template already emits the special tokens it needs; letting the
    # tokenizer add its own on top would duplicate them (e.g. a second BOS).
    prompt_np = tokenizer(
        prompt_text, add_special_tokens=False, return_tensors="np"
    ).input_ids
    prompt_ids = mx.array(prompt_np, dtype=mx.int32)
    prompt_len = prompt_ids.shape[1]
    print(f"[gen] prompt token count: {prompt_len}")

    t0 = time.perf_counter()
    out_ids = generate_text(
        model,
        prompt_ids,
        gen_length=args.gen_length,
        block_length=args.block_length,
        steps_per_block=args.steps_per_block,
        temperature=args.temperature,
        threshold=args.threshold,
        mask_token_id=config.mask_token_id,
        eos_token_id=config.eos_token_id,
    )
    # MLX evaluates lazily; force the computation so the timing covers it.
    mx.eval(out_ids)
    dt = time.perf_counter() - t0

    # Output rows hold the prompt followed by the generated tokens; keep only the latter.
    gen_ids = out_ids[0, prompt_len:].tolist()
    text = tokenizer.decode(gen_ids, skip_special_tokens=False)
    print(f"\n[gen] ==== GENERATED ({dt:.1f}s) ====\n{text}")


if __name__ == "__main__":
    main()