| """ZImageTransformer2DModel β MLX native S3-DiT for Z-Image-Turbo. |
| |
| Architecture (from model config + weight shapes): |
| - 30 main DiT layers + 2 context_refiner + 2 noise_refiner |
| - dim=3840, n_heads=30, head_dim=128 |
| - Dual-norm (pre+post) for both attention and FFN |
| - SwiGLU FFN (w1/w2/w3), intermediate=10240 |
| - QK-Norm (RMSNorm on head_dim=128) |
| - AdaLN modulation: 4 outputs per block (shift_attn, scale_attn, shift_ffn, scale_ffn) |
| - N-dim RoPE: axes_dims=[32,48,48], rope_theta=256 |
| - Timestep embedding: sinusoidal(256) β MLP(256β1024β256) |
| - Caption projector: RMSNorm(2560) β Linear(2560β3840) |
| - Patch embed: Linear(64β3840) (in_channels=16, patch_size=2 β 16Γ2Β²=64) |
| - Final layer: adaLN(256β3840) + Linear(3840β64) |
| |
| Weight key patterns: |
| t_embedder.mlp.{0,2}.{weight,bias} |
| cap_embedder.{0,1}.{weight,bias} (0=RMSNorm, 1=Linear) |
| cap_pad_token, x_pad_token |
| all_x_embedder.2-1.{weight,bias} |
| layers.N.{adaLN_modulation.0, attention.*, attention_norm*, feed_forward.*, ffn_norm*} |
| context_refiner.N.{attention.*, attention_norm*, feed_forward.*, ffn_norm*} |
| noise_refiner.N.{adaLN_modulation.0, attention.*, attention_norm*, feed_forward.*, ffn_norm*} |
| all_final_layer.2-1.{linear, adaLN_modulation.1} |
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
| from dataclasses import dataclass, field |
|
|
| import mlx.core as mx |
| import mlx.nn as nn |
|
|
|
|
| |
|
|
@dataclass
class ZImageDiTConfig:
    """Hyper-parameters for the Z-Image S3-DiT; defaults match Z-Image-Turbo."""

    dim: int = 3840                # transformer hidden width (= n_heads * head_dim)
    n_heads: int = 30              # attention heads
    n_kv_heads: int = 30           # NOTE(review): not consumed by DiTAttention in this file -- confirm GQA is unused
    n_layers: int = 30             # main DiT blocks
    n_refiner_layers: int = 2      # blocks in each of context_refiner / noise_refiner
    head_dim: int = 128            # per-head width (also the QK-Norm dim)
    ffn_dim: int = 10240           # SwiGLU intermediate size
    in_channels: int = 16          # latent channels
    patch_size: int = 2            # spatial patching factor; each token covers p x p latents
    cap_feat_dim: int = 2560       # text-encoder hidden size fed to cap_embedder
    t_embed_dim: int = 256         # timestep embedding width (also adaLN conditioning width)
    t_hidden_dim: int = 1024       # hidden width of the timestep MLP
    axes_dims: list[int] = field(default_factory=lambda: [32, 48, 48])      # RoPE dims per (t, h, w) axis; sum <= head_dim
    axes_lens: list[int] = field(default_factory=lambda: [1536, 512, 512])  # max positions per axis (table sizes)
    rope_theta: float = 256.0      # RoPE base frequency
    norm_eps: float = 1e-5         # RMSNorm epsilon
    qk_norm: bool = True           # apply RMSNorm to per-head q/k
    t_scale: float = 1000.0        # multiplier applied to timesteps before embedding
|
|
|
|
| |
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learned per-channel gain."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dim,))
        self.eps = eps

    def __call__(self, x: mx.array) -> mx.array:
        # Normalize by the RMS over the last axis, then apply the gain.
        mean_sq = mx.mean(mx.square(x), axis=-1, keepdims=True)
        return x * mx.rsqrt(mean_sq + self.eps) * self.weight
|
|
|
|
| |
|
|
def timestep_embedding(t: mx.array, dim: int = 256) -> mx.array:
    """Classic sinusoidal embedding: ``[cos(t*f), sin(t*f)]`` over ``dim//2`` frequencies.

    Args:
        t: (B,) timestep values.
        dim: output embedding width (must be even to split cos/sin equally).

    Returns:
        (B, dim) float32 embedding.
    """
    half = dim // 2
    # Geometric frequency ladder from 1 down to ~1/10000.
    exponents = mx.arange(half, dtype=mx.float32) / half
    freqs = mx.exp(-math.log(10000.0) * exponents)
    phases = t.astype(mx.float32)[:, None] * freqs[None, :]
    return mx.concatenate([mx.cos(phases), mx.sin(phases)], axis=-1)
|
|
|
|
class TimestepEmbedder(nn.Module):
    """Map scalar timesteps to embeddings: sinusoid -> Linear -> SiLU -> Linear.

    ``mlp`` is a plain list with a ``None`` placeholder at index 1 so the two
    linear layers load from checkpoint keys ``mlp.0`` / ``mlp.2`` (index 1 is
    the activation slot in the original layout).
    """

    def __init__(self, t_embed_dim: int = 256, hidden_dim: int = 1024):
        super().__init__()
        self.mlp = [
            nn.Linear(t_embed_dim, hidden_dim),
            None,  # activation slot -- keeps checkpoint indices {0, 2} aligned
            nn.Linear(hidden_dim, t_embed_dim),
        ]

    def __call__(self, t: mx.array) -> mx.array:
        # Sinusoid width is recovered from the first linear's input dim.
        emb = timestep_embedding(t, self.mlp[0].weight.shape[1])
        hidden = nn.silu(self.mlp[0](emb))
        return self.mlp[2](hidden)
|
|
|
|
| |
|
|
class RopeEmbedder:
    """Axis-wise RoPE angle tables indexed by integer position IDs.

    Mirrors the diffusers ``RopeEmbedder``: each axis gets a precomputed
    ``(max_len, dim // 2)`` table of rotation angles; the forward pass
    gathers rows by position ID and concatenates across axes to produce
    ``(seq_len, sum(axes_dims) // 2)`` angles.  These feed
    :func:`apply_rope`, which performs the complex rotation
    ``view_as_complex(x) * polar(1, angles)`` with real cos/sin math.
    """

    def __init__(
        self,
        axes_dims: list[int],
        axes_lens: list[int],
        theta: float = 256.0,
    ):
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        self.theta = theta

        # One (max_len, dim // 2) angle table per axis.
        self._freq_tables: list[mx.array] = []
        for dim, max_len in zip(axes_dims, axes_lens):
            inv_freq = 1.0 / (theta ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim))
            positions = mx.arange(max_len, dtype=mx.float32)
            self._freq_tables.append(mx.outer(positions, inv_freq))

    def __call__(self, pos_ids: mx.array) -> mx.array:
        """Gather RoPE angles for integer position IDs.

        Args:
            pos_ids: (seq_len, n_axes) integer position IDs, one column per axis.

        Returns:
            (seq_len, rope_half_dim) rotation angles.
        """
        gathered = [
            table[pos_ids[:, axis].astype(mx.int32)]
            for axis, table in enumerate(self._freq_tables)
        ]
        return mx.concatenate(gathered, axis=-1)
|
|
|
|
def build_position_ids(
    cap_len: int,
    pH: int,
    pW: int,
) -> tuple[mx.array, mx.array]:
    """Build position ID grids matching diffusers patchify_and_embed.

    Caption tokens: ``create_coordinate_grid(size=(cap_len, 1, 1), start=(1, 0, 0))``
    -> t-axis = 1..cap_len, h-axis = 0, w-axis = 0

    Image tokens: ``create_coordinate_grid(size=(1, pH, pW), start=(cap_len+1, 0, 0))``
    -> t-axis = cap_len+1, h-axis = 0..pH-1, w-axis = 0..pW-1
    enumerated row-major (h outer, w inner).

    Args:
        cap_len: number of (unpadded) caption tokens.
        pH: patch-grid height.
        pW: patch-grid width.

    Returns:
        (img_pos_ids, cap_pos_ids) each of shape (N, 3), int32.
    """
    # Caption: t runs 1..cap_len, spatial axes fixed at 0.
    cap_t = mx.arange(1, cap_len + 1, dtype=mx.int32)[:, None]
    cap_hw = mx.zeros((cap_len, 2), dtype=mx.int32)
    cap_pos = mx.concatenate([cap_t, cap_hw], axis=-1)

    # Image: vectorized row-major (h, w) grid -- same ordering as the
    # nested `for h: for w:` enumeration, without a Python-level loop.
    t_val = cap_len + 1
    img_t = mx.full((pH * pW, 1), t_val, dtype=mx.int32)
    img_h = mx.broadcast_to(
        mx.arange(pH, dtype=mx.int32)[:, None], (pH, pW)
    ).reshape(pH * pW, 1)
    img_w = mx.broadcast_to(
        mx.arange(pW, dtype=mx.int32)[None, :], (pH, pW)
    ).reshape(pH * pW, 1)
    img_pos = mx.concatenate([img_t, img_h, img_w], axis=-1)

    return img_pos, cap_pos
|
|
|
|
def apply_rope(x: mx.array, freqs: mx.array) -> mx.array:
    """Rotate features with interleaved-pair RoPE.

    Mirrors diffusers' complex-valued formulation,
    ``view_as_real(view_as_complex(x.reshape(..., -1, 2)) * freqs_cis)``,
    using only real cos/sin arithmetic.  Channels beyond
    ``2 * freqs.shape[-1]`` pass through unrotated.

    Args:
        x: (B, n_heads, L, head_dim)
        freqs: (L, rope_half_dim) where rope_half_dim = sum(axes_dims) // 2
    """
    half = freqs.shape[-1]
    full = 2 * half
    rotated_part, passthrough = x[..., :full], x[..., full:]

    cos = mx.cos(freqs)[None, None]
    sin = mx.sin(freqs)[None, None]

    # Interleaved pairing: channel 2i is the real lane, 2i+1 the imaginary.
    pairs = rotated_part.reshape(*rotated_part.shape[:-1], half, 2)
    re, im = pairs[..., 0], pairs[..., 1]

    # Complex multiply (re + i*im) * (cos + i*sin), re-interleaved.
    rotated = mx.stack([re * cos - im * sin, re * sin + im * cos], axis=-1)
    rotated = rotated.reshape(*rotated.shape[:-2], full)

    return mx.concatenate([rotated, passthrough], axis=-1)
|
|
|
|
| |
|
|
class DiTAttention(nn.Module):
    """Multi-head self-attention with optional per-head QK RMSNorm and RoPE."""

    def __init__(self, dim: int, n_heads: int, head_dim: int, qk_norm: bool = True, norm_eps: float = 1e-5):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        inner = n_heads * head_dim
        self.to_q = nn.Linear(dim, inner, bias=False)
        self.to_k = nn.Linear(dim, inner, bias=False)
        self.to_v = nn.Linear(dim, inner, bias=False)
        # List wrapper keeps the checkpoint key at `to_out.0`.
        self.to_out = [nn.Linear(inner, dim, bias=False)]

        if qk_norm:
            self.norm_q = RMSNorm(head_dim, eps=norm_eps)
            self.norm_k = RMSNorm(head_dim, eps=norm_eps)
        else:
            self.norm_q = None
            self.norm_k = None

    def __call__(self, x: mx.array, freqs: mx.array | None = None, mask: mx.array | None = None) -> mx.array:
        """x: (B, L, dim); freqs: optional RoPE angles; mask: optional (B, L) keep-mask."""
        B, L, _ = x.shape
        q = self.to_q(x).reshape(B, L, self.n_heads, self.head_dim)
        k = self.to_k(x).reshape(B, L, self.n_heads, self.head_dim)
        v = self.to_v(x).reshape(B, L, self.n_heads, self.head_dim)

        # QK-Norm operates on the head_dim axis, before the head transpose.
        if self.norm_q is not None:
            q = self.norm_q(q)
            k = self.norm_k(k)

        # (B, L, H, D) -> (B, H, L, D) for attention.
        q, k, v = (a.transpose(0, 2, 1, 3) for a in (q, k, v))

        if freqs is not None:
            q = apply_rope(q, freqs)
            k = apply_rope(k, freqs)

        scale = 1.0 / math.sqrt(self.head_dim)
        if mask is None:
            attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
        else:
            # Boolean keep-mask -> additive bias: 0 where kept, -1e9 where masked.
            bias = (1.0 - mask[:, None, None, :].astype(q.dtype)) * (-1e9)
            attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=bias)

        attn = attn.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.to_out[0](attn)
|
|
|
|
| |
|
|
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward: ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(self, dim: int, ffn_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, ffn_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(ffn_dim, dim, bias=False)  # output projection
        self.w3 = nn.Linear(dim, ffn_dim, bias=False)  # value projection

    def __call__(self, x: mx.array) -> mx.array:
        gate = nn.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
|
|
|
|
| |
|
|
class AdaLNModulation(nn.Module):
    """AdaLN-Zero: project conditioning to shift/scale pairs.

    Output dim = dim * n_mods (e.g. 3840 * 4 = 15360 for main blocks).

    NOTE(review): this class is not referenced anywhere else in this file --
    DiTBlock and FinalLayer build their ``adaLN_modulation`` lists inline.
    """
    def __init__(self, cond_dim: int, out_dim: int):
        super().__init__()
        # Private name so the weight can be re-exposed under key "0"
        # (matching `adaLN_modulation.0.*` checkpoint keys) via `parameters`.
        self._linear = nn.Linear(cond_dim, out_dim)

    # NOTE(review): `nn.Module.parameters` is a *method* in MLX; shadowing it
    # with a property changes this class's module API -- any code calling
    # `mod.parameters()` on an AdaLNModulation will fail. Confirm this is
    # intentional (and that MLX's parameter traversal actually consults it)
    # before reusing this class.
    @property
    def parameters(self):
        return {"0": {"weight": self._linear.weight, "bias": self._linear.bias}}

    def __call__(self, c: mx.array) -> mx.array:
        # Apply the conditioning projection.
        return self._linear(c)
|
|
|
|
| |
|
|
class DiTBlock(nn.Module):
    """S3-DiT block: AdaLN-modulated attention + SwiGLU FFN with dual norms.

    The adaLN projection yields four chunks: scale_attn, gate_attn,
    scale_ffn, gate_ffn.  Scales are applied as ``1 + scale`` on the
    pre-norm output; gates are squashed through ``tanh`` and multiply the
    post-normed residual branch.
    """

    def __init__(self, cfg: ZImageDiTConfig):
        super().__init__()
        self.attention = DiTAttention(cfg.dim, cfg.n_heads, cfg.head_dim, cfg.qk_norm, cfg.norm_eps)
        self.attention_norm1 = RMSNorm(cfg.dim, eps=cfg.norm_eps)
        self.attention_norm2 = RMSNorm(cfg.dim, eps=cfg.norm_eps)
        self.feed_forward = SwiGLUFFN(cfg.dim, cfg.ffn_dim)
        self.ffn_norm1 = RMSNorm(cfg.dim, eps=cfg.norm_eps)
        self.ffn_norm2 = RMSNorm(cfg.dim, eps=cfg.norm_eps)

        # List wrapper keeps the checkpoint key at `adaLN_modulation.0`.
        self.adaLN_modulation = [nn.Linear(cfg.t_embed_dim, cfg.dim * 4)]

    def __call__(self, x: mx.array, c: mx.array, freqs: mx.array | None = None, mask: mx.array | None = None) -> mx.array:
        """
        Args:
            x: (B, L, dim) hidden states
            c: (B, t_embed_dim) conditioning (timestep embedding)
            freqs: optional RoPE angles for the sequence
            mask: optional (B, L) boolean attention mask

        Returns:
            (B, L, dim) updated hidden states.
        """
        # (B, 4*dim) -> four (B, 1, dim) chunks broadcast over the sequence.
        mod = self.adaLN_modulation[0](c)[:, None, :]
        scale_msa, gate_msa, scale_mlp, gate_mlp = mx.split(mod, 4, axis=-1)

        # Attention branch: modulated pre-norm in, tanh-gated post-norm out.
        attn_in = self.attention_norm1(x) * (1.0 + scale_msa)
        attn_out = self.attention(attn_in, freqs, mask)
        x = x + mx.tanh(gate_msa) * self.attention_norm2(attn_out)

        # FFN branch: same modulation pattern.
        ffn_in = self.ffn_norm1(x) * (1.0 + scale_mlp)
        x = x + mx.tanh(gate_mlp) * self.ffn_norm2(self.feed_forward(ffn_in))

        return x
|
|
|
|
| |
|
|
class RefinerBlock(nn.Module):
    """Plain (non-modulated) transformer block, used by the context_refiner."""

    def __init__(self, cfg: ZImageDiTConfig):
        super().__init__()
        self.attention = DiTAttention(cfg.dim, cfg.n_heads, cfg.head_dim, cfg.qk_norm, cfg.norm_eps)
        self.attention_norm1 = RMSNorm(cfg.dim, eps=cfg.norm_eps)
        self.attention_norm2 = RMSNorm(cfg.dim, eps=cfg.norm_eps)
        self.feed_forward = SwiGLUFFN(cfg.dim, cfg.ffn_dim)
        self.ffn_norm1 = RMSNorm(cfg.dim, eps=cfg.norm_eps)
        self.ffn_norm2 = RMSNorm(cfg.dim, eps=cfg.norm_eps)

    def __call__(self, x: mx.array, freqs: mx.array | None = None, mask: mx.array | None = None) -> mx.array:
        # Dual-norm residual: pre-norm -> attention -> post-norm -> add.
        x = x + self.attention_norm2(self.attention(self.attention_norm1(x), freqs, mask))
        # Same pattern for the feed-forward branch.
        x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
        return x
|
|
|
|
| |
|
|
class FinalLayer(nn.Module):
    """Output head: affine-free LayerNorm, adaLN scale, Linear(dim -> patch_dim)."""

    def __init__(self, dim: int, patch_dim: int, t_embed_dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, patch_dim)
        # None placeholder keeps the linear on checkpoint key
        # `adaLN_modulation.1` (index 0 is the SiLU slot in the original layout).
        self.adaLN_modulation = [None, nn.Linear(t_embed_dim, dim)]

    def __call__(self, x: mx.array, c: mx.array) -> mx.array:
        """x: (B, L, dim) tokens; c: (B, t_embed_dim) conditioning."""
        # Conditioning path: SiLU -> Linear -> 1 + scale, broadcast over tokens.
        scale = (1.0 + self.adaLN_modulation[1](nn.silu(c)))[:, None, :]
        # LayerNorm with no learned affine (weight/bias = None), then modulate.
        normed = mx.fast.layer_norm(x, None, None, eps=1e-6)
        return self.linear(normed * scale)
|
|
|
|
| |
|
|
class ZImageTransformer(nn.Module):
    """ZImageTransformer2DModel -- S3-DiT for Z-Image-Turbo.

    Forward flow (matching ``__call__`` below):
        1. Embed timestep -> adaln_input (B, 256)
        2. Patchify + embed image latents -> img tokens (B, L_img, 3840)
        3. noise_refiner (2 AdaLN blocks) over image tokens with RoPE
        4. Project caption features (RMSNorm + Linear); pad to a multiple of 32
        5. context_refiner (2 non-AdaLN blocks) over caption tokens with RoPE
        6. Unified sequence = [img, cap]  (image tokens FIRST)
        7. Main DiT layers (30 AdaLN blocks)
        8. final_layer on the FULL unified sequence
        9. Slice the first L_img tokens -> unpatchify to (B, C, H, W)
    """

    def __init__(self, cfg: ZImageDiTConfig | None = None):
        super().__init__()
        if cfg is None:
            cfg = ZImageDiTConfig()
        self.cfg = cfg

        # Timestep -> (B, t_embed_dim) conditioning vector shared by all AdaLN blocks.
        self.t_embedder = TimestepEmbedder(cfg.t_embed_dim, cfg.t_hidden_dim)

        # Caption projector; plain list so weight keys are cap_embedder.{0,1}.
        self.cap_embedder = [
            RMSNorm(cfg.cap_feat_dim, eps=cfg.norm_eps),
            nn.Linear(cfg.cap_feat_dim, cfg.dim),
        ]

        # Pad tokens (overwritten by checkpoint weights; zeros until loaded).
        # x_pad_token is kept for checkpoint compatibility; it is not used below.
        self.cap_pad_token = mx.zeros((1, cfg.dim))
        self.x_pad_token = mx.zeros((1, cfg.dim))

        # Patch embedding; dict key "2-1" mirrors the checkpoint key
        # all_x_embedder.2-1 (presumably patch_size=2 variant -- TODO confirm).
        patch_dim = cfg.in_channels * cfg.patch_size * cfg.patch_size
        self.all_x_embedder = {"2-1": nn.Linear(patch_dim, cfg.dim)}

        # Caption-token refiner (no AdaLN modulation).
        self.context_refiner = [RefinerBlock(cfg) for _ in range(cfg.n_refiner_layers)]

        # Main AdaLN-modulated stack.
        self.layers = [DiTBlock(cfg) for _ in range(cfg.n_layers)]

        # Image-token refiner: same AdaLN block type as the main stack.
        self.noise_refiner = [DiTBlock(cfg) for _ in range(cfg.n_refiner_layers)]

        # Output head, keyed like the patch embedder.
        self.all_final_layer = {
            "2-1": FinalLayer(cfg.dim, patch_dim, cfg.t_embed_dim)
        }

        # RoPE angle tables (plain object, not a trainable submodule).
        self._rope = RopeEmbedder(cfg.axes_dims, cfg.axes_lens, cfg.rope_theta)

    def _patchify(self, x: mx.array) -> mx.array:
        """Convert image latents to a patch sequence.

        Matches diffusers: channels-last within each patch.
        x: (B, C, H, W) -> (B, H//p * W//p, p*p*C)

        diffusers logic:
            image.view(C, 1, 1, h, pH, w, pW)
            image.permute(1, 3, 5, 2, 4, 6, 0)  # (1, h, w, 1, pH, pW, C)
            reshape -> (h*w, pH*pW*C)
        """
        B, C, H, W = x.shape
        p = self.cfg.patch_size
        pH, pW = H // p, W // p
        # Split H and W into (grid, within-patch) axes.
        x = x.reshape(B, C, pH, p, pW, p)
        # Reorder to (B, grid_h, grid_w, patch_h, patch_w, C): channels-last per patch.
        x = x.transpose(0, 2, 4, 3, 5, 1)
        # Flatten grid into the sequence axis, patch contents into features.
        x = x.reshape(B, pH * pW, p * p * C)
        return x

    def _unpatchify(self, x: mx.array, h: int, w: int) -> mx.array:
        """Convert a patch sequence back to image latents (inverse of _patchify).

        Matches diffusers: channels-last within each patch.
        x: (B, pH*pW, p*p*C) -> (B, C, H, W), where h/w are the FULL latent
        height/width (not the patch-grid size).

        diffusers logic:
            x.view(1, h, w, 1, pH, pW, C)
            x.permute(6, 0, 3, 1, 4, 2, 5)  # (C, 1, 1, h, pH, w, pW)
            reshape -> (C, H, W)
        """
        B = x.shape[0]
        p = self.cfg.patch_size
        C = self.cfg.in_channels
        pH, pW = h // p, w // p
        # Undo the feature flattening: (B, grid_h, grid_w, patch_h, patch_w, C).
        x = x.reshape(B, pH, pW, p, p, C)
        # Back to (B, C, grid_h, patch_h, grid_w, patch_w).
        x = x.transpose(0, 5, 1, 3, 2, 4)
        # Merge grid and within-patch axes into full H and W.
        x = x.reshape(B, C, h, w)
        return x

    def __call__(
        self,
        x: mx.array,
        t: mx.array,
        cap_feats: mx.array,
        cap_mask: mx.array | None = None,
    ) -> mx.array:
        """Forward pass -- matches diffusers ZImageTransformer2DModel.forward().

        Correct execution order (from diffusers source):
        1. t_embed
        2. x_embed -> noise_refiner (image tokens with RoPE)
        3. cap_embed -> context_refiner (text tokens with RoPE)
        4. build unified [img, cap] sequence (IMAGE FIRST in basic mode)
        5. main layers (30 blocks with AdaLN + RoPE)
        6. final_layer on FULL unified sequence
        7. extract image tokens -> unpatchify

        Args:
            x: (B, C, H, W) noisy latents
            t: (B,) timesteps (1-sigma, scaled by pipeline)
            cap_feats: (B, L_text, cap_feat_dim) text encoder hidden states
            cap_mask: (B, L_text) boolean mask for padding.
                NOTE(review): currently unused -- padded caption positions are
                attended to by every block; confirm this matches the reference
                implementation before relying on masked captions.

        Returns:
            noise_pred: (B, C, H, W) predicted noise
        """
        B, C, H, W = x.shape
        cfg = self.cfg
        p = cfg.patch_size
        pH, pW = H // p, W // p

        # 1. Timestep conditioning (pipeline passes t in [0, 1]; scaled here).
        adaln_input = self.t_embedder(t * cfg.t_scale)

        # 2. Patchify latents and project patches to model width.
        img = self._patchify(x)
        img = self.all_x_embedder["2-1"](img)

        L_cap_orig = cap_feats.shape[1]
        L_img = img.shape[1]

        # Caption length is padded up to a multiple of 32 tokens.
        SEQ_MULTI_OF = 32
        pad_len = (-L_cap_orig) % SEQ_MULTI_OF
        L_cap = L_cap_orig + pad_len  # padded caption length (kept for clarity)

        # Position IDs: captions occupy t = 1..L_cap_orig with (h, w) = 0;
        # image tokens share t = L_cap_orig + 1 and spread over (h, w).
        img_pos_ids, cap_pos_ids = build_position_ids(L_cap_orig, pH, pW)

        # RoPE angles per token.
        img_freqs = self._rope(img_pos_ids)
        cap_freqs_orig = self._rope(cap_pos_ids)

        # Padded caption slots get zero angles (identity rotation).
        if pad_len > 0:
            cap_freqs = mx.concatenate([
                cap_freqs_orig,
                mx.zeros((pad_len, cap_freqs_orig.shape[-1]))
            ], axis=0)
        else:
            cap_freqs = cap_freqs_orig

        # 3. Refine image tokens (AdaLN-conditioned on the timestep).
        for block in self.noise_refiner:
            img = block(img, adaln_input, img_freqs)

        # 4. Project caption features to model width (RMSNorm, then Linear).
        cap = self.cap_embedder[0](cap_feats)
        cap = self.cap_embedder[1](cap)

        # Append the learned pad token to reach the padded caption length.
        if pad_len > 0:
            pad_tok = mx.broadcast_to(self.cap_pad_token, (B, pad_len, cfg.dim))
            cap = mx.concatenate([cap, pad_tok], axis=1)

        # 5. Refine caption tokens (no timestep conditioning).
        for block in self.context_refiner:
            cap = block(cap, cap_freqs)

        # 6. Unified sequence: image tokens first, then captions; RoPE angles
        # are concatenated in the same order.
        unified = mx.concatenate([img, cap], axis=1)
        unified_freqs = mx.concatenate([img_freqs, cap_freqs], axis=0)

        # 7. Main DiT stack over the unified sequence.
        for block in self.layers:
            unified = block(unified, adaln_input, unified_freqs)

        # 8. Output head applied to the whole sequence; image part taken below.
        unified = self.all_final_layer["2-1"](unified, adaln_input)

        # 9. Image tokens are the first L_img entries of the unified sequence.
        img_out = unified[:, :L_img, :]
        out = self._unpatchify(img_out, H, W)
        return out
|
|