Any-to-Any
MLX
diffusion-lm
mixture-of-experts
multimodal
text-to-image
image-understanding
apple-silicon
llada
Instructions to use treadon/mlx-llada2-uni with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use treadon/mlx-llada2-uni with MLX:
```bash
# Download the model from the Hub
pip install "huggingface_hub[hf_xet]"
huggingface-cli download --local-dir mlx-llada2-uni treadon/mlx-llada2-uni
```
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- LM Studio
"""Text generation CLI for MLX LLaDA2.0-Uni backbone.

Usage:
    python generate.py --prompt "Hello world" [--gen-length 64] [--steps-per-block 16]
"""
import argparse
import json
import time
from pathlib import Path

import mlx.core as mx
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

from llada2.generate import generate_text
from llada2.model import LLaDA2Config, LLaDA2Model
from llada2.weights import load_weights_into_model
def apply_chat_template(tokenizer, prompt: str) -> str:
    """Render ``prompt`` as a single-turn user message with the chat template.

    The template is applied with ``tokenize=False``, so the result is the
    formatted prompt *text*, not token ids; the caller tokenizes it
    separately. ``add_generation_prompt=True`` is passed through to the
    tokenizer's ``apply_chat_template``.

    Args:
        tokenizer: Any object exposing
            ``apply_chat_template(messages, tokenize=..., add_generation_prompt=...)``.
        prompt: Raw user text to wrap as the content of one ``"user"`` message.

    Returns:
        Whatever ``tokenizer.apply_chat_template`` returns for the single
        user message (a string when ``tokenize=False``).
    """
    messages = [{"role": "user", "content": prompt}]
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
def main(argv=None):
    """Command-line entry point: load the model, generate text, print it.

    Steps, in order:
      1. Parse the CLI options.
      2. Fetch the weight shards, weight index, config and tokenizer files
         of ``--repo-id`` with ``snapshot_download``.
      3. Load the tokenizer and build ``LLaDA2Model`` from ``config.json``.
      4. Load the weights into the model as ``bfloat16``.
      5. Chat-template and tokenize the prompt.
      6. Call ``generate_text`` and print the decoded tokens that follow
         the prompt, together with the elapsed generation time.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            lets argparse read ``sys.argv[1:]``, which is the original
            command-line behavior.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, required=True)
    parser.add_argument("--gen-length", type=int, default=128)
    parser.add_argument("--block-length", type=int, default=32)
    parser.add_argument("--steps-per-block", type=int, default=16)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--threshold", type=float, default=0.95)
    parser.add_argument("--repo-id", default="inclusionAI/LLaDA2.0-Uni")
    args = parser.parse_args(argv)

    print("[gen] fetching model files…")
    # Restrict the download to the files this script actually reads.
    snap = Path(
        snapshot_download(
            args.repo_id,
            allow_patterns=[
                "model-*.safetensors", "model.safetensors.index.json",
                "config.json", "tokenizer*", "special_tokens_map.json",
            ],
        )
    )

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(str(snap), trust_remote_code=True)

    # Build model
    cfg_data = json.loads((snap / "config.json").read_text())
    config = LLaDA2Config.from_hf(cfg_data)
    model = LLaDA2Model(config)

    # Load weights
    print(f"[gen] loading weights from {snap}")
    # perf_counter() is monotonic, so the interval cannot be skewed by
    # wall-clock adjustments the way time.time() can.
    t0 = time.perf_counter()
    load_weights_into_model(model, snap, dtype=mx.bfloat16)
    print(f"[gen] weights loaded in {time.perf_counter()-t0:.1f}s")

    # Tokenize prompt
    prompt_text = apply_chat_template(tokenizer, args.prompt)
    print(f"\n[gen] prompt (chat-templated):\n{prompt_text!r}\n")
    # NOTE(review): the templated text is tokenized with the tokenizer's
    # default special-token handling — confirm this does not add a second
    # set of special tokens on top of those the template already emitted.
    prompt_ids = tokenizer(prompt_text, return_tensors="np").input_ids
    prompt_ids = mx.array(prompt_ids, dtype=mx.int32)
    prompt_len = prompt_ids.shape[1]
    print(f"[gen] prompt token count: {prompt_len}")

    # Generate
    t0 = time.perf_counter()
    out_ids = generate_text(
        model, prompt_ids,
        gen_length=args.gen_length,
        block_length=args.block_length,
        steps_per_block=args.steps_per_block,
        temperature=args.temperature,
        threshold=args.threshold,
        mask_token_id=config.mask_token_id,
        eos_token_id=config.eos_token_id,
    )
    # Evaluate before reading the clock so the reported time covers the
    # actual computation (MLX evaluates arrays lazily).
    mx.eval(out_ids)
    dt = time.perf_counter() - t0

    # Keep only the tokens of the first sequence that follow the prompt.
    gen_ids = out_ids[0, prompt_len:].tolist()
    text = tokenizer.decode(gen_ids, skip_special_tokens=False)
    print(f"\n[gen] ==== GENERATED ({dt:.1f}s) ====\n{text}")
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()