Any-to-Any
MLX
diffusion-lm
mixture-of-experts
multimodal
text-to-image
image-understanding
apple-silicon
llada
Instructions to use treadon/mlx-llada2-uni with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use treadon/mlx-llada2-uni with MLX:
# Download the model from the Hub
pip install "huggingface_hub[hf_xet]"
huggingface-cli download --local-dir mlx-llada2-uni treadon/mlx-llada2-uni
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- LM Studio
File size: 3,033 Bytes
025033b
"""Text generation CLI for MLX LLaDA2.0-Uni backbone.
Usage:
python generate.py --prompt "Hello world" [--gen-length 64] [--steps-per-block 16]
"""
import argparse
import json
import time
from pathlib import Path
import mlx.core as mx
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from llada2.generate import generate_text
from llada2.model import LLaDA2Config, LLaDA2Model
from llada2.weights import load_weights_into_model
def apply_chat_template(tokenizer, prompt: str) -> str:
    """Render *prompt* as a single user turn using the tokenizer's chat template.

    The template is applied with ``tokenize=False``, so this returns the
    formatted prompt *string*, not token IDs. ``add_generation_prompt=True``
    appends the template's generation prompt after the user turn.
    """
    conversation = [dict(role="user", content=prompt)]
    rendered = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
    return rendered
def _positive_int(value: str) -> int:
    """argparse ``type=`` callable: parse *value* as an int and require it to be > 0."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number <= 0:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {value!r}")
    return number


def _parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(
        description="Text generation CLI for MLX LLaDA2.0-Uni backbone."
    )
    parser.add_argument("--prompt", type=str, required=True,
                        help="user message, wrapped with the model's chat template")
    # Length/step counts must be positive; catching this here gives a clear
    # CLI error instead of a failure deep inside generation.
    parser.add_argument("--gen-length", type=_positive_int, default=128,
                        help="passed to generate_text as gen_length")
    parser.add_argument("--block-length", type=_positive_int, default=32,
                        help="passed to generate_text as block_length")
    parser.add_argument("--steps-per-block", type=_positive_int, default=16,
                        help="passed to generate_text as steps_per_block")
    parser.add_argument("--temperature", type=float, default=0.0,
                        help="passed to generate_text as temperature")
    parser.add_argument("--threshold", type=float, default=0.95,
                        help="passed to generate_text as threshold")
    parser.add_argument("--repo-id", default="inclusionAI/LLaDA2.0-Uni",
                        help="Hugging Face Hub repo to fetch config, tokenizer and weights from")
    return parser.parse_args()


def _fetch_snapshot(repo_id: str) -> Path:
    """Fetch the config, tokenizer and safetensors files of *repo_id*; return the local dir."""
    print("[gen] fetching model files…")
    snap = snapshot_download(
        repo_id,
        allow_patterns=[
            "model-*.safetensors", "model.safetensors.index.json",
            "config.json", "tokenizer*", "special_tokens_map.json",
            # Newer transformers releases may keep the chat template in a
            # standalone chat_template.* file; matches nothing if absent.
            "chat_template*",
        ],
    )
    return Path(snap)


def _load_model(snap: Path):
    """Build an LLaDA2Model from ``snap/config.json`` and load bfloat16 weights.

    Returns ``(config, model)``.
    """
    cfg_data = json.loads((snap / "config.json").read_text(encoding="utf-8"))
    config = LLaDA2Config.from_hf(cfg_data)
    model = LLaDA2Model(config)
    print(f"[gen] loading weights from {snap}")
    t0 = time.perf_counter()
    load_weights_into_model(model, snap, dtype=mx.bfloat16)
    print(f"[gen] weights loaded in {time.perf_counter() - t0:.1f}s")
    return config, model


def main():
    """Fetch the model, load it with MLX, and print a completion for ``--prompt``."""
    args = _parse_args()
    snap = _fetch_snapshot(args.repo_id)

    tokenizer = AutoTokenizer.from_pretrained(str(snap), trust_remote_code=True)
    config, model = _load_model(snap)

    prompt_text = apply_chat_template(tokenizer, args.prompt)
    print(f"\n[gen] prompt (chat-templated):\n{prompt_text!r}\n")
    # The chat template already contains whatever special tokens it needs;
    # letting the tokenizer add its own again (e.g. a BOS) could duplicate them.
    prompt_ids = tokenizer(
        prompt_text, return_tensors="np", add_special_tokens=False
    ).input_ids
    prompt_ids = mx.array(prompt_ids, dtype=mx.int32)
    prompt_len = prompt_ids.shape[1]
    print(f"[gen] prompt token count: {prompt_len}")

    t0 = time.perf_counter()
    out_ids = generate_text(
        model, prompt_ids,
        gen_length=args.gen_length,
        block_length=args.block_length,
        steps_per_block=args.steps_per_block,
        temperature=args.temperature,
        threshold=args.threshold,
        mask_token_id=config.mask_token_id,
        eos_token_id=config.eos_token_id,
    )
    # MLX evaluates lazily: force the computation so the timing covers it.
    mx.eval(out_ids)
    dt = time.perf_counter() - t0

    # Assumes out_ids is the prompt followed by the generated tokens
    # (as the original slicing implies); keep only the generated tail.
    gen_ids = out_ids[0, prompt_len:].tolist()
    text = tokenizer.decode(gen_ids, skip_special_tokens=False)
    print(f"\n[gen] ==== GENERATED ({dt:.1f}s) ====\n{text}")


if __name__ == "__main__":
    main()
|