# Z-Image-Turbo-MLX / qwen3_encoder.py
# Uploaded by illusion615 via huggingface_hub (revision 64566e4, verified).
"""Qwen3 Text Encoder β€” MLX native implementation for Z-Image-Turbo.
Architecture (from model config):
- 36 layers, hidden_size=2560, 32 attention heads, 8 KV heads (GQA 4:1)
- head_dim=128, intermediate_size=9728
- hidden_act=silu (SwiGLU FFN)
- RMSNorm (eps=1e-6), QK-Norm on q/k projections
- RoPE (theta=1_000_000)
- vocab_size=151936
Weight key pattern:
model.embed_tokens.weight
model.layers.N.input_layernorm.weight
model.layers.N.self_attn.{q_proj,k_proj,v_proj,o_proj}.weight
model.layers.N.self_attn.{q_norm,k_norm}.weight
model.layers.N.post_attention_layernorm.weight
model.layers.N.mlp.{gate_proj,up_proj,down_proj}.weight
model.norm.weight
"""
from __future__ import annotations
import math
from dataclasses import dataclass
import mlx.core as mx
import mlx.nn as nn
# ── Config ────────────────────────────────────────────────────────
@dataclass
class Qwen3EncoderConfig:
    """Architecture hyperparameters for the Qwen3 text encoder.

    Defaults mirror the Z-Image-Turbo checkpoint config described in the
    module docstring.
    """
    hidden_size: int = 2560               # model (residual stream) width
    num_hidden_layers: int = 36           # number of transformer layers
    num_attention_heads: int = 32         # query heads
    num_key_value_heads: int = 8          # KV heads (GQA, 4 query heads per KV head)
    head_dim: int = 128                   # per-head dimension
    intermediate_size: int = 9728         # SwiGLU FFN inner width
    rms_norm_eps: float = 1e-6            # RMSNorm stability epsilon
    rope_theta: float = 1_000_000.0       # RoPE base frequency
    vocab_size: int = 151936              # token vocabulary size
    max_position_embeddings: int = 40960  # maximum supported sequence length
# ── RMSNorm ───────────────────────────────────────────────────────
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learned scale.

    The squared-mean reduction and rsqrt are computed in float32 and the
    result cast back to the input dtype, matching the HuggingFace
    Qwen3RMSNorm reference, so half-precision activations do not lose
    precision (or overflow) inside the reduction.
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = mx.ones((dim,))  # learned per-channel scale
        self.eps = eps                 # added inside rsqrt for stability
    def __call__(self, x: mx.array) -> mx.array:
        dtype = x.dtype
        xf = x.astype(mx.float32)
        # 1 / sqrt(mean(x^2) + eps), accumulated in float32.
        inv_rms = mx.rsqrt(mx.mean(xf * xf, axis=-1, keepdims=True) + self.eps)
        return (xf * inv_rms).astype(dtype) * self.weight
# ── RoPE ──────────────────────────────────────────────────────────
class RotaryEmbedding(nn.Module):
    """Rotary position-embedding tables, cached lazily and grown on demand.

    ``max_seq_len`` is accepted for interface compatibility but the cos/sin
    tables are built only as long as the longest sequence seen so far, not
    preallocated.
    """
    def __init__(self, dim: int, theta: float = 1_000_000.0, max_seq_len: int = 8192):
        super().__init__()
        self.dim = dim
        self.theta = theta
        # Leading underscores keep these buffers out of the parameter tree.
        self._inv_freq = 1.0 / (theta ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim))
        self._max_cached = 0
        self._cos_cache = None
        self._sin_cache = None
    def _update_cache(self, seq_len: int):
        # Rebuild only when the request exceeds what is already cached.
        if self._cos_cache is not None and seq_len <= self._max_cached:
            return
        positions = mx.arange(seq_len, dtype=mx.float32)
        angles = mx.outer(positions, self._inv_freq)
        # Duplicate the angle table so it spans the full head_dim.
        full = mx.concatenate([angles, angles], axis=-1)
        self._cos_cache = mx.cos(full)
        self._sin_cache = mx.sin(full)
        self._max_cached = seq_len
    def __call__(self, seq_len: int) -> tuple[mx.array, mx.array]:
        """Return (cos, sin) tables, each of shape (seq_len, dim)."""
        self._update_cache(seq_len)
        return self._cos_cache[:seq_len], self._sin_cache[:seq_len]
def _rotate_half(x: mx.array) -> mx.array:
    """Rotate the last axis by swapping halves: (a, b) -> (-b, a)."""
    half = x.shape[-1] // 2
    return mx.concatenate([-x[..., half:], x[..., :half]], axis=-1)
def apply_rotary_pos_emb(q: mx.array, k: mx.array, cos: mx.array, sin: mx.array) -> tuple[mx.array, mx.array]:
    """Rotate q and k by their position-dependent angles.

    q/k have shape (B, heads, L, head_dim); cos/sin have shape
    (L, head_dim) and are broadcast across the batch and head axes.
    """
    cos_b = mx.expand_dims(mx.expand_dims(cos, 0), 0)  # (1, 1, L, head_dim)
    sin_b = mx.expand_dims(mx.expand_dims(sin, 0), 0)
    rotated_q = q * cos_b + _rotate_half(q) * sin_b
    rotated_k = k * cos_b + _rotate_half(k) * sin_b
    return rotated_q, rotated_k
# ── Attention ─────────────────────────────────────────────────────
class Qwen3Attention(nn.Module):
    """Self-attention with grouped-query KV sharing and per-head QK-Norm.

    32 query heads attend over 8 key/value heads; each KV head serves a
    group of 4 query heads. Queries and keys are RMS-normalized per head
    before RoPE is applied.
    """
    def __init__(self, cfg: Qwen3EncoderConfig):
        super().__init__()
        self.n_heads = cfg.num_attention_heads
        self.n_kv_heads = cfg.num_key_value_heads
        self.head_dim = cfg.head_dim
        self.n_rep = self.n_heads // self.n_kv_heads  # queries per KV head
        q_dim = self.n_heads * self.head_dim
        kv_dim = self.n_kv_heads * self.head_dim
        self.q_proj = nn.Linear(cfg.hidden_size, q_dim, bias=False)
        self.k_proj = nn.Linear(cfg.hidden_size, kv_dim, bias=False)
        self.v_proj = nn.Linear(cfg.hidden_size, kv_dim, bias=False)
        self.o_proj = nn.Linear(q_dim, cfg.hidden_size, bias=False)
        # QK-Norm: per-head RMSNorm on queries and keys.
        self.q_norm = RMSNorm(self.head_dim, eps=cfg.rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=cfg.rms_norm_eps)
    def __call__(
        self,
        x: mx.array,
        cos: mx.array,
        sin: mx.array,
        mask: mx.array | None = None,
    ) -> mx.array:
        """Attend over x (B, L, hidden); mask, if given, is additive."""
        batch, seq, _ = x.shape
        # Project and split into per-head views: (B, L, H, head_dim).
        queries = self.q_proj(x).reshape(batch, seq, self.n_heads, self.head_dim)
        keys = self.k_proj(x).reshape(batch, seq, self.n_kv_heads, self.head_dim)
        values = self.v_proj(x).reshape(batch, seq, self.n_kv_heads, self.head_dim)
        # QK-Norm acts on the trailing head_dim axis; then move heads ahead
        # of the sequence axis -> (B, H, L, head_dim) for RoPE and matmuls.
        queries = self.q_norm(queries).transpose(0, 2, 1, 3)
        keys = self.k_norm(keys).transpose(0, 2, 1, 3)
        values = values.transpose(0, 2, 1, 3)
        queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)
        # GQA: replicate each KV head across its query group.
        if self.n_rep > 1:
            keys = mx.repeat(keys, self.n_rep, axis=1)
            values = mx.repeat(values, self.n_rep, axis=1)
        scale = 1.0 / math.sqrt(self.head_dim)
        scores = (queries @ keys.transpose(0, 1, 3, 2)) * scale
        if mask is not None:
            scores = scores + mask
        weights = mx.softmax(scores, axis=-1)
        merged = (weights @ values).transpose(0, 2, 1, 3).reshape(batch, seq, -1)
        return self.o_proj(merged)
# ── MLP (SwiGLU) ─────────────────────────────────────────────────
class Qwen3MLP(nn.Module):
    """SwiGLU feed-forward block: down_proj(silu(gate_proj(x)) * up_proj(x))."""
    def __init__(self, cfg: Qwen3EncoderConfig):
        super().__init__()
        self.gate_proj = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=False)
        self.up_proj = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=False)
        self.down_proj = nn.Linear(cfg.intermediate_size, cfg.hidden_size, bias=False)
    def __call__(self, x: mx.array) -> mx.array:
        gated = nn.silu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
# ── Transformer Layer ────────────────────────────────────────────
class Qwen3DecoderLayer(nn.Module):
    """Pre-norm transformer layer: attention then SwiGLU FFN, each residual.

    Attribute names mirror the checkpoint's weight key pattern so state
    loads without remapping.
    """
    def __init__(self, cfg: Qwen3EncoderConfig):
        super().__init__()
        self.input_layernorm = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
        self.self_attn = Qwen3Attention(cfg)
        self.post_attention_layernorm = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
        self.mlp = Qwen3MLP(cfg)
    def __call__(
        self,
        x: mx.array,
        cos: mx.array,
        sin: mx.array,
        mask: mx.array | None = None,
    ) -> mx.array:
        """One layer: x + attn(norm(x)), then x + mlp(norm(x))."""
        x = x + self.self_attn(self.input_layernorm(x), cos, sin, mask)
        return x + self.mlp(self.post_attention_layernorm(x))
# ── Full Encoder ─────────────────────────────────────────────────
class Qwen3Encoder(nn.Module):
    """Qwen3 text encoder for Z-Image-Turbo.

    Used purely as an encoder (no generation): runs the token embedding
    and all but the final transformer layer, returning the penultimate
    hidden state — ``hidden_states[-2]`` in HuggingFace terms — which is
    what diffusers' ZImagePipeline consumes. The final layer and final
    norm are never applied to the returned state.
    """
    def __init__(self, cfg: Qwen3EncoderConfig | None = None):
        """Build the encoder; the default config matches Z-Image-Turbo."""
        super().__init__()
        if cfg is None:
            cfg = Qwen3EncoderConfig()
        self.cfg = cfg
        self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.hidden_size)
        self.layers = [Qwen3DecoderLayer(cfg) for _ in range(cfg.num_hidden_layers)]
        # Kept so checkpoint weights (model.norm.weight) load cleanly even
        # though the final norm is not applied to the returned state.
        self.norm = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
        self.rotary_emb = RotaryEmbedding(
            dim=cfg.head_dim,
            theta=cfg.rope_theta,
            max_seq_len=cfg.max_position_embeddings,
        )
    def __call__(self, input_ids: mx.array, mask: mx.array | None = None) -> mx.array:
        """Encode text tokens.

        Returns the second-to-last hidden state (``hidden_states[-2]``),
        matching diffusers ZImagePipeline which uses
        ``text_encoder(..., output_hidden_states=True).hidden_states[-2]``.
        Since that state is exactly the input to the final layer, only the
        first ``num_hidden_layers - 1`` layers are run: the original loop
        executed the last layer and discarded its output (dead compute),
        and crashed with NameError when ``num_hidden_layers == 1``.

        Applies a causal attention mask by default (matching HuggingFace
        Qwen3Model, which masks causally internally).

        Args:
            input_ids: (B, L) token IDs.
            mask: optional additive attention mask broadcastable to
                (B, 1, L, L) — None builds a causal mask.

        Returns:
            hidden_states: (B, L, hidden_size) — penultimate layer output,
            no final norm applied.
        """
        _, L = input_ids.shape
        x = self.embed_tokens(input_ids)
        cos, sin = self.rotary_emb(L)
        if mask is None:
            # Upper triangle set to -1e9 blocks attention to future tokens.
            mask = mx.triu(mx.full((L, L), -1e9), k=1)[None, None, :, :]
        # Run every layer except the last; x after this loop IS the
        # penultimate hidden state (for 1 layer, that is the embedding).
        for layer in self.layers[:-1]:
            x = layer(x, cos, sin, mask)
        return x