illusion615 committed
Commit 64566e4 · verified · 1 Parent(s): 69458a5

Upload folder using huggingface_hub

Files changed (10)
  1. README.md +163 -0
  2. __init__.py +1 -0
  3. autoencoder.py +217 -0
  4. config.json +57 -0
  5. pipeline.py +323 -0
  6. qwen3_encoder.py +266 -0
  7. scheduler.py +79 -0
  8. tokenizer.py +70 -0
  9. weight_loader.py +195 -0
  10. zimage_dit.py +606 -0
README.md ADDED
@@ -0,0 +1,163 @@
---
license: apache-2.0
language:
- en
- zh
library_name: mlx
tags:
- mlx
- text-to-image
- apple-silicon
- image-generation
- diffusion
- dit
base_model: Tongyi-MAI/Z-Image-Turbo
pipeline_tag: text-to-image
---

# Z-Image-Turbo MLX

**Pure MLX (Apple Silicon) inference pipeline for [Z-Image-Turbo](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo)**, a 10.26B-parameter text-to-image model by Tongyi-MAI.

Zero PyTorch dependency. Runs natively on Apple Silicon via the Metal GPU.

## Highlights

- **100% MLX native**: no torch, no diffusers, no transformers required
- **bfloat16 inference** with `mx.compile()` kernel fusion
- **Fused attention** via `mx.fast.scaled_dot_product_attention`
- **Optional quantization** (4-bit / 8-bit) for low-memory devices
- **Pixel-identical quality** to the PyTorch reference (verified VAE pixel diff = 0.00 on the same latent input)
- **Chinese & English** prompts supported (Qwen3 text encoder)

## Performance (Apple Silicon)

| Resolution | Steps | MLX (bf16) | PyTorch MPS | Ratio |
|-----------|-------|-----------|-------------|-------|
| 512×512 | 4 | 5.6s | 5.4s | 1.04× |
| 512×512 | 8 | 10.6s | 10.0s | 1.06× |
| 768×768 | 8 | 26.5s | 24.6s | 1.08× |

| Metric | MLX | PyTorch MPS |
|--------|-----|-------------|
| Load time | **6.3s** | 16.8s |
| Memory (loaded) | 19.1 GB | ~19 GB |
| Disk for dependencies | ~200 MB | ~5 GB |

## Quick Start

### Install

```bash
pip install mlx safetensors tokenizers pillow numpy
```

### Download

```bash
# Clone this repo (includes inference code)
git clone https://huggingface.co/illusion615/Z-Image-Turbo-MLX

# Or download the original weights (the inference code handles both)
huggingface-cli download Tongyi-MAI/Z-Image-Turbo
```

### Generate

```python
from pipeline import ZImageMLXPipeline

pipe = ZImageMLXPipeline()
pipe.load()

# Generate an image
image = pipe.generate(
    prompt="a red cat sitting on a wooden table, detailed fur, soft lighting",
    width=512,
    height=512,
    num_steps=8,
    seed=42,
)

# Save
from PIL import Image
Image.fromarray(image).save("output.png")

pipe.unload()
```
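With the pipeline loaded, repeated `generate()` calls amortize the one-time load cost. A small helper for sweeping seeds at a fixed prompt (a sketch; `sweep_seeds` is not part of this repo's API):

```python
def sweep_seeds(pipe, prompt: str, seeds, **gen_kwargs):
    """Generate one image per seed, reusing a single loaded pipeline."""
    return {s: pipe.generate(prompt=prompt, seed=s, **gen_kwargs) for s in seeds}

# images = sweep_seeds(pipe, "a red cat", seeds=[1, 2, 3], width=512, height=512)
# for s, img in images.items():
#     Image.fromarray(img).save(f"cat_seed{s}.png")
```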

### With Quantization (low memory)

```python
import mlx.nn as nn

pipe = ZImageMLXPipeline()
pipe.load()

# Quantize DiT to 4-bit (reduces memory from 19 GB to 11 GB)
def large_only(path, module):
    return isinstance(module, nn.Linear) and module.weight.shape[-1] >= 1024

nn.quantize(pipe._dit, bits=4, group_size=64, class_predicate=large_only)
```
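The 19 GB → 11 GB figure can be sanity-checked with back-of-the-envelope arithmetic. Group-wise quantization stores a scale and bias per group alongside the packed weights, so the effective bits per weight sit slightly above the nominal width (a rough estimate, assuming only the ~6.15B-parameter DiT is quantized and two fp16 values per group of overhead):

```python
def quantized_gb(params_billion: float, bits: int, group_size: int = 64) -> float:
    """Rough weight-memory estimate: packed bits plus ~two fp16 values per group."""
    bits_per_weight = bits + (2 * 16) / group_size  # scale + bias overhead
    return params_billion * bits_per_weight / 8     # 1e9 params at 1 byte ~= 1 GB

dit_bf16 = 6.15 * 16 / 8          # ~12.3 GB unquantized
dit_4bit = quantized_gb(6.15, 4)  # ~3.5 GB packed
# ~8.8 GB saved on the DiT, roughly matching the observed 19 GB -> 11 GB.
```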

## Architecture

```
ZImageMLXPipeline
├── Qwen3 Text Encoder (4.02B params, bfloat16)
│   └── 36-layer decoder-only transformer, hidden_size=2560
├── ZImage DiT Transformer (6.15B params, bfloat16)
│   ├── 2 noise_refiner blocks (with AdaLN)
│   ├── 2 context_refiner blocks (no AdaLN)
│   ├── 30 main DiT blocks (with AdaLN + RoPE)
│   └── Final layer (AdaLN + Linear)
├── VAE Decoder (84M params, float32 force_upcast)
│   └── 4 UpDecoderBlock2D + MidBlock2D with attention
└── FlowMatch Euler Scheduler (shift=3.0)
```
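The FlowMatch Euler scheduler at the bottom of the tree is simple enough to sketch in full. Assuming the standard shift-warped sigma schedule and first-order Euler update (a numpy illustration, not the `scheduler.py` source):

```python
import numpy as np

def get_sigmas(num_steps: int, shift: float = 3.0) -> np.ndarray:
    """Shift-warped noise schedule from 1.0 down to 0.0 (num_steps + 1 values)."""
    s = np.linspace(1.0, 1.0 / num_steps, num_steps)
    shifted = shift * s / (1.0 + (shift - 1.0) * s)
    return np.append(shifted, 0.0)

def euler_step(sample, model_output, sigma, sigma_next):
    """First-order Euler update used by flow-matching samplers."""
    return sample + (sigma_next - sigma) * model_output

sigmas = get_sigmas(8, shift=3.0)
# 9 values, starting at 1.0 and strictly decreasing to 0.0
```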

## Files

```
├── pipeline.py        # Main inference pipeline
├── zimage_dit.py      # DiT transformer (S3-DiT architecture)
├── qwen3_encoder.py   # Qwen3 text encoder
├── autoencoder.py     # VAE decoder
├── scheduler.py       # FlowMatch Euler sampler
├── tokenizer.py       # Fast BPE tokenizer
├── weight_loader.py   # Safetensors loader with key mapping
└── config.json        # Pipeline configuration
```

## Model Source

This is a pure-MLX inference adaptation of [Tongyi-MAI/Z-Image-Turbo](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo). The original model weights are loaded from the upstream HuggingFace repository. All inference code is original work.

## Verified Quality

The MLX pipeline has been validated against the PyTorch/diffusers reference at four levels:

| Component | Validation | Result |
|-----------|-----------|--------|
| Tokenizer | Token-by-token comparison | Exact match |
| Text Encoder | Cosine similarity | 0.999989 |
| Scheduler | Sigma schedule diff | max 1e-7 |
| VAE Decoder | Pixel difference (same latent) | 0.00 |

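The Text Encoder row above can be reproduced in a few lines once hidden states from both backends are dumped to disk (the `.npy` file names here are hypothetical):

```python
import numpy as np

def mean_token_cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Mean per-token cosine similarity between two (L, D) hidden-state arrays."""
    a = a.astype(np.float64)
    b = b.astype(np.float64)
    num = (a * b).sum(axis=-1)
    den = np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1)
    return float((num / den).mean())

# mlx_h = np.load("mlx_hidden.npy")[0]    # hypothetical dumps, shape (L, 2560)
# ref_h = np.load("torch_hidden.npy")[0]
# mean_token_cosine(mlx_h, ref_h)         # expect ~0.99999 for a faithful port
```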
## License

Apache 2.0 (same as upstream Z-Image-Turbo).

## Citation

If you use this MLX adaptation, please also cite the original model:

```bibtex
@misc{z-image-turbo,
  title={Z-Image-Turbo},
  author={Tongyi-MAI},
  year={2025},
  url={https://huggingface.co/Tongyi-MAI/Z-Image-Turbo}
}
```
__init__.py ADDED
@@ -0,0 +1 @@
# Z-Image-Turbo MLX native backend
autoencoder.py ADDED
@@ -0,0 +1,217 @@
"""AutoencoderKL Decoder: pure MLX implementation.

Decodes latent representations to RGB images without any PyTorch/diffusers
dependency. The architecture matches diffusers AutoencoderKL with the
Z-Image-Turbo VAE config:

    latent_channels         = 16
    block_out_channels      = [128, 256, 512, 512]
    layers_per_block        = 2   (decoder uses layers_per_block + 1 = 3)
    norm_num_groups         = 32
    mid_block_add_attention = true
    force_upcast            = true  (all ops in float32)
    scaling_factor          = 0.3611
    shift_factor            = 0.1159

Data format: NHWC throughout (MLX convention).
"""

from __future__ import annotations

import math

import mlx.core as mx
import mlx.nn as nn

# Match the diffusers VAE GroupNorm: eps=1e-6, pytorch_compatible=True
_GN_EPS = 1e-6


def _gn(groups: int, channels: int) -> nn.GroupNorm:
    return nn.GroupNorm(groups, channels, eps=_GN_EPS, pytorch_compatible=True)


# ── Building blocks ──────────────────────────────────────────────


class ResnetBlock2D(nn.Module):
    """Residual block: GroupNorm → SiLU → Conv → GroupNorm → SiLU → Conv + skip."""

    def __init__(self, in_channels: int, out_channels: int, groups: int = 32):
        super().__init__()
        self.norm1 = _gn(groups, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = _gn(groups, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        self.conv_shortcut = None
        if in_channels != out_channels:
            self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def __call__(self, x: mx.array) -> mx.array:
        residual = x
        x = nn.silu(self.norm1(x))
        x = self.conv1(x)
        x = nn.silu(self.norm2(x))
        x = self.conv2(x)
        if self.conv_shortcut is not None:
            residual = self.conv_shortcut(residual)
        return x + residual


class AttentionBlock(nn.Module):
    """Single-head self-attention over spatial positions (NHWC)."""

    def __init__(self, channels: int, groups: int = 32):
        super().__init__()
        self.group_norm = _gn(groups, channels)
        self.to_q = nn.Linear(channels, channels)
        self.to_k = nn.Linear(channels, channels)
        self.to_v = nn.Linear(channels, channels)
        # diffusers wraps the out-projection in a list (Sequential): to_out.0
        self.to_out = [nn.Linear(channels, channels)]

    def __call__(self, x: mx.array) -> mx.array:
        residual = x
        B, H, W, C = x.shape
        x = self.group_norm(x)
        x = x.reshape(B, H * W, C)

        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        scale = 1.0 / math.sqrt(C)
        attn = (q @ k.transpose(0, 2, 1)) * scale
        attn = mx.softmax(attn, axis=-1)
        x = attn @ v

        x = self.to_out[0](x)
        x = x.reshape(B, H, W, C)
        return x + residual


class Upsample2D(nn.Module):
    """2x nearest-neighbour upsample followed by a 3x3 conv."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def __call__(self, x: mx.array) -> mx.array:
        # Nearest-neighbour 2x in NHWC: repeat rows, then columns
        x = mx.repeat(x, 2, axis=1)
        x = mx.repeat(x, 2, axis=2)
        return self.conv(x)


class UpDecoderBlock2D(nn.Module):
    """Decoder up-block: N resnet blocks + optional 2x upsample."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 3,
        add_upsample: bool = True,
        groups: int = 32,
    ):
        super().__init__()
        self.resnets = []
        for i in range(num_layers):
            res_in = in_channels if i == 0 else out_channels
            self.resnets.append(ResnetBlock2D(res_in, out_channels, groups))

        self.upsamplers = []
        if add_upsample:
            self.upsamplers.append(Upsample2D(out_channels))

    def __call__(self, x: mx.array) -> mx.array:
        for resnet in self.resnets:
            x = resnet(x)
        for up in self.upsamplers:
            x = up(x)
        return x


class MidBlock2D(nn.Module):
    """Mid block: resnet → self-attention → resnet."""

    def __init__(self, channels: int, groups: int = 32):
        super().__init__()
        self.resnets = [
            ResnetBlock2D(channels, channels, groups),
            ResnetBlock2D(channels, channels, groups),
        ]
        self.attentions = [AttentionBlock(channels, groups)]

    def __call__(self, x: mx.array) -> mx.array:
        x = self.resnets[0](x)
        x = self.attentions[0](x)
        x = self.resnets[1](x)
        return x


# ── Decoder ──────────────────────────────────────────────────────


class Decoder(nn.Module):
    """AutoencoderKL Decoder (NHWC, pure MLX).

    The module hierarchy matches diffusers weight-key paths after stripping
    the ``decoder.`` prefix, so weights can be loaded directly.
    """

    def __init__(
        self,
        latent_channels: int = 16,
        block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
    ):
        super().__init__()
        reversed_ch = list(reversed(block_out_channels))  # [512, 512, 256, 128]

        # Input projection
        self.conv_in = nn.Conv2d(latent_channels, reversed_ch[0], kernel_size=3, padding=1)

        # Mid block
        self.mid_block = MidBlock2D(reversed_ch[0], norm_num_groups)

        # Up blocks (3 upsamples → 8x total spatial increase)
        self.up_blocks = []
        for i, out_ch in enumerate(reversed_ch):
            in_ch = reversed_ch[i - 1] if i > 0 else reversed_ch[0]
            add_upsample = i < len(reversed_ch) - 1
            self.up_blocks.append(
                UpDecoderBlock2D(
                    in_channels=in_ch,
                    out_channels=out_ch,
                    num_layers=layers_per_block + 1,
                    add_upsample=add_upsample,
                    groups=norm_num_groups,
                )
            )

        # Output
        self.conv_norm_out = _gn(norm_num_groups, reversed_ch[-1])
        self.conv_out = nn.Conv2d(reversed_ch[-1], 3, kernel_size=3, padding=1)

    def __call__(self, z: mx.array) -> mx.array:
        """Decode latents → image.

        Args:
            z: (B, H, W, C) latent tensor in NHWC, **already scaled**.

        Returns:
            (B, 8H, 8W, 3) decoded image.
        """
        x = self.conv_in(z)
        x = self.mid_block(x)
        for block in self.up_blocks:
            x = block(x)
        x = nn.silu(self.conv_norm_out(x))
        x = self.conv_out(x)
        return x
config.json ADDED
@@ -0,0 +1,57 @@
{
  "pipeline_type": "ZImageMLXPipeline",
  "base_model": "Tongyi-MAI/Z-Image-Turbo",
  "framework": "mlx",
  "model": {
    "total_params": "10.26B",
    "text_encoder": {
      "type": "Qwen3",
      "params": "4.02B",
      "hidden_size": 2560,
      "num_layers": 36,
      "num_attention_heads": 32,
      "num_key_value_heads": 8,
      "dtype": "bfloat16"
    },
    "transformer": {
      "type": "ZImageTransformer (S3-DiT)",
      "params": "6.15B",
      "dim": 3840,
      "n_heads": 30,
      "head_dim": 128,
      "n_layers": 30,
      "n_refiner_layers": 2,
      "ffn_dim": 10240,
      "in_channels": 16,
      "patch_size": 2,
      "dtype": "bfloat16"
    },
    "vae": {
      "type": "AutoencoderKL Decoder",
      "params": "84M",
      "latent_channels": 16,
      "block_out_channels": [128, 256, 512, 512],
      "scaling_factor": 0.3611,
      "shift_factor": 0.1159,
      "dtype": "float32"
    },
    "scheduler": {
      "type": "FlowMatchEulerDiscrete",
      "shift": 3.0,
      "num_train_timesteps": 1000
    }
  },
  "quantization": {
    "supported_bits": [4, 8, 16],
    "default_bits": 16,
    "group_size": 64,
    "min_quantize_dim": 1024
  },
  "generation_defaults": {
    "width": 512,
    "height": 512,
    "num_steps": 8,
    "guidance_scale": 0.0,
    "max_text_len": 256
  }
}
pipeline.py ADDED
@@ -0,0 +1,323 @@
"""Z-Image-Turbo MLX Pipeline: end-to-end text-to-image generation.

Flow:
    1. Tokenize prompt → token IDs
    2. Qwen3 encoder → text hidden states (MLX)
    3. Initialize random latents
    4. Denoise loop: DiT forward pass x N steps (MLX)
    5. VAE decode latents → RGB image (MLX native)
    6. Save to PNG
"""

from __future__ import annotations

import logging
import time
from pathlib import Path

import mlx.core as mx
import numpy as np
from PIL import Image

from .autoencoder import Decoder
from .qwen3_encoder import Qwen3Encoder, Qwen3EncoderConfig
from .zimage_dit import ZImageTransformer, ZImageDiTConfig
from .scheduler import FlowMatchEulerScheduler
from .tokenizer import Qwen2Tokenizer
from .weight_loader import (
    _find_model_path,
    _log_memory,
    load_text_encoder_weights,
    load_transformer_weights,
    load_vae_decoder_weights,
)

logger = logging.getLogger("zimage-mlx")


def _cast_to_bf16(model):
    """Cast all parameters of an nn.Module to bfloat16 in place.

    This halves memory and speeds up Metal compute for the DiT.
    """
    from mlx.utils import tree_map

    params = model.parameters()
    bf16_params = tree_map(
        lambda x: x.astype(mx.bfloat16) if isinstance(x, mx.array) else x, params
    )
    model.update(bf16_params)
    return model


class ZImageMLXPipeline:
    """End-to-end Z-Image-Turbo inference pipeline, 100% MLX.

    All stages run on Apple Silicon via MLX: text encoding,
    DiT denoising, and VAE decoding. No PyTorch dependency.
    """

    def __init__(self, model_id: str = "Tongyi-MAI/Z-Image-Turbo"):
        self.model_id = model_id
        self._model_path: Path | None = None
        self._tokenizer: Qwen2Tokenizer | None = None
        self._encoder: Qwen3Encoder | None = None
        self._dit: ZImageTransformer | None = None
        self._dit_compiled = None  # mx.compile'd forward pass
        self._scheduler = FlowMatchEulerScheduler(shift=3.0)
        self._vae: Decoder | None = None
        self._loaded = False

    def load(self, model_path: Path | None = None):
        """Load all model components.

        Memory strategy (staged loading):
        - Encoder, DiT, and VAE are loaded sequentially.
        - During generation, the encoder is released after text encoding
          to reduce peak memory (see ``generate()``).
        """
        t0 = time.monotonic()
        self._model_path = model_path or _find_model_path(self.model_id)
        _log_memory("before load")

        # 1. Tokenizer
        logger.info("[ZImage-MLX] Loading tokenizer...")
        self._tokenizer = Qwen2Tokenizer(self._model_path)

        # 2. Text encoder (Qwen3)
        logger.info("[ZImage-MLX] Loading text encoder (Qwen3, 36 layers)...")
        self._encoder = Qwen3Encoder(Qwen3EncoderConfig())
        te_weights = load_text_encoder_weights(self._model_path)
        self._encoder.load_weights(list(te_weights.items()))
        # Weights are already bfloat16 on disk; keep them as-is for memory savings
        mx.eval(self._encoder.parameters())
        del te_weights  # release the weight dict immediately
        logger.info("[ZImage-MLX] Text encoder loaded (bfloat16)")
        _log_memory("after text encoder")

        # 3. DiT (ZImageTransformer)
        logger.info("[ZImage-MLX] Loading transformer (S3-DiT, 30+2+2 layers)...")
        self._dit = ZImageTransformer(ZImageDiTConfig())
        dit_weights = load_transformer_weights(self._model_path)
        self._dit.load_weights(list(dit_weights.items()))
        # Cast DiT to bfloat16 for faster inference (~2x speedup, ~50% memory);
        # PyTorch diffusers also runs at bfloat16 on MPS
        self._dit = _cast_to_bf16(self._dit)
        mx.eval(self._dit.parameters())
        del dit_weights  # release the weight dict immediately
        # Compile the DiT forward pass for additional Metal kernel fusion
        self._dit_compiled = mx.compile(self._dit)
        logger.info("[ZImage-MLX] Transformer loaded (bfloat16 + compiled)")
        _log_memory("after transformer")

        # 4. VAE decoder (MLX native)
        logger.info("[ZImage-MLX] Loading VAE decoder...")
        self._vae = Decoder()
        vae_weights = load_vae_decoder_weights(self._model_path)
        self._vae.load_weights(vae_weights)
        mx.eval(self._vae.parameters())
        del vae_weights  # release the weight list immediately
        logger.info("[ZImage-MLX] VAE decoder loaded")

        elapsed = time.monotonic() - t0
        self._loaded = True
        _log_memory("after full load")
        logger.info("[ZImage-MLX] Pipeline loaded in %.1fs", elapsed)

    def generate(
        self,
        prompt: str,
        width: int = 768,
        height: int = 768,
        num_steps: int = 8,
        seed: int | None = None,
        guidance_scale: float = 0.0,  # Z-Image-Turbo typically uses 0
        max_text_len: int = 256,
    ) -> np.ndarray:
        """Generate an image from a text prompt.

        Args:
            prompt: Text description (Chinese or English)
            width: Output width (must be divisible by 16)
            height: Output height (must be divisible by 16)
            num_steps: Number of denoising steps
            seed: Random seed (None for random)
            guidance_scale: CFG scale (0 = no guidance)
            max_text_len: Max text token length

        Returns:
            RGB image as a numpy array (H, W, 3) uint8
        """
        if not self._loaded:
            raise RuntimeError("Pipeline not loaded. Call load() first.")

        t0 = time.monotonic()
        if seed is None:
            seed = int(time.time()) % (2**31)

        # Ensure the encoder is available (it may have been released after a
        # previous generation)
        self._reload_encoder()

        # ── 1. Tokenize (with chat template, like diffusers) ──
        chat_result = self._tokenizer.apply_chat_template(prompt, max_length=max_text_len)
        token_ids = chat_result["input_ids"]       # list[int]
        attn_mask = chat_result["attention_mask"]  # list[int]

        input_ids = mx.array([token_ids])  # (1, L)

        # ── 2. Text encode ──
        t_enc = time.monotonic()
        if self._encoder is None:
            raise RuntimeError("Text encoder not loaded. Call load() first.")
        all_hidden = self._encoder(input_ids)  # (1, L, 2560), bfloat16
        cap_feats = all_hidden  # (1, L, 2560)
        mx.eval(cap_feats)
        logger.info(
            "[ZImage-MLX] Text encoded in %.2fs, %d tokens",
            time.monotonic() - t_enc, cap_feats.shape[1],
        )

        # Release the encoder to free memory before DiT denoising.
        self._release_encoder()

        # ── 3. Initialize latents ──
        latent_h = height // 8
        latent_w = width // 8
        mx.random.seed(seed)
        # Use bfloat16 latents to match DiT precision
        latents = mx.random.normal((1, 16, latent_h, latent_w)).astype(mx.bfloat16)

        # Ensure cap_feats is bfloat16 for the DiT
        cap_feats = cap_feats.astype(mx.bfloat16)

        # ── 4. Denoise loop ──
        sigmas = self._scheduler.get_sigmas(num_steps)
        mx.eval(sigmas)
        sigmas_list = sigmas.tolist()

        dit_fn = self._dit_compiled if self._dit_compiled is not None else self._dit

        t_denoise = time.monotonic()
        for i in range(num_steps):
            sigma = sigmas_list[i]
            sigma_next = sigmas_list[i + 1]

            t_step = mx.array([1.0 - sigma], dtype=mx.bfloat16)
            noise_pred = dit_fn(latents, t_step, cap_feats)

            # Diffusers negates the model output before passing it to the
            # scheduler, which then does:
            #   prev_sample = sample + (sigma_next - sigma) * model_output
            noise_pred = -noise_pred

            latents = self._scheduler.step(noise_pred, sigma, sigma_next, latents)
            mx.eval(latents)

            logger.info(
                "[ZImage-MLX] Step %d/%d (sigma %.4f → %.4f)",
                i + 1, num_steps, sigma, sigma_next,
            )

        denoise_time = time.monotonic() - t_denoise
        logger.info(
            "[ZImage-MLX] Denoised in %.2fs (%.2fs/step)",
            denoise_time, denoise_time / num_steps,
        )

        # ── 5. VAE decode (MLX native) ──
        t_vae = time.monotonic()
        image = self._vae_decode(latents)
        logger.info("[ZImage-MLX] VAE decoded in %.2fs", time.monotonic() - t_vae)

        total = time.monotonic() - t0
        logger.info("[ZImage-MLX] Total generation: %.2fs", total)

        return image

    def _vae_decode(self, latents: mx.array) -> np.ndarray:
        """Decode latents → RGB image using the MLX VAE.

        Diffusers formula:
            z = latents / scaling_factor + shift_factor
            raw = vae.decode(z)       # output in [-1, 1]
            image = raw / 2 + 0.5     # denormalize to [0, 1]
        """
        scaling_factor = 0.3611
        shift_factor = 0.1159

        # NCHW → NHWC for MLX convolutions
        z = latents.transpose(0, 2, 3, 1)  # (B,C,H,W) → (B,H,W,C)
        z = z.astype(mx.float32)  # force_upcast
        z = z / scaling_factor + shift_factor

        decoded = self._vae(z)  # (B,8H,8W,3) in [-1, 1]
        mx.eval(decoded)

        # Denormalize [-1,1] → [0,1], then clamp → uint8
        img = decoded[0] / 2.0 + 0.5
        img = mx.clip(img, 0.0, 1.0)
        img = np.array(img)
        img = (img * 255).astype(np.uint8)
        return img

    def generate_and_save(
        self,
        prompt: str,
        output_path: str,
        width: int = 768,
        height: int = 768,
        num_steps: int = 8,
        seed: int | None = None,
    ) -> dict:
        """Generate an image and save it to a file.

        Returns:
            Dict with generation metadata.
        """
        t0 = time.monotonic()
        if seed is None:
            seed = int(time.time()) % (2**31)

        image = self.generate(
            prompt=prompt,
            width=width,
            height=height,
            num_steps=num_steps,
            seed=seed,
        )

        # Save
        img = Image.fromarray(image)
        img.save(output_path)

        elapsed = time.monotonic() - t0
        return {
            "image_path": output_path,
            "width": width,
            "height": height,
            "seed": seed,
            "num_steps": num_steps,
            "elapsed_s": round(elapsed, 2),
            "prompt": prompt,
        }

    def _release_encoder(self):
        """Release the text encoder to free ~5 GB before denoising."""
        if self._encoder is not None:
            self._encoder = None
            mx.clear_cache()
            _log_memory("after releasing encoder")

    def _reload_encoder(self):
        """Reload the encoder for the next generation (lazy, on demand)."""
        if self._encoder is None and self._model_path is not None:
            logger.info("[ZImage-MLX] Reloading text encoder...")
            self._encoder = Qwen3Encoder(Qwen3EncoderConfig())
            te_weights = load_text_encoder_weights(self._model_path)
            self._encoder.load_weights(list(te_weights.items()))
            # Weights are bfloat16 on disk; keep as-is
            mx.eval(self._encoder.parameters())
            del te_weights
            _log_memory("after reloading encoder")

    def unload(self):
        """Release all model memory."""
        self._encoder = None
        self._dit = None
        self._vae = None
        self._tokenizer = None
        self._loaded = False
        mx.clear_cache()
        _log_memory("after full unload")
        logger.info("[ZImage-MLX] Pipeline unloaded")
qwen3_encoder.py ADDED
@@ -0,0 +1,266 @@
1
+ """Qwen3 Text Encoder β€” MLX native implementation for Z-Image-Turbo.
2
+
3
+ Architecture (from model config):
4
+ - 36 layers, hidden_size=2560, 32 attention heads, 8 KV heads (GQA 4:1)
5
+ - head_dim=128, intermediate_size=9728
6
+ - hidden_act=silu (SwiGLU FFN)
7
+ - RMSNorm (eps=1e-6), QK-Norm on q/k projections
8
+ - RoPE (theta=1_000_000)
9
+ - vocab_size=151936
10
+
11
+ Weight key pattern:
12
+ model.embed_tokens.weight
13
+ model.layers.N.input_layernorm.weight
14
+ model.layers.N.self_attn.{q_proj,k_proj,v_proj,o_proj}.weight
15
+ model.layers.N.self_attn.{q_norm,k_norm}.weight
16
+ model.layers.N.post_attention_layernorm.weight
17
+ model.layers.N.mlp.{gate_proj,up_proj,down_proj}.weight
18
+ model.norm.weight
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import math
24
+ from dataclasses import dataclass
25
+
26
+ import mlx.core as mx
27
+ import mlx.nn as nn
28
+
29
+
30
+ # ── Config ────────────────────────────────────────────────────────
31
+
32
+ @dataclass
33
+ class Qwen3EncoderConfig:
34
+ hidden_size: int = 2560
35
+ num_hidden_layers: int = 36
36
+ num_attention_heads: int = 32
37
+ num_key_value_heads: int = 8
38
+ head_dim: int = 128
39
+ intermediate_size: int = 9728
40
+ rms_norm_eps: float = 1e-6
41
+ rope_theta: float = 1_000_000.0
42
+ vocab_size: int = 151936
43
+ max_position_embeddings: int = 40960
44
+
45
+
46
+ # ── RMSNorm ───────────────────────────────────────────────────────
47
+
48
+ class RMSNorm(nn.Module):
49
+ def __init__(self, dim: int, eps: float = 1e-6):
50
+ super().__init__()
51
+ self.weight = mx.ones((dim,))
52
+ self.eps = eps
53
+
54
+ def __call__(self, x: mx.array) -> mx.array:
55
+ rms = mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + self.eps)
56
+ return x * rms * self.weight
57
+
58
+
59
+ # ── RoPE ──────────────────────────────────────────────────────────
60
+
61
+ class RotaryEmbedding(nn.Module):
62
+ def __init__(self, dim: int, theta: float = 1_000_000.0, max_seq_len: int = 8192):
63
+ super().__init__()
64
+ self.dim = dim
65
+ self.theta = theta
66
+ inv_freq = 1.0 / (theta ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim))
67
+ self._inv_freq = inv_freq
68
+ self._max_cached = 0
69
+ self._cos_cache = None
70
+ self._sin_cache = None
71
+
72
+ def _update_cache(self, seq_len: int):
73
+ if seq_len <= self._max_cached and self._cos_cache is not None:
74
+ return
75
+ t = mx.arange(seq_len, dtype=mx.float32)
76
+ freqs = mx.outer(t, self._inv_freq)
77
+ emb = mx.concatenate([freqs, freqs], axis=-1)
78
+ self._cos_cache = mx.cos(emb)
79
+ self._sin_cache = mx.sin(emb)
80
+ self._max_cached = seq_len
81
+
82
+ def __call__(self, seq_len: int) -> tuple[mx.array, mx.array]:
83
+ self._update_cache(seq_len)
84
+ return self._cos_cache[:seq_len], self._sin_cache[:seq_len]
85
+
86
+
87
+ def _rotate_half(x: mx.array) -> mx.array:
88
+ x1, x2 = mx.split(x, 2, axis=-1)
89
+ return mx.concatenate([-x2, x1], axis=-1)
90
+
91
+
92
+ def apply_rotary_pos_emb(q: mx.array, k: mx.array, cos: mx.array, sin: mx.array) -> tuple[mx.array, mx.array]:
93
+ # q/k shape: (B, heads, L, head_dim)
94
+ # cos/sin: (seq_len, head_dim) β†’ (1, 1, seq_len, head_dim)
95
+ cos = cos[None, None, :, :]
96
+ sin = sin[None, None, :, :]
97
+ q_rot = q * cos + _rotate_half(q) * sin
98
+ k_rot = k * cos + _rotate_half(k) * sin
99
+ return q_rot, k_rot
100
+
101
+
+# ── Attention ─────────────────────────────────────────────────────
+
+class Qwen3Attention(nn.Module):
+    def __init__(self, cfg: Qwen3EncoderConfig):
+        super().__init__()
+        self.n_heads = cfg.num_attention_heads
+        self.n_kv_heads = cfg.num_key_value_heads
+        self.head_dim = cfg.head_dim
+        self.n_rep = self.n_heads // self.n_kv_heads  # GQA repeat factor
+
+        self.q_proj = nn.Linear(cfg.hidden_size, self.n_heads * self.head_dim, bias=False)
+        self.k_proj = nn.Linear(cfg.hidden_size, self.n_kv_heads * self.head_dim, bias=False)
+        self.v_proj = nn.Linear(cfg.hidden_size, self.n_kv_heads * self.head_dim, bias=False)
+        self.o_proj = nn.Linear(self.n_heads * self.head_dim, cfg.hidden_size, bias=False)
+
+        # QK-Norm
+        self.q_norm = RMSNorm(self.head_dim, eps=cfg.rms_norm_eps)
+        self.k_norm = RMSNorm(self.head_dim, eps=cfg.rms_norm_eps)
+
+    def __call__(
+        self,
+        x: mx.array,
+        cos: mx.array,
+        sin: mx.array,
+        mask: mx.array | None = None,
+    ) -> mx.array:
+        B, L, _ = x.shape
+
+        q = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim)
+        k = self.k_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim)
+        v = self.v_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim)
+
+        # QK-Norm (per-head)
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+
+        # Transpose to (B, heads, L, head_dim) for RoPE
+        q = q.transpose(0, 2, 1, 3)
+        k = k.transpose(0, 2, 1, 3)
+        v = v.transpose(0, 2, 1, 3)
+
+        # Apply RoPE
+        q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+        # GQA: repeat KV heads
+        if self.n_rep > 1:
+            k = mx.repeat(k, self.n_rep, axis=1)
+            v = mx.repeat(v, self.n_rep, axis=1)
+
+        # Scaled dot-product attention
+        scale = 1.0 / math.sqrt(self.head_dim)
+        attn = (q @ k.transpose(0, 1, 3, 2)) * scale
+
+        if mask is not None:
+            attn = attn + mask
+
+        attn = mx.softmax(attn, axis=-1)
+        out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+        return self.o_proj(out)
+
+
+# ── MLP (SwiGLU) ─────────────────────────────────────────────────
+
+class Qwen3MLP(nn.Module):
+    def __init__(self, cfg: Qwen3EncoderConfig):
+        super().__init__()
+        self.gate_proj = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(cfg.intermediate_size, cfg.hidden_size, bias=False)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+# ── Transformer Layer ────────────────────────────────────────────
+
+class Qwen3DecoderLayer(nn.Module):
+    def __init__(self, cfg: Qwen3EncoderConfig):
+        super().__init__()
+        self.input_layernorm = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
+        self.self_attn = Qwen3Attention(cfg)
+        self.post_attention_layernorm = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
+        self.mlp = Qwen3MLP(cfg)
+
+    def __call__(
+        self,
+        x: mx.array,
+        cos: mx.array,
+        sin: mx.array,
+        mask: mx.array | None = None,
+    ) -> mx.array:
+        # Pre-norm attention
+        h = self.input_layernorm(x)
+        h = self.self_attn(h, cos, sin, mask)
+        x = x + h
+
+        # Pre-norm FFN
+        h = self.post_attention_layernorm(x)
+        h = self.mlp(h)
+        x = x + h
+
+        return x
+
+
+# ── Full Encoder ─────────────────────────────────────────────────
+
+class Qwen3Encoder(nn.Module):
+    """Qwen3 text encoder for Z-Image-Turbo.
+
+    Uses the decoder-only model as a text encoder: runs all 36 layers
+    and returns hidden states (no generation; a causal mask is still
+    applied, matching HuggingFace Qwen3Model).
+    """
+
+    def __init__(self, cfg: Qwen3EncoderConfig | None = None):
+        super().__init__()
+        if cfg is None:
+            cfg = Qwen3EncoderConfig()
+        self.cfg = cfg
+
+        self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.hidden_size)
+        self.layers = [Qwen3DecoderLayer(cfg) for _ in range(cfg.num_hidden_layers)]
+        self.norm = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
+        self.rotary_emb = RotaryEmbedding(
+            dim=cfg.head_dim,
+            theta=cfg.rope_theta,
+            max_seq_len=cfg.max_position_embeddings,
+        )
+
+    def __call__(self, input_ids: mx.array, mask: mx.array | None = None) -> mx.array:
+        """Encode text tokens.
+
+        Returns the second-to-last hidden state (hidden_states[-2]),
+        matching diffusers ZImagePipeline, which uses
+        ``text_encoder(..., output_hidden_states=True).hidden_states[-2]``.
+
+        Applies a causal attention mask by default (matching HuggingFace
+        Qwen3Model, which uses causal masking internally).
+
+        Args:
+            input_ids: (B, L) token IDs
+            mask: optional attention mask (B, 1, L, L); None = auto causal mask
+
+        Returns:
+            hidden_states: (B, L, hidden_size), penultimate layer output
+        """
+        B, L = input_ids.shape
+        x = self.embed_tokens(input_ids)
+
+        cos, sin = self.rotary_emb(L)
+
+        # Build causal mask if none provided (matches HuggingFace Qwen3Model)
+        if mask is None:
+            mask = mx.full((L, L), -1e9)
+            mask = mx.triu(mask, k=1)  # strictly upper triangle = -1e9 (≈ -inf)
+            mask = mask[None, None, :, :]  # (1, 1, L, L)
+
+        n_layers = len(self.layers)
+        penultimate = x
+        for i, layer in enumerate(self.layers):
+            x = layer(x, cos, sin, mask)
+            if i == n_layers - 2:
+                # Capture second-to-last layer output (no final norm applied)
+                penultimate = x
+
+        return penultimate
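The auto-built additive causal mask can be sketched in NumPy; this mirrors the `mx.full` + `mx.triu` construction above and is illustrative only:

```python
import numpy as np

L = 4
mask = np.triu(np.full((L, L), -1e9), k=1)  # strictly-upper triangle blocked
mask = mask[None, None, :, :]               # broadcast over (batch, heads)

# Row i may attend to columns 0..i only
assert mask.shape == (1, 1, 4, 4)
assert mask[0, 0, 0, 0] == 0.0    # token 0 sees itself
assert mask[0, 0, 0, 1] == -1e9   # ...but not token 1
assert mask[0, 0, 3, 0] == 0.0    # last token sees the first
```

Adding this to the pre-softmax scores drives masked positions toward zero attention weight.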
scheduler.py ADDED
@@ -0,0 +1,79 @@
+"""FlowMatch Euler Discrete Scheduler for Z-Image-Turbo.
+
+Implements the ODE-based Flow Matching sampling schedule used by
+Z-Image-Turbo (same as FLUX-schnell).
+
+Config: shift=3.0, num_train_timesteps=1000
+"""
+
+from __future__ import annotations
+
+import mlx.core as mx
+
+
+class FlowMatchEulerScheduler:
+    """Flow Matching Euler sampler.
+
+    The forward diffusion maps data x0 to noise x1 via:
+        x_t = (1 - t) * x0 + t * noise
+
+    Denoising reverses this with Euler steps from t=1 → t=0.
+    """
+
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        shift: float = 3.0,
+        sigma_min: float = 0.002994012087583542,
+    ):
+        self.num_train_timesteps = num_train_timesteps
+        self.shift = shift
+        self.sigma_min = sigma_min
+
+    def get_sigmas(self, num_steps: int) -> mx.array:
+        """Compute the sigma schedule matching diffusers FlowMatchEulerDiscreteScheduler.
+
+        diffusers logic:
+        1. sigma_max = 1.0, sigma_min ≈ 0.00299 (not 0!)
+        2. timesteps = linspace(sigma_max * 1000, sigma_min * 1000, num_steps)
+        3. sigmas_raw = timesteps / 1000
+        4. sigmas = shift * raw / (1 + (shift - 1) * raw)
+        5. append terminal sigma = 0
+        """
+        # Match diffusers: linspace from sigma_max to sigma_min, NOT to 0
+        timesteps = mx.linspace(
+            float(self.num_train_timesteps),
+            float(self.sigma_min * self.num_train_timesteps),
+            num_steps,
+        )
+        raw_sigmas = timesteps / float(self.num_train_timesteps)
+        sigmas = self.shift * raw_sigmas / (1.0 + (self.shift - 1.0) * raw_sigmas)
+        # Append terminal zero
+        sigmas = mx.concatenate([sigmas, mx.array([0.0])])
+        return sigmas
+
+    def unshift_sigma(self, shifted_sigma: float) -> float:
+        """Invert the shift to recover the original (unshifted) sigma.
+
+        shifted = shift * s / (1 + (shift - 1) * s)
+        so s = shifted / (shift - (shift - 1) * shifted)
+        """
+        if self.shift == 1.0:
+            return shifted_sigma
+        denom = self.shift - (self.shift - 1.0) * shifted_sigma
+        if denom == 0:
+            return 1.0
+        return shifted_sigma / denom
+
+    def step(
+        self,
+        model_output: mx.array,
+        sigma: float,
+        sigma_next: float,
+        sample: mx.array,
+    ) -> mx.array:
+        """Single Euler step.
+
+        v-prediction: x_{t-1} = x_t + (sigma_next - sigma) * v_pred
+        """
+        dt = sigma_next - sigma
+        return sample + dt * model_output
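The shifted sigma schedule can be reproduced without MLX. This pure-Python sketch of the same linspace + shift formula is illustrative only, not the pipeline code:

```python
def get_sigmas(num_steps, shift=3.0, sigma_min=0.002994012087583542, n_train=1000):
    # linspace from n_train down to sigma_min * n_train (NOT down to 0)
    ts = [n_train + (sigma_min * n_train - n_train) * i / (num_steps - 1)
          for i in range(num_steps)]
    raw = [t / n_train for t in ts]
    # Time-shift: pushes more steps toward high noise levels for shift > 1
    sigmas = [shift * r / (1.0 + (shift - 1.0) * r) for r in raw]
    return sigmas + [0.0]  # terminal sigma

sigmas = get_sigmas(8)
assert abs(sigmas[0] - 1.0) < 1e-9                     # starts at pure noise
assert sigmas[-1] == 0.0                               # terminal sigma appended
assert all(a > b for a, b in zip(sigmas, sigmas[1:]))  # strictly decreasing
```

An Euler step then just moves the sample by `(sigma_next - sigma) * v_pred`, exactly as in `step()` above.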
tokenizer.py ADDED
@@ -0,0 +1,70 @@
+"""Qwen2 Tokenizer adapter for Z-Image-Turbo.
+
+Uses the `tokenizers` library directly for fast BPE tokenization,
+avoiding the slow AutoTokenizer.from_pretrained() initialization.
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+from pathlib import Path
+
+logger = logging.getLogger("zimage-mlx")
+
+
+class Qwen2Tokenizer:
+    """Fast BPE tokenizer using the tokenizers library."""
+
+    def __init__(self, model_path: Path):
+        from tokenizers import Tokenizer as HFTokenizer
+
+        tokenizer_path = model_path / "tokenizer"
+        json_file = tokenizer_path / "tokenizer.json"
+        if not json_file.exists():
+            json_file = model_path / "tokenizer.json"
+        if not json_file.exists():
+            raise FileNotFoundError(f"tokenizer.json not found in {model_path}")
+
+        self._tokenizer = HFTokenizer.from_file(str(json_file))
+
+        # Load chat template from tokenizer_config.json if available
+        config_file = tokenizer_path / "tokenizer_config.json"
+        if not config_file.exists():
+            config_file = model_path / "tokenizer_config.json"
+        self._chat_template = None
+        if config_file.exists():
+            with open(config_file) as f:
+                cfg = json.load(f)
+            self._chat_template = cfg.get("chat_template")
+
+        logger.info("[ZImage] Tokenizer loaded: vocab_size=%d", self._tokenizer.get_vocab_size())
+
+    def encode(self, text: str, max_length: int = 512) -> list[int]:
+        """Encode text to token IDs."""
+        encoded = self._tokenizer.encode(text)
+        ids = encoded.ids
+        if len(ids) > max_length:
+            ids = ids[:max_length]
+        return ids
+
+    def apply_chat_template(self, prompt: str, max_length: int = 512) -> dict:
+        """Apply the Qwen3 chat template format and tokenize.
+
+        Wraps the prompt in chat format:
+            <|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n
+
+        Returns a dict with 'input_ids' and 'attention_mask'.
+        """
+        # Build chat-formatted text manually (Qwen3 chat template)
+        chat_text = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
+        encoded = self._tokenizer.encode(chat_text)
+        ids = encoded.ids
+        if len(ids) > max_length:
+            ids = ids[:max_length]
+        attn_mask = [1] * len(ids)
+        return {"input_ids": ids, "attention_mask": attn_mask}
+
+    @property
+    def vocab_size(self) -> int:
+        return self._tokenizer.get_vocab_size()
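The truncation and attention-mask logic in `apply_chat_template` reduces to a few lines; this sketch is illustrative, and `wrap_and_mask` is a made-up helper operating on pre-tokenized IDs:

```python
def wrap_and_mask(prompt, token_ids, max_length=512):
    # Chat wrapping happens on the text; truncation + mask on the IDs
    chat_text = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
    ids = token_ids[:max_length]
    return chat_text, {"input_ids": ids, "attention_mask": [1] * len(ids)}

chat_text, out = wrap_and_mask("a cat in the snow", list(range(600)), max_length=512)
assert chat_text.startswith("<|im_start|>user\n")
assert chat_text.endswith("<|im_start|>assistant\n")
assert len(out["input_ids"]) == 512
assert out["attention_mask"] == [1] * 512
```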
weight_loader.py ADDED
@@ -0,0 +1,195 @@
+"""Weight loader for Z-Image-Turbo MLX backend.
+
+Loads safetensors weights from the HuggingFace cache and maps them
+to the MLX module hierarchy.
+"""
+
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+
+import mlx.core as mx
+
+logger = logging.getLogger("zimage-mlx")
+
+# Default HF cache path for Z-Image-Turbo
+_DEFAULT_MODEL_ID = "Tongyi-MAI/Z-Image-Turbo"
+_HF_CACHE = Path.home() / ".cache" / "huggingface" / "hub"
+
+# Local weights directory (project-local, survives HF cache cleanup)
+_LOCAL_WEIGHTS_DIR = Path(__file__).parent / "weights"
+
+
+def _find_model_path(model_id: str = _DEFAULT_MODEL_ID) -> Path:
+    """Find the local weight path for a model.
+
+    Priority:
+    1. Project-local ``backends/mlx_zimage/weights/`` (if text_encoder/ exists)
+    2. HF cache ``~/.cache/huggingface/hub/models--Tongyi-MAI--Z-Image-Turbo/``
+    """
+    # 1. Local weights directory
+    if _LOCAL_WEIGHTS_DIR.is_dir() and (_LOCAL_WEIGHTS_DIR / "text_encoder").is_dir():
+        logger.info("[ZImage] Using local weights: %s", _LOCAL_WEIGHTS_DIR)
+        return _LOCAL_WEIGHTS_DIR
+
+    # 2. HF cache
+    safe_id = model_id.replace("/", "--")
+    model_dir = _HF_CACHE / f"models--{safe_id}"
+    if not model_dir.exists():
+        raise FileNotFoundError(
+            f"Model not found. Neither local ({_LOCAL_WEIGHTS_DIR}) "
+            f"nor HF cache ({model_dir}) available."
+        )
+    # Find the latest snapshot
+    snapshots = sorted(model_dir.glob("snapshots/*"), key=lambda p: p.stat().st_mtime, reverse=True)
+    if not snapshots:
+        raise FileNotFoundError(f"No snapshots found in {model_dir}")
+    logger.info("[ZImage] Using HF cache: %s", snapshots[0])
+    return snapshots[0]
+
+
+def _log_memory(label: str) -> None:
+    """Log Metal memory usage (safe no-op if unavailable)."""
+    try:
+        active = mx.metal.get_active_memory() / (1024 ** 3)
+        peak = mx.metal.get_peak_memory() / (1024 ** 3)
+        logger.info("[ZImage] MEM %s: active=%.2f GB, peak=%.2f GB", label, active, peak)
+    except Exception:
+        pass  # mx.metal not available (e.g. CI / non-Apple)
+
+
+def _load_safetensors_shards(
+    shard_dir: Path,
+    pattern: str = "*.safetensors",
+    *,
+    key_filter: str | None = None,
+) -> dict[str, mx.array]:
+    """Load safetensors files via mx.load(): zero-copy, preserves bfloat16.
+
+    Args:
+        shard_dir: Directory containing safetensors shard files.
+        pattern: Glob pattern for shard files.
+        key_filter: If set, only load keys starting with this prefix.
+    """
+    files = sorted(shard_dir.glob(pattern))
+    if not files:
+        raise FileNotFoundError(f"No safetensors files in {shard_dir}")
+
+    params: dict[str, mx.array] = {}
+    for f in files:
+        # mx.load() natively reads safetensors → mx.array (preserves bfloat16)
+        shard = mx.load(str(f))
+        if key_filter:
+            shard = {k: v for k, v in shard.items() if k.startswith(key_filter)}
+        params.update(shard)
+        logger.info("[ZImage] Loaded shard %s (%d keys)", f.name, len(shard))
+
+    logger.info("[ZImage] Total: %d keys from %d files in %s", len(params), len(files), shard_dir.name)
+    _log_memory(f"after loading {shard_dir.name}")
+    return params
+
+
+# ── Text Encoder weight mapping ──────────────────────────────────
+
+def load_text_encoder_weights(model_path: Path | None = None) -> dict[str, mx.array]:
+    """Load and map Qwen3 text encoder weights for MLX.
+
+    The safetensors keys use the pattern:
+        model.embed_tokens.weight
+        model.layers.N.input_layernorm.weight
+        model.layers.N.self_attn.q_proj.weight
+        ...
+        model.norm.weight
+
+    Our MLX module uses:
+        embed_tokens.weight
+        layers.N.input_layernorm.weight
+        layers.N.self_attn.q_proj.weight
+        ...
+        norm.weight
+
+    So we strip the leading "model." prefix.
+    """
+    if model_path is None:
+        model_path = _find_model_path()
+
+    te_dir = model_path / "text_encoder"
+    raw = _load_safetensors_shards(te_dir, "model-*.safetensors")
+
+    mapped: dict[str, mx.array] = {}
+    for key, tensor in raw.items():
+        # Strip "model." prefix
+        if key.startswith("model."):
+            new_key = key[len("model."):]
+        else:
+            new_key = key
+
+        mapped[new_key] = tensor
+
+    logger.info("[ZImage] Text encoder: %d parameters mapped", len(mapped))
+    return mapped
+
+
+# ── Transformer weight mapping ───────────────────────────────────
+
+def load_transformer_weights(model_path: Path | None = None) -> dict[str, mx.array]:
+    """Load ZImageTransformer2DModel weights."""
+    if model_path is None:
+        model_path = _find_model_path()
+
+    dit_dir = model_path / "transformer"
+    raw = _load_safetensors_shards(dit_dir, "diffusion_pytorch_model-*.safetensors")
+
+    # Keys are already flat (no "model." prefix), use as-is
+    logger.info("[ZImage] Transformer: %d parameters loaded", len(raw))
+    return raw
+
+
+# ── VAE weight mapping ───────────────────────────────────────────
+
+def load_vae_weights(model_path: Path | None = None) -> dict[str, mx.array]:
+    """Load AutoencoderKL weights."""
+    if model_path is None:
+        model_path = _find_model_path()
+
+    vae_dir = model_path / "vae"
+    raw = _load_safetensors_shards(vae_dir)
+
+    logger.info("[ZImage] VAE: %d parameters loaded", len(raw))
+    return raw
+
+
+def load_vae_decoder_weights(model_path: Path | None = None) -> list[tuple[str, mx.array]]:
+    """Load VAE decoder weights, mapped for the MLX Decoder module.
+
+    Only loads keys starting with ``decoder.`` (skips encoder weights
+    to avoid wasting memory). Performs two transformations:
+    1. Strips the ``decoder.`` prefix so keys match the Decoder module tree.
+    2. Transposes Conv2d weights from PyTorch (O, I, kH, kW) → MLX (O, kH, kW, I).
+
+    Returns a list of (key, array) tuples ready for ``Decoder.load_weights()``.
+    """
+    if model_path is None:
+        model_path = _find_model_path()
+
+    vae_dir = model_path / "vae"
+    # Only load decoder.* keys; skip encoder weights entirely
+    raw = _load_safetensors_shards(vae_dir, key_filter="decoder.")
+
+    weights: list[tuple[str, mx.array]] = []
+    for key, val in raw.items():
+        key = key[len("decoder."):]
+
+        # Conv2d weight: (O, I, kH, kW) → (O, kH, kW, I)
+        if val.ndim == 4:
+            val = val.transpose(0, 2, 3, 1)
+
+        # force_upcast: ensure float32 for numerical stability
+        val = val.astype(mx.float32)
+
+        weights.append((key, val))
+
+    logger.info("[ZImage] VAE decoder: %d parameters mapped", len(weights))
+    return weights
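The Conv2d layout transpose performed in `load_vae_decoder_weights` can be checked with NumPy; this is an illustrative sketch with dummy data:

```python
import numpy as np

# PyTorch Conv2d weight layout: (O, I, kH, kW); MLX expects (O, kH, kW, I)
w_torch = np.arange(2 * 3 * 2 * 2, dtype=np.float32).reshape(2, 3, 2, 2)
w_mlx = w_torch.transpose(0, 2, 3, 1)

assert w_mlx.shape == (2, 2, 2, 3)
# Element-wise correspondence: w_mlx[o, kh, kw, i] == w_torch[o, i, kh, kw]
assert w_mlx[1, 0, 1, 2] == w_torch[1, 2, 0, 1]
# The transpose is invertible back to the PyTorch layout
assert np.allclose(w_mlx.transpose(0, 3, 1, 2), w_torch)
```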
zimage_dit.py ADDED
@@ -0,0 +1,606 @@
+"""ZImageTransformer2DModel: MLX-native S3-DiT for Z-Image-Turbo.
+
+Architecture (from model config + weight shapes):
+    - 30 main DiT layers + 2 context_refiner + 2 noise_refiner
+    - dim=3840, n_heads=30, head_dim=128
+    - Dual-norm (pre + post) for both attention and FFN
+    - SwiGLU FFN (w1/w2/w3), intermediate=10240
+    - QK-Norm (RMSNorm on head_dim=128)
+    - AdaLN modulation: 4 outputs per block (scale_attn, gate_attn, scale_ffn, gate_ffn)
+    - N-dim RoPE: axes_dims=[32, 48, 48], rope_theta=256
+    - Timestep embedding: sinusoidal(256) → MLP(256→1024→256)
+    - Caption projector: RMSNorm(2560) → Linear(2560→3840)
+    - Patch embed: Linear(64→3840) (in_channels=16, patch_size=2 → 16×2²=64)
+    - Final layer: adaLN(256→3840) + Linear(3840→64)
+
+Weight key patterns:
+    t_embedder.mlp.{0,2}.{weight,bias}
+    cap_embedder.{0,1}.{weight,bias}  (0=RMSNorm, 1=Linear)
+    cap_pad_token, x_pad_token
+    all_x_embedder.2-1.{weight,bias}
+    layers.N.{adaLN_modulation.0, attention.*, attention_norm*, feed_forward.*, ffn_norm*}
+    context_refiner.N.{attention.*, attention_norm*, feed_forward.*, ffn_norm*}
+    noise_refiner.N.{adaLN_modulation.0, attention.*, attention_norm*, feed_forward.*, ffn_norm*}
+    all_final_layer.2-1.{linear, adaLN_modulation.1}
+"""
+
+from __future__ import annotations
+
+import math
+from dataclasses import dataclass, field
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+# ── Config ────────────────────────────────────────────────────────
+
+@dataclass
+class ZImageDiTConfig:
+    dim: int = 3840
+    n_heads: int = 30
+    n_kv_heads: int = 30
+    n_layers: int = 30
+    n_refiner_layers: int = 2
+    head_dim: int = 128
+    ffn_dim: int = 10240
+    in_channels: int = 16
+    patch_size: int = 2
+    cap_feat_dim: int = 2560   # Qwen3 hidden_size
+    t_embed_dim: int = 256     # timestep embedding dim
+    t_hidden_dim: int = 1024   # timestep MLP hidden
+    axes_dims: list[int] = field(default_factory=lambda: [32, 48, 48])
+    axes_lens: list[int] = field(default_factory=lambda: [1536, 512, 512])
+    rope_theta: float = 256.0
+    norm_eps: float = 1e-5
+    qk_norm: bool = True
+    t_scale: float = 1000.0
+
+
+# ── RMSNorm ───────────────────────────────────────────────────────
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-5):
+        super().__init__()
+        self.weight = mx.ones((dim,))
+        self.eps = eps
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return x * mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + self.eps) * self.weight
+
+
+# ── Timestep Embedding ────────────────────────────────────────────
+
+def timestep_embedding(t: mx.array, dim: int = 256) -> mx.array:
+    """Sinusoidal timestep embedding."""
+    half = dim // 2
+    freqs = mx.exp(-math.log(10000.0) * mx.arange(half, dtype=mx.float32) / half)
+    args = t[:, None].astype(mx.float32) * freqs[None, :]
+    return mx.concatenate([mx.cos(args), mx.sin(args)], axis=-1)
+
+
+class TimestepEmbedder(nn.Module):
+    """Sinusoidal → MLP timestep embedder: sin(t) → Linear → SiLU → Linear."""
+
+    def __init__(self, t_embed_dim: int = 256, hidden_dim: int = 1024):
+        super().__init__()
+        self.mlp = [
+            nn.Linear(t_embed_dim, hidden_dim),   # mlp.0
+            None,                                 # SiLU (index 1, not a layer)
+            nn.Linear(hidden_dim, t_embed_dim),   # mlp.2
+        ]
+
+    def __call__(self, t: mx.array) -> mx.array:
+        x = timestep_embedding(t, self.mlp[0].weight.shape[1])
+        x = nn.silu(self.mlp[0](x))
+        x = self.mlp[2](x)
+        return x
+
+
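The sinusoidal embedding above can be mirrored in NumPy to check its basic properties. This is an illustrative sketch; `timestep_embedding_np` is not part of the repo:

```python
import math
import numpy as np

def timestep_embedding_np(t, dim=256):
    half = dim // 2
    # Geometric frequency ladder from 1 down to ~1/10000
    freqs = np.exp(-math.log(10000.0) * np.arange(half) / half)
    args = np.asarray(t, dtype=np.float64)[:, None] * freqs[None, :]
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1)

emb = timestep_embedding_np([0.0, 1000.0], dim=256)
assert emb.shape == (2, 256)
# t = 0: the cos half is all ones, the sin half is all zeros
assert np.allclose(emb[0, :128], 1.0) and np.allclose(emb[0, 128:], 0.0)
# Every component stays in [-1, 1]
assert np.all(np.abs(emb) <= 1.0)
```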
+# ── N-dim RoPE (matches diffusers RopeEmbedder) ──────────────────
+
+class RopeEmbedder:
+    """Precomputed per-axis frequency tables, indexed by position IDs.
+
+    Matches diffusers ``RopeEmbedder``:
+    1. Precompute complex frequencies per axis (as real angle tables here)
+    2. At forward time, gather from tables using integer position IDs
+    3. Concatenate per-axis results → (seq_len, sum(axes_dims)//2)
+
+    The returned angles are used with :func:`apply_rope`, which does the
+    equivalent of ``torch.view_as_complex(x) * polar(1, angles)`` using
+    real-valued cos/sin operations.
+    """
+
+    def __init__(
+        self,
+        axes_dims: list[int],
+        axes_lens: list[int],
+        theta: float = 256.0,
+    ):
+        self.axes_dims = axes_dims
+        self.axes_lens = axes_lens
+        self.theta = theta
+        # Precompute per-axis frequency tables
+        self._freq_tables: list[mx.array] = []
+        for d, e in zip(axes_dims, axes_lens):
+            inv_freq = 1.0 / (theta ** (mx.arange(0, d, 2, dtype=mx.float32) / d))
+            timestep = mx.arange(e, dtype=mx.float32)
+            freqs = mx.outer(timestep, inv_freq)  # (e, d/2)
+            self._freq_tables.append(freqs)
+
+    def __call__(self, pos_ids: mx.array) -> mx.array:
+        """Look up RoPE angles from precomputed tables.
+
+        Args:
+            pos_ids: (seq_len, 3) integer position IDs, one per axis.
+
+        Returns:
+            (seq_len, rope_half_dim) rotation angles.
+        """
+        parts = []
+        for i in range(len(self.axes_dims)):
+            idx = pos_ids[:, i].astype(mx.int32)
+            parts.append(self._freq_tables[i][idx])  # (seq_len, d_i/2)
+        return mx.concatenate(parts, axis=-1)
+
+
+def build_position_ids(
+    cap_len: int,
+    pH: int,
+    pW: int,
+) -> tuple[mx.array, mx.array]:
+    """Build position ID grids matching diffusers patchify_and_embed.
+
+    Caption tokens: ``create_coordinate_grid(size=(cap_len, 1, 1), start=(1, 0, 0))``
+    → t-axis = 1..cap_len, h-axis = 0, w-axis = 0
+
+    Image tokens: ``create_coordinate_grid(size=(1, pH, pW), start=(cap_len+1, 0, 0))``
+    → t-axis = cap_len+1, h-axis = 0..pH-1, w-axis = 0..pW-1
+
+    Returns:
+        (img_pos_ids, cap_pos_ids) each of shape (N, 3)
+    """
+    # Caption: (cap_len, 3): t varies, h=0, w=0
+    cap_t = mx.arange(1, cap_len + 1, dtype=mx.int32)[:, None]  # (cap_len, 1)
+    cap_hw = mx.zeros((cap_len, 2), dtype=mx.int32)
+    cap_pos = mx.concatenate([cap_t, cap_hw], axis=-1)  # (cap_len, 3)
+
+    # Image: (pH*pW, 3): t = cap_len+1, h and w vary
+    t_val = cap_len + 1
+    img_ids = []
+    for h in range(pH):
+        for w in range(pW):
+            img_ids.append([t_val, h, w])
+    img_pos = mx.array(img_ids, dtype=mx.int32)  # (pH*pW, 3)
+
+    return img_pos, cap_pos
+
+
+def apply_rope(x: mx.array, freqs: mx.array) -> mx.array:
+    """Apply rotary position embedding using interleaved pairing.
+
+    Equivalent to diffusers' complex multiplication:
+        ``x_complex = view_as_complex(x.reshape(..., -1, 2))``
+        ``x_out = view_as_real(x_complex * freqs_cis).flatten()``
+
+    x: (B, n_heads, L, head_dim)
+    freqs: (L, rope_half_dim) where rope_half_dim = sum(axes_dims)//2
+    """
+    rope_half_dim = freqs.shape[-1]
+    rope_dim = rope_half_dim * 2
+    x_rope = x[..., :rope_dim]
+    x_pass = x[..., rope_dim:]
+
+    cos = mx.cos(freqs)[None, None, :, :]  # (1, 1, L, rope_half_dim)
+    sin = mx.sin(freqs)[None, None, :, :]
+
+    # Interleaved pairing: (x[0], x[1]), (x[2], x[3]), ...
+    x_even = x_rope[..., 0::2]  # even indices → "real"
+    x_odd = x_rope[..., 1::2]   # odd indices → "imag"
+
+    out_even = x_even * cos - x_odd * sin
+    out_odd = x_even * sin + x_odd * cos
+
+    # Interleave back: [re0, im0, re1, im1, ...]
+    out = mx.stack([out_even, out_odd], axis=-1)  # (..., rope_half_dim, 2)
+    x_rope = out.reshape(*out.shape[:-2], rope_dim)
+
+    return mx.concatenate([x_rope, x_pass], axis=-1)
+
+
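The interleaved pairing used by `apply_rope` (adjacent dimensions form a pair, in contrast to the encoder's half-split convention) can be sketched in NumPy. This is illustrative only; the sketch uses a simpler `(B, L, D)` layout than the module's `(B, n_heads, L, head_dim)`:

```python
import numpy as np

def apply_rope_interleaved(x, freqs):
    # freqs: (L, rope_half_dim) angles; pairs are (x[0], x[1]), (x[2], x[3]), ...
    rope_dim = freqs.shape[-1] * 2
    x_rope, x_pass = x[..., :rope_dim], x[..., rope_dim:]
    cos, sin = np.cos(freqs), np.sin(freqs)
    x_even, x_odd = x_rope[..., 0::2], x_rope[..., 1::2]
    out = np.stack([x_even * cos - x_odd * sin,
                    x_even * sin + x_odd * cos], axis=-1)
    return np.concatenate([out.reshape(*out.shape[:-2], rope_dim), x_pass], axis=-1)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 6, 16))   # feature dim 16; only the first 8 dims are rotated
freqs = rng.normal(size=(6, 4))   # rope_half_dim = 4 -> rope_dim = 8

y = apply_rope_interleaved(x, freqs)
assert y.shape == x.shape
assert np.allclose(y[..., 8:], x[..., 8:])  # pass-through dims untouched
assert np.allclose(np.linalg.norm(y, axis=-1), np.linalg.norm(x, axis=-1))  # rotation
```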
211
+ # ── Attention Block ───────────────────────────────────────────────
212
+
213
+ class DiTAttention(nn.Module):
214
+ """Self-attention with QK-Norm and optional RoPE."""
215
+
216
+ def __init__(self, dim: int, n_heads: int, head_dim: int, qk_norm: bool = True, norm_eps: float = 1e-5):
217
+ super().__init__()
218
+ self.n_heads = n_heads
219
+ self.head_dim = head_dim
220
+ self.to_q = nn.Linear(dim, n_heads * head_dim, bias=False)
221
+ self.to_k = nn.Linear(dim, n_heads * head_dim, bias=False)
222
+ self.to_v = nn.Linear(dim, n_heads * head_dim, bias=False)
223
+ self.to_out = [nn.Linear(n_heads * head_dim, dim, bias=False)] # to_out.0
224
+
225
+ if qk_norm:
226
+ self.norm_q = RMSNorm(head_dim, eps=norm_eps)
227
+ self.norm_k = RMSNorm(head_dim, eps=norm_eps)
228
+ else:
229
+ self.norm_q = None
230
+ self.norm_k = None
231
+
232
+ def __call__(self, x: mx.array, freqs: mx.array | None = None, mask: mx.array | None = None) -> mx.array:
233
+ B, L, _ = x.shape
234
+ q = self.to_q(x).reshape(B, L, self.n_heads, self.head_dim)
235
+ k = self.to_k(x).reshape(B, L, self.n_heads, self.head_dim)
236
+ v = self.to_v(x).reshape(B, L, self.n_heads, self.head_dim)
237
+
238
+ # QK-Norm
239
+ if self.norm_q is not None:
240
+ q = self.norm_q(q)
241
+ k = self.norm_k(k)
242
+
243
+ # (B, n_heads, L, head_dim)
244
+ q = q.transpose(0, 2, 1, 3)
245
+ k = k.transpose(0, 2, 1, 3)
246
+ v = v.transpose(0, 2, 1, 3)
247
+
248
+ # RoPE
249
+ if freqs is not None:
250
+ q = apply_rope(q, freqs)
251
+ k = apply_rope(k, freqs)
252
+
253
+ # Fused scaled dot-product attention (Metal kernel, no NxN materialization)
254
+ scale = 1.0 / math.sqrt(self.head_dim)
255
+ if mask is not None:
256
+ # Convert boolean mask (B, L) to additive mask for fused attention
257
+ attn_mask = mask[:, None, None, :].astype(q.dtype)
258
+ attn_mask = (1.0 - attn_mask) * (-1e9)
259
+ out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=attn_mask)
260
+ else:
261
+ out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
262
+
263
+ out = out.transpose(0, 2, 1, 3).reshape(B, L, -1)
264
+ return self.to_out[0](out)
265
+
266
+
267
+ # ── SwiGLU FFN ────────────────────────────────────────────────────
268
+
269
+ class SwiGLUFFN(nn.Module):
270
+ """SwiGLU feed-forward: gate * silu(w1(x)) + w3(x) β†’ w2."""
271
+ def __init__(self, dim: int, ffn_dim: int):
272
+ super().__init__()
273
+ self.w1 = nn.Linear(dim, ffn_dim, bias=False) # gate
274
+ self.w2 = nn.Linear(ffn_dim, dim, bias=False) # down
275
+ self.w3 = nn.Linear(dim, ffn_dim, bias=False) # up
276
+
277
+ def __call__(self, x: mx.array) -> mx.array:
278
+ return self.w2(nn.silu(self.w1(x)) * self.w3(x))
279
+
280
+
281
+ # ── AdaLN Modulation ─────────────────────────────────────────────
282
+
283
+ class AdaLNModulation(nn.Module):
284
+ """AdaLN-Zero: project conditioning to shift/scale pairs.
285
+
286
+ Output dim = dim * n_mods (e.g. 3840 * 4 = 15360 for main blocks).
287
+ """
288
+ def __init__(self, cond_dim: int, out_dim: int):
289
+ super().__init__()
290
+ # Weight key is adaLN_modulation.0 (index 0 in a Sequential-like list)
291
+ self._linear = nn.Linear(cond_dim, out_dim)
292
+
293
+ # Expose as list for weight loading: adaLN_modulation.0.weight/bias
294
+ @property
295
+ def parameters(self):
296
+ return {"0": {"weight": self._linear.weight, "bias": self._linear.bias}}
297
+
298
+ def __call__(self, c: mx.array) -> mx.array:
299
+ return self._linear(c)
300
+
301
+
302
+ # ── DiT Block (main layers + noise_refiner) ──────────────────────
303
+
304
+ class DiTBlock(nn.Module):
305
+ """S3-DiT block with AdaLN modulation.
306
+
307
+ 4 modulations: shift_attn, scale_attn, shift_ffn, scale_ffn
308
+ Dual-norm: pre-norm + post-norm for both attention and FFN.
309
+ """
310
+
311
+ def __init__(self, cfg: ZImageDiTConfig):
312
+ super().__init__()
313
+ self.attention = DiTAttention(cfg.dim, cfg.n_heads, cfg.head_dim, cfg.qk_norm, cfg.norm_eps)
314
+ self.attention_norm1 = RMSNorm(cfg.dim, eps=cfg.norm_eps) # pre-attn norm
315
+ self.attention_norm2 = RMSNorm(cfg.dim, eps=cfg.norm_eps) # post-attn norm
316
+ self.feed_forward = SwiGLUFFN(cfg.dim, cfg.ffn_dim)
317
+ self.ffn_norm1 = RMSNorm(cfg.dim, eps=cfg.norm_eps) # pre-ffn norm
318
+ self.ffn_norm2 = RMSNorm(cfg.dim, eps=cfg.norm_eps) # post-ffn norm
319
+
320
+ # AdaLN: 4 modulation signals (shift_a, scale_a, shift_f, scale_f)
321
+ self.adaLN_modulation = [nn.Linear(cfg.t_embed_dim, cfg.dim * 4)]
322
+
323
+ def __call__(self, x: mx.array, c: mx.array, freqs: mx.array | None = None, mask: mx.array | None = None) -> mx.array:
324
+ """
325
+ Args:
326
+ x: (B, L, dim) hidden states
327
+ c: (B, t_embed_dim) conditioning (timestep embedding)
328
+ freqs: optional RoPE frequencies for image tokens
329
+ mask: optional (B, L) boolean attention mask
330
+ """
331
+ # Compute modulation from conditioning
332
+ mod = self.adaLN_modulation[0](c) # (B, dim*4)
333
+ scale_msa, gate_msa, scale_mlp, gate_mlp = mx.split(mod, 4, axis=-1)
334
+
335
+ gate_msa = mx.tanh(gate_msa)
336
+ gate_mlp = mx.tanh(gate_mlp)
337
+ scale_msa = 1.0 + scale_msa
338
+ scale_mlp = 1.0 + scale_mlp
339
+
340
+ scale_msa = scale_msa[:, None, :]
341
+ gate_msa = gate_msa[:, None, :]
342
+ scale_mlp = scale_mlp[:, None, :]
343
+ gate_mlp = gate_mlp[:, None, :]
344
+
345
+ attn_out = self.attention(self.attention_norm1(x) * scale_msa, freqs, mask)
346
+ x = x + gate_msa * self.attention_norm2(attn_out)
347
+
348
+ x = x + gate_mlp * self.ffn_norm2(
349
+ self.feed_forward(self.ffn_norm1(x) * scale_mlp)
350
+ )
351
+
352
+ return x
353
+
354
+
+
+# ── Refiner Block (context_refiner — no AdaLN) ──────────────────
+
+class RefinerBlock(nn.Module):
+    """Refiner block WITHOUT AdaLN modulation (used for context_refiner)."""
+
+    def __init__(self, cfg: ZImageDiTConfig):
+        super().__init__()
+        self.attention = DiTAttention(cfg.dim, cfg.n_heads, cfg.head_dim, cfg.qk_norm, cfg.norm_eps)
+        self.attention_norm1 = RMSNorm(cfg.dim, eps=cfg.norm_eps)
+        self.attention_norm2 = RMSNorm(cfg.dim, eps=cfg.norm_eps)
+        self.feed_forward = SwiGLUFFN(cfg.dim, cfg.ffn_dim)
+        self.ffn_norm1 = RMSNorm(cfg.dim, eps=cfg.norm_eps)
+        self.ffn_norm2 = RMSNorm(cfg.dim, eps=cfg.norm_eps)
+
+    def __call__(self, x: mx.array, freqs: mx.array | None = None, mask: mx.array | None = None) -> mx.array:
+        h = self.attention_norm1(x)
+        h = self.attention(h, freqs, mask)
+        h = self.attention_norm2(h)
+        x = x + h
+
+        h = self.ffn_norm1(x)
+        h = self.feed_forward(h)
+        h = self.ffn_norm2(h)
+        x = x + h
+
+        return x
+
+
+# ── Final Layer ───────────────────────────────────────────────────
+
+class FinalLayer(nn.Module):
+    """Final projection: LayerNorm + adaLN scale + Linear(dim → patch_dim)."""
+    def __init__(self, dim: int, patch_dim: int, t_embed_dim: int):
+        super().__init__()
+        self.linear = nn.Linear(dim, patch_dim)
+        # adaLN_modulation: index 0 is a placeholder for SiLU (applied
+        # functionally below); index 1 is the Linear, matching the weight
+        # key `adaLN_modulation.1`
+        self.adaLN_modulation = [None, nn.Linear(t_embed_dim, dim)]
+
+    def __call__(self, x: mx.array, c: mx.array) -> mx.array:
+        # SiLU is part of FinalLayer's adaLN_modulation (unlike DiTBlock)
+        scale = 1.0 + self.adaLN_modulation[1](nn.silu(c))  # (B, dim)
+        scale = scale[:, None, :]  # (B, 1, dim)
+
+        # LayerNorm (no learnable params) + scale + linear
+        x = mx.fast.layer_norm(x, None, None, eps=1e-6)
+        x = x * scale
+        x = self.linear(x)
+        return x
+
+
+# ── Full ZImage Transformer ──────────────────────────────────────
+
+class ZImageTransformer(nn.Module):
+    """ZImageTransformer2DModel — S3-DiT for Z-Image-Turbo.
+
+    Forward flow:
+        1. Embed timestep → t_emb (B, 256)
+        2. Patchify + embed image latents, then noise refiner (2 blocks, AdaLN + RoPE)
+        3. Project caption features (RMSNorm + Linear), then context refiner (2 blocks, no AdaLN)
+        4. Concatenate [img, cap] → unified sequence (image tokens first)
+        5. Main DiT layers (30 blocks, AdaLN + RoPE)
+        6. Final layer on the full unified sequence
+        7. Extract image tokens → unpatchify
+    """
+
+    def __init__(self, cfg: ZImageDiTConfig | None = None):
+        super().__init__()
+        if cfg is None:
+            cfg = ZImageDiTConfig()
+        self.cfg = cfg
+
+        # Timestep embedder
+        self.t_embedder = TimestepEmbedder(cfg.t_embed_dim, cfg.t_hidden_dim)
+
+        # Caption projector: cap_embedder.0 = RMSNorm, cap_embedder.1 = Linear
+        self.cap_embedder = [
+            RMSNorm(cfg.cap_feat_dim, eps=cfg.norm_eps),
+            nn.Linear(cfg.cap_feat_dim, cfg.dim),
+        ]
+
+        # Learnable padding tokens
+        self.cap_pad_token = mx.zeros((1, cfg.dim))
+        self.x_pad_token = mx.zeros((1, cfg.dim))
+
+        # Image patch embedder — key uses "2-1" suffix for patch_size=2
+        # We store as a dict to match weight key `all_x_embedder.2-1.{weight,bias}`
+        patch_dim = cfg.in_channels * cfg.patch_size * cfg.patch_size  # 16 * 4 = 64
+        self.all_x_embedder = {"2-1": nn.Linear(patch_dim, cfg.dim)}
+
+        # Context refiner (no AdaLN)
+        self.context_refiner = [RefinerBlock(cfg) for _ in range(cfg.n_refiner_layers)]
+
+        # Main DiT layers (with AdaLN)
+        self.layers = [DiTBlock(cfg) for _ in range(cfg.n_layers)]
+
+        # Noise refiner (with AdaLN)
+        self.noise_refiner = [DiTBlock(cfg) for _ in range(cfg.n_refiner_layers)]
+
+        # Final layer — key uses "2-1" suffix
+        self.all_final_layer = {
+            "2-1": FinalLayer(cfg.dim, patch_dim, cfg.t_embed_dim)
+        }
+
+        # Precomputed RoPE frequency tables (matches diffusers RopeEmbedder)
+        self._rope = RopeEmbedder(cfg.axes_dims, cfg.axes_lens, cfg.rope_theta)
+
+    def _patchify(self, x: mx.array) -> mx.array:
+        """Convert image latents to patch sequence.
+
+        Matches diffusers: channels-last within each patch.
+        x: (B, C, H, W) → (B, H//p * W//p, p*p*C)
+
+        diffusers logic:
+            image.view(C, 1, 1, h, pH, w, pW)
+            image.permute(1, 3, 5, 2, 4, 6, 0)  # (1, h, w, 1, pH, pW, C)
+            reshape → (h*w, pH*pW*C)
+        """
+        B, C, H, W = x.shape
+        p = self.cfg.patch_size
+        pH, pW = H // p, W // p
+        # (B, C, pH, p, pW, p)
+        x = x.reshape(B, C, pH, p, pW, p)
+        # → (B, pH, pW, p, p, C) — channels LAST per patch
+        x = x.transpose(0, 2, 4, 3, 5, 1)
+        # → (B, pH*pW, p*p*C)
+        x = x.reshape(B, pH * pW, p * p * C)
+        return x
+
+    def _unpatchify(self, x: mx.array, h: int, w: int) -> mx.array:
+        """Convert patch sequence back to image latents.
+
+        Matches diffusers: channels-last within each patch.
+        x: (B, pH*pW, p*p*C) → (B, C, H, W)
+
+        diffusers logic:
+            x.view(1, h, w, 1, pH, pW, C)
+            x.permute(6, 0, 3, 1, 4, 2, 5)  # (C, 1, 1, h, pH, w, pW)
+            reshape → (C, H, W)
+        """
+        B = x.shape[0]
+        p = self.cfg.patch_size
+        C = self.cfg.in_channels
+        pH, pW = h // p, w // p
+        # (B, pH, pW, p, p, C)
+        x = x.reshape(B, pH, pW, p, p, C)
+        # → (B, C, pH, p, pW, p)
+        x = x.transpose(0, 5, 1, 3, 2, 4)
+        # → (B, C, H, W)
+        x = x.reshape(B, C, h, w)
+        return x
+
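The two transpositions above are exact inverses of each other. A standalone NumPy re-implementation of the same reshapes (illustration only, not part of the uploaded module) confirms the round trip:

```python
import numpy as np

def patchify(x, p):
    # (B, C, H, W) -> (B, (H//p)*(W//p), p*p*C), channels-last within each patch
    B, C, H, W = x.shape
    pH, pW = H // p, W // p
    x = x.reshape(B, C, pH, p, pW, p)
    x = x.transpose(0, 2, 4, 3, 5, 1)      # (B, pH, pW, p, p, C)
    return x.reshape(B, pH * pW, p * p * C)

def unpatchify(x, p, C, H, W):
    # (B, (H//p)*(W//p), p*p*C) -> (B, C, H, W)
    B = x.shape[0]
    pH, pW = H // p, W // p
    x = x.reshape(B, pH, pW, p, p, C)
    x = x.transpose(0, 5, 1, 3, 2, 4)      # (B, C, pH, p, pW, p)
    return x.reshape(B, C, H, W)

# Z-Image-Turbo uses C=16 latent channels and patch_size=2
latents = np.arange(2 * 16 * 8 * 8, dtype=np.float32).reshape(2, 16, 8, 8)
seq = patchify(latents, 2)                 # (2, 16, 64)
restored = unpatchify(seq, 2, 16, 8, 8)
assert np.array_equal(latents, restored)   # exact round trip
```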
+    def __call__(
+        self,
+        x: mx.array,
+        t: mx.array,
+        cap_feats: mx.array,
+        cap_mask: mx.array | None = None,
+    ) -> mx.array:
+        """Forward pass — matches diffusers ZImageTransformer2DModel.forward().
+
+        Correct execution order (from diffusers source):
+            1. t_embed
+            2. x_embed → noise_refiner (image tokens with RoPE)
+            3. cap_embed → context_refiner (text tokens with RoPE)
+            4. build unified [img, cap] sequence (IMAGE FIRST in basic mode)
+            5. main layers (30 blocks with AdaLN + RoPE)
+            6. final_layer on FULL unified sequence
+            7. extract image tokens → unpatchify
+
+        Args:
+            x: (B, C, H, W) noisy latents
+            t: (B,) timesteps (1-sigma, scaled by pipeline)
+            cap_feats: (B, L_text, cap_feat_dim) text encoder hidden states
+            cap_mask: (B, L_text) boolean mask for padding
+
+        Returns:
+            noise_pred: (B, C, H, W) predicted noise
+        """
+        B, C, H, W = x.shape
+        cfg = self.cfg
+        p = cfg.patch_size
+        pH, pW = H // p, W // p
+
+        # 1. Timestep embedding → adaln_input
+        adaln_input = self.t_embedder(t * cfg.t_scale)  # (B, 256)
+
+        # 2. Patchify + embed image latents
+        img = self._patchify(x)  # (B, pH*pW, patch_dim=64)
+        img = self.all_x_embedder["2-1"](img)  # (B, pH*pW, dim=3840)
+
+        L_cap_orig = cap_feats.shape[1]
+        L_img = img.shape[1]
+
+        # Pad caption to SEQ_MULTI_OF=32 (matching diffusers _pad_with_ids)
+        SEQ_MULTI_OF = 32
+        pad_len = (-L_cap_orig) % SEQ_MULTI_OF
+        L_cap = L_cap_orig + pad_len
+
+        # Build position IDs matching diffusers (cap: t=1..L_cap_orig, img: t=L_cap_orig+1)
+        # NOTE: position IDs use original cap length (not padded), padding tokens get (0,0,0) IDs
+        img_pos_ids, cap_pos_ids = build_position_ids(L_cap_orig, pH, pW)
+
+        # Look up RoPE frequencies from precomputed tables
+        img_freqs = self._rope(img_pos_ids)  # (L_img, rope_half_dim)
+        cap_freqs_orig = self._rope(cap_pos_ids)  # (L_cap_orig, rope_half_dim)
+
+        # Pad cap RoPE freqs with zeros for padding positions (same as diffusers)
+        if pad_len > 0:
+            cap_freqs = mx.concatenate([
+                cap_freqs_orig,
+                mx.zeros((pad_len, cap_freqs_orig.shape[-1]))
+            ], axis=0)
+        else:
+            cap_freqs = cap_freqs_orig
+
+        # noise_refiner on image tokens (with AdaLN, with RoPE)
+        for block in self.noise_refiner:
+            img = block(img, adaln_input, img_freqs)
+
+        # 3. Caption embedding (cap_embedder is RMSNorm then Linear)
+        cap = self.cap_embedder[0](cap_feats)  # RMSNorm
+        cap = self.cap_embedder[1](cap)  # Linear → (B, L_cap_orig, dim=3840)
+
+        # Pad caption with cap_pad_token (matching diffusers _pad_with_ids).
+        # In diffusers, ALL tokens (real + pad) attend to each other fully —
+        # cap_pad_token is a learned vector, not masked out. The diffusers
+        # "attn_mask" is only for batch-level padding (all-True for BS=1).
+        if pad_len > 0:
+            pad_tok = mx.broadcast_to(self.cap_pad_token, (B, pad_len, cfg.dim))
+            cap = mx.concatenate([cap, pad_tok], axis=1)  # (B, L_cap, dim)
+
+        # context_refiner on text tokens (no AdaLN, WITH RoPE, no mask needed)
+        for block in self.context_refiner:
+            cap = block(cap, cap_freqs)
+
+        # 4. Build unified sequence [img, cap] — IMAGE FIRST (diffusers basic mode)
+        unified = mx.concatenate([img, cap], axis=1)  # (B, L_img + L_cap, dim)
+        unified_freqs = mx.concatenate([img_freqs, cap_freqs], axis=0)
+
+        # 5. Main DiT layers (30 blocks, with AdaLN conditioning + RoPE)
+        for block in self.layers:
+            unified = block(unified, adaln_input, unified_freqs)
+
+        # 6. Final layer on FULL unified sequence (as diffusers does)
+        unified = self.all_final_layer["2-1"](unified, adaln_input)
+
+        # 7. Extract image tokens (first L_img tokens) and unpatchify
+        img_out = unified[:, :L_img, :]  # (B, L_img, patch_dim=64)
+        out = self._unpatchify(img_out, H, W)
+        return out
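The `pad_len = (-L_cap_orig) % SEQ_MULTI_OF` rule in the forward pass rounds the caption length up to the next multiple of 32 without ever over-padding. A standalone check of that arithmetic (illustration only, not part of the uploaded module):

```python
SEQ_MULTI_OF = 32  # caption sequences are padded to a multiple of 32
for L_cap_orig in (1, 31, 32, 33, 100):
    pad_len = (-L_cap_orig) % SEQ_MULTI_OF
    L_cap = L_cap_orig + pad_len
    assert L_cap % SEQ_MULTI_OF == 0 and 0 <= pad_len < SEQ_MULTI_OF
    print(L_cap_orig, "->", L_cap)  # 1->32, 31->32, 32->32, 33->64, 100->128
```

Note that an already-aligned length (e.g. 32) gets `pad_len == 0`, so no padding token or zero RoPE frequency is appended in that case.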