# Z-Image-Turbo-MLX / zimage_dit.py
# (Hugging Face upload metadata: uploaded by illusion615 via huggingface_hub,
#  revision 64566e4 verified)
"""ZImageTransformer2DModel β€” MLX native S3-DiT for Z-Image-Turbo.
Architecture (from model config + weight shapes):
- 30 main DiT layers + 2 context_refiner + 2 noise_refiner
- dim=3840, n_heads=30, head_dim=128
- Dual-norm (pre+post) for both attention and FFN
- SwiGLU FFN (w1/w2/w3), intermediate=10240
- QK-Norm (RMSNorm on head_dim=128)
- AdaLN modulation: 4 outputs per block (shift_attn, scale_attn, shift_ffn, scale_ffn)
- N-dim RoPE: axes_dims=[32,48,48], rope_theta=256
- Timestep embedding: sinusoidal(256) β†’ MLP(256β†’1024β†’256)
- Caption projector: RMSNorm(2560) β†’ Linear(2560β†’3840)
- Patch embed: Linear(64β†’3840) (in_channels=16, patch_size=2 β†’ 16Γ—2Β²=64)
- Final layer: adaLN(256β†’3840) + Linear(3840β†’64)
Weight key patterns:
t_embedder.mlp.{0,2}.{weight,bias}
cap_embedder.{0,1}.{weight,bias} (0=RMSNorm, 1=Linear)
cap_pad_token, x_pad_token
all_x_embedder.2-1.{weight,bias}
layers.N.{adaLN_modulation.0, attention.*, attention_norm*, feed_forward.*, ffn_norm*}
context_refiner.N.{attention.*, attention_norm*, feed_forward.*, ffn_norm*}
noise_refiner.N.{adaLN_modulation.0, attention.*, attention_norm*, feed_forward.*, ffn_norm*}
all_final_layer.2-1.{linear, adaLN_modulation.1}
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field
import mlx.core as mx
import mlx.nn as nn
# ── Config ────────────────────────────────────────────────────────
@dataclass
class ZImageDiTConfig:
    """Architecture hyperparameters for Z-Image-Turbo's S3-DiT.

    Defaults mirror the released checkpoint (see the module docstring for the
    weight-shape derivation).
    """
    dim: int = 3840          # transformer hidden size (n_heads * head_dim)
    n_heads: int = 30        # attention heads
    n_kv_heads: int = 30     # KV heads — unused by this file's attention (plain MHA)
    n_layers: int = 30       # main DiT blocks
    n_refiner_layers: int = 2  # blocks in each of context_refiner / noise_refiner
    head_dim: int = 128      # per-head dim (also the QK-Norm width)
    ffn_dim: int = 10240     # SwiGLU intermediate size
    in_channels: int = 16    # VAE latent channels
    patch_size: int = 2      # latent patchification factor
    cap_feat_dim: int = 2560 # Qwen3 hidden_size
    t_embed_dim: int = 256 # timestep embedding dim
    t_hidden_dim: int = 1024 # timestep MLP hidden
    axes_dims: list[int] = field(default_factory=lambda: [32, 48, 48])   # RoPE dims per (t, h, w) axis
    axes_lens: list[int] = field(default_factory=lambda: [1536, 512, 512])  # max positions per axis
    rope_theta: float = 256.0  # RoPE base frequency
    norm_eps: float = 1e-5     # RMSNorm epsilon
    qk_norm: bool = True       # apply RMSNorm to Q/K per head
    t_scale: float = 1000.0    # timestep multiplier applied before embedding
# ── RMSNorm ───────────────────────────────────────────────────────
class RMSNorm(nn.Module):
    """Root-mean-square layer norm with a learnable gain (no mean-centering)."""
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dim,))
        self.eps = eps
    def __call__(self, x: mx.array) -> mx.array:
        # Normalize by the RMS over the last axis, then apply the gain.
        mean_sq = mx.mean(mx.square(x), axis=-1, keepdims=True)
        return x * mx.rsqrt(mean_sq + self.eps) * self.weight
# ── Timestep Embedding ────────────────────────────────────────────
def timestep_embedding(t: mx.array, dim: int = 256) -> mx.array:
    """Classic sinusoidal embedding: (B,) timesteps -> (B, dim) features.

    Layout is [cos | sin] over ``dim // 2`` geometrically spaced frequencies
    (max period 10000), matching the diffusers convention.
    """
    half = dim // 2
    exponent = -math.log(10000.0) / half
    freqs = mx.exp(exponent * mx.arange(half, dtype=mx.float32))
    phases = t.astype(mx.float32)[:, None] * freqs[None, :]
    return mx.concatenate([mx.cos(phases), mx.sin(phases)], axis=-1)
class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps: sinusoid -> Linear -> SiLU -> Linear.

    The MLP is stored as a 3-slot list so parameters load under the
    checkpoint keys ``t_embedder.mlp.0`` / ``t_embedder.mlp.2`` (slot 1
    stands for the parameter-free SiLU and stays ``None``).
    """
    def __init__(self, t_embed_dim: int = 256, hidden_dim: int = 1024):
        super().__init__()
        self.mlp = [
            nn.Linear(t_embed_dim, hidden_dim),  # mlp.0
            None,                                # mlp.1 — SiLU, no weights
            nn.Linear(hidden_dim, t_embed_dim),  # mlp.2
        ]
    def __call__(self, t: mx.array) -> mx.array:
        # Infer the sinusoid width from mlp.0's input features (weight is out x in).
        sin_dim = self.mlp[0].weight.shape[1]
        h = timestep_embedding(t, sin_dim)
        return self.mlp[2](nn.silu(self.mlp[0](h)))
# ── N-dim RoPE (matches diffusers RopeEmbedder) ──────────────────
class RopeEmbedder:
    """Table-driven N-axis RoPE angle lookup (mirrors diffusers ``RopeEmbedder``).

    For axis ``i`` a table of shape ``(axes_lens[i], axes_dims[i] // 2)``
    holding ``position * inv_freq`` angles is precomputed once; a forward
    call is then just integer gathers plus a concatenation. The returned
    angles feed :func:`apply_rope`, which performs the real-valued
    equivalent of ``view_as_complex(x) * polar(1, angles)``.
    """
    def __init__(
        self,
        axes_dims: list[int],
        axes_lens: list[int],
        theta: float = 256.0,
    ):
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        self.theta = theta
        # One angle table per axis, built once up front.
        self._freq_tables: list[mx.array] = []
        for dim, length in zip(axes_dims, axes_lens):
            inv_freq = 1.0 / (theta ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim))
            positions = mx.arange(length, dtype=mx.float32)
            self._freq_tables.append(mx.outer(positions, inv_freq))  # (length, dim/2)
    def __call__(self, pos_ids: mx.array) -> mx.array:
        """Gather angles for integer ``pos_ids`` of shape (seq_len, n_axes).

        Returns:
            (seq_len, sum(axes_dims) // 2) rotation angles.
        """
        gathered = [
            table[pos_ids[:, axis].astype(mx.int32)]  # (seq_len, d_axis/2)
            for axis, table in enumerate(self._freq_tables)
        ]
        return mx.concatenate(gathered, axis=-1)
def build_position_ids(
    cap_len: int,
    pH: int,
    pW: int,
) -> tuple[mx.array, mx.array]:
    """Build per-token (t, h, w) position IDs matching diffusers patchify_and_embed.

    Caption tokens: ``create_coordinate_grid(size=(cap_len, 1, 1), start=(1, 0, 0))``
        -> t-axis = 1..cap_len, h-axis = 0, w-axis = 0
    Image tokens:   ``create_coordinate_grid(size=(1, pH, pW), start=(cap_len+1, 0, 0))``
        -> t-axis = cap_len+1 (constant), h = 0..pH-1, w = 0..pW-1,
        in row-major (h-then-w) order.

    Args:
        cap_len: number of (unpadded) caption tokens.
        pH, pW: patch-grid height and width.

    Returns:
        (img_pos_ids, cap_pos_ids) of shapes (pH*pW, 3) and (cap_len, 3).
    """
    # Caption: (cap_len, 3) — t varies, h=0, w=0
    cap_t = mx.arange(1, cap_len + 1, dtype=mx.int32)[:, None]  # (cap_len, 1)
    cap_hw = mx.zeros((cap_len, 2), dtype=mx.int32)
    cap_pos = mx.concatenate([cap_t, cap_hw], axis=-1)  # (cap_len, 3)
    # Image: vectorized row-major grid (replaces the previous O(pH*pW)
    # Python double loop with three array ops).
    n_img = pH * pW
    t_col = mx.full((n_img,), cap_len + 1, dtype=mx.int32)
    h_col = mx.repeat(mx.arange(pH, dtype=mx.int32), pW)  # 0,0,...,1,1,...
    w_col = mx.tile(mx.arange(pW, dtype=mx.int32), pH)    # 0,1,...,0,1,...
    img_pos = mx.stack([t_col, h_col, w_col], axis=-1)    # (pH*pW, 3)
    return img_pos, cap_pos
def apply_rope(x: mx.array, freqs: mx.array) -> mx.array:
    """Rotate query/key features by RoPE angles using interleaved pairs.

    Real-valued equivalent of diffusers' complex form:
        ``view_as_real(view_as_complex(x.reshape(..., -1, 2)) * freqs_cis)``

    Args:
        x: (B, n_heads, L, head_dim) features.
        freqs: (L, rope_half_dim) angles, rope_half_dim = sum(axes_dims) // 2.

    Returns:
        x with the first ``2 * rope_half_dim`` channels rotated; any trailing
        channels pass through untouched.
    """
    half = freqs.shape[-1]
    rot_dim = 2 * half
    rotated, passthrough = x[..., :rot_dim], x[..., rot_dim:]
    # Broadcast the angle table over (batch, heads): (1, 1, L, half)
    cos = mx.cos(freqs)[None, None]
    sin = mx.sin(freqs)[None, None]
    # Adjacent channels form (real, imag) pairs: (x0, x1), (x2, x3), ...
    re = rotated[..., 0::2]
    im = rotated[..., 1::2]
    new_re = re * cos - im * sin
    new_im = re * sin + im * cos
    # Re-interleave back to [re0, im0, re1, im1, ...]
    pairs = mx.stack([new_re, new_im], axis=-1)  # (..., half, 2)
    rotated = pairs.reshape(*pairs.shape[:-2], rot_dim)
    return mx.concatenate([rotated, passthrough], axis=-1)
# ── Attention Block ───────────────────────────────────────────────
class DiTAttention(nn.Module):
    """Multi-head self-attention with optional per-head QK RMSNorm and RoPE."""
    def __init__(self, dim: int, n_heads: int, head_dim: int, qk_norm: bool = True, norm_eps: float = 1e-5):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        inner = n_heads * head_dim
        self.to_q = nn.Linear(dim, inner, bias=False)
        self.to_k = nn.Linear(dim, inner, bias=False)
        self.to_v = nn.Linear(dim, inner, bias=False)
        # Single-element list so the checkpoint key is to_out.0
        self.to_out = [nn.Linear(inner, dim, bias=False)]
        if qk_norm:
            self.norm_q = RMSNorm(head_dim, eps=norm_eps)
            self.norm_k = RMSNorm(head_dim, eps=norm_eps)
        else:
            self.norm_q = None
            self.norm_k = None
    def __call__(self, x: mx.array, freqs: mx.array | None = None, mask: mx.array | None = None) -> mx.array:
        """x: (B, L, dim); freqs: optional RoPE angles; mask: optional (B, L) bool."""
        B, L, _ = x.shape
        q = self.to_q(x).reshape(B, L, self.n_heads, self.head_dim)
        k = self.to_k(x).reshape(B, L, self.n_heads, self.head_dim)
        v = self.to_v(x).reshape(B, L, self.n_heads, self.head_dim)
        # Per-head RMSNorm on queries and keys (QK-Norm)
        if self.norm_q is not None:
            q, k = self.norm_q(q), self.norm_k(k)
        # Head-major layout: (B, n_heads, L, head_dim)
        q, k, v = (a.transpose(0, 2, 1, 3) for a in (q, k, v))
        if freqs is not None:
            q = apply_rope(q, freqs)
            k = apply_rope(k, freqs)
        # Fused SDPA (Metal kernel — never materializes the LxL matrix)
        scale = 1.0 / math.sqrt(self.head_dim)
        if mask is None:
            out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
        else:
            # Boolean (B, L) mask -> additive bias: 0 where visible, -1e9 where hidden
            bias = (1.0 - mask[:, None, None, :].astype(q.dtype)) * (-1e9)
            out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=bias)
        out = out.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.to_out[0](out)
# ── SwiGLU FFN ────────────────────────────────────────────────────
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward: w2(silu(w1(x)) * w3(x))."""
    def __init__(self, dim: int, ffn_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, ffn_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(ffn_dim, dim, bias=False)  # down projection
        self.w3 = nn.Linear(dim, ffn_dim, bias=False)  # up projection
    def __call__(self, x: mx.array) -> mx.array:
        gate = nn.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
# ── AdaLN Modulation ─────────────────────────────────────────────
class AdaLNModulation(nn.Module):
    """AdaLN-Zero conditioning projection: cond -> concatenated modulation signals.

    Output dim = dim * n_mods (e.g. 3840 * 4 = 15360 for main blocks).

    NOTE(review): this helper is currently unused — DiTBlock/FinalLayer store
    their modulation Linear in a bare list instead so checkpoint keys come out
    as ``adaLN_modulation.{0,1}``. Kept for API compatibility.

    Fixes vs. the previous version:
      * The Linear was stored as ``self._linear``; MLX skips attributes with a
        leading underscore when collecting parameters, so its weights were
        invisible to ``Module.parameters()`` / ``update()``.
      * A ``parameters`` property shadowed ``nn.Module.parameters()`` (which
        is a method), breaking any caller invoking ``module.parameters()``.
    """
    def __init__(self, cond_dim: int, out_dim: int):
        super().__init__()
        self.linear = nn.Linear(cond_dim, out_dim)
    def __call__(self, c: mx.array) -> mx.array:
        return self.linear(c)
# ── DiT Block (main layers + noise_refiner) ──────────────────────
class DiTBlock(nn.Module):
    """S3-DiT block: AdaLN-modulated attention + SwiGLU FFN with dual norms.

    The conditioning vector is projected to four per-dim signals:
    (scale_attn, gate_attn, scale_ffn, gate_ffn). Scales act on the pre-norm
    input (as ``1 + scale``); tanh-squashed gates weight the post-norm output
    of each sub-layer before the residual add.
    """
    def __init__(self, cfg: ZImageDiTConfig):
        super().__init__()
        self.attention = DiTAttention(cfg.dim, cfg.n_heads, cfg.head_dim, cfg.qk_norm, cfg.norm_eps)
        self.attention_norm1 = RMSNorm(cfg.dim, eps=cfg.norm_eps)  # pre-attention
        self.attention_norm2 = RMSNorm(cfg.dim, eps=cfg.norm_eps)  # post-attention
        self.feed_forward = SwiGLUFFN(cfg.dim, cfg.ffn_dim)
        self.ffn_norm1 = RMSNorm(cfg.dim, eps=cfg.norm_eps)  # pre-FFN
        self.ffn_norm2 = RMSNorm(cfg.dim, eps=cfg.norm_eps)  # post-FFN
        # Bare list so the checkpoint key is adaLN_modulation.0
        self.adaLN_modulation = [nn.Linear(cfg.t_embed_dim, cfg.dim * 4)]
    def __call__(self, x: mx.array, c: mx.array, freqs: mx.array | None = None, mask: mx.array | None = None) -> mx.array:
        """
        Args:
            x: (B, L, dim) hidden states
            c: (B, t_embed_dim) conditioning (timestep embedding)
            freqs: optional RoPE angles for the sequence
            mask: optional (B, L) boolean attention mask
        """
        # Project conditioning to the four modulation signals; add a token axis.
        signals = mx.split(self.adaLN_modulation[0](c), 4, axis=-1)
        scale_attn, gate_attn, scale_ffn, gate_ffn = (s[:, None, :] for s in signals)
        scale_attn = 1.0 + scale_attn
        scale_ffn = 1.0 + scale_ffn
        gate_attn = mx.tanh(gate_attn)
        gate_ffn = mx.tanh(gate_ffn)
        # Attention sub-layer: scaled pre-norm, gated post-norm residual
        h = self.attention(self.attention_norm1(x) * scale_attn, freqs, mask)
        x = x + gate_attn * self.attention_norm2(h)
        # FFN sub-layer, same pattern
        h = self.feed_forward(self.ffn_norm1(x) * scale_ffn)
        return x + gate_ffn * self.ffn_norm2(h)
# ── Refiner Block (context_refiner β€” no AdaLN) ──────────────────
class RefinerBlock(nn.Module):
    """Plain dual-norm transformer block (no AdaLN) — used for context_refiner."""
    def __init__(self, cfg: ZImageDiTConfig):
        super().__init__()
        self.attention = DiTAttention(cfg.dim, cfg.n_heads, cfg.head_dim, cfg.qk_norm, cfg.norm_eps)
        self.attention_norm1 = RMSNorm(cfg.dim, eps=cfg.norm_eps)
        self.attention_norm2 = RMSNorm(cfg.dim, eps=cfg.norm_eps)
        self.feed_forward = SwiGLUFFN(cfg.dim, cfg.ffn_dim)
        self.ffn_norm1 = RMSNorm(cfg.dim, eps=cfg.norm_eps)
        self.ffn_norm2 = RMSNorm(cfg.dim, eps=cfg.norm_eps)
    def __call__(self, x: mx.array, freqs: mx.array | None = None, mask: mx.array | None = None) -> mx.array:
        # Attention: pre-norm -> attend -> post-norm -> residual add
        x = x + self.attention_norm2(self.attention(self.attention_norm1(x), freqs, mask))
        # FFN: same dual-norm residual pattern
        return x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
# ── Final Layer ───────────────────────────────────────────────────
class FinalLayer(nn.Module):
    """Output head: parameter-free LayerNorm, adaLN scale, Linear(dim -> patch_dim)."""
    def __init__(self, dim: int, patch_dim: int, t_embed_dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, patch_dim)
        # Two-slot list so the Linear loads under adaLN_modulation.1
        # (slot 0 stands for the parameter-free SiLU).
        self.adaLN_modulation = [None, nn.Linear(t_embed_dim, dim)]
    def __call__(self, x: mx.array, c: mx.array) -> mx.array:
        # Here (unlike DiTBlock) the SiLU belongs to the modulation branch.
        scale = 1.0 + self.adaLN_modulation[1](nn.silu(c))  # (B, dim)
        normed = mx.fast.layer_norm(x, None, None, eps=1e-6)  # no learnable affine
        return self.linear(normed * scale[:, None, :])
# ── Full ZImage Transformer ──────────────────────────────────────
class ZImageTransformer(nn.Module):
    """ZImageTransformer2DModel — S3-DiT for Z-Image-Turbo.
    Forward flow (matches __call__ below):
    1. Embed timestep → adaln_input (B, 256)
    2. Patchify + embed image latents → img tokens (B, L_img, 3840)
    3. Noise refiner (2 blocks, with AdaLN, with RoPE) on image tokens
    4. Project caption features (RMSNorm + Linear), pad to a multiple of 32
       with the learned cap_pad_token
    5. Context refiner (2 blocks, no AdaLN, with RoPE) on caption tokens
    6. Concatenate [img, cap] — IMAGE FIRST — into the unified sequence
    7. Main DiT layers (30 blocks, with AdaLN + RoPE)
    8. Final layer on the FULL unified sequence
    9. Extract the first L_img tokens → unpatchify → (B, C, H, W)
    """
    def __init__(self, cfg: ZImageDiTConfig | None = None):
        super().__init__()
        if cfg is None:
            cfg = ZImageDiTConfig()
        self.cfg = cfg
        # Timestep embedder
        self.t_embedder = TimestepEmbedder(cfg.t_embed_dim, cfg.t_hidden_dim)
        # Caption projector: cap_embedder.0 = RMSNorm, cap_embedder.1 = Linear
        self.cap_embedder = [
            RMSNorm(cfg.cap_feat_dim, eps=cfg.norm_eps),
            nn.Linear(cfg.cap_feat_dim, cfg.dim),
        ]
        # Learnable padding tokens (loaded from checkpoint; zeros are placeholders)
        self.cap_pad_token = mx.zeros((1, cfg.dim))
        self.x_pad_token = mx.zeros((1, cfg.dim))
        # Image patch embedder — key uses "2-1" suffix for patch_size=2
        # We store as a dict to match weight key `all_x_embedder.2-1.{weight,bias}`
        patch_dim = cfg.in_channels * cfg.patch_size * cfg.patch_size  # 16 * 4 = 64
        self.all_x_embedder = {"2-1": nn.Linear(patch_dim, cfg.dim)}
        # Context refiner (no AdaLN)
        self.context_refiner = [RefinerBlock(cfg) for _ in range(cfg.n_refiner_layers)]
        # Main DiT layers (with AdaLN)
        self.layers = [DiTBlock(cfg) for _ in range(cfg.n_layers)]
        # Noise refiner (with AdaLN)
        self.noise_refiner = [DiTBlock(cfg) for _ in range(cfg.n_refiner_layers)]
        # Final layer — key uses "2-1" suffix
        self.all_final_layer = {
            "2-1": FinalLayer(cfg.dim, patch_dim, cfg.t_embed_dim)
        }
        # Precomputed RoPE frequency tables (matches diffusers RopeEmbedder)
        self._rope = RopeEmbedder(cfg.axes_dims, cfg.axes_lens, cfg.rope_theta)
    def _patchify(self, x: mx.array) -> mx.array:
        """Convert image latents to patch sequence.
        Matches diffusers: channels-last within each patch.
        x: (B, C, H, W) → (B, H//p * W//p, p*p*C)
        diffusers logic:
            image.view(C, 1, 1, h, pH, w, pW)
            image.permute(1, 3, 5, 2, 4, 6, 0)  # (1, h, w, 1, pH, pW, C)
            reshape → (h*w, pH*pW*C)
        """
        B, C, H, W = x.shape
        p = self.cfg.patch_size
        pH, pW = H // p, W // p
        # (B, C, pH, p, pW, p)
        x = x.reshape(B, C, pH, p, pW, p)
        # → (B, pH, pW, p, p, C) — channels LAST per patch
        x = x.transpose(0, 2, 4, 3, 5, 1)
        # → (B, pH*pW, p*p*C)
        x = x.reshape(B, pH * pW, p * p * C)
        return x
    def _unpatchify(self, x: mx.array, h: int, w: int) -> mx.array:
        """Convert patch sequence back to image latents.
        Matches diffusers: channels-last within each patch.
        x: (B, pH*pW, p*p*C) → (B, C, H, W)
        diffusers logic:
            x.view(1, h, w, 1, pH, pW, C)
            x.permute(6, 0, 3, 1, 4, 2, 5)  # (C, 1, 1, h, pH, w, pW)
            reshape → (C, H, W)
        """
        B = x.shape[0]
        p = self.cfg.patch_size
        C = self.cfg.in_channels
        pH, pW = h // p, w // p
        # (B, pH, pW, p, p, C)
        x = x.reshape(B, pH, pW, p, p, C)
        # → (B, C, pH, p, pW, p)
        x = x.transpose(0, 5, 1, 3, 2, 4)
        # → (B, C, H, W)
        x = x.reshape(B, C, h, w)
        return x
    def __call__(
        self,
        x: mx.array,
        t: mx.array,
        cap_feats: mx.array,
        cap_mask: mx.array | None = None,
    ) -> mx.array:
        """Forward pass — matches diffusers ZImageTransformer2DModel.forward().
        Correct execution order (from diffusers source):
        1. t_embed
        2. x_embed → noise_refiner (image tokens with RoPE)
        3. cap_embed → context_refiner (text tokens with RoPE)
        4. build unified [img, cap] sequence (IMAGE FIRST in basic mode)
        5. main layers (30 blocks with AdaLN + RoPE)
        6. final_layer on FULL unified sequence
        7. extract image tokens → unpatchify
        Args:
            x: (B, C, H, W) noisy latents
            t: (B,) timesteps (1-sigma, scaled by pipeline)
            cap_feats: (B, L_text, cap_feat_dim) text encoder hidden states
            cap_mask: (B, L_text) boolean mask for padding — accepted for API
                parity but currently UNUSED (see the note before the pad step:
                in diffusers all real+pad tokens attend to each other fully)
        Returns:
            noise_pred: (B, C, H, W) predicted noise
        """
        B, C, H, W = x.shape
        cfg = self.cfg
        p = cfg.patch_size
        pH, pW = H // p, W // p
        # 1. Timestep embedding → adaln_input
        adaln_input = self.t_embedder(t * cfg.t_scale)  # (B, 256)
        # 2. Patchify + embed image latents
        img = self._patchify(x)  # (B, pH*pW, patch_dim=64)
        img = self.all_x_embedder["2-1"](img)  # (B, pH*pW, dim=3840)
        L_cap_orig = cap_feats.shape[1]
        L_img = img.shape[1]
        # Pad caption to SEQ_MULTI_OF=32 (matching diffusers _pad_with_ids)
        SEQ_MULTI_OF = 32
        pad_len = (-L_cap_orig) % SEQ_MULTI_OF
        L_cap = L_cap_orig + pad_len
        # Build position IDs matching diffusers (cap: t=1..L_cap_orig, img: t=L_cap_orig+1)
        # NOTE: position IDs use original cap length (not padded), padding tokens get (0,0,0) IDs
        img_pos_ids, cap_pos_ids = build_position_ids(L_cap_orig, pH, pW)
        # Look up RoPE frequencies from precomputed tables
        img_freqs = self._rope(img_pos_ids)  # (L_img, rope_half_dim)
        cap_freqs_orig = self._rope(cap_pos_ids)  # (L_cap_orig, rope_half_dim)
        # Pad cap RoPE freqs with zeros for padding positions (same as diffusers)
        if pad_len > 0:
            cap_freqs = mx.concatenate([
                cap_freqs_orig,
                mx.zeros((pad_len, cap_freqs_orig.shape[-1]))
            ], axis=0)
        else:
            cap_freqs = cap_freqs_orig
        # noise_refiner on image tokens (with AdaLN, with RoPE)
        for block in self.noise_refiner:
            img = block(img, adaln_input, img_freqs)
        # 3. Caption embedding (cap_embedder is RMSNorm then Linear)
        cap = self.cap_embedder[0](cap_feats)  # RMSNorm
        cap = self.cap_embedder[1](cap)  # Linear → (B, L_cap_orig, dim=3840)
        # Pad caption with cap_pad_token (matching diffusers _pad_with_ids).
        # In diffusers, ALL tokens (real + pad) attend to each other fully —
        # cap_pad_token is a learned vector, not masked out. The diffusers
        # "attn_mask" is only for batch-level padding (all-True for BS=1).
        if pad_len > 0:
            pad_tok = mx.broadcast_to(self.cap_pad_token, (B, pad_len, cfg.dim))
            cap = mx.concatenate([cap, pad_tok], axis=1)  # (B, L_cap, dim)
        # context_refiner on text tokens (no AdaLN, WITH RoPE, no mask needed)
        for block in self.context_refiner:
            cap = block(cap, cap_freqs)
        # 4. Build unified sequence [img, cap] — IMAGE FIRST (diffusers basic mode)
        unified = mx.concatenate([img, cap], axis=1)  # (B, L_img + L_cap, dim)
        unified_freqs = mx.concatenate([img_freqs, cap_freqs], axis=0)
        # 5. Main DiT layers (30 blocks, with AdaLN conditioning + RoPE)
        for block in self.layers:
            unified = block(unified, adaln_input, unified_freqs)
        # 6. Final layer on FULL unified sequence (as diffusers does)
        unified = self.all_final_layer["2-1"](unified, adaln_input)
        # 7. Extract image tokens (first L_img tokens) and unpatchify
        img_out = unified[:, :L_img, :]  # (B, L_img, patch_dim=64)
        out = self._unpatchify(img_out, H, W)
        return out