"""MLX port of LLaDA2.0-Uni MoE backbone. Bidirectional-attention diffusion LLM. Images are in-vocabulary VQ tokens (offset at image_token_offset), so the backbone is purely sequence-in/logits-out. """ from __future__ import annotations import math from dataclasses import dataclass from typing import Optional import mlx.core as mx import mlx.nn as nn @dataclass class LLaDA2Config: vocab_size: int = 173568 hidden_size: int = 2048 intermediate_size: int = 5120 num_hidden_layers: int = 20 num_attention_heads: int = 16 num_key_value_heads: int = 4 head_dim: int = 128 rms_norm_eps: float = 1e-6 max_position_embeddings: int = 8192 rope_theta: float = 600000.0 partial_rotary_factor: float = 0.5 pad_token_id: int = 156892 mask_token_id: int = 156895 eos_token_id: int = 156892 image_token_offset: int = 157184 num_experts: int = 256 num_shared_experts: int = 1 num_experts_per_tok: int = 8 n_group: int = 8 topk_group: int = 4 routed_scaling_factor: float = 2.5 moe_intermediate_size: int = 512 first_k_dense_replace: int = 1 @classmethod def from_hf(cls, hf_config: dict) -> "LLaDA2Config": return cls( vocab_size=hf_config["vocab_size"], hidden_size=hf_config["hidden_size"], intermediate_size=hf_config["intermediate_size"], num_hidden_layers=hf_config["num_hidden_layers"], num_attention_heads=hf_config["num_attention_heads"], num_key_value_heads=hf_config["num_key_value_heads"], head_dim=hf_config.get("head_dim", 128), rms_norm_eps=hf_config.get("rms_norm_eps", 1e-6), max_position_embeddings=hf_config.get("max_position_embeddings", 8192), rope_theta=hf_config.get("rope_theta", 600000.0), partial_rotary_factor=hf_config.get("partial_rotary_factor", 0.5), pad_token_id=hf_config.get("pad_token_id", 156892), image_token_offset=hf_config.get("image_token_offset", 157184), num_experts=hf_config["num_experts"], num_shared_experts=hf_config.get("num_shared_experts", 1), num_experts_per_tok=hf_config["num_experts_per_tok"], n_group=hf_config["n_group"], topk_group=hf_config["topk_group"], routed_scaling_factor=hf_config.get("routed_scaling_factor", 2.5), moe_intermediate_size=hf_config["moe_intermediate_size"], first_k_dense_replace=hf_config.get("first_k_dense_replace", 1), ) class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.weight = mx.ones((dim,)) self.eps = eps def __call__(self, x: mx.array) -> mx.array: return mx.fast.rms_norm(x, self.weight, self.eps) def _rope_inv_freq(head_dim: int, partial_rotary_factor: float, base: float): dim = int(head_dim * partial_rotary_factor) half = dim // 2 inv = 1.0 / (base ** (mx.arange(0, half, dtype=mx.float32) * 2.0 / dim)) return inv, dim def build_rope_cache(max_seq_len: int, head_dim: int, partial_rotary_factor: float, base: float): """Precompute cos/sin for positions [0, max_seq_len). Shape: [max_seq_len, rope_dim].""" inv_freq, rope_dim = _rope_inv_freq(head_dim, partial_rotary_factor, base) positions = mx.arange(max_seq_len, dtype=mx.float32) freqs = positions[:, None] * inv_freq[None, :] # [S, rope_dim/2] emb = mx.concatenate([freqs, freqs], axis=-1) # [S, rope_dim] (HF llama style) return mx.cos(emb), mx.sin(emb), rope_dim def apply_rope_partial(q: mx.array, k: mx.array, cos: mx.array, sin: mx.array, rope_dim: int): """q, k: [B, H, S, D]. Apply rotary to first rope_dim of D.""" q_rot, q_pass = q[..., :rope_dim], q[..., rope_dim:] k_rot, k_pass = k[..., :rope_dim], k[..., rope_dim:] cos_b = cos[None, None, :, :] sin_b = sin[None, None, :, :] def rotate_half(x): half = x.shape[-1] // 2 return mx.concatenate([-x[..., half:], x[..., :half]], axis=-1) q_rot = (q_rot * cos_b) + (rotate_half(q_rot) * sin_b) k_rot = (k_rot * cos_b) + (rotate_half(k_rot) * sin_b) q_out = mx.concatenate([q_rot, q_pass], axis=-1) if q_pass.shape[-1] > 0 else q_rot k_out = mx.concatenate([k_rot, k_pass], axis=-1) if k_pass.shape[-1] > 0 else k_rot return q_out, k_out class Attention(nn.Module): """GQA with packed QKV linear, QK norm, partial RoPE, bidirectional attention.""" def __init__(self, config: LLaDA2Config): super().__init__() self.num_heads = config.num_attention_heads self.num_kv_heads = config.num_key_value_heads self.head_dim = config.head_dim self.hidden_size = config.hidden_size self.scaling = self.head_dim ** -0.5 qkv_out = (self.num_heads + 2 * self.num_kv_heads) * self.head_dim self.query_key_value = nn.Linear(self.hidden_size, qkv_out, bias=False) self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) def __call__(self, x: mx.array, cos: mx.array, sin: mx.array, rope_dim: int, attn_mask: Optional[mx.array] = None) -> mx.array: B, S, _ = x.shape qkv = self.query_key_value(x) qkv = qkv.reshape(B, S, self.num_heads + 2 * self.num_kv_heads, self.head_dim) q = qkv[:, :, : self.num_heads, :] k = qkv[:, :, self.num_heads : self.num_heads + self.num_kv_heads, :] v = qkv[:, :, self.num_heads + self.num_kv_heads :, :] # Per-head RMSNorm then transpose to [B, H, S, D] q = self.query_layernorm(q).transpose(0, 2, 1, 3) k = self.key_layernorm(k).transpose(0, 2, 1, 3) v = v.transpose(0, 2, 1, 3) q, k = apply_rope_partial(q, k, cos, sin, rope_dim) out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scaling, mask=attn_mask) out = out.transpose(0, 2, 1, 3).reshape(B, S, self.num_heads * self.head_dim) return self.dense(out) class DenseMLP(nn.Module): def __init__(self, config: LLaDA2Config): super().__init__() self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) def __call__(self, x: mx.array) -> mx.array: return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) class MoEGate(nn.Module): """DeepSeek-V2 style: sigmoid + group-limited topk with expert_bias.""" def __init__(self, config: LLaDA2Config): super().__init__() self.num_experts = config.num_experts self.top_k = config.num_experts_per_tok self.n_group = config.n_group self.topk_group = config.topk_group self.routed_scaling_factor = config.routed_scaling_factor self.weight = mx.zeros((self.num_experts, config.hidden_size)) self.expert_bias = mx.zeros((self.num_experts,)) def __call__(self, x: mx.array): """x: [N, H]. Returns (top_idx [N, top_k] int32, top_weight [N, top_k]).""" x_f = x.astype(mx.float32) logits = x_f @ self.weight.astype(mx.float32).T # [N, E] scores = mx.sigmoid(logits) scores_b = scores + self.expert_bias.astype(scores.dtype) N = x.shape[0] group_size = self.num_experts // self.n_group g = scores_b.reshape(N, self.n_group, group_size) # Sum of top-2 per group as group score top2_vals = mx.topk(g, 2, axis=-1) # ascending: [N, n_group, 2] group_scores = top2_vals.sum(axis=-1) # [N, n_group] # Pick top `topk_group` groups (largest). argpartition returns unsorted; that's fine. group_top = mx.argpartition(-group_scores, self.topk_group - 1, axis=-1)[:, : self.topk_group] group_mask = mx.zeros((N, self.n_group), dtype=mx.bool_) group_mask = mx.put_along_axis( group_mask, group_top, mx.ones(group_top.shape, dtype=mx.bool_), axis=-1 ) # Expand to expert mask score_mask = mx.broadcast_to( group_mask[:, :, None], (N, self.n_group, group_size) ).reshape(N, self.num_experts) masked = mx.where(score_mask, scores_b, mx.array(-float("inf"), dtype=scores_b.dtype)) # Top-k experts top_idx = mx.argpartition(-masked, self.top_k - 1, axis=-1)[:, : self.top_k] # [N, top_k] top_scores = mx.take_along_axis(scores, top_idx, axis=-1) # original sigmoid if self.top_k > 1: top_weight = top_scores / (top_scores.sum(axis=-1, keepdims=True) + 1e-20) else: top_weight = top_scores top_weight = top_weight * self.routed_scaling_factor return top_idx, top_weight.astype(x.dtype) class MoEBlock(nn.Module): """256 routed experts + 1 shared. Experts packed as [E, out, in] for per-expert matmul loop.""" def __init__(self, config: LLaDA2Config): super().__init__() self.num_experts = config.num_experts self.moe_hidden = config.moe_intermediate_size self.hidden = config.hidden_size self.top_k = config.num_experts_per_tok self.gate = MoEGate(config) # Packed expert weights, PyTorch-linear layout [E, out, in]. # Pre-allocated; weight loader overwrites via model.update(). self.experts_gate_w = mx.zeros((self.num_experts, self.moe_hidden, self.hidden)) self.experts_up_w = mx.zeros((self.num_experts, self.moe_hidden, self.hidden)) self.experts_down_w = mx.zeros((self.num_experts, self.hidden, self.moe_hidden)) shared_inter = config.moe_intermediate_size * config.num_shared_experts self.shared_gate_proj = nn.Linear(self.hidden, shared_inter, bias=False) self.shared_up_proj = nn.Linear(self.hidden, shared_inter, bias=False) self.shared_down_proj = nn.Linear(shared_inter, self.hidden, bias=False) def __call__(self, x: mx.array) -> mx.array: B, S, H = x.shape identity = x x_flat = x.reshape(B * S, H) N = x_flat.shape[0] top_idx, top_weight = self.gate(x_flat) # [N, top_k] int, float # Flatten into (N*top_k) slots in row-major order: [tok0_slot0, tok0_slot1, ...] flat_expert = top_idx.reshape(-1) # [N*top_k] slot_weight = top_weight.reshape(-1, 1).astype(x.dtype) # [N*top_k, 1] token_ids = mx.repeat(mx.arange(N), self.top_k) # [N*top_k] slot_x = x_flat[token_ids] # [N_slots, H] N_slots = slot_x.shape[0] # Gathering [N_slots, H_moe, H] is memory-heavy (O(N_slots * H_moe * H)). # At H_moe=512, H=2048, bf16: 2 MB per slot. N_slots=8700 → ~17 GB per # gather. Chunk to cap peak memory. Empirically 512-slot chunks are fine. CHUNK = 512 parts = [] for start in range(0, N_slots, CHUNK): end = min(start + CHUNK, N_slots) chunk_expert = flat_expert[start:end] chunk_x = slot_x[start:end] chunk_w = slot_weight[start:end] g_w = self.experts_gate_w[chunk_expert] # [k, H_moe, H] u_w = self.experts_up_w[chunk_expert] # [k, H_moe, H] d_w = self.experts_down_w[chunk_expert] # [k, H, H_moe] x3 = chunk_x[:, None, :] # [k, 1, H] g_out = (x3 @ g_w.transpose(0, 2, 1)).squeeze(1) # [k, H_moe] u_out = (x3 @ u_w.transpose(0, 2, 1)).squeeze(1) # [k, H_moe] act = nn.silu(g_out) * u_out # [k, H_moe] d_out = (act[:, None, :] @ d_w.transpose(0, 2, 1)).squeeze(1) # [k, H] parts.append(d_out * chunk_w) # [k, H] # Force materialization to release the [k, H_moe, H] gather buffers mx.eval(parts[-1]) weighted = mx.concatenate(parts, axis=0) # [N_slots, H] # Sum top_k contributions per token y = weighted.reshape(N, self.top_k, H).sum(axis=1) # [N, H] y = y.reshape(B, S, H) shared = self.shared_down_proj( nn.silu(self.shared_gate_proj(identity)) * self.shared_up_proj(identity) ) return y + shared class DecoderLayer(nn.Module): def __init__(self, config: LLaDA2Config, layer_idx: int): super().__init__() self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.attention = Attention(config) if layer_idx >= config.first_k_dense_replace: self.mlp = MoEBlock(config) else: self.mlp = DenseMLP(config) def __call__(self, x: mx.array, cos: mx.array, sin: mx.array, rope_dim: int, attn_mask: Optional[mx.array]) -> mx.array: h = self.input_layernorm(x) h = self.attention(h, cos, sin, rope_dim, attn_mask) x = x + h h = self.post_attention_layernorm(x) h = self.mlp(h) return x + h class LLaDA2Backbone(nn.Module): def __init__(self, config: LLaDA2Config): super().__init__() self.config = config self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = [DecoderLayer(config, i) for i in range(config.num_hidden_layers)] self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) cos, sin, rope_dim = build_rope_cache( config.max_position_embeddings, config.head_dim, config.partial_rotary_factor, config.rope_theta, ) self._rope_cos = cos self._rope_sin = sin self._rope_dim = rope_dim def __call__(self, input_ids: mx.array, attn_mask: Optional[mx.array] = None, position_ids: Optional[mx.array] = None) -> mx.array: S = input_ids.shape[-1] if position_ids is None: cos = self._rope_cos[:S] # [S, rope_dim] sin = self._rope_sin[:S] else: # position_ids: [B, S] or [S]. Gather per-position cos/sin from cache. # For simplicity we assume all rows share the same positions (single CFG # path); the t2i caller calls model() separately for cond and uncond. if position_ids.ndim == 2: position_ids = position_ids[0] # [S] cos = self._rope_cos[position_ids] # [S, rope_dim] sin = self._rope_sin[position_ids] h = self.word_embeddings(input_ids) for layer in self.layers: h = layer(h, cos, sin, self._rope_dim, attn_mask) return self.norm(h) class LLaDA2Model(nn.Module): """Full LM: backbone + lm_head.""" def __init__(self, config: LLaDA2Config): super().__init__() self.config = config self.model = LLaDA2Backbone(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) def __call__(self, input_ids: mx.array, attn_mask: Optional[mx.array] = None, position_ids: Optional[mx.array] = None) -> mx.array: h = self.model(input_ids, attn_mask, position_ids) return self.lm_head(h)