| """Qwen3 Text Encoder β MLX native implementation for Z-Image-Turbo. |
| |
| Architecture (from model config): |
| - 36 layers, hidden_size=2560, 32 attention heads, 8 KV heads (GQA 4:1) |
| - head_dim=128, intermediate_size=9728 |
| - hidden_act=silu (SwiGLU FFN) |
| - RMSNorm (eps=1e-6), QK-Norm on q/k projections |
| - RoPE (theta=1_000_000) |
| - vocab_size=151936 |
| |
| Weight key pattern: |
| model.embed_tokens.weight |
| model.layers.N.input_layernorm.weight |
| model.layers.N.self_attn.{q_proj,k_proj,v_proj,o_proj}.weight |
| model.layers.N.self_attn.{q_norm,k_norm}.weight |
| model.layers.N.post_attention_layernorm.weight |
| model.layers.N.mlp.{gate_proj,up_proj,down_proj}.weight |
| model.norm.weight |
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
| from dataclasses import dataclass |
|
|
| import mlx.core as mx |
| import mlx.nn as nn |
|
|
|
|
| |
|
|
@dataclass
class Qwen3EncoderConfig:
    """Hyperparameters for the Qwen3 text encoder.

    Defaults match the Z-Image-Turbo checkpoint described in the module
    docstring (36 layers, GQA 4:1, SwiGLU FFN, RoPE theta=1e6).
    """

    hidden_size: int = 2560                  # model width (embedding / residual stream)
    num_hidden_layers: int = 36              # number of decoder layers
    num_attention_heads: int = 32            # query heads
    num_key_value_heads: int = 8             # KV heads (GQA: 32/8 = 4 queries per KV head)
    head_dim: int = 128                      # per-head dimension (32 * 128 = 4096 != hidden_size; projections handle this)
    intermediate_size: int = 9728            # SwiGLU FFN inner width
    rms_norm_eps: float = 1e-6               # epsilon for all RMSNorm layers
    rope_theta: float = 1_000_000.0          # RoPE base frequency
    vocab_size: int = 151936                 # tokenizer vocabulary size
    max_position_embeddings: int = 40960     # maximum supported sequence length
|
|
|
|
| |
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: scale by 1/RMS(x) and a learned gain.

    No mean-centering and no bias, matching the Qwen3 reference norm.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = mx.ones((dim,))
        self.eps = eps

    def __call__(self, x: mx.array) -> mx.array:
        # Mean of squares over the feature axis, then reciprocal sqrt.
        mean_sq = mx.mean(mx.square(x), axis=-1, keepdims=True)
        return x * mx.rsqrt(mean_sq + self.eps) * self.weight
|
|
|
|
| |
|
|
class RotaryEmbedding(nn.Module):
    """Lazily cached cos/sin tables for rotary position embeddings.

    The cache grows on demand to the longest sequence length seen so far;
    ``max_seq_len`` is accepted for interface compatibility but the tables
    are not pre-allocated up front.
    """

    def __init__(self, dim: int, theta: float = 1_000_000.0, max_seq_len: int = 8192):
        super().__init__()
        self.dim = dim
        self.theta = theta
        # Standard RoPE inverse frequencies: theta^(-2i/dim) for even i.
        exponents = mx.arange(0, dim, 2, dtype=mx.float32) / dim
        self._inv_freq = 1.0 / (theta ** exponents)
        self._max_cached = 0
        self._cos_cache = None
        self._sin_cache = None

    def _grow_cache(self, seq_len: int):
        """Rebuild the cos/sin tables when a longer sequence is requested."""
        if self._cos_cache is not None and seq_len <= self._max_cached:
            return
        positions = mx.arange(seq_len, dtype=mx.float32)
        angles = mx.outer(positions, self._inv_freq)
        # Duplicate the angles so the table spans the full head_dim
        # (rotate-half convention pairs dim i with dim i + dim/2).
        table = mx.concatenate([angles, angles], axis=-1)
        self._cos_cache = mx.cos(table)
        self._sin_cache = mx.sin(table)
        self._max_cached = seq_len

    def __call__(self, seq_len: int) -> tuple[mx.array, mx.array]:
        """Return (cos, sin) tables of shape (seq_len, dim)."""
        self._grow_cache(seq_len)
        return self._cos_cache[:seq_len], self._sin_cache[:seq_len]
|
|
|
|
def _rotate_half(x: mx.array) -> mx.array:
    """Map (x1, x2) halves of the last axis to (-x2, x1) for RoPE rotation."""
    half = x.shape[-1] // 2
    return mx.concatenate([-x[..., half:], x[..., :half]], axis=-1)
|
|
|
|
def apply_rotary_pos_emb(q: mx.array, k: mx.array, cos: mx.array, sin: mx.array) -> tuple[mx.array, mx.array]:
    """Rotate query and key head dimensions by position-dependent angles.

    cos/sin have shape (L, head_dim) and are broadcast against q/k of
    shape (B, H, L, head_dim).
    """
    cos = cos.reshape(1, 1, *cos.shape)
    sin = sin.reshape(1, 1, *sin.shape)
    return (
        q * cos + _rotate_half(q) * sin,
        k * cos + _rotate_half(k) * sin,
    )
|
|
|
|
| |
|
|
class Qwen3Attention(nn.Module):
    """Grouped-query attention with per-head QK-RMSNorm and RoPE.

    32 query heads share 8 KV heads (4:1 GQA). Attention is computed with
    the fused ``mx.fast.scaled_dot_product_attention`` kernel, which
    supports GQA natively — K/V are never physically repeated and the full
    (B, H, L, L) score matrix is never materialized, unlike the naive
    matmul + softmax formulation.
    """

    def __init__(self, cfg: Qwen3EncoderConfig):
        super().__init__()
        self.n_heads = cfg.num_attention_heads
        self.n_kv_heads = cfg.num_key_value_heads
        self.head_dim = cfg.head_dim
        self.n_rep = self.n_heads // self.n_kv_heads
        # Hoisted: 1/sqrt(head_dim) softmax scale.
        self.scale = 1.0 / math.sqrt(cfg.head_dim)

        self.q_proj = nn.Linear(cfg.hidden_size, self.n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(cfg.hidden_size, self.n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(cfg.hidden_size, self.n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.n_heads * self.head_dim, cfg.hidden_size, bias=False)

        # Qwen3 normalizes q/k per head (over head_dim), not per layer.
        self.q_norm = RMSNorm(self.head_dim, eps=cfg.rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=cfg.rms_norm_eps)

    def __call__(
        self,
        x: mx.array,
        cos: mx.array,
        sin: mx.array,
        mask: mx.array | None = None,
    ) -> mx.array:
        """Run attention over a full sequence (no KV cache).

        Args:
            x: (B, L, hidden_size) input hidden states.
            cos/sin: (L, head_dim) rotary tables from RotaryEmbedding.
            mask: optional additive attention mask broadcastable to
                (B, n_heads, L, L); None means no masking.

        Returns:
            (B, L, hidden_size) attention output.
        """
        B, L, _ = x.shape

        q = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim)
        k = self.k_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim)
        v = self.v_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim)

        # QK-Norm is applied before RoPE, matching the reference model.
        q = self.q_norm(q)
        k = self.k_norm(k)

        # (B, heads, L, head_dim) layout expected by RoPE and SDPA.
        q = q.transpose(0, 2, 1, 3)
        k = k.transpose(0, 2, 1, 3)
        v = v.transpose(0, 2, 1, 3)

        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Fused kernel handles the n_heads != n_kv_heads (GQA) case directly.
        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)

        out = out.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(out)
|
|
|
|
| |
|
|
class Qwen3MLP(nn.Module):
    """SwiGLU feed-forward block: down_proj(silu(gate_proj(x)) * up_proj(x))."""

    def __init__(self, cfg: Qwen3EncoderConfig):
        super().__init__()
        self.gate_proj = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=False)
        self.up_proj = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=False)
        self.down_proj = nn.Linear(cfg.intermediate_size, cfg.hidden_size, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        gated = nn.silu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
|
|
|
|
| |
|
|
class Qwen3DecoderLayer(nn.Module):
    """Pre-norm transformer block: attention then SwiGLU MLP, each residual.

    Attribute names mirror the checkpoint key pattern so weights load
    directly (input_layernorm, self_attn, post_attention_layernorm, mlp).
    """

    def __init__(self, cfg: Qwen3EncoderConfig):
        super().__init__()
        self.input_layernorm = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
        self.self_attn = Qwen3Attention(cfg)
        self.post_attention_layernorm = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
        self.mlp = Qwen3MLP(cfg)

    def __call__(
        self,
        x: mx.array,
        cos: mx.array,
        sin: mx.array,
        mask: mx.array | None = None,
    ) -> mx.array:
        # Attention sub-block with residual connection.
        x = x + self.self_attn(self.input_layernorm(x), cos, sin, mask)
        # Feed-forward sub-block with residual connection.
        return x + self.mlp(self.post_attention_layernorm(x))
|
|
|
|
| |
|
|
class Qwen3Encoder(nn.Module):
    """Qwen3 text encoder for Z-Image-Turbo.

    Runs the decoder stack as a pure encoder (no KV cache, no generation)
    and returns the penultimate hidden state — the tensor entering the
    final decoder layer, without the final RMSNorm applied — matching
    diffusers ZImagePipeline, which consumes
    ``text_encoder(..., output_hidden_states=True).hidden_states[-2]``.
    """

    def __init__(self, cfg: Qwen3EncoderConfig | None = None):
        super().__init__()
        if cfg is None:
            cfg = Qwen3EncoderConfig()
        self.cfg = cfg

        self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.hidden_size)
        self.layers = [Qwen3DecoderLayer(cfg) for _ in range(cfg.num_hidden_layers)]
        # `norm` is never applied in __call__ (hidden_states[-2] is pre-norm),
        # but it is kept so the checkpoint's model.norm.weight loads cleanly.
        self.norm = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
        self.rotary_emb = RotaryEmbedding(
            dim=cfg.head_dim,
            theta=cfg.rope_theta,
            max_seq_len=cfg.max_position_embeddings,
        )

    def __call__(self, input_ids: mx.array, mask: mx.array | None = None) -> mx.array:
        """Encode text tokens to the penultimate hidden state.

        Applies a causal attention mask by default (matching HuggingFace
        Qwen3Model, which masks causally even when used as an encoder).

        Args:
            input_ids: (B, L) token IDs.
            mask: optional additive attention mask (B, 1, L, L);
                None = build a causal mask automatically.

        Returns:
            (B, L, hidden_size) — output of the second-to-last layer,
            identical to HF ``hidden_states[-2]``.
        """
        B, L = input_ids.shape
        x = self.embed_tokens(input_ids)

        cos, sin = self.rotary_emb(L)

        if mask is None:
            # Additive causal mask: -1e9 above the diagonal, 0 elsewhere.
            mask = mx.triu(mx.full((L, L), -1e9), k=1)[None, None, :, :]

        # hidden_states[-2] is the output of layer N-2 (0-indexed), i.e. the
        # state after every layer except the last. The previous version ran
        # the final layer only to discard its output (~1/36 wasted compute)
        # and left `penultimate` unbound for single-layer configs; stopping
        # one layer early fixes both while producing the identical result.
        for layer in self.layers[:-1]:
            x = layer(x, cos, sin, mask)

        return x
|
|