Any-to-Any
MLX
diffusion-lm
mixture-of-experts
multimodal
text-to-image
image-understanding
apple-silicon
llada
Instructions to use treadon/mlx-llada2-uni with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use treadon/mlx-llada2-uni with MLX:
# Download the model from the Hub
pip install "huggingface_hub[hf_xet]"
huggingface-cli download --local-dir mlx-llada2-uni treadon/mlx-llada2-uni
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- LM Studio
| """MLX port of LLaDA2.0-Uni MoE backbone. | |
| Bidirectional-attention diffusion LLM. Images are in-vocabulary VQ tokens | |
| (offset at image_token_offset), so the backbone is purely sequence-in/logits-out. | |
| """ | |
| from __future__ import annotations | |
| import math | |
| from dataclasses import dataclass | |
| from typing import Optional | |
| import mlx.core as mx | |
| import mlx.nn as nn | |
@dataclass
class LLaDA2Config:
    """Hyper-parameters of the LLaDA2.0-Uni MoE backbone.

    Defaults reproduce the released configuration; ``from_hf`` builds an
    instance from a Hugging Face ``config.json`` dictionary.
    """

    # Vocabulary and transformer shape.
    vocab_size: int = 173568
    hidden_size: int = 2048
    intermediate_size: int = 5120  # width of the dense (non-MoE) MLP layers
    num_hidden_layers: int = 20
    num_attention_heads: int = 16
    num_key_value_heads: int = 4  # fewer KV heads than query heads (GQA)
    head_dim: int = 128
    rms_norm_eps: float = 1e-6

    # Rotary position embedding; only `partial_rotary_factor` of each head's
    # channels are rotated.
    max_position_embeddings: int = 8192
    rope_theta: float = 600000.0
    partial_rotary_factor: float = 0.5

    # Special token ids; image VQ tokens start at `image_token_offset`.
    pad_token_id: int = 156892
    mask_token_id: int = 156895
    eos_token_id: int = 156892
    image_token_offset: int = 157184

    # Mixture-of-experts routing and sizes.
    num_experts: int = 256
    num_shared_experts: int = 1
    num_experts_per_tok: int = 8
    n_group: int = 8  # experts are split into this many routing groups
    topk_group: int = 4  # groups kept per token before expert selection
    routed_scaling_factor: float = 2.5
    moe_intermediate_size: int = 512
    first_k_dense_replace: int = 1  # layers with index < this use a dense MLP

    @classmethod
    def from_hf(cls, hf_config: dict) -> "LLaDA2Config":
        """Build a config from a Hugging Face ``config.json`` mapping.

        Keys accessed with ``[]`` are required and raise ``KeyError`` when
        absent; the others fall back to this class's defaults.
        """
        return cls(
            vocab_size=hf_config["vocab_size"],
            hidden_size=hf_config["hidden_size"],
            intermediate_size=hf_config["intermediate_size"],
            num_hidden_layers=hf_config["num_hidden_layers"],
            num_attention_heads=hf_config["num_attention_heads"],
            num_key_value_heads=hf_config["num_key_value_heads"],
            head_dim=hf_config.get("head_dim", 128),
            rms_norm_eps=hf_config.get("rms_norm_eps", 1e-6),
            max_position_embeddings=hf_config.get("max_position_embeddings", 8192),
            rope_theta=hf_config.get("rope_theta", 600000.0),
            partial_rotary_factor=hf_config.get("partial_rotary_factor", 0.5),
            pad_token_id=hf_config.get("pad_token_id", 156892),
            mask_token_id=hf_config.get("mask_token_id", 156895),
            # eos_token_id is not read: HF configs may store it as a list,
            # so the int default is kept.
            image_token_offset=hf_config.get("image_token_offset", 157184),
            num_experts=hf_config["num_experts"],
            num_shared_experts=hf_config.get("num_shared_experts", 1),
            num_experts_per_tok=hf_config["num_experts_per_tok"],
            n_group=hf_config["n_group"],
            topk_group=hf_config["topk_group"],
            routed_scaling_factor=hf_config.get("routed_scaling_factor", 2.5),
            moe_intermediate_size=hf_config["moe_intermediate_size"],
            first_k_dense_replace=hf_config.get("first_k_dense_replace", 1),
        )
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learned per-channel scale.

    Delegates to ``mx.fast.rms_norm``, which normalizes over the last axis.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        # Scale starts at ones, i.e. an identity rescale.
        self.weight = mx.ones((dim,))
        self.eps = eps

    def __call__(self, x: mx.array) -> mx.array:
        scale, epsilon = self.weight, self.eps
        return mx.fast.rms_norm(x, scale, epsilon)
def _rope_inv_freq(head_dim: int, partial_rotary_factor: float, base: float):
    """Return ``(inv_freq, rotary_dim)`` for a partially rotated head.

    ``rotary_dim = int(head_dim * partial_rotary_factor)`` channels are rotated.
    ``inv_freq`` is a float32 vector of length ``rotary_dim // 2``, with entry
    ``i`` equal to ``1 / base ** (2 * i / rotary_dim)``.
    """
    rotary_dim = int(head_dim * partial_rotary_factor)
    exponents = mx.arange(0, rotary_dim // 2, dtype=mx.float32) * 2.0 / rotary_dim
    inv_freq = 1.0 / (base ** exponents)
    return inv_freq, rotary_dim
def build_rope_cache(max_seq_len: int, head_dim: int, partial_rotary_factor: float, base: float):
    """Precompute RoPE cos/sin tables for positions ``0 .. max_seq_len - 1``.

    Returns ``(cos, sin, rope_dim)``. ``cos`` and ``sin`` have shape
    ``[max_seq_len, rope_dim]``. The per-position angle vector is repeated
    twice along the last axis (HF LLaMA "rotate_half" layout), so column ``j``
    and column ``j + rope_dim // 2`` share a frequency.
    """
    inv_freq, rope_dim = _rope_inv_freq(head_dim, partial_rotary_factor, base)
    pos = mx.arange(max_seq_len, dtype=mx.float32)
    angles = mx.expand_dims(pos, 1) * mx.expand_dims(inv_freq, 0)  # [S, rope_dim // 2]
    table = mx.concatenate((angles, angles), axis=1)  # [S, rope_dim]
    return mx.cos(table), mx.sin(table), rope_dim
def apply_rope_partial(q: mx.array, k: mx.array, cos: mx.array, sin: mx.array, rope_dim: int):
    """Apply rotary position embedding to the first ``rope_dim`` channels of q and k.

    Args:
        q, k: ``[B, H, S, D]`` query / key tensors. The trailing
            ``D - rope_dim`` channels pass through unchanged.
        cos, sin: ``[S, rope_dim]`` tables (e.g. from ``build_rope_cache``).
        rope_dim: number of leading channels to rotate.

    Returns:
        ``(q_out, k_out)`` with the same shapes and dtypes as ``q`` and ``k``.
    """

    def rotate_half(x):
        half = x.shape[-1] // 2
        return mx.concatenate([-x[..., half:], x[..., :half]], axis=-1)

    def rotate(x):
        # Cast the tables to the activation dtype (as the HF reference does).
        # Multiplying bf16 activations by the float32 tables would otherwise
        # promote q/k to float32. The promotion then leaks through attention
        # into the residual stream and every following layer.
        c = cos.astype(x.dtype)[None, None, :, :]
        s = sin.astype(x.dtype)[None, None, :, :]
        x_rot, x_pass = x[..., :rope_dim], x[..., rope_dim:]
        x_rot = (x_rot * c) + (rotate_half(x_rot) * s)
        if x_pass.shape[-1] == 0:
            return x_rot
        return mx.concatenate([x_rot, x_pass], axis=-1)

    return rotate(q), rotate(k)
class Attention(nn.Module):
    """Grouped-query self-attention with a packed QKV projection.

    Per-head RMSNorm is applied to queries and keys, then partial RoPE.
    No causal mask is built here; the only mask is the optional ``attn_mask``
    argument, so attention is bidirectional by default.
    """

    def __init__(self, config: LLaDA2Config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        self.hidden_size = config.hidden_size
        self.scaling = self.head_dim ** -0.5

        # One projection emits num_heads query heads, then num_kv_heads key
        # heads, then num_kv_heads value heads.
        packed_heads = self.num_heads + 2 * self.num_kv_heads
        self.query_key_value = nn.Linear(self.hidden_size, packed_heads * self.head_dim, bias=False)
        self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def __call__(self, x: mx.array, cos: mx.array, sin: mx.array, rope_dim: int,
                 attn_mask: Optional[mx.array] = None) -> mx.array:
        """Attend over ``x`` (``[B, S, hidden]``) and return ``[B, S, hidden]``."""
        batch, seq_len, _ = x.shape
        n_q, n_kv = self.num_heads, self.num_kv_heads

        packed = self.query_key_value(x).reshape(batch, seq_len, n_q + 2 * n_kv, self.head_dim)
        q, k, v = mx.split(packed, [n_q, n_q + n_kv], axis=2)

        # Normalize each head vector while heads are still on axis 2, then
        # move heads in front of the sequence axis: [B, H, S, D].
        q = self.query_layernorm(q).transpose(0, 2, 1, 3)
        k = self.key_layernorm(k).transpose(0, 2, 1, 3)
        v = v.transpose(0, 2, 1, 3)

        q, k = apply_rope_partial(q, k, cos, sin, rope_dim)
        attended = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scaling, mask=attn_mask)
        merged = attended.transpose(0, 2, 1, 3).reshape(batch, seq_len, n_q * self.head_dim)
        return self.dense(merged)
class DenseMLP(nn.Module):
    """SwiGLU feed-forward block: ``down_proj(silu(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, config: LLaDA2Config):
        super().__init__()
        d_model, d_ff = config.hidden_size, config.intermediate_size
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        gated = nn.silu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
class MoEGate(nn.Module):
    """Sigmoid router with group-limited top-k and a selection-only expert bias.

    This matches DeepSeek-V3-style routing.
    - Experts are partitioned into ``n_group`` equal groups.
    - Each group is scored by the sum of its two best biased scores.
    - Only the ``topk_group`` best groups are eligible.
    - ``top_k`` experts are chosen among them using the biased scores.
    - Routing weights are taken from the unbiased sigmoid scores.
    """

    def __init__(self, config: LLaDA2Config):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.routed_scaling_factor = config.routed_scaling_factor
        self.weight = mx.zeros((self.num_experts, config.hidden_size))
        self.expert_bias = mx.zeros((self.num_experts,))

    def __call__(self, x: mx.array):
        """Route each row of ``x`` (``[N, hidden]``) to ``top_k`` experts.

        Returns ``(top_idx, top_weight)``, both ``[N, top_k]``:
        - ``top_idx`` holds integer expert indices as produced by ``mx.argpartition``.
        - ``top_weight`` holds the routing weights, cast to ``x.dtype``.
        """
        n_tokens = x.shape[0]
        per_group = self.num_experts // self.n_group

        # Router math runs in float32 regardless of the activation dtype.
        logits = x.astype(mx.float32) @ self.weight.astype(mx.float32).T  # [N, E]
        scores = mx.sigmoid(logits)
        # The bias influences which experts are picked, not how they are weighted.
        biased = scores + self.expert_bias.astype(scores.dtype)

        # Group ranking: a group's score is the sum of its two largest biased
        # scores (mx.topk output order is irrelevant since we only sum it).
        grouped = biased.reshape(n_tokens, self.n_group, per_group)
        group_scores = mx.topk(grouped, 2, axis=-1).sum(axis=-1)  # [N, n_group]

        # Keep the `topk_group` highest groups: argpartition of the negated
        # scores puts them first (in no particular order).
        kept_groups = mx.argpartition(-group_scores, self.topk_group - 1, axis=-1)[:, : self.topk_group]
        group_keep = mx.put_along_axis(
            mx.zeros((n_tokens, self.n_group), dtype=mx.bool_),
            kept_groups,
            mx.ones(kept_groups.shape, dtype=mx.bool_),
            axis=-1,
        )
        expert_keep = mx.broadcast_to(
            group_keep[:, :, None], (n_tokens, self.n_group, per_group)
        ).reshape(n_tokens, self.num_experts)

        # Expert selection restricted to kept groups (others forced to -inf).
        neg_inf = mx.array(-float("inf"), dtype=biased.dtype)
        candidates = mx.where(expert_keep, biased, neg_inf)
        top_idx = mx.argpartition(-candidates, self.top_k - 1, axis=-1)[:, : self.top_k]

        # Weights come from the unbiased sigmoid scores, renormalized over the
        # chosen experts when more than one is selected, then scaled.
        top_weight = mx.take_along_axis(scores, top_idx, axis=-1)
        if self.top_k > 1:
            top_weight = top_weight / (top_weight.sum(axis=-1, keepdims=True) + 1e-20)
        top_weight = top_weight * self.routed_scaling_factor
        return top_idx, top_weight.astype(x.dtype)
class MoEBlock(nn.Module):
    """Routed mixture-of-experts SwiGLU FFN plus an always-on shared expert.

    The ``num_experts`` routed experts are stored packed, in PyTorch-linear
    layout ``[E, out, in]``. Each token is sent to ``num_experts_per_tok`` of
    them by ``MoEGate``, and their outputs are summed with the gate weights.
    The shared path (width ``moe_intermediate_size * num_shared_experts``)
    runs on every token and is added to the routed output.
    """

    def __init__(self, config: LLaDA2Config):
        super().__init__()
        self.num_experts = config.num_experts
        self.moe_hidden = config.moe_intermediate_size
        self.hidden = config.hidden_size
        self.top_k = config.num_experts_per_tok
        self.gate = MoEGate(config)
        # Packed expert weights, PyTorch-linear layout [E, out, in].
        # Pre-allocated; weight loader overwrites via model.update().
        self.experts_gate_w = mx.zeros((self.num_experts, self.moe_hidden, self.hidden))
        self.experts_up_w = mx.zeros((self.num_experts, self.moe_hidden, self.hidden))
        self.experts_down_w = mx.zeros((self.num_experts, self.hidden, self.moe_hidden))
        shared_inter = config.moe_intermediate_size * config.num_shared_experts
        self.shared_gate_proj = nn.Linear(self.hidden, shared_inter, bias=False)
        self.shared_up_proj = nn.Linear(self.hidden, shared_inter, bias=False)
        self.shared_down_proj = nn.Linear(shared_inter, self.hidden, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        """Apply the MoE FFN to ``x`` (``[B, S, hidden]``); returns the same shape."""
        B, S, H = x.shape
        x_flat = x.reshape(B * S, H)
        top_idx, top_weight = self.gate(x_flat)  # [N, top_k] each

        # mx.gather_mm selects each slot's expert matrix inside the matmul.
        # Per-slot copies of the [H_moe, H] weights are therefore never
        # materialized, and no chunking or mid-graph mx.eval is required.
        # x_e is [N, 1, 1, H]. Its batch dims [N, 1] broadcast against
        # top_idx [N, top_k], so each product is [N, top_k, 1, out].
        x_e = x_flat[:, None, None, :]
        gate_out = mx.gather_mm(x_e, self.experts_gate_w.swapaxes(-1, -2), rhs_indices=top_idx)
        up_out = mx.gather_mm(x_e, self.experts_up_w.swapaxes(-1, -2), rhs_indices=top_idx)
        act = nn.silu(gate_out) * up_out  # [N, top_k, 1, H_moe]
        down_out = mx.gather_mm(act, self.experts_down_w.swapaxes(-1, -2), rhs_indices=top_idx)
        expert_out = down_out.squeeze(-2)  # [N, top_k, H]

        # Weighted sum of the top_k expert outputs per token.
        weights = top_weight.astype(x.dtype)[:, :, None]  # [N, top_k, 1]
        y = (expert_out * weights).sum(axis=1).reshape(B, S, H)

        shared = self.shared_down_proj(
            nn.silu(self.shared_gate_proj(x)) * self.shared_up_proj(x)
        )
        return y + shared
class DecoderLayer(nn.Module):
    """Pre-norm transformer block: attention then FFN, each in a residual branch.

    Layers with ``layer_idx < config.first_k_dense_replace`` use ``DenseMLP``;
    all later layers use ``MoEBlock``.
    """

    def __init__(self, config: LLaDA2Config, layer_idx: int):
        super().__init__()
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attention = Attention(config)
        use_moe = layer_idx >= config.first_k_dense_replace
        self.mlp = MoEBlock(config) if use_moe else DenseMLP(config)

    def __call__(self, x: mx.array, cos: mx.array, sin: mx.array, rope_dim: int,
                 attn_mask: Optional[mx.array]) -> mx.array:
        attn_out = self.attention(self.input_layernorm(x), cos, sin, rope_dim, attn_mask)
        x = x + attn_out
        ffn_out = self.mlp(self.post_attention_layernorm(x))
        return x + ffn_out
class LLaDA2Backbone(nn.Module):
    """Token embedding, a stack of ``DecoderLayer``s, and a final RMSNorm.

    RoPE cos/sin tables covering ``config.max_position_embeddings`` positions
    are built once at construction. They are stored under underscore-prefixed
    names, presumably so MLX's ``Module`` does not treat them as parameters
    (confirm against the weight loader).
    """

    def __init__(self, config: LLaDA2Config):
        super().__init__()
        self.config = config
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = [DecoderLayer(config, i) for i in range(config.num_hidden_layers)]
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        cos, sin, rope_dim = build_rope_cache(
            config.max_position_embeddings, config.head_dim,
            config.partial_rotary_factor, config.rope_theta,
        )
        self._rope_cos = cos
        self._rope_sin = sin
        self._rope_dim = rope_dim

    def __call__(self, input_ids: mx.array,
                 attn_mask: Optional[mx.array] = None,
                 position_ids: Optional[mx.array] = None) -> mx.array:
        """Return normalized hidden states for ``input_ids``.

        Args:
            input_ids: token ids, with the sequence on the last axis.
            attn_mask: optional mask forwarded unchanged to every attention layer.
            position_ids: optional ``[S]`` or ``[B, S]`` positions. For
                ``[B, S]`` only the first row is used for the whole batch.

        Raises:
            ValueError: if ``position_ids`` is None and the sequence is longer
                than the precomputed RoPE table.
        """
        S = input_ids.shape[-1]
        if position_ids is None:
            table_len = self._rope_cos.shape[0]
            # Slicing past the table would silently truncate and fail later
            # with an unrelated broadcast error; fail clearly instead.
            if S > table_len:
                raise ValueError(
                    f"sequence length {S} exceeds max_position_embeddings={table_len}"
                )
            cos = self._rope_cos[:S]  # [S, rope_dim]
            sin = self._rope_sin[:S]
        else:
            # All rows are assumed to share the same positions (single CFG
            # path); the t2i caller calls model() separately for cond and uncond.
            # NOTE(review): position values are not range-checked here (that
            # would need a host sync); ids >= max_position_embeddings index
            # past the cache.
            if position_ids.ndim == 2:
                position_ids = position_ids[0]  # [S]
            cos = self._rope_cos[position_ids]  # [S, rope_dim]
            sin = self._rope_sin[position_ids]
        h = self.word_embeddings(input_ids)
        for layer in self.layers:
            h = layer(h, cos, sin, self._rope_dim, attn_mask)
        return self.norm(h)
class LLaDA2Model(nn.Module):
    """Complete language model: ``LLaDA2Backbone`` followed by an ``lm_head``.

    ``lm_head`` is its own ``nn.Linear`` (not weight-tied to the embedding)
    and maps hidden states to vocabulary logits.
    """

    def __init__(self, config: LLaDA2Config):
        super().__init__()
        self.config = config
        self.model = LLaDA2Backbone(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def __call__(self, input_ids: mx.array,
                 attn_mask: Optional[mx.array] = None,
                 position_ids: Optional[mx.array] = None) -> mx.array:
        hidden = self.model(input_ids, attn_mask=attn_mask, position_ids=position_ids)
        logits = self.lm_head(hidden)
        return logits