---
library_name: mlx
tags:
- mlx
- diffusion-lm
- mixture-of-experts
- multimodal
- text-to-image
- image-understanding
- apple-silicon
- llada
base_model: inclusionAI/LLaDA2.0-Uni
license: apache-2.0
pipeline_tag: any-to-any
---

# MLX LLaDA2.0-Uni

> Follow [**@treadon on X**](https://x.com/treadon) and [**treadon on Hugging Face**](https://huggingface.co/treadon) for more AI experiments, evals, and projects.

An [MLX](https://github.com/ml-explore/mlx) port of [inclusionAI/LLaDA2.0-Uni](https://huggingface.co/inclusionAI/LLaDA2.0-Uni), a **16B MoE unified multimodal *diffusion* LLM**, running natively on Apple Silicon.

LLaDA2 doesn't generate left-to-right. It's a *diffusion* language model: fill a template with mask tokens, then iteratively un-mask them over multiple denoising steps using **bidirectional** attention, much like image diffusion models but over discrete tokens. Images are represented as **in-vocabulary VQ tokens** (offset at index 157184), so the backbone is just a sequence-in / logits-out transformer that does text chat, image understanding (VQA), and text-to-image in one model.
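To make the "fill with masks, then un-mask over several steps" idea concrete, here is a minimal sketch in NumPy with a toy stand-in for the model. Everything in it is an assumption for illustration: `toy_logits`, `MASK_ID`, the tiny vocabulary, and the "commit the most confident positions first" schedule are hypothetical and are not taken from this repo or from LLaDA2's actual sampler.

```python
import numpy as np

MASK_ID = 0       # hypothetical mask-token id for this toy vocabulary
VOCAB_SIZE = 16   # toy vocabulary, unrelated to LLaDA2's real one


def toy_logits(tokens: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Stand-in for a bidirectional transformer: returns [seq_len, vocab] logits.

    A real model would attend over the whole sequence; this toy returns random
    logits and forbids predicting MASK_ID itself.
    """
    logits = rng.normal(size=(tokens.shape[0], VOCAB_SIZE))
    logits[:, MASK_ID] = -np.inf
    return logits


def denoise(prompt: np.ndarray, gen_len: int, steps: int, seed: int = 0) -> np.ndarray:
    """Iteratively un-mask `gen_len` positions over at most `steps` steps.

    The prompt must not contain MASK_ID, or it would be treated as masked.
    """
    rng = np.random.default_rng(seed)
    tokens = np.concatenate([prompt, np.full(gen_len, MASK_ID)])
    for step in range(steps):
        masked = np.flatnonzero(tokens == MASK_ID)
        if masked.size == 0:
            break
        logits = toy_logits(tokens, rng)
        # Softmax over the vocabulary gives a confidence per position.
        shifted = logits - logits.max(axis=-1, keepdims=True)
        probs = np.exp(shifted)
        probs /= probs.sum(axis=-1, keepdims=True)
        pred = probs.argmax(axis=-1)
        conf = probs.max(axis=-1)
        # Commit an even share of the remaining masks, most confident first.
        n_commit = int(np.ceil(masked.size / (steps - step)))
        chosen = masked[np.argsort(-conf[masked])[:n_commit]]
        tokens[chosen] = pred[chosen]
    return tokens
```

Calling `denoise(np.array([3, 5, 7]), gen_len=8, steps=4)` leaves the three prompt tokens untouched and fills the eight masked slots two at a time. Unlike left-to-right decoding, the order in which positions are committed depends on confidence, not on position.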

A bowl of fresh strawberries on a wooden table, morning light, photorealistic (512×512, cfg 4.0, 8 MLX steps × 9 blocks → 25-step decoder)

A medieval castle on a hill at sunset, oil painting style (same settings)

*(Generated at 512×512 for speed. LLaDA2 was designed for 1024×1024; running at the native 32×32 token grid produces significantly sharper results, but each generation takes several hours on M4 Pro. The 16×16 grid at 512×512 works well for still-life subjects; complex subjects like cat faces can show mild distortion at this reduced resolution.)*

## Why this is interesting

`llama.cpp` doesn't support diffusion LLMs: every architecture there is a hand-coded graph assuming causal attention + next-token sampling. LLaDA2 needs bidirectional attention and a multi-step denoising loop, so GGUF was not an option.

MLX is a general-purpose tensor framework: you write the forward pass and the denoising loop in Python, and it runs.

## Quick Start

```bash
git clone https://huggingface.co/treadon/mlx-llada2-uni
cd mlx-llada2-uni
git clone https://github.com/inclusionAI/LLaDA2.0-Uni llada2-uni-repo

python -m venv .venv && source .venv/bin/activate
pip install mlx "transformers==4.51" diffusers torch torchvision huggingface_hub safetensors Pillow accelerate torchdiffeq einops

# Text-only chat
python generate.py --prompt "Name three primary colors."

# Image understanding (VQA)
python image_understand.py --image some_photo.jpg --question "What's in this image?"

# Text-to-image (≈18 min on M4 Pro for a 512×512)
python t2i.py --prompt "A photorealistic cat sitting on a wooden table"
```

First run pulls ~45 GB from [inclusionAI/LLaDA2.0-Uni](https://huggingface.co/inclusionAI/LLaDA2.0-Uni): the 32 GB MoE backbone, 12 GB diffusion decoder, 2.4 GB image tokenizer (ViT + VQVAE), and the 170 MB FLUX VAE.

Requirements:

- Apple Silicon (tested on M4 Pro 64 GB).
- Python 3.10+.
- `transformers==4.51` specifically; newer versions have hard `flash_attn` imports that fail on Apple Silicon.

## What's in this repo

Just the MLX code + sample outputs. Weights are loaded directly from [inclusionAI/LLaDA2.0-Uni](https://huggingface.co/inclusionAI/LLaDA2.0-Uni); we don't redistribute them.

```
mlx-llada2-uni/
├── llada2/
│   ├── model.py            # 16B MoE backbone (GQA, partial-RoPE, DeepSeek-V2 MoE, bidirectional attention)
│   ├── weights.py          # HF safetensors → MLX loader (packs per-expert ModuleList into stacked arrays)
│   ├── generate.py         # Block-diffusion text generation
│   └── generate_image.py   # Block-diffusion VQ-token generation with CFG (for t2i)
├── generate.py             # Text CLI
├── image_understand.py     # VQA CLI (hybrid: PyTorch image tokenizer → MLX backbone)
├── t2i.py                  # Text-to-image CLI (hybrid: MLX VQ gen → PyTorch decoder)
└── samples/                # Example outputs
```

## Architecture

### Backbone (runs in MLX)

| Field | Value |
|---|---|
| Parameters | 16B total, ~1B active per token |
| hidden_size | 2048 |
| layers | 20 (1 dense + 19 MoE) |
| attention | 16Q / 4KV (GQA), head_dim 128, QK RMSNorm, **bidirectional** |
| RoPE | partial (rotates first 64 of 128 dims), θ=600000 |
| MoE | 256 routed + 1 shared, top-8 per token, DeepSeek-V2 group-limited (n_group=8, topk_group=4) |
| moe_intermediate_size | 512 |
| vocab | 173568 (text + 16384-entry VQ codebook at offset 157184) |

### Pipeline (hybrid)

| Component | Size | Where |
|---|---|---|
| LLaDA2 MoE backbone | 32 GB bf16 | **MLX** (GPU) |
| Image tokenizer (ViT + VQVAE) | 2.4 GB | **PyTorch CPU** |
| SigVQ prior | 400 MB fp32 | **PyTorch MPS** |
| ZImage decoder (6.2B, 30 layers) | 12 GB bf16 | **PyTorch MPS** |
| FLUX VAE | 170 MB bf16 | **PyTorch MPS** |

Porting the 6.2B ZImageTransformer2DModel decoder + ~4B image tokenizer to MLX would take weeks. They run in PyTorch on MPS instead. The heavy lifting (the 16B MoE LLM that does all three tasks) is the MLX-native part.
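
The MoE row in the backbone table (256 routed experts, top-8 per token, `n_group=8`, `topk_group=4`) can be sketched roughly as follows. This is a NumPy illustration, not the code in `llada2/model.py`: the function name `group_limited_topk` is hypothetical, and three details are assumptions on my part, namely that a group is scored by its best expert, that the bias affects selection only, and that the combine weights are the renormalized sigmoid scores. The shared expert is left out.

```python
import numpy as np


def group_limited_topk(logits, expert_bias, n_group=8, topk_group=4, top_k=8):
    """Pick `top_k` experts per token, restricted to the best `topk_group` groups.

    logits:      [n_tokens, n_experts] raw router outputs
    expert_bias: [n_experts] bias added for selection only (assumption)
    Returns (indices, weights), each of shape [n_tokens, top_k].
    """
    n_tokens, n_experts = logits.shape
    per_group = n_experts // n_group

    scores = 1.0 / (1.0 + np.exp(-logits))   # sigmoid gate
    biased = scores + expert_bias            # scores used for selection

    # Score each group by its best expert (assumed; some variants sum the top 2).
    group_scores = biased.reshape(n_tokens, n_group, per_group).max(axis=-1)

    # Keep the `topk_group` best groups and mask out every other expert.
    keep = np.argpartition(-group_scores, topk_group - 1, axis=-1)[:, :topk_group]
    group_mask = np.zeros((n_tokens, n_group), dtype=bool)
    np.put_along_axis(group_mask, keep, True, axis=-1)
    expert_mask = np.repeat(group_mask, per_group, axis=-1)
    masked = np.where(expert_mask, biased, -np.inf)

    # k-largest indices via argpartition (unordered within the top k).
    idx = np.argpartition(-masked, top_k - 1, axis=-1)[:, :top_k]

    # Combine weights: unbiased sigmoid scores, renormalized per token (assumed).
    w = np.take_along_axis(scores, idx, axis=-1)
    return idx, w / w.sum(axis=-1, keepdims=True)
```

With 256 experts split into 8 groups of 32, each token ends up with 8 experts drawn from at most 4 groups, which bounds how many expert groups a single token can touch.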

## Conversion notes

| Area | Issue | Fix |
|---|---|---|
| Packed QKV | `query_key_value` linear outputs `(nh + 2·nkv)·d` | Split along channel, per-head RMSNorm |
| Partial RoPE | HF rotates first `head_dim * partial_rotary_factor` dims | Skip concat if non-rotated half is empty |
| MoE gate | Sigmoid + `expert_bias` + group-limited top-k | `argpartition(-scores, k-1)[:, :k]` for k-largest indices |
| MoE dispatch | Need gather `[N_slots, H_moe, H]` weights | Chunked at 512 slots to cap transient memory |
| Packed experts | HF stores per-expert ModuleList | Stack into `[E, out, in]` at load time |
| Bidirectional mask | HF uses bool tril-block-diag mask | Additive fp32 mask with −inf off-block |
| CFG uncond padding | **Bug (fixed)** | Separate forward passes; uncond gets its own `attn_mask` (pads masked) and `position_ids` (zeros on pad, start-at-0 for real tokens) |
| Decoder fp64 | MPS doesn't support fp64 | Patched `decoder/transport/transport.py` to use fp32 |
| `flash_attn` import | Not available on Apple Silicon | Stub `sys.modules['flash_attn']` + downgrade transformers to 4.51 |

## The interesting bug

The cond/uncond CFG paths had different prompt lengths (cond=36, uncond=30 for a typical t2i prompt). My first version shared the same `attn_mask` and `position_ids` between them, meaning the uncond path:

1. Attended to the left-pad tokens as if they were content.
2. Saw its real uncond tokens at RoPE positions shifted by `pad_len` (so its first real token sat at position 6 rather than position 0, where it was trained).

CFG computes `logits_uncond + 4·(logits_cond − logits_uncond)`, so small errors in the uncond direction got amplified 4× into visible vertical **chromatic stripes** in the decoded image.

Fix: separate forward passes for cond and uncond, so each path gets its own attention mask (pads masked to −inf) and `position_ids` (real tokens start at RoPE position 0, matching training).

## Performance (M4 Pro, 64 GB)

### Text

- Short Q&A (e.g. "what is 2+2?"): 15s for 32 tokens.
- Load time: 5.9s warm cache / 8s cold.

### Image understanding

- 512×512 input → 1024 VQ tokens + 10 question tokens → MLX backbone → 48 gen tokens ≈ **90s end-to-end**.

### Text-to-image (512×512)

| Decoder | Decode steps | Decode time | Quality |
|---|---|---|---|
| `decoder-turbo` (distilled) | 8 | ~45s | Stripes on our tokens (brittle to small VQ noise) |
| `decoder/` (full) | 50 | ~8 min | ✅ Clean, production-quality |
| `decoder/` (full) | 25 | ~4 min | ✅ Clean (tested) |

MLX VQ generation itself: ~10 min (8 steps × 9 blocks, CFG scale 4.0).

Default CLI: 50 steps for quality. Use `--decoder-steps 25` to halve the decode time with no visible loss.

## Links

- Original model: [inclusionAI/LLaDA2.0-Uni](https://huggingface.co/inclusionAI/LLaDA2.0-Uni)
- Paper: [arXiv:2604.20796](https://arxiv.org/abs/2604.20796)
- Apple MLX: [github.com/ml-explore/mlx](https://github.com/ml-explore/mlx)
- Built by [@treadon](https://x.com/treadon)

## More from me

For other projects and writeups, see [**riteshkhanna.com**](https://riteshkhanna.com), follow [**@treadon on X**](https://x.com/treadon), or [**treadon on Hugging Face**](https://huggingface.co/treadon).