# mlx-llada2-uni / llada2/model.py
# Uploaded by treadon with huggingface_hub (commit 732e128, verified).
"""MLX port of LLaDA2.0-Uni MoE backbone.
Bidirectional-attention diffusion LLM. Images are in-vocabulary VQ tokens
(offset at image_token_offset), so the backbone is purely sequence-in/logits-out.
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
@dataclass
class LLaDA2Config:
    """Hyper-parameters for the LLaDA2.0-Uni MoE backbone.

    Field defaults are the values this port assumes when a HuggingFace config
    does not provide them; ``from_hf`` overrides them from ``config.json``.
    """

    vocab_size: int = 173568
    hidden_size: int = 2048
    intermediate_size: int = 5120  # width of the dense MLP (layers < first_k_dense_replace)
    num_hidden_layers: int = 20
    num_attention_heads: int = 16
    num_key_value_heads: int = 4
    head_dim: int = 128
    rms_norm_eps: float = 1e-6
    max_position_embeddings: int = 8192  # length of the precomputed RoPE table
    rope_theta: float = 600000.0
    partial_rotary_factor: float = 0.5  # fraction of head_dim that is rotated
    pad_token_id: int = 156892
    mask_token_id: int = 156895
    eos_token_id: int = 156892
    image_token_offset: int = 157184
    num_experts: int = 256
    num_shared_experts: int = 1
    num_experts_per_tok: int = 8
    n_group: int = 8
    topk_group: int = 4
    routed_scaling_factor: float = 2.5
    moe_intermediate_size: int = 512
    first_k_dense_replace: int = 1

    @classmethod
    def from_hf(cls, hf_config: dict) -> "LLaDA2Config":
        """Build a config from a HuggingFace ``config.json`` dictionary.

        Architecture-defining keys (layer sizes, head counts, expert counts and
        routing groups) are required: a missing one raises ``KeyError``.
        Every other field of this dataclass -- including ``mask_token_id`` and
        ``eos_token_id`` -- is taken from ``hf_config`` when the key is present
        and not ``None``; otherwise the dataclass default is kept. Values are
        used as-is (no type coercion). Keys that are not fields are ignored.
        """
        from dataclasses import fields

        required = (
            "vocab_size",
            "hidden_size",
            "intermediate_size",
            "num_hidden_layers",
            "num_attention_heads",
            "num_key_value_heads",
            "num_experts",
            "num_experts_per_tok",
            "n_group",
            "topk_group",
            "moe_intermediate_size",
        )
        kwargs = {name: hf_config[name] for name in required}
        for field in fields(cls):
            if field.name in kwargs:
                continue
            value = hf_config.get(field.name)
            # Unset ids are commonly serialized as null; keep the default then
            # instead of propagating None into the model.
            if value is not None:
                kwargs[field.name] = value
        return cls(**kwargs)
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learnable scale.

    The computation is delegated to ``mx.fast.rms_norm``; the scale vector is
    initialized to ones.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = mx.ones((dim,))
        self.eps = eps

    def __call__(self, x: mx.array) -> mx.array:
        scale, epsilon = self.weight, self.eps
        normalized = mx.fast.rms_norm(x, scale, epsilon)
        return normalized
def _rope_inv_freq(head_dim: int, partial_rotary_factor: float, base: float):
    """Return ``(inv_freq, rope_dim)`` for partial rotary embeddings.

    ``rope_dim = int(head_dim * partial_rotary_factor)`` channels are rotated,
    and ``inv_freq[i] = 1 / base ** (2 * i / rope_dim)`` for
    ``i`` in ``[0, rope_dim // 2)`` (float32).
    """
    rope_dim = int(head_dim * partial_rotary_factor)
    pair_count = rope_dim // 2
    exponents = mx.arange(0, pair_count, dtype=mx.float32) * 2.0 / rope_dim
    inv_freq = 1.0 / (base ** exponents)
    return inv_freq, rope_dim
def build_rope_cache(max_seq_len: int, head_dim: int, partial_rotary_factor: float, base: float):
    """Precompute RoPE tables for positions ``0 .. max_seq_len - 1``.

    Returns ``(cos, sin, rope_dim)`` where ``cos`` and ``sin`` have shape
    ``[max_seq_len, rope_dim]``. The per-pair angles are tiled twice along the
    last axis (HF llama "rotate_half" layout), which is the layout
    ``apply_rope_partial`` expects.
    """
    inv_freq, rope_dim = _rope_inv_freq(head_dim, partial_rotary_factor, base)
    pos = mx.arange(max_seq_len, dtype=mx.float32)
    angles = mx.expand_dims(pos, 1) * mx.expand_dims(inv_freq, 0)  # [S, rope_dim // 2]
    table = mx.concatenate([angles, angles], axis=-1)  # [S, rope_dim]
    return mx.cos(table), mx.sin(table), rope_dim
def apply_rope_partial(q: mx.array, k: mx.array, cos: mx.array, sin: mx.array, rope_dim: int):
    """Apply rotary embeddings to the leading ``rope_dim`` channels of q and k.

    ``q`` and ``k`` are ``[B, H, S, D]``; ``cos`` and ``sin`` are
    ``[S, rope_dim]`` and are broadcast over batch and heads. Channels from
    ``rope_dim`` onward pass through unchanged. Returns the rotated ``(q, k)``.
    """
    cos4 = cos[None, None, :, :]
    sin4 = sin[None, None, :, :]

    def _rotate_half(t):
        mid = t.shape[-1] // 2
        return mx.concatenate([-t[..., mid:], t[..., :mid]], axis=-1)

    def _rotate(t):
        head, tail = t[..., :rope_dim], t[..., rope_dim:]
        head = (head * cos4) + (_rotate_half(head) * sin4)
        if tail.shape[-1] == 0:
            return head
        return mx.concatenate([head, tail], axis=-1)

    return _rotate(q), _rotate(k)
class Attention(nn.Module):
    """Grouped-query self-attention with packed QKV, QK-norm and partial RoPE.

    One linear layer produces Q (``num_heads`` heads) followed by K and V
    (``num_kv_heads`` heads each). Q and K receive per-head RMSNorm before RoPE.
    No causal mask is constructed here; only ``attn_mask`` (if given) is passed
    to the attention kernel, so attention is bidirectional by default.
    """

    def __init__(self, config: LLaDA2Config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        self.hidden_size = config.hidden_size
        self.scaling = self.head_dim ** -0.5
        packed_heads = self.num_heads + 2 * self.num_kv_heads
        self.query_key_value = nn.Linear(self.hidden_size, packed_heads * self.head_dim, bias=False)
        self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def __call__(self, x: mx.array, cos: mx.array, sin: mx.array, rope_dim: int,
                 attn_mask: Optional[mx.array] = None) -> mx.array:
        batch, seq_len, _ = x.shape
        n_q, n_kv = self.num_heads, self.num_kv_heads
        packed = self.query_key_value(x).reshape(batch, seq_len, n_q + 2 * n_kv, self.head_dim)
        q = packed[:, :, :n_q]
        k = packed[:, :, n_q:n_q + n_kv]
        v = packed[:, :, n_q + n_kv:]
        # QK-norm normalizes over head_dim, so it runs before heads move to axis 1.
        q = mx.transpose(self.query_layernorm(q), (0, 2, 1, 3))
        k = mx.transpose(self.key_layernorm(k), (0, 2, 1, 3))
        v = mx.transpose(v, (0, 2, 1, 3))
        q, k = apply_rope_partial(q, k, cos, sin, rope_dim)
        attended = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scaling, mask=attn_mask)
        merged = mx.transpose(attended, (0, 2, 1, 3)).reshape(batch, seq_len, n_q * self.head_dim)
        return self.dense(merged)
class DenseMLP(nn.Module):
    """Bias-free SwiGLU feed-forward: ``down(silu(gate(x)) * up(x))``."""

    def __init__(self, config: LLaDA2Config):
        super().__init__()
        d_model, d_ff = config.hidden_size, config.intermediate_size
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        gated = nn.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
class MoEGate(nn.Module):
    """DeepSeek-V2/V3-style router: sigmoid scores with group-limited top-k.

    Expert *selection* uses ``sigmoid(logits) + expert_bias``; the returned
    *weights* use the un-biased sigmoid scores of the chosen experts,
    renormalized to sum to 1 when ``top_k > 1`` and then multiplied by
    ``routed_scaling_factor``.

    Raises:
        ValueError: if the routing geometry from ``config`` is inconsistent.
    """

    def __init__(self, config: LLaDA2Config):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.routed_scaling_factor = config.routed_scaling_factor
        # Validate up front: otherwise bad values surface as an opaque
        # reshape/topk failure in __call__, or (top_k larger than the experts
        # surviving the group mask) as silently routing to -inf-masked experts
        # whose raw sigmoid scores are then used as weights.
        if self.n_group <= 0 or self.num_experts % self.n_group != 0:
            raise ValueError(
                f"num_experts ({self.num_experts}) must be a positive multiple of "
                f"n_group ({self.n_group})"
            )
        group_size = self.num_experts // self.n_group
        if group_size < 2:
            raise ValueError(
                f"each expert group needs at least 2 experts for top-2 group scoring, "
                f"got {group_size}"
            )
        if not 1 <= self.topk_group <= self.n_group:
            raise ValueError(
                f"topk_group ({self.topk_group}) must be in [1, n_group={self.n_group}]"
            )
        if not 1 <= self.top_k <= self.topk_group * group_size:
            raise ValueError(
                f"num_experts_per_tok ({self.top_k}) must be in "
                f"[1, topk_group * group_size = {self.topk_group * group_size}]"
            )
        self.weight = mx.zeros((self.num_experts, config.hidden_size))
        self.expert_bias = mx.zeros((self.num_experts,))

    def __call__(self, x: mx.array):
        """Route tokens ``x`` of shape ``[N, H]``.

        Returns ``(top_idx, top_weight)``, both ``[N, top_k]``: the selected
        expert indices (not sorted within a row) and their routing weights
        cast to ``x.dtype``.
        """
        n_tokens = x.shape[0]
        group_size = self.num_experts // self.n_group
        # Router math runs in float32 regardless of the activation dtype.
        logits = x.astype(mx.float32) @ self.weight.astype(mx.float32).T  # [N, E]
        scores = mx.sigmoid(logits)
        biased = scores + self.expert_bias.astype(scores.dtype)  # selection only

        # Group score = sum of the two largest biased scores in each group.
        # mx.topk does not guarantee an order, which does not matter for a sum.
        grouped = biased.reshape(n_tokens, self.n_group, group_size)
        group_scores = mx.topk(grouped, 2, axis=-1).sum(axis=-1)  # [N, n_group]

        # Keep the `topk_group` best groups; argpartition's order is irrelevant.
        group_top = mx.argpartition(-group_scores, self.topk_group - 1, axis=-1)[:, : self.topk_group]
        group_mask = mx.put_along_axis(
            mx.zeros((n_tokens, self.n_group), dtype=mx.bool_),
            group_top,
            mx.ones(group_top.shape, dtype=mx.bool_),
            axis=-1,
        )
        expert_mask = mx.broadcast_to(
            group_mask[:, :, None], (n_tokens, self.n_group, group_size)
        ).reshape(n_tokens, self.num_experts)
        masked = mx.where(expert_mask, biased, mx.array(-float("inf"), dtype=biased.dtype))

        # Top-k experts among the surviving groups; weights from un-biased scores.
        top_idx = mx.argpartition(-masked, self.top_k - 1, axis=-1)[:, : self.top_k]  # [N, top_k]
        top_scores = mx.take_along_axis(scores, top_idx, axis=-1)
        if self.top_k > 1:
            top_scores = top_scores / (top_scores.sum(axis=-1, keepdims=True) + 1e-20)
        top_weight = top_scores * self.routed_scaling_factor
        return top_idx, top_weight.astype(x.dtype)
class MoEBlock(nn.Module):
    """Routed SwiGLU experts (``top_k`` per token) plus shared SwiGLU expert(s).

    Routed expert weights are packed PyTorch-linear style as ``[E, out, in]``
    (pre-allocated here and overwritten by the weight loader via
    ``model.update()``). They are applied with ``mx.gather_mm``, which selects
    each slot's expert matrix inside the matmul. This avoids materializing a
    ``[N * top_k, H_moe, H]`` copy of the weights. It also avoids the chunking
    plus ``mx.eval`` that materialization required, which forced a graph sync
    in the forward pass and made the block unusable under ``mx.compile``.
    """

    def __init__(self, config: LLaDA2Config):
        super().__init__()
        self.num_experts = config.num_experts
        self.moe_hidden = config.moe_intermediate_size
        self.hidden = config.hidden_size
        self.top_k = config.num_experts_per_tok
        self.gate = MoEGate(config)
        # Packed expert weights, PyTorch-linear layout [E, out, in].
        self.experts_gate_w = mx.zeros((self.num_experts, self.moe_hidden, self.hidden))
        self.experts_up_w = mx.zeros((self.num_experts, self.moe_hidden, self.hidden))
        self.experts_down_w = mx.zeros((self.num_experts, self.hidden, self.moe_hidden))
        shared_inter = config.moe_intermediate_size * config.num_shared_experts
        self.shared_gate_proj = nn.Linear(self.hidden, shared_inter, bias=False)
        self.shared_up_proj = nn.Linear(self.hidden, shared_inter, bias=False)
        self.shared_down_proj = nn.Linear(shared_inter, self.hidden, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        """Apply the MoE layer to ``x`` of shape ``[B, S, H]``; returns ``[B, S, H]``."""
        B, S, H = x.shape
        x_flat = x.reshape(B * S, H)
        top_idx, top_weight = self.gate(x_flat)  # [N, top_k] indices / weights

        # Shape x as [N, 1, 1, H]: its batch dims [N, 1] broadcast against the
        # [N, top_k] expert indices, giving one [1, H] row per (token, slot).
        # Weights are transposed to [E, in, out] so each slot computes x @ W^T.
        x_in = x_flat[:, None, None, :]
        g_out = mx.gather_mm(x_in, self.experts_gate_w.swapaxes(-1, -2), rhs_indices=top_idx)  # [N, k, 1, H_moe]
        u_out = mx.gather_mm(x_in, self.experts_up_w.swapaxes(-1, -2), rhs_indices=top_idx)  # [N, k, 1, H_moe]
        act = nn.silu(g_out) * u_out
        d_out = mx.gather_mm(act, self.experts_down_w.swapaxes(-1, -2), rhs_indices=top_idx)  # [N, k, 1, H]
        d_out = d_out.squeeze(-2)  # [N, k, H]

        # Weight each slot and sum the top_k contributions per token.
        slot_weight = top_weight[..., None].astype(x.dtype)  # [N, k, 1]
        y = (d_out * slot_weight).sum(axis=1).reshape(B, S, H)

        shared = self.shared_down_proj(
            nn.silu(self.shared_gate_proj(x)) * self.shared_up_proj(x)
        )
        return y + shared
class DecoderLayer(nn.Module):
    """Pre-norm block: ``x + attention(norm(x))`` followed by ``x + mlp(norm(x))``.

    Layers with index below ``config.first_k_dense_replace`` use a dense SwiGLU
    MLP; all later layers use the mixture-of-experts block.
    """

    def __init__(self, config: LLaDA2Config, layer_idx: int):
        super().__init__()
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attention = Attention(config)
        is_dense = layer_idx < config.first_k_dense_replace
        self.mlp = DenseMLP(config) if is_dense else MoEBlock(config)

    def __call__(self, x: mx.array, cos: mx.array, sin: mx.array, rope_dim: int,
                 attn_mask: Optional[mx.array]) -> mx.array:
        attn_out = self.attention(self.input_layernorm(x), cos, sin, rope_dim, attn_mask)
        residual = x + attn_out
        mlp_out = self.mlp(self.post_attention_layernorm(residual))
        return residual + mlp_out
class LLaDA2Backbone(nn.Module):
    """Token embedding -> decoder layers -> final RMSNorm; returns hidden states."""

    def __init__(self, config: LLaDA2Config):
        super().__init__()
        self.config = config
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = [DecoderLayer(config, i) for i in range(config.num_hidden_layers)]
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Precomputed RoPE tables of shape [max_position_embeddings, rope_dim].
        # NOTE(review): the leading underscore presumably keeps these arrays out
        # of the MLX parameter tree (and thus the checkpoint) -- confirm.
        cos, sin, rope_dim = build_rope_cache(
            config.max_position_embeddings, config.head_dim,
            config.partial_rotary_factor, config.rope_theta,
        )
        self._rope_cos = cos
        self._rope_sin = sin
        self._rope_dim = rope_dim

    def _ensure_rope_len(self, length: int) -> None:
        """Rebuild the RoPE tables if they do not cover positions ``[0, length)``.

        Previously, a sequence longer than ``max_position_embeddings`` got a
        truncated table and failed with a broadcast error inside RoPE. Position
        ids past the table indexed out of bounds. Positions beyond the trained
        context are extrapolation; quality there is not guaranteed.
        """
        if length <= self._rope_cos.shape[0]:
            return
        self._rope_cos, self._rope_sin, self._rope_dim = build_rope_cache(
            length, self.config.head_dim,
            self.config.partial_rotary_factor, self.config.rope_theta,
        )

    def __call__(self, input_ids: mx.array,
                 attn_mask: Optional[mx.array] = None,
                 position_ids: Optional[mx.array] = None) -> mx.array:
        """Run the backbone.

        Args:
            input_ids: token ids; the sequence length is the last axis.
            attn_mask: optional mask forwarded unchanged to every attention layer.
            position_ids: optional ``[S]`` or ``[B, S]`` positions. For ``[B, S]``
                only row 0 is used: all rows are assumed to share positions (the
                t2i caller runs cond and uncond as separate calls).

        Returns:
            Final-normed hidden states.
        """
        S = input_ids.shape[-1]
        if position_ids is None:
            self._ensure_rope_len(S)
            cos = self._rope_cos[:S]  # [S, rope_dim]
            sin = self._rope_sin[:S]
        else:
            if position_ids.ndim == 2:
                position_ids = position_ids[0]  # [S]
            if position_ids.size > 0:
                self._ensure_rope_len(int(position_ids.max().item()) + 1)
            cos = self._rope_cos[position_ids]  # [S, rope_dim]
            sin = self._rope_sin[position_ids]
        h = self.word_embeddings(input_ids)
        for layer in self.layers:
            h = layer(h, cos, sin, self._rope_dim, attn_mask)
        return self.norm(h)
class LLaDA2Model(nn.Module):
    """Complete language model: ``LLaDA2Backbone`` followed by a bias-free ``lm_head``.

    Calling the model returns vocabulary logits for every input position.
    """

    def __init__(self, config: LLaDA2Config):
        super().__init__()
        self.config = config
        self.model = LLaDA2Backbone(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def __call__(self, input_ids: mx.array,
                 attn_mask: Optional[mx.array] = None,
                 position_ids: Optional[mx.array] = None) -> mx.array:
        hidden = self.model(input_ids, attn_mask=attn_mask, position_ids=position_ids)
        logits = self.lm_head(hidden)
        return logits