# mlx-llada2-uni / generate.py
# treadon's picture
# Upload generate.py with huggingface_hub
# 025033b verified
"""Text generation CLI for MLX LLaDA2.0-Uni backbone.
Usage:
python generate.py --prompt "Hello world" [--gen-length 64] [--steps-per-block 16]
"""
import argparse
import json
import time
from pathlib import Path
import mlx.core as mx
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from llada2.generate import generate_text
from llada2.model import LLaDA2Config, LLaDA2Model
from llada2.weights import load_weights_into_model
def apply_chat_template(tokenizer, prompt: str) -> str:
    """Render ``prompt`` as a single user turn via the tokenizer's chat template.

    The template is applied with ``tokenize=False`` and
    ``add_generation_prompt=True``; whatever the tokenizer returns for that
    call (the templated text rather than token ids) is handed back unchanged.
    """
    user_turn = {"role": "user", "content": prompt}
    conversation = [user_turn]
    rendered = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
    return rendered
def main(argv=None):
    """Run the text-generation CLI end to end.

    Steps: parse command-line options, download the model snapshot from the
    Hugging Face Hub, build the tokenizer and ``LLaDA2Model``, load the
    weights as bfloat16, tokenize the chat-templated prompt, call
    ``generate_text`` and print the decoded continuation with timing.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which is the original behavior.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, required=True)
    parser.add_argument("--gen-length", type=int, default=128)
    parser.add_argument("--block-length", type=int, default=32)
    parser.add_argument("--steps-per-block", type=int, default=16)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--threshold", type=float, default=0.95)
    parser.add_argument("--repo-id", default="inclusionAI/LLaDA2.0-Uni")
    args = parser.parse_args(argv)

    # Only the weight shards, their index, the config and the tokenizer files
    # are requested from the repo.
    print("[gen] fetching model files…")
    snap = snapshot_download(
        args.repo_id,
        allow_patterns=[
            "model-*.safetensors", "model.safetensors.index.json",
            "config.json", "tokenizer*", "special_tokens_map.json",
        ],
    )
    snap = Path(snap)

    # Load tokenizer.
    # NOTE(review): trust_remote_code=True lets the tokenizer run Python code
    # shipped inside the downloaded repo; only point --repo-id at repos you trust.
    tokenizer = AutoTokenizer.from_pretrained(str(snap), trust_remote_code=True)

    # Build model from the snapshot's config.json.
    cfg_data = json.loads((snap / "config.json").read_text())
    config = LLaDA2Config.from_hf(cfg_data)
    model = LLaDA2Model(config)

    # Load weights
    print(f"[gen] loading weights from {snap}")
    t0 = time.time()
    load_weights_into_model(model, snap, dtype=mx.bfloat16)
    print(f"[gen] weights loaded in {time.time()-t0:.1f}s")

    # Tokenize prompt
    prompt_text = apply_chat_template(tokenizer, args.prompt)
    print(f"\n[gen] prompt (chat-templated):\n{prompt_text!r}\n")
    # The chat template has already inserted the special tokens it needs, so
    # the tokenizer must not add its own on top (that could duplicate them).
    prompt_ids = tokenizer(
        prompt_text, return_tensors="np", add_special_tokens=False
    ).input_ids
    prompt_ids = mx.array(prompt_ids, dtype=mx.int32)
    print(f"[gen] prompt token count: {prompt_ids.shape[1]}")

    # Generate
    t0 = time.time()
    out_ids = generate_text(
        model, prompt_ids,
        gen_length=args.gen_length,
        block_length=args.block_length,
        steps_per_block=args.steps_per_block,
        temperature=args.temperature,
        threshold=args.threshold,
        mask_token_id=config.mask_token_id,
        eos_token_id=config.eos_token_id,
    )
    # Force evaluation before stopping the clock so the reported time covers
    # the actual computation.
    mx.eval(out_ids)
    dt = time.time() - t0

    # Take row 0 and skip the first prompt-length positions.
    # NOTE(review): presumably out_ids is the prompt followed by the generated
    # tokens — confirm against generate_text.
    gen_ids = out_ids[0, prompt_ids.shape[1]:].tolist()
    # Special tokens are kept in the printed text (skip_special_tokens=False).
    text = tokenizer.decode(gen_ids, skip_special_tokens=False)
    print(f"\n[gen] ==== GENERATED ({dt:.1f}s) ====\n{text}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()