---
library_name: mlx
tags:
- mlx
- diffusion-lm
- mixture-of-experts
- multimodal
- text-to-image
- image-understanding
- apple-silicon
- llada
base_model: inclusionAI/LLaDA2.0-Uni
license: apache-2.0
pipeline_tag: any-to-any
---
# MLX LLaDA2.0-Uni
> Follow [**@treadon on X**](https://x.com/treadon) and [**treadon on Hugging Face**](https://huggingface.co/treadon) for more AI experiments, evals, and projects.
An [MLX](https://github.com/ml-explore/mlx) port of [inclusionAI/LLaDA2.0-Uni](https://huggingface.co/inclusionAI/LLaDA2.0-Uni), a **16B MoE unified multimodal *diffusion* LLM**, running natively on Apple Silicon.
LLaDA2 doesn't generate left-to-right. It's a *diffusion* language model: fill a template with `<mask>` tokens, then iteratively un-mask them over multiple denoising steps using **bidirectional** attention β€” like image diffusion models but over discrete tokens. Images are represented as **in-vocabulary VQ tokens** (offset at index 157184), so the backbone is just a sequence-in / logits-out transformer that does text chat, image understanding (VQA), and text-to-image in one model.
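
A minimal sketch of the un-masking loop, in plain Python. `toy_model`, `MASK`, and the "reveal the most confident guesses at each step" rule are illustrative assumptions, not the code in `llada2/generate.py`. The real backbone runs one bidirectional forward pass per step, and LLaDA2 applies the loop block by block rather than over the whole sequence at once.

```python
import random

MASK = -1  # hypothetical placeholder id for <mask>; the real model uses a vocab token


def toy_model(tokens, rng):
    """Hypothetical stand-in for the backbone: one (token, confidence) guess per position.

    A real diffusion LM would run a bidirectional forward pass over the whole
    sequence and read guesses/confidences off the softmax; here they are random.
    """
    return [(rng.randrange(100), rng.random()) for _ in tokens]


def denoise(length, steps, seed=0):
    rng = random.Random(seed)
    tokens = [MASK] * length             # start from a fully masked template
    per_step = -(-length // steps)       # ceil(length / steps) tokens revealed per step
    for _ in range(steps):
        masked = [i for i, t in enumerate(tokens) if t == MASK]
        if not masked:
            break
        guesses = toy_model(tokens, rng)
        # Commit only the most confident guesses at still-masked positions;
        # every other position stays <mask> and is re-predicted next step.
        masked.sort(key=lambda i: guesses[i][1], reverse=True)
        for i in masked[:per_step]:
            tokens[i] = guesses[i][0]
    return tokens


out = denoise(length=8, steps=4)
print(out)  # 8 token ids, none of them MASK
```

Committing a fixed number of tokens per step keeps the step count predictable; confidence-ordered selection is one common choice for masked diffusion LMs, and the exact schedule LLaDA2 uses may differ.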
<table>
<tr>
<td><img src="samples/strawberries.png" width="320"><br><sub>A bowl of fresh strawberries on a wooden table, morning light, photorealistic (512×512, cfg 4.0, 8 MLX steps × 9 blocks → 25-step decoder)</sub></td>
<td><img src="samples/castle.png" width="320"><br><sub>A medieval castle on a hill at sunset, oil painting style (same settings)</sub></td>
</tr>
</table>
*(Generated at 512×512 for speed. LLaDA2 was designed for 1024×1024; running at the native 32×32 token grid produces significantly sharper results, but each generation takes several hours on an M4 Pro. The 16×16 grid at 512×512 works well for still-life subjects; complex subjects like cat faces can show mild distortion at this reduced resolution.)*
## Why this is interesting
`llama.cpp` doesn't support diffusion LLMs: every architecture there is a hand-coded graph that assumes causal attention and next-token sampling. LLaDA2 needs bidirectional attention and a multi-step denoising loop, so a GGUF/llama.cpp port was not an option. MLX is a general-purpose tensor framework: you write the forward pass and the denoising loop in Python, and it runs.
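
To make the attention difference concrete, here is a small pure-Python sketch of an additive causal mask next to a block-level mask (full attention inside a block, plus attention to every earlier block). Reading HF's "tril-block-diag" mask as this block-causal layout is our assumption; `causal_mask`, `block_mask`, and `show` are hypothetical helpers.

```python
NEG_INF = float("-inf")


def causal_mask(n):
    """Autoregressive mask: query i may attend to keys 0..i only."""
    return [[0.0 if j <= i else NEG_INF for j in range(n)] for i in range(n)]


def block_mask(n, block):
    """Block-level mask (assumed layout): bidirectional inside a block,
    plus attention to all earlier blocks, never to later ones."""
    return [[0.0 if j // block <= i // block else NEG_INF for j in range(n)]
            for i in range(n)]


def show(mask):
    # 'x' = allowed (0.0 added to the score), '.' = blocked (-inf added)
    for row in mask:
        print("".join("x" if v == 0.0 else "." for v in row))


show(causal_mask(4))
print()
show(block_mask(4, block=2))
```

With `block=2`, token 0 can attend to token 1 (impossible under a causal mask), which is exactly the bidirectionality a next-token-only runtime has no slot for.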
## Quick Start
```bash
git clone https://huggingface.co/treadon/mlx-llada2-uni
cd mlx-llada2-uni
git clone https://github.com/inclusionAI/LLaDA2.0-Uni llada2-uni-repo
python -m venv .venv && source .venv/bin/activate
pip install mlx "transformers==4.51" diffusers torch torchvision huggingface_hub safetensors Pillow accelerate torchdiffeq einops
# Text-only chat
python generate.py --prompt "Name three primary colors."
# Image understanding (VQA)
python image_understand.py --image some_photo.jpg --question "What's in this image?"
# Text-to-image (≈18 min for a 512×512 image on an M4 Pro)
python t2i.py --prompt "A photorealistic cat sitting on a wooden table"
```
First run pulls ~45 GB from [inclusionAI/LLaDA2.0-Uni](https://huggingface.co/inclusionAI/LLaDA2.0-Uni): the 32 GB MoE backbone, the 12 GB diffusion decoder, the 2.4 GB image tokenizer (ViT + VQVAE), and the 170 MB FLUX VAE.
Requirements:
- Apple Silicon (tested on M4 Pro 64 GB).
- Python 3.10+.
- `transformers==4.51` specifically; newer versions have hard `flash_attn` imports that fail on Apple Silicon.
## What's in this repo
Just the MLX code + sample outputs. Weights are loaded directly from
[inclusionAI/LLaDA2.0-Uni](https://huggingface.co/inclusionAI/LLaDA2.0-Uni); we don't redistribute them.
```
mlx-llada2-uni/
├── llada2/
│   ├── model.py            # 16B MoE backbone (GQA, partial-RoPE, DeepSeek-V2 MoE, bidirectional attention)
│   ├── weights.py          # HF safetensors → MLX loader (packs per-expert ModuleList into stacked arrays)
│   ├── generate.py         # Block-diffusion text generation
│   └── generate_image.py   # Block-diffusion VQ-token generation with CFG (for t2i)
├── generate.py             # Text CLI
├── image_understand.py     # VQA CLI (hybrid: PyTorch image tokenizer → MLX backbone)
├── t2i.py                  # Text-to-image CLI (hybrid: MLX VQ gen → PyTorch decoder)
└── samples/                # Example outputs
```
## Architecture
### Backbone (runs in MLX)
| Field | Value |
|---|---|
| Parameters | 16B total, ~1B active per token |
| hidden_size | 2048 |
| layers | 20 (1 dense + 19 MoE) |
| attention | 16Q / 4KV (GQA), head_dim 128, QK RMSNorm, **bidirectional** |
| RoPE | partial (rotates first 64 of 128 dims), θ=600000 |
| MoE | 256 routed + 1 shared, top-8 per token, DeepSeek-V2 group-limited (n_group=8, topk_group=4) |
| moe_intermediate_size | 512 |
| vocab | 173568 (text + 16384-entry VQ codebook at offset 157184) |
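
Partial RoPE, sketched for a single head vector: only the first 64 of the 128 dims are rotated (θ = 600000), and the rest pass through unchanged. This assumes the Hugging Face "rotate_half" pairing convention (dim i paired with dim i + 32 inside the rotated slice); `partial_rope` is a hypothetical helper, not the function in `llada2/model.py`.

```python
import math


def partial_rope(x, pos, rotary_dim=64, theta=600000.0):
    """Rotate the first `rotary_dim` entries of one head vector by position `pos`.

    Pairing follows the rotate_half convention: element i is rotated together
    with element i + rotary_dim // 2. Entries past rotary_dim are untouched.
    """
    half = rotary_dim // 2
    rot, rest = x[:rotary_dim], x[rotary_dim:]
    out = [0.0] * rotary_dim
    for i in range(half):
        inv_freq = theta ** (-2.0 * i / rotary_dim)  # 1 / theta^(2i / rotary_dim)
        angle = pos * inv_freq
        c, s = math.cos(angle), math.sin(angle)
        a, b = rot[i], rot[i + half]
        out[i] = a * c - b * s
        out[i + half] = b * c + a * s
    # When rotary_dim == len(x), `rest` is empty and the concat is a no-op.
    return out + list(rest)


q = [1.0] * 128
rotated = partial_rope(q, pos=3)  # dims 64..127 are still 1.0
```

The rotation preserves the norm of the rotated slice and is the identity at position 0, which makes both easy sanity checks when comparing an MLX port against the HF reference.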
### Pipeline (hybrid)
| Component | Size | Where |
|---|---|---|
| LLaDA2 MoE backbone | 32 GB bf16 | **MLX** (GPU) |
| Image tokenizer (ViT + VQVAE) | 2.4 GB | **PyTorch CPU** |
| SigVQ prior | 400 MB fp32 | **PyTorch MPS** |
| ZImage decoder (6.2B, 30 layers) | 12 GB bf16 | **PyTorch MPS** |
| FLUX VAE | 170 MB bf16 | **PyTorch MPS** |
Porting the 6.2B ZImageTransformer2DModel decoder and the 2.4 GB image tokenizer to MLX would take weeks, so those components stay in PyTorch (the tokenizer on CPU, the prior, decoder, and VAE on MPS). The heavy lifting (the 16B MoE LLM that does all three tasks) is the MLX-native part.
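
Part of the MLX → PyTorch hand-off for text-to-image is mapping backbone token ids back to VQ codebook indices using the offset from the vocab row above (157184 + 16384 = 173568). A sketch, assuming the decoder consumes raw codebook indices laid out row-major on a square grid (16×16 at 512×512); every function name here is hypothetical.

```python
VQ_OFFSET = 157184     # first image-token id in the backbone vocab
CODEBOOK_SIZE = 16384  # VQ codebook entries
VOCAB_SIZE = 173568    # text ids + image ids = VQ_OFFSET + CODEBOOK_SIZE


def code_to_token(code):
    """Codebook index -> backbone token id."""
    if not 0 <= code < CODEBOOK_SIZE:
        raise ValueError(f"codebook index out of range: {code}")
    return VQ_OFFSET + code


def token_to_code(token):
    """Backbone token id -> codebook index (image tokens only)."""
    if not VQ_OFFSET <= token < VOCAB_SIZE:
        raise ValueError(f"not an image token: {token}")
    return token - VQ_OFFSET


def image_codes(generated_ids):
    """Drop text/special tokens and map the image tokens to codebook indices."""
    return [t - VQ_OFFSET for t in generated_ids if VQ_OFFSET <= t < VOCAB_SIZE]


def to_grid(codes, side):
    """Reshape a flat code list into `side` rows (row-major order assumed)."""
    if len(codes) != side * side:
        raise ValueError(f"expected {side * side} codes, got {len(codes)}")
    return [codes[r * side:(r + 1) * side] for r in range(side)]
```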
## Conversion notes
| Area | Issue | Fix |
|---|---|---|
| Packed QKV | `query_key_value` linear outputs `(nh + 2·nkv)·d` | Split along channel, per-head RMSNorm |
| Partial RoPE | HF rotates first `head_dim * partial_rotary_factor` dims | Skip concat if non-rotated half is empty |
| MoE gate | Sigmoid + `expert_bias` + group-limited top-k | `argpartition(-scores, k-1)[:, :k]` for k-largest indices |
| MoE dispatch | Need gather `[N_slots, H_moe, H]` weights | Chunked at 512 slots to cap transient memory |
| Packed experts | HF stores per-expert ModuleList | Stack into `[E, out, in]` at load time |
| Bidirectional mask | HF uses bool tril-block-diag mask | Additive fp32 mask with −inf off-block |
| CFG uncond padding | **Bug (fixed)** | Separate forward passes; uncond gets its own `attn_mask` (pads masked) and `position_ids` (zeros on pad, start-at-0 for real tokens) |
| Decoder fp64 | MPS doesn't support fp64 | Patched `decoder/transport/transport.py` to use fp32 |
| `flash_attn` import | Not available on Apple Silicon | Stub `sys.modules['flash_attn']` + downgrade transformers to 4.51 |
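
A scalar, single-token sketch of the MoE gate row in the table above: sigmoid scores, `expert_bias` used only for selection, and group-limited top-k. Assumptions: groups are ranked by their best biased expert (DeepSeek-V2-style), the mixing weights are the unbiased scores of the chosen experts renormalised to sum to 1, and any extra routing scale the real model applies is omitted. Sizes are shrunk for readability; the real model uses 256 experts, `n_group=8`, `topk_group=4`, top-8. `group_limited_topk` is a hypothetical helper.

```python
import math


def group_limited_topk(logits, expert_bias, n_group, topk_group, top_k):
    """Route one token; returns (chosen expert ids, mixing weights)."""
    n_exp = len(logits)
    per_group = n_exp // n_group
    scores = [1.0 / (1.0 + math.exp(-z)) for z in logits]   # sigmoid
    biased = [s + b for s, b in zip(scores, expert_bias)]   # bias affects selection only

    # 1) Keep the `topk_group` best groups, scored by their best biased expert.
    group_score = [max(biased[g * per_group:(g + 1) * per_group]) for g in range(n_group)]
    keep = sorted(range(n_group), key=lambda g: group_score[g], reverse=True)[:topk_group]

    # 2) Pick the top_k experts among the surviving groups only.
    candidates = [e for g in keep for e in range(g * per_group, (g + 1) * per_group)]
    chosen = sorted(candidates, key=lambda e: biased[e], reverse=True)[:top_k]

    # 3) Mixing weights come from the unbiased scores, renormalised.
    total = sum(scores[e] for e in chosen)
    return chosen, [scores[e] / total for e in chosen]


# 16 experts in 4 groups of 4; keep 2 groups, then 4 experts.
logits = [0.0] * 16
logits[0], logits[5], logits[6], logits[10] = 5.0, 4.0, 3.9, 4.5
chosen, weights = group_limited_topk(logits, [0.0] * 16, n_group=4, topk_group=2, top_k=4)
```

In this example expert 5 has the third-highest logit but is never picked, because its group loses to groups 0 and 2 in step 1. A vectorised MLX version would replace the `sorted` calls with `argpartition`, as the table notes.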
## The interesting bug
The cond/uncond CFG paths had different prompt lengths (cond=36, uncond=30 for a typical t2i prompt). My first version shared the same `attn_mask` and `position_ids` between them, meaning the uncond path:
1. Attended to the left-pad `<mask>` tokens as if they were content.
2. Saw its real uncond tokens at RoPE positions shifted by `pad_len` (so the model sees `<uncondition>` at position 6 rather than at position 0, where it appears in training).
CFG computes `logits_uncond + 4·(logits_cond − logits_uncond)`, i.e. `4·logits_cond − 3·logits_uncond`, so small errors in the uncond logits were amplified (3×, with the sign flipped) into visible vertical **chromatic stripes** in the decoded image.
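
Expanding the guidance formula shows how an uncond error propagates: `u + s·(c − u) = s·c + (1 − s)·u`, so an error δ in the uncond logits is multiplied by s = 4 inside the guidance term and lands in the guided logits as (1 − s)·δ = −3δ. A tiny numeric check (`cfg` is a hypothetical helper):

```python
def cfg(cond, uncond, scale=4.0):
    """Classifier-free guidance on a logit vector: u + s * (c - u)."""
    return [u + scale * (c - u) for c, u in zip(cond, uncond)]


cond = [2.0, 0.5, -1.0]
uncond = [1.0, 0.0, -1.0]
noise = [0.1, 0.0, 0.0]  # small error in the uncond logits only

clean = cfg(cond, uncond)
noisy = cfg(cond, [u + e for u, e in zip(uncond, noise)])
print([round(n - c, 6) for n, c in zip(noisy, clean)])  # [-0.3, 0.0, 0.0]
```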
Fix: separate forward passes for cond and uncond, so each path gets its own attention mask (pads masked to −inf) and position_ids (real tokens start at RoPE position 0, matching training).
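
A sketch of the per-path inputs the fix calls for, assuming left padding and a per-key additive mask that is broadcast over query rows (and combined with the block mask). `left_pad_inputs` is a hypothetical helper, not a function in this repo; the 36/30 lengths come from the bug write-up above.

```python
NEG_INF = float("-inf")


def left_pad_inputs(real_len, total_len):
    """Per-key additive mask and RoPE positions for one left-padded sequence.

    Pad keys get -inf so no query attends to them, and pad positions are 0.
    Real tokens count from 0, so the first real token sits at the RoPE
    position it had in training regardless of how much padding precedes it.
    """
    pad = total_len - real_len
    if pad < 0:
        raise ValueError("real_len exceeds total_len")
    key_mask = [NEG_INF] * pad + [0.0] * real_len
    position_ids = [0] * pad + list(range(real_len))
    return key_mask, position_ids


# Each CFG path builds its own inputs and runs its own forward pass.
cond_mask, cond_pos = left_pad_inputs(36, 36)
uncond_mask, uncond_pos = left_pad_inputs(30, 36)
print(uncond_pos[:8])  # [0, 0, 0, 0, 0, 0, 0, 1]
```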
## Performance (M4 Pro, 64 GB)
### Text
- Short Q&A (e.g. "what is 2+2?"): 15s for 32 tokens.
- Load time: 5.9s warm cache / 8s cold.
### Image understanding
- 512×512 input → 1024 VQ tokens + 10 question tokens → MLX backbone → 48 gen tokens ≈ **90s end-to-end**.
### Text-to-image (512Γ—512)
| Decoder | Decode steps | Decode time | Quality |
|---|---|---|---|
| `decoder-turbo` (distilled) | 8 | ~45s | Stripes on our tokens (brittle to small VQ noise) |
| `decoder/` (full) | 50 | ~8 min | ✅ Clean, production-quality |
| `decoder/` (full) | 25 | ~4 min | ✅ Clean (tested) |
MLX VQ generation itself: ~10 min (8 steps × 9 blocks, CFG scale 4.0).
Default CLI: 50 steps for quality. Use `--decoder-steps 25` to halve the decode time with no visible loss.
## Links
- Original model: [inclusionAI/LLaDA2.0-Uni](https://huggingface.co/inclusionAI/LLaDA2.0-Uni)
- Paper: [arXiv:2604.20796](https://arxiv.org/abs/2604.20796)
- Apple MLX: [github.com/ml-explore/mlx](https://github.com/ml-explore/mlx)
- Built by [@treadon](https://x.com/treadon)
## More from me
For other projects and writeups, see [**riteshkhanna.com**](https://riteshkhanna.com), follow [**@treadon on X**](https://x.com/treadon), or [**treadon on Hugging Face**](https://huggingface.co/treadon).