illusion615 commited on
Commit
31f3da5
Β·
verified Β·
1 Parent(s): 8e1e444

Upload folder using huggingface_hub

Browse files
Files changed (11) hide show
  1. README.md +109 -0
  2. __init__.py +6 -0
  3. autoencoder.py +188 -0
  4. clip_encoder.py +155 -0
  5. download_weights.py +41 -0
  6. flux_model.py +447 -0
  7. pipeline.py +410 -0
  8. sampler.py +125 -0
  9. t5_encoder.py +226 -0
  10. tokenizers.py +150 -0
  11. weight_loader.py +236 -0
README.md ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ library_name: mlx
6
+ tags:
7
+ - mlx
8
+ - text-to-image
9
+ - apple-silicon
10
+ - image-generation
11
+ - diffusion
12
+ - flux
13
+ base_model: black-forest-labs/FLUX.1-schnell
14
+ pipeline_tag: text-to-image
15
+ ---
16
+
17
+ # FLUX.1-schnell MLX Pipeline
18
+
19
+ **Pure MLX (Apple Silicon) inference pipeline for [FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell)** — a fast text-to-image model by Black Forest Labs.
20
+
21
+ Zero PyTorch dependency. Runs natively on Apple Silicon via Metal GPU.
22
+
23
+ ## Highlights
24
+
25
+ - **100% MLX native** β€” no torch, no diffusers needed
26
+ - **4-bit quantization** support via `argmaxinc/mlx-FLUX.1-schnell-4bit-quantized`
27
+ - **Fast 4-step generation** (FLUX.1-schnell is distilled for speed)
28
+ - **T5-XXL + CLIP-L** dual text encoders
29
+ - **FluxTransformer** with 19 Joint Blocks + 38 Single Blocks + N-dim RoPE
30
+
31
+ ## Architecture
32
+
33
+ ```
34
+ FluxPipeline
35
+ ├── T5-XXL Encoder (24 layers, hidden=4096)
36
+ │   └── Relative positional attention + GatedFFN
37
+ ├── CLIP-L Encoder (23 layers, hidden=768)
38
+ │   └── Causal mask + EOS pooling
39
+ ├── FluxTransformer (DiT)
40
+ │   ├── 19 JointTransformerBlock (txt+img joint attention)
41
+ │   ├── 38 SingleTransformerBlock (img self-attention)
42
+ │   └── N-dim RoPE (axes_dim=[16,56,56])
43
+ ├── AutoencoderKL Decoder
44
+ │   └── Latent channels=16, block_out=[128,256,512,512]
45
+ └── FlowMatchEuler Sampler
46
+ ```
47
+
48
+ ## Quick Start
49
+
50
+ ### Install
51
+
52
+ ```bash
53
+ pip install mlx safetensors sentencepiece tokenizers pillow numpy
54
+ ```
55
+
56
+ ### Download Weights
57
+
58
+ ```bash
59
+ # 4-bit quantized (recommended, ~5GB)
60
+ huggingface-cli download argmaxinc/mlx-FLUX.1-schnell-4bit-quantized
61
+
62
+ # Or full precision
63
+ huggingface-cli download argmaxinc/mlx-FLUX.1-schnell
64
+ ```
65
+
66
+ ### Generate
67
+
68
+ ```python
69
+ from pipeline import FluxPipeline
70
+
71
+ pipe = FluxPipeline()
72
+ pipe.load()
73
+
74
+ result = pipe.generate_and_save(
75
+ prompt="a beautiful sunset over mountains",
76
+ output_path="output.png",
77
+ width=512,
78
+ height=512,
79
+ num_steps=4,
80
+ seed=42,
81
+ )
82
+ print(f"Generated in {result['elapsed_s']}s")
83
+
84
+ pipe.unload()
85
+ ```
86
+
87
+ ## Files
88
+
89
+ ```
90
+ ├── pipeline.py          # Main inference pipeline
91
+ ├── flux_model.py        # FluxTransformer (JointBlock + SingleBlock)
92
+ ├── t5_encoder.py        # T5-XXL text encoder
93
+ ├── clip_encoder.py      # CLIP-L text encoder
94
+ ├── autoencoder.py       # VAE decoder
95
+ ├── sampler.py           # FlowMatch Euler sampler
96
+ ├── tokenizers.py        # T5 + CLIP tokenizers
97
+ ├── weight_loader.py     # Weight loading + key mapping
98
+ └── download_weights.py  # HF Hub download helper
99
+ ```
100
+
101
+ ## Model Source
102
+
103
+ Inference code is original work. Weights are loaded from:
104
+ - [argmaxinc/mlx-FLUX.1-schnell-4bit-quantized](https://huggingface.co/argmaxinc/mlx-FLUX.1-schnell-4bit-quantized) (default)
105
+ - [black-forest-labs/FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) (original)
106
+
107
+ ## License
108
+
109
+ Apache 2.0
__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """MLX FLUX pipeline package.
2
+
3
+ Provides a minimal FLUX.1-schnell diffusion pipeline implemented in
4
+ pure MLX for Apple Silicon inference. Uses pre-converted weights from
5
+ HuggingFace (argmaxinc/mlx-FLUX.1-schnell or 4bit variant).
6
+ """
autoencoder.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FLUX VAE decoder β€” param names match argmaxinc ae.safetensors keys.
2
+
3
+ Weight key structure:
4
+ decoder.conv_in.*
5
+ decoder.mid.block_{1,2}.{norm1,conv1,norm2,conv2}.*
6
+ decoder.mid.attn_1.{norm,q,k,v,proj_out}.*
7
+ decoder.up.{0-3}.block.{0-2}.{norm1,conv1,norm2,conv2,nin_shortcut}.*
8
+ decoder.up.{1-3}.upsample.conv.*
9
+ decoder.norm_out.*
10
+ decoder.conv_out.*
11
+
12
+ Note: up blocks are indexed in reverse β€” up.3 is the first decoder stage
13
+ (highest channels), up.0 is the last (lowest channels).
14
+
15
+ All conv weights loaded as PyTorch [O,I,kH,kW] are transposed to MLX
16
+ [O,kH,kW,I] in the pipeline's _load_vae().
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import mlx.core as mx
22
+ import mlx.nn as nn
23
+
24
+
25
+ # ── Building blocks (param names match weight keys) ──────────────────────────
26
+
27
class ResnetBlock(nn.Module):
    """Residual conv block; param names mirror the weight keys
    block_{i}.{norm1,conv1,norm2,conv2,nin_shortcut}.*"""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        # 1x1 projection on the skip path only when channels change.
        self.nin_shortcut = (
            nn.Conv2d(in_ch, out_ch, kernel_size=1) if in_ch != out_ch else None
        )

    def __call__(self, x):
        skip = x if self.nin_shortcut is None else self.nin_shortcut(x)
        out = self.conv1(nn.silu(self.norm1(x)))
        out = self.conv2(nn.silu(self.norm2(out)))
        return skip + out
49
+
50
+
51
class AttnBlock(nn.Module):
    """Single-head spatial self-attention over the H*W positions.

    Matches weight keys: attn_1.{norm,q,k,v,proj_out}.*
    Q/K/V/O are 1x1 Conv2d layers (matching the stored weight shapes).
    """

    def __init__(self, channels: int):
        super().__init__()
        self.norm = nn.GroupNorm(32, channels)
        self.q = nn.Conv2d(channels, channels, kernel_size=1)
        self.k = nn.Conv2d(channels, channels, kernel_size=1)
        self.v = nn.Conv2d(channels, channels, kernel_size=1)
        self.proj_out = nn.Conv2d(channels, channels, kernel_size=1)

    def __call__(self, x):
        B, H, W, C = x.shape
        n = H * W
        normed = self.norm(x)
        queries = self.q(normed).reshape(B, n, C)
        keys = self.k(normed).reshape(B, n, C)
        values = self.v(normed).reshape(B, n, C)

        # Scaled dot-product attention over flattened spatial positions.
        logits = (queries @ keys.transpose(0, 2, 1)) * (C ** -0.5)
        probs = mx.softmax(logits, axis=-1)
        attended = (probs @ values).reshape(B, H, W, C)
        return x + self.proj_out(attended)
77
+
78
+
79
class Upsample(nn.Module):
    """2x nearest-neighbor upsample followed by a 3x3 conv.

    Matches weight keys: upsample.conv.*
    """

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def __call__(self, x):
        # Nearest-neighbor 2x upsample in NHWC layout: duplicate each row
        # (axis=1) and each column (axis=2). The previous unused
        # `B, H, W, C = x.shape` unpack has been removed.
        x = mx.repeat(x, 2, axis=1)
        x = mx.repeat(x, 2, axis=2)
        return self.conv(x)
91
+
92
+
93
class UpBlock(nn.Module):
    """One decoder up-stage: N resnet blocks plus an optional 2x upsample.

    Matches weight keys: up.{i}.block.{0-2}.* and up.{i}.upsample.*
    """

    def __init__(self, in_ch: int, out_ch: int, num_blocks: int = 3, has_upsample: bool = True):
        super().__init__()
        blocks = []
        for j in range(num_blocks):
            # Only the first resnet changes the channel count.
            blocks.append(ResnetBlock(in_ch if j == 0 else out_ch, out_ch))
        self.block = blocks
        self.upsample = Upsample(out_ch) if has_upsample else None

    def __call__(self, x):
        for resnet in self.block:
            x = resnet(x)
        return x if self.upsample is None else self.upsample(x)
110
+
111
+
112
class MidBlock(nn.Module):
    """Decoder middle stage: resnet -> attention -> resnet.

    Matches weight keys: mid.{block_1, attn_1, block_2}.*
    """

    def __init__(self, channels: int):
        super().__init__()
        self.block_1 = ResnetBlock(channels, channels)
        self.attn_1 = AttnBlock(channels)
        self.block_2 = ResnetBlock(channels, channels)

    def __call__(self, x):
        return self.block_2(self.attn_1(self.block_1(x)))
126
+
127
+
128
+ # ── Decoder ──────────────────────────────────────────────────────────────────
129
+
130
class Decoder(nn.Module):
    """VAE Decoder. Param paths match: decoder.{conv_in,mid,up,norm_out,conv_out}.*

    Up block order (matching weight keys):
        up.3 -> 512->512 + upsample   (first stage executed)
        up.2 -> 512->512 + upsample
        up.1 -> 512->256 + upsample
        up.0 -> 256->128              (no upsample, last stage executed)
    """

    def __init__(self):
        super().__init__()
        # 16 latent channels in, 512 feature channels at the deepest stage.
        self.conv_in = nn.Conv2d(16, 512, kernel_size=3, padding=1)

        self.mid = MidBlock(512)

        # up blocks — indexed 0-3 to mirror the weight keys, but executed in
        # reverse (3→2→1→0); the list is declared lowest-channel-count first
        # so that self.up[i] lines up with the `decoder.up.{i}` key prefix.
        self.up = [
            UpBlock(256, 128, num_blocks=3, has_upsample=False),  # up.0
            UpBlock(512, 256, num_blocks=3, has_upsample=True),   # up.1
            UpBlock(512, 512, num_blocks=3, has_upsample=True),   # up.2
            UpBlock(512, 512, num_blocks=3, has_upsample=True),   # up.3
        ]

        self.norm_out = nn.GroupNorm(32, 128)
        self.conv_out = nn.Conv2d(128, 3, kernel_size=3, padding=1)

    def __call__(self, z):
        # z: latent tensor, channels-last (see AutoencoderKL docstring).
        h = self.conv_in(z)
        h = self.mid(h)
        # Process up blocks in reverse order: 3, 2, 1, 0
        for i in reversed(range(len(self.up))):
            h = self.up[i](h)
        h = nn.silu(self.norm_out(h))
        h = self.conv_out(h)
        return h
166
+
167
+
168
+ # ── AutoencoderKL ────────────────────────────────────────────────────────────
169
+
170
class AutoencoderKL(nn.Module):
    """FLUX VAE — decode-only path.

    Input:  z     [B, H/8, W/8, 16]  (latent, channels-last)
    Output: image [B, H, W, 3]       (RGB in [0, 1])
    """

    # FLUX latent scaling constants, undone before decoding.
    SCALE_FACTOR = 0.3611
    SHIFT_FACTOR = 0.1159

    def __init__(self):
        super().__init__()
        self.decoder = Decoder()

    def decode(self, z: mx.array) -> mx.array:
        # Undo latent scaling, decode, then map [-1, 1] output into [0, 1].
        latents = z / self.SCALE_FACTOR + self.SHIFT_FACTOR
        decoded = self.decoder(latents)
        return mx.clip((decoded + 1.0) / 2.0, 0.0, 1.0)
clip_encoder.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CLIP-L text encoder for FLUX pipeline.
2
+
3
+ Implements a 23-layer CLIP text encoder with absolute position embeddings
4
+ and causal self-attention β€” matching the HuggingFace
5
+ ``openai/clip-vit-large-patch14`` architecture used by FLUX.1.
6
+
7
+ Weight source: ``black-forest-labs/FLUX.1-schnell`` β†’
8
+ ``text_encoder/model.safetensors``
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import math
14
+
15
+ import mlx.core as mx
16
+ import mlx.nn as nn
17
+
18
+
19
+ # ── CLIP Config ──────────────────────────────────────────────────────────────
20
+
21
class CLIPConfig:
    """Static configuration for the CLIP-L text encoder used by FLUX.1."""

    vocab_size: int = 49408
    d_model: int = 768
    num_heads: int = 12
    head_dim: int = 64  # d_model / num_heads
    intermediate_size: int = 3072
    num_layers: int = 23  # FLUX uses 23, not 12
    max_position_embeddings: int = 77
29
+
30
+
31
+ # ── Building blocks ──────────────────────────────────────────────────────────
32
+
33
class CLIPAttention(nn.Module):
    """CLIP multi-head self-attention with an optional additive mask."""

    def __init__(self, cfg: CLIPConfig):
        super().__init__()
        self.num_heads = cfg.num_heads
        self.head_dim = cfg.head_dim
        width = cfg.d_model

        self.q_proj = nn.Linear(width, width)
        self.k_proj = nn.Linear(width, width)
        self.v_proj = nn.Linear(width, width)
        self.out_proj = nn.Linear(width, width)

    def __call__(self, x: mx.array, causal_mask: mx.array | None = None) -> mx.array:
        batch, seq, _ = x.shape
        heads, hdim = self.num_heads, self.head_dim

        # Project then split heads: [B, L, d] -> [B, H, L, D]
        def split_heads(t):
            return t.reshape(batch, seq, heads, hdim).transpose(0, 2, 1, 3)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))

        scores = (q @ k.transpose(0, 1, 3, 2)) / math.sqrt(hdim)
        if causal_mask is not None:
            scores = scores + causal_mask

        probs = mx.softmax(scores, axis=-1)
        merged = (probs @ v).transpose(0, 2, 1, 3).reshape(batch, seq, -1)
        return self.out_proj(merged)
65
+
66
+
67
class CLIPMLP(nn.Module):
    """CLIP feed-forward network (GELU activation).

    NOTE(review): HF's openai/clip-vit-large-patch14 config specifies the
    "quick_gelu" activation (x * sigmoid(1.702 * x)); this block uses MLX's
    tanh-based gelu_approx instead — confirm against reference outputs.
    """

    def __init__(self, cfg: CLIPConfig):
        super().__init__()
        self.fc1 = nn.Linear(cfg.d_model, cfg.intermediate_size)
        self.fc2 = nn.Linear(cfg.intermediate_size, cfg.d_model)

    def __call__(self, x: mx.array) -> mx.array:
        return self.fc2(nn.gelu_approx(self.fc1(x)))
77
+
78
+
79
class CLIPEncoderLayer(nn.Module):
    """Single pre-norm CLIP encoder layer: Norm → Attention → Norm → MLP."""

    def __init__(self, cfg: CLIPConfig):
        super().__init__()
        self.norm1 = nn.LayerNorm(cfg.d_model)
        self.attn = CLIPAttention(cfg)
        self.norm2 = nn.LayerNorm(cfg.d_model)
        self.mlp = CLIPMLP(cfg)

    def __call__(self, x: mx.array, causal_mask: mx.array | None = None) -> mx.array:
        attn_out = self.attn(self.norm1(x), causal_mask)
        x = x + attn_out
        return x + self.mlp(self.norm2(x))
93
+
94
+
95
+ # ── CLIP Encoder ─────────────────────────────────────────────────────────────
96
+
97
class CLIPEncoder(nn.Module):
    """CLIP-L text encoder: 23-layer transformer with absolute position embeddings.

    Input:  token_ids [B, 77]
    Output: (pooled [B, 768], hidden_states [B, 77, 768])

    The pooled output is taken from the first EOS token position,
    following the CLIP convention.
    """

    def __init__(self, cfg: CLIPConfig | None = None):
        super().__init__()
        if cfg is None:
            cfg = CLIPConfig()
        self.cfg = cfg

        self.token_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.max_position_embeddings, cfg.d_model)
        self.layers = [CLIPEncoderLayer(cfg) for _ in range(cfg.num_layers)]
        self.final_norm = nn.LayerNorm(cfg.d_model)

    def _build_causal_mask(self, seq_len: int) -> mx.array:
        """Build additive causal attention mask [1, 1, L, L]."""
        mask = mx.full((seq_len, seq_len), -1e9)
        mask = mx.triu(mask, k=1)  # upper triangle = -1e9, diagonal+below = 0
        return mask.reshape(1, 1, seq_len, seq_len)

    def __call__(self, token_ids: mx.array) -> tuple[mx.array, mx.array]:
        B, L = token_ids.shape

        # Token + absolute position embeddings.
        positions = mx.arange(L)
        x = self.token_emb(token_ids) + self.pos_emb(positions)

        causal_mask = self._build_causal_mask(L)

        for layer in self.layers:
            x = layer(x, causal_mask)

        x = self.final_norm(x)  # [B, L, d_model]

        # Pooled output: position of the FIRST EOS token (id 49407).
        # argmax over the 0/1 mask returns the first match; if no EOS is
        # present it falls back to position 0. (The dead `has_eos` flag
        # that was computed but never used has been removed.)
        eos_id = 49407
        eos_mask = (token_ids == eos_id).astype(mx.int32)
        eos_pos = mx.argmax(eos_mask, axis=-1)  # [B]

        # Gather the hidden state at each batch element's EOS position.
        idx = eos_pos.reshape(B, 1, 1)
        idx = mx.broadcast_to(idx, (B, 1, x.shape[-1]))
        pooled = mx.take_along_axis(x, idx, axis=1).squeeze(1)  # [B, d_model]

        return pooled, x
download_weights.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Download shared T5/CLIP weights for the self-built FLUX MLX pipeline.
3
+
4
+ Stores files locally in backends/mlx_flux/weights/ so they are co-located
5
+ with our code and won't be accidentally deleted as "unused HF cache".
6
+ """
7
+ import os
8
+ import shutil
9
+ from pathlib import Path
10
+ from huggingface_hub import hf_hub_download
11
+
12
+ REPO = "black-forest-labs/FLUX.1-schnell"
13
+ WEIGHTS_DIR = Path(__file__).parent / "weights"
14
+
15
+ FILES = [
16
+ # (HF repo path, local filename)
17
+ ("text_encoder_2/model-00001-of-00002.safetensors", "t5_shard1.safetensors"),
18
+ ("text_encoder_2/model-00002-of-00002.safetensors", "t5_shard2.safetensors"),
19
+ ("text_encoder/model.safetensors", "clip_text_encoder.safetensors"),
20
+ ("tokenizer_2/spiece.model", "t5_spiece.model"),
21
+ ("tokenizer/vocab.json", "clip_vocab.json"),
22
+ ("tokenizer/merges.txt", "clip_merges.txt"),
23
+ ]
24
+
25
+
26
def main():
    """Download each entry in FILES into WEIGHTS_DIR, skipping files that
    already exist with non-zero size so the script is safe to re-run."""
    WEIGHTS_DIR.mkdir(exist_ok=True)
    for hf_path, local_name in FILES:
        dest = WEIGHTS_DIR / local_name
        # Non-zero-size check guards against a partially-written copy
        # left behind by an interrupted earlier run.
        if dest.exists() and dest.stat().st_size > 0:
            print(f" SKIP {local_name} (already exists, {dest.stat().st_size / 1024 / 1024:.1f} MB)")
            continue
        print(f" DOWNLOADING {hf_path} -> {local_name} ...")
        # hf_hub_download fetches into the shared HF cache; copy2 then
        # materializes a stable local copy next to the code.
        cached = hf_hub_download(REPO, hf_path)
        shutil.copy2(cached, dest)
        print(f" OK {local_name} ({dest.stat().st_size / 1024 / 1024:.1f} MB)")
    print("\nAll weights ready in", WEIGHTS_DIR)
38
+
39
+
40
+ if __name__ == "__main__":
41
+ main()
flux_model.py ADDED
@@ -0,0 +1,447 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FLUX DiT β€” rewritten to match mflux reference implementation exactly.
2
+
3
+ Parameter names match argmaxinc/mlx-FLUX.1-schnell weights.
4
+ Forward pass logic matches filipstrand/mflux.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import math
10
+
11
+ import mlx.core as mx
12
+ import mlx.nn as nn
13
+
14
+
15
+ # ── Config ───────────────────────────────────────────────────────────────────
16
+
17
class FluxConfig:
    """Static FLUX.1-schnell DiT hyperparameters."""

    hidden_size: int = 3072
    num_heads: int = 24
    head_dim: int = 128
    mlp_ratio: float = 4.0
    num_joint_blocks: int = 19
    num_single_blocks: int = 38
    # Per-axis RoPE dims for (time, height, width); they sum to head_dim.
    axes_dim: tuple[int, ...] = (16, 56, 56)
    theta: float = 10_000.0
    in_channels: int = 64
    context_dim: int = 4096  # T5-XXL hidden size
    pooled_dim: int = 768    # CLIP-L pooled size
29
+
30
+
31
+ # ── RoPE (matches mflux EmbedND) ────────────────────────────────────────────
32
+
33
def _rope_single_axis(pos: mx.array, dim: int, theta: float) -> mx.array:
    """Compute RoPE rotation matrices for one positional axis.

    Args:
        pos: [B, seq] positions along this axis.
        dim: number of embedding dims allotted to this axis (even).
        theta: RoPE frequency base.

    Returns:
        [B, seq, dim//2, 2, 2] rotation matrices.
    """
    exponents = mx.arange(0, dim, 2, dtype=mx.float32) / dim
    inv_freq = 1.0 / (theta ** exponents)  # [dim//2]
    angles = pos[:, :, None].astype(mx.float32) * inv_freq[None, :]  # [B, seq, dim//2]
    c = mx.cos(angles)
    s = mx.sin(angles)
    # Row-major 2x2 rotation matrix [[cos, -sin], [sin, cos]] per angle.
    rot = mx.stack([c, -s, s, c], axis=-1)
    return rot.reshape(pos.shape[0], -1, dim // 2, 2, 2)
47
+
48
+
49
def compute_rope(ids: mx.array, axes_dim=(16, 56, 56), theta=10000.0) -> mx.array:
    """Compute N-dimensional RoPE embeddings.

    Args:
        ids: [B, seq, 3] position IDs (time, height, width)

    Returns:
        [B, 1, seq, head_dim//2, 2, 2] rotation matrices
    """
    per_axis = [
        _rope_single_axis(ids[..., axis], axes_dim[axis], theta)
        for axis in range(3)
    ]
    # Concatenate along the dim//2 axis → total = sum(axes_dim)//2 = 64.
    emb = mx.concatenate(per_axis, axis=-3)
    # Insert a broadcast head axis: [B, 1, seq, 64, 2, 2].
    return emb[:, None]
63
+
64
+
65
def apply_rope(q: mx.array, k: mx.array, freqs: mx.array):
    """Apply rotary embeddings to q and k (matches mflux exactly).

    Args:
        q, k: [B, H, L, D] where D = head_dim
        freqs: [B, 1, L, D//2, 2, 2] rotation matrices from compute_rope()

    Returns:
        (q_rotated, k_rotated), both [B, H, L, D] cast to float32.
    """
    # Reshape to pairs: [B, H, L, D//2, 1, 2]
    xq_ = q.astype(mx.float32).reshape(*q.shape[:-1], -1, 1, 2)
    xk_ = k.astype(mx.float32).reshape(*k.shape[:-1], -1, 1, 2)

    # freqs[..., 0] = [[cos, -sin]]  shape [..., 2]
    # freqs[..., 1] = [[sin,  cos]]  shape [..., 2]
    # xq_[..., 0] = first of pair (scalar), xq_[..., 1] = second of pair,
    # so each line below is one row of the 2x2 rotation applied per pair.
    xq_out = freqs[..., 0] * xq_[..., 0] + freqs[..., 1] * xq_[..., 1]
    xk_out = freqs[..., 0] * xk_[..., 0] + freqs[..., 1] * xk_[..., 1]

    return xq_out.reshape(*q.shape).astype(mx.float32), xk_out.reshape(*k.shape).astype(mx.float32)
82
+
83
+
84
+ # ── Timestep embedding (matches mflux TimeTextEmbed) ─────────────────────────
85
+
86
def timestep_embedding(t: mx.array, dim: int = 256) -> mx.array:
    """Sinusoidal timestep embedding with half-freq swap (mflux convention).

    Args:
        t: [B] timesteps.
        dim: embedding width (even).

    Returns:
        [B, dim] embedding, halves swapped from the usual [sin, cos] order.
    """
    half = dim // 2
    inv_freq = mx.exp(-math.log(10000.0) * mx.arange(half, dtype=mx.float32) / half)
    phase = t[:, None].astype(mx.float32) * inv_freq[None, :]
    sin_cos = mx.concatenate([mx.sin(phase), mx.cos(phase)], axis=-1)
    # mflux swaps the halves: [sin, cos] → [cos, sin].
    return mx.concatenate([sin_cos[:, half:], sin_cos[:, :half]], axis=-1)
98
+
99
+
100
+ # ── AdaLN modulation (matches mflux AdaLayerNormZero) ────────────────────────
101
+
102
class AdaLNModulation(nn.Module):
    """SiLU → Linear producing n_params concatenated modulation vectors.

    Matches weight keys: adaLN_modulation.layers.1.*
    (layers.0 is a parameter-free SiLU; layers.1 is the Linear.)
    """

    def __init__(self, dim: int, n_params: int):
        super().__init__()
        self.layers = [nn.SiLU(), nn.Linear(dim, n_params * dim)]

    def __call__(self, x: mx.array) -> mx.array:
        silu, linear = self.layers
        return linear(silu(x))
115
+
116
+
117
+ # ── QK Norm ──────────────────────────────────────────────────────────────────
118
+
119
class QKNorm(nn.Module):
    """Per-head RMSNorm applied to queries and keys before attention.

    Matches weight keys: qk_norm.{q_norm, k_norm}.weight
    """
    def __init__(self, dim: int):
        super().__init__()
        self.q_norm = nn.RMSNorm(dim)
        self.k_norm = nn.RMSNorm(dim)
125
+
126
+
127
+ # ── Attention ────────────────────────────────────────────────────────────────
128
+
129
class Attention(nn.Module):
    """Projection container for attention: separate Q/K/V/O linears.

    Matches weight keys: attn.{q,k,v,o}_proj.*
    The attention math itself lives in the transformer blocks.
    """

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        for proj_name in ("q_proj", "k_proj", "v_proj", "o_proj"):
            setattr(self, proj_name, nn.Linear(dim, dim))
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
139
+
140
+
141
+ # ── MLP ──────────────────────────────────────────────────────────────────────
142
+
143
class MLP(nn.Module):
    """Two-layer GELU feed-forward. Matches weight keys: mlp.{fc1,fc2}.*"""

    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden)
        self.fc2 = nn.Linear(hidden, dim)

    def __call__(self, x: mx.array) -> mx.array:
        activated = nn.gelu(self.fc1(x))
        return self.fc2(activated)
152
+
153
+
154
+ # ── Joint Transformer Block ─────────────────────────────────────────────────
155
+
156
class ImageTransformerSubBlock(nn.Module):
    """Image side of a joint block (parameter container only; the forward
    pass lives in JointTransformerBlock).

    Matches weight keys:
        image_transformer_block.{adaLN_modulation, attn, mlp, qk_norm}.*
    """
    def __init__(self, cfg: FluxConfig):
        super().__init__()
        H = cfg.hidden_size
        self.adaLN_modulation = AdaLNModulation(H, 6)  # shift/scale/gate x2
        self.attn = Attention(H, cfg.num_heads)
        self.mlp = MLP(H, int(H * cfg.mlp_ratio))
        self.qk_norm = QKNorm(cfg.head_dim)
167
+
168
+
169
class TextTransformerSubBlock(nn.Module):
    """Text side of a joint block (parameter container only; the forward
    pass lives in JointTransformerBlock).

    Matches weight keys:
        text_transformer_block.{adaLN_modulation, attn, mlp, qk_norm}.*
    """
    def __init__(self, cfg: FluxConfig):
        super().__init__()
        H = cfg.hidden_size
        self.adaLN_modulation = AdaLNModulation(H, 6)  # shift/scale/gate x2
        self.attn = Attention(H, cfg.num_heads)
        # Same FFN width as the image side: int(H * mlp_ratio). (An earlier
        # comment claimed hidden_size/2 = 1536 here, contradicting the code.)
        self.mlp = MLP(H, int(H * cfg.mlp_ratio))
        self.qk_norm = QKNorm(cfg.head_dim)
182
+
183
+
184
class JointTransformerBlock(nn.Module):
    """Dual-stream DiT block. Matches: multimodal_transformer_blocks.{i}.*

    Text and image tokens keep separate parameters (modulation, QKV/O, FFN)
    but attend jointly over the concatenated [txt, img] sequence.
    """

    def __init__(self, cfg: FluxConfig):
        super().__init__()
        self.image_transformer_block = ImageTransformerSubBlock(cfg)
        self.text_transformer_block = TextTransformerSubBlock(cfg)
        self._num_heads = cfg.num_heads
        self._head_dim = cfg.head_dim

    def __call__(self, img, txt, vec, rope_emb):
        B = img.shape[0]
        H, D = self._num_heads, self._head_dim
        img_len = img.shape[1]
        txt_len = txt.shape[1]

        # 1. AdaLN modulation: 6 vectors per stream —
        #    (shift1, scale1, gate1) for attention, (shift2, scale2, gate2) for FFN.
        img_params = self.image_transformer_block.adaLN_modulation(vec)
        i_s1, i_sc1, i_g1, i_s2, i_sc2, i_g2 = mx.split(img_params, 6, axis=-1)

        txt_params = self.text_transformer_block.adaLN_modulation(vec)
        t_s1, t_sc1, t_g1, t_s2, t_sc2, t_g2 = mx.split(txt_params, 6, axis=-1)

        # 2. LayerNorm(affine=False) + modulate. affine=False means the
        #    inline-constructed LayerNorm holds no learned parameters.
        img_norm = nn.LayerNorm(img.shape[-1], affine=False, eps=1e-6)(img)
        img_norm = img_norm * (1 + i_sc1[:, None, :]) + i_s1[:, None, :]
        txt_norm = nn.LayerNorm(txt.shape[-1], affine=False, eps=1e-6)(txt)
        txt_norm = txt_norm * (1 + t_sc1[:, None, :]) + t_s1[:, None, :]

        # 3. Q/K/V projections + per-head QK RMSNorm
        img_q = self.image_transformer_block.attn.q_proj(img_norm).reshape(B, img_len, H, D)
        img_k = self.image_transformer_block.attn.k_proj(img_norm).reshape(B, img_len, H, D)
        img_v = self.image_transformer_block.attn.v_proj(img_norm).reshape(B, img_len, H, D)
        txt_q = self.text_transformer_block.attn.q_proj(txt_norm).reshape(B, txt_len, H, D)
        txt_k = self.text_transformer_block.attn.k_proj(txt_norm).reshape(B, txt_len, H, D)
        txt_v = self.text_transformer_block.attn.v_proj(txt_norm).reshape(B, txt_len, H, D)

        img_q = self.image_transformer_block.qk_norm.q_norm(img_q)
        img_k = self.image_transformer_block.qk_norm.k_norm(img_k)
        txt_q = self.text_transformer_block.qk_norm.q_norm(txt_q)
        txt_k = self.text_transformer_block.qk_norm.k_norm(txt_k)

        # 4. Concat for joint attention, text first — matching the
        #    rope_emb ordering built in FluxTransformer.__call__.
        q = mx.concatenate([txt_q, img_q], axis=1).transpose(0, 2, 1, 3)  # [B,H,L,D]
        k = mx.concatenate([txt_k, img_k], axis=1).transpose(0, 2, 1, 3)
        v = mx.concatenate([txt_v, img_v], axis=1).transpose(0, 2, 1, 3)

        # 5. RoPE
        q, k = apply_rope(q, k, rope_emb)

        # 6. Scaled dot-product attention over the full joint sequence
        scale = math.sqrt(D)
        scores = (q @ k.transpose(0, 1, 3, 2)) / scale
        weights = mx.softmax(scores, axis=-1)
        attn_out = (weights @ v).transpose(0, 2, 1, 3).reshape(B, txt_len + img_len, -1)

        # 7. Split back into text and image parts
        txt_attn = attn_out[:, :txt_len, :]
        img_attn = attn_out[:, txt_len:, :]

        # 8. Output projection + gated residual for the attention branch
        img_attn = self.image_transformer_block.attn.o_proj(img_attn)
        txt_attn = self.text_transformer_block.attn.o_proj(txt_attn)
        img = img + i_g1[:, None, :] * img_attn
        txt = txt + t_g1[:, None, :] * txt_attn

        # 9. FFN with its own LayerNorm(affine=False) + modulate + gate
        img_ff_in = nn.LayerNorm(img.shape[-1], affine=False, eps=1e-6)(img)
        img_ff_in = img_ff_in * (1 + i_sc2[:, None, :]) + i_s2[:, None, :]
        img = img + i_g2[:, None, :] * self.image_transformer_block.mlp(img_ff_in)

        txt_ff_in = nn.LayerNorm(txt.shape[-1], affine=False, eps=1e-6)(txt)
        txt_ff_in = txt_ff_in * (1 + t_sc2[:, None, :]) + t_s2[:, None, :]
        txt = txt + t_g2[:, None, :] * self.text_transformer_block.mlp(txt_ff_in)

        return img, txt
260
+
261
+
262
+ # ── Single Transformer Block ────────────────────────────────────────────────
263
+
264
class SingleTransformerSubBlock(nn.Module):
    """Single-stream DiT block: AdaLN → parallel (attention | MLP) → gated sum.

    Matches weight keys: unified_transformer_blocks.{i}.transformer_block.*
    """

    def __init__(self, cfg: FluxConfig):
        super().__init__()
        H = cfg.hidden_size
        mlp_hidden = int(H * cfg.mlp_ratio)
        self.adaLN_modulation = AdaLNModulation(H, 3)  # shift, scale, gate
        self.attn = Attention(H, cfg.num_heads)
        self.mlp = MLP(H, mlp_hidden)
        self.qk_norm = QKNorm(cfg.head_dim)

    def __call__(self, x, vec, rope_emb):
        B, L, D = x.shape
        H, HD = self._get_heads()
        residual = x

        # 1. AdaLN: parameter-free LayerNorm modulated by (shift, scale).
        params = self.adaLN_modulation(vec)
        shift, scale, gate = mx.split(params, 3, axis=-1)
        x_norm = nn.LayerNorm(D, affine=False, eps=1e-6)(x)
        x_norm = x_norm * (1 + scale[:, None, :]) + shift[:, None, :]

        # 2. Attention with per-head QK RMSNorm and RoPE
        q = self.attn.q_proj(x_norm).reshape(B, L, H, HD)
        k = self.attn.k_proj(x_norm).reshape(B, L, H, HD)
        v = self.attn.v_proj(x_norm).reshape(B, L, H, HD)

        q = self.qk_norm.q_norm(q)
        k = self.qk_norm.k_norm(k)

        q = q.transpose(0, 2, 1, 3)
        k = k.transpose(0, 2, 1, 3)
        v = v.transpose(0, 2, 1, 3)

        q, k = apply_rope(q, k, rope_emb)

        sc = math.sqrt(HD)
        scores = (q @ k.transpose(0, 1, 3, 2)) / sc
        w = mx.softmax(scores, axis=-1)
        attn_out = (w @ v).transpose(0, 2, 1, 3).reshape(B, L, -1)

        # 3. Parallel MLP branch, sharing x_norm with the attention branch
        mlp_out = nn.gelu_approx(self.mlp.fc1(x_norm))

        # 4. FLUX's fused proj_out(concat(attn, mlp)) is equivalent to
        #    o_proj(attn_out) + fc2(mlp_out) with the weights split across
        #    these two linears, then a single gate on the sum. (A dead
        #    mx.concatenate of the two branches was removed here.)
        attn_projected = self.attn.o_proj(attn_out)
        mlp_projected = self.mlp.fc2(mlp_out)
        out = gate[:, None, :] * (attn_projected + mlp_projected)

        return residual + out

    def _get_heads(self):
        # Derive from the attention module instead of hard-coding (24, 128)
        # so a non-default FluxConfig stays consistent.
        return self.attn.num_heads, self.attn.head_dim
322
+
323
+
324
class SingleTransformerBlock(nn.Module):
    """Thin wrapper so parameters land under the key path:
    unified_transformer_blocks.{i}.transformer_block.*
    """

    def __init__(self, cfg: FluxConfig):
        super().__init__()
        self.transformer_block = SingleTransformerSubBlock(cfg)

    def __call__(self, x, vec, rope_emb):
        inner = self.transformer_block
        return inner(x, vec, rope_emb)
332
+
333
+
334
+ # ── FLUX Transformer ────────────────────────────────────────────────────────
335
+
336
class FluxTransformer(nn.Module):
    """FLUX DiT matching argmaxinc weights + mflux forward pass.

    __call__ inputs:
        img:       [B, img_seq, 64] patchified latents
        img_ids:   [B, img_seq, 3] positional IDs for RoPE
        txt:       [B, txt_seq, 4096] T5 hidden states
        txt_ids:   [B, txt_seq, 3] positional IDs for the text tokens
        y:         [B, 768] CLIP pooled embedding
        timesteps: [B] scheduler timesteps

    Returns: [B, img_seq, 64] model output in patch space.
    """

    def __init__(self, cfg: FluxConfig | None = None):
        super().__init__()
        if cfg is None:
            cfg = FluxConfig()
        self.cfg = cfg
        H = cfg.hidden_size

        # x_embedder: Linear (NOT Conv2d) — matches mflux
        self.x_embedder = nn.Linear(cfg.in_channels, H)

        # context_embedder: projects T5 hidden states to the DiT width
        self.context_embedder = nn.Linear(cfg.context_dim, H)

        # Timestep + pooled-text embeddings (match mflux naming)
        self.t_embedder = _TimestepEmbedder(H)
        self.y_embedder = _TextEmbedder(cfg.pooled_dim, H)

        # Transformer blocks
        self.multimodal_transformer_blocks = [
            JointTransformerBlock(cfg) for _ in range(cfg.num_joint_blocks)
        ]
        self.unified_transformer_blocks = [
            SingleTransformerBlock(cfg) for _ in range(cfg.num_single_blocks)
        ]

        # Final layer (matches mflux AdaLayerNormContinuous)
        self.final_layer = _FinalLayer(H, cfg.in_channels)

    def __call__(self, img, img_ids, txt, txt_ids, y, timesteps):
        # 1. Embeddings
        img = self.x_embedder(img)        # [B, seq, 64] → [B, seq, 3072]
        txt = self.context_embedder(txt)  # [B, seq, 4096] → [B, seq, 3072]

        # 2. Conditioning vector — timesteps already scaled to [0, 1000]
        #    by the scheduler; combined with the CLIP pooled embedding.
        t_emb = timestep_embedding(timesteps, 256)
        vec = self.t_embedder(t_emb) + self.y_embedder(y)

        # 3. RoPE for the full joint sequence, text tokens first — the same
        #    ordering the joint blocks use when concatenating Q/K/V.
        all_ids = mx.concatenate([txt_ids, img_ids], axis=1)
        rope_emb = compute_rope(all_ids, self.cfg.axes_dim, self.cfg.theta)

        # 4. Dual-stream joint blocks
        for block in self.multimodal_transformer_blocks:
            img, txt = block(img, txt, vec, rope_emb)

        # 5. Merge streams for the single blocks: [txt, img]
        img = mx.concatenate([txt, img], axis=1)

        # 6. Single-stream blocks (rope_emb covers the full sequence)
        for block in self.unified_transformer_blocks:
            img = block(img, vec, rope_emb)

        # 7. Drop the text tokens — txt still holds the embedded text, so
        #    txt.shape[1] is the text length — then apply the final layer.
        img = img[:, txt.shape[1]:, :]
        img = self.final_layer(img, vec)

        return img
396
+
397
+
398
+ # ── Helper modules (match mflux TimestepEmbedder/TextEmbedder naming) ────────
399
+
400
class _TimestepEmbedder(nn.Module):
    """Projects the 256-dim sinusoidal timestep embedding to hidden width.

    Matches: t_embedder.mlp.layers.{0,2}.*
    """
    def __init__(self, dim):
        super().__init__()
        self.mlp = _MLP2(256, dim)

    def __call__(self, x):
        return self.mlp(x)
408
+
409
+
410
class _TextEmbedder(nn.Module):
    """Projects the CLIP pooled embedding to the DiT hidden width.

    Matches: y_embedder.mlp.layers.{0,2}.*
    """
    def __init__(self, in_dim, dim):
        super().__init__()
        self.mlp = _MLP2(in_dim, dim)

    def __call__(self, x):
        return self.mlp(x)
418
+
419
+
420
class _MLP2(nn.Module):
    """Two-layer MLP with SiLU. Matches weight keys: mlp.layers.{0,2}.*
    (layers.1, the SiLU, has no parameters.)
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.layers = [nn.Linear(in_dim, out_dim), nn.SiLU(), nn.Linear(out_dim, out_dim)]

    def __call__(self, x):
        first, act, second = self.layers
        return second(act(first(x)))
430
+
431
+
432
class _FinalLayer(nn.Module):
    """Matches: final_layer.{adaLN_modulation, linear}.*

    AdaLayerNormContinuous-style head: a parameter-free LayerNorm modulated
    by (scale, shift) derived from the conditioning vector, followed by a
    linear projection back to patch channels.
    """
    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.adaLN_modulation = AdaLNModulation(hidden_size, 2)
        self.linear = nn.Linear(hidden_size, out_channels)

    def __call__(self, x, vec):
        params = self.adaLN_modulation(vec)
        # Split order is (scale, shift) — matches diffusers'
        # AdaLayerNormContinuous chunking convention.
        scale, shift = mx.split(params, 2, axis=-1)
        x = nn.LayerNorm(x.shape[-1], affine=False, eps=1e-6)(x)
        x = x * (1 + scale[:, None, :]) + shift[:, None, :]
        return self.linear(x)
pipeline.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FLUX.1-schnell pipeline β€” end-to-end text-to-image on MLX.
2
+
3
+ Orchestrates the full inference chain:
4
+ 1. Tokenize prompt (T5 + CLIP)
5
+ 2. Encode text (T5 β†’ embeddings, CLIP β†’ pooled)
6
+ 3. Initialize latent noise + patchify
7
+ 4. Denoising loop (rectified flow, Euler steps)
8
+ 5. Unpatchify + VAE decode β†’ PIL.Image
9
+
10
+ Memory strategy: components are loaded/unloaded in phases so that
11
+ only one large model occupies unified memory at a time.
12
+
13
+ Usage::
14
+
15
+ pipe = FluxPipeline()
16
+ pipe.load("argmaxinc/mlx-FLUX.1-schnell-4bit-quantized")
17
+ img = pipe.generate("a cat in a garden", steps=4, width=512, height=512)
18
+ img.save("output.png")
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import logging
24
+ import os
25
+ import time
26
+ from pathlib import Path
27
+
28
+ import mlx.core as mx
29
+ import mlx.nn as nn
30
+ from huggingface_hub import hf_hub_download
31
+ from PIL import Image
32
+
33
+ from .autoencoder import AutoencoderKL
34
+ from .clip_encoder import CLIPEncoder
35
+ from .flux_model import FluxTransformer
36
+ from .sampler import FlowMatchEulerScheduler, compute_img_ids, patchify, unpatchify
37
+ from .t5_encoder import T5Encoder
38
+ from .tokenizers import CLIPTokenizer, T5Tokenizer
39
+ from . import weight_loader
40
+
41
+ logger = logging.getLogger("image-server")
42
+
43
+ # ── Model repos ──────────────────────────────────────────────────────────────
44
+
45
+ DEFAULT_MODEL = "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized"
46
+ FULL_MODEL = "argmaxinc/mlx-FLUX.1-schnell"
47
+ T5_CLIP_REPO = "black-forest-labs/FLUX.1-schnell"
48
+
49
+ # Local weights directory (co-located with pipeline code, won't be
50
+ # accidentally deleted when cleaning HF cache)
51
+ _WEIGHTS_DIR = Path(__file__).parent / "weights"
52
+
53
+
54
class FluxPipeline:
    """FLUX.1-schnell pipeline for MLX (Apple Silicon).

    Manages model loading, text encoding, denoising, and VAE decoding
    with phase-based memory management.
    """

    def __init__(self, model_id: str | None = None, quantize: bool = False):
        """Configure the pipeline; no weights are loaded until :meth:`load`.

        Args:
            model_id: HF repo id for the DiT weights (defaults to the
                4-bit quantized build).
            quantize: stored for callers; quantization is actually detected
                from the downloaded weights in ``_load_dit``.
        """
        self.model_id = model_id or DEFAULT_MODEL
        self.quantize = quantize

        # Performance options — uses the module-level `os` import; the
        # original re-imported os locally, shadowing it needlessly.
        self.use_compile = os.environ.get("IMAGE_MLX_COMPILE", "0") in ("1", "true", "yes")
        self.phased_memory = os.environ.get("IMAGE_MLX_PHASED_MEM", "1") not in ("0", "false", "no")

        # Components (lazily loaded by load(); set back to None by unload()
        # and by the phased-memory transitions inside generate()).
        self.t5_tokenizer: T5Tokenizer | None = None
        self.clip_tokenizer: CLIPTokenizer | None = None
        self.t5_encoder: T5Encoder | None = None
        self.clip_encoder: CLIPEncoder | None = None
        self.transformer: FluxTransformer | None = None
        self.vae: AutoencoderKL | None = None
        self.scheduler = FlowMatchEulerScheduler()

        self._loaded = False
80
+
81
+ # ── Loading ──────────────────────────────────────────────────────────
82
+
83
    def load(self, model_id: str | None = None) -> None:
        """Download and load all model components.

        Args:
            model_id: optional override of the repo id chosen at construction.

        Weight loading is best-effort per component: a failure in any one
        loader is logged and the component keeps its random initialization,
        so the pipeline still comes up (useful for smoke tests without
        network access). ``_loaded`` is set regardless.
        """
        repo_id = model_id or self.model_id
        t0 = time.time()
        logger.info("[MLX] Loading FLUX pipeline from %s ...", repo_id)

        # 1. Download files (DiT + VAE from the chosen repo; tokenizer assets
        #    from the local weights/ dir when present, else the BFL repo).
        dit_path = self._download_dit(repo_id)
        ae_path = hf_hub_download(repo_id, "ae.safetensors")
        t5_spiece = self._local_or_hf("t5_spiece.model", "tokenizer_2/spiece.model")
        clip_vocab = self._local_or_hf("clip_vocab.json", "tokenizer/vocab.json")
        clip_merges = self._local_or_hf("clip_merges.txt", "tokenizer/merges.txt")

        # T5 weights (multi-shard); empty list means "not available".
        t5_paths = self._download_t5_weights()

        # CLIP weights; None means "not available".
        clip_path = self._download_clip_weights()

        # 2. Init tokenizers (lengths match FLUX conventions: T5=256, CLIP=77)
        self.t5_tokenizer = T5Tokenizer(t5_spiece, max_length=256)
        self.clip_tokenizer = CLIPTokenizer(clip_vocab, clip_merges, max_length=77)

        # 3. Build models (random-initialized until weights land below)
        self.t5_encoder = T5Encoder()
        self.clip_encoder = CLIPEncoder()
        self.transformer = FluxTransformer()
        self.vae = AutoencoderKL()

        # 4. Load weights (each component independently — partial loading OK)
        if t5_paths:
            try:
                weight_loader.load_t5(t5_paths, self.t5_encoder)
            except Exception as exc:
                logger.warning("[MLX] T5 weight loading failed: %s — using random init", exc)
        if clip_path:
            try:
                weight_loader.load_clip(clip_path, self.clip_encoder)
            except Exception as exc:
                logger.warning("[MLX] CLIP weight loading failed: %s — using random init", exc)

        # DiT weights — load directly into transformer model
        try:
            self._load_dit(dit_path)
        except Exception as exc:
            logger.warning("[MLX] DiT weight loading failed: %s — using random init", exc)

        # VAE weights
        try:
            self._load_vae(ae_path)
        except Exception as exc:
            logger.warning("[MLX] VAE weight loading failed: %s — using random init", exc)

        self._loaded = True
        logger.info("[MLX] FLUX pipeline ready (%.1fs)", time.time() - t0)
138
+
139
+ def _download_dit(self, repo_id: str) -> str:
140
+ """Download DiT weights file."""
141
+ if "4bit" in repo_id:
142
+ return hf_hub_download(repo_id, "flux-schnell-4bit-quantized.safetensors")
143
+ return hf_hub_download(repo_id, "flux-schnell.safetensors")
144
+
145
+ def _download_t5_weights(self) -> list[str]:
146
+ """Get T5-XXL encoder weight paths β€” local weights/ dir first, HF fallback."""
147
+ local1 = _WEIGHTS_DIR / "t5_shard1.safetensors"
148
+ local2 = _WEIGHTS_DIR / "t5_shard2.safetensors"
149
+ if local1.exists() and local2.exists():
150
+ logger.info("[MLX] T5 weights ready (local)")
151
+ return [str(local1), str(local2)]
152
+ try:
153
+ p1 = hf_hub_download(T5_CLIP_REPO, "text_encoder_2/model-00001-of-00002.safetensors")
154
+ p2 = hf_hub_download(T5_CLIP_REPO, "text_encoder_2/model-00002-of-00002.safetensors")
155
+ logger.info("[MLX] T5 weights ready (HF cache)")
156
+ return [p1, p2]
157
+ except Exception as exc:
158
+ logger.warning("[MLX] T5 weights not available: %s β€” text encoding will be limited", exc)
159
+ return []
160
+
161
+ def _download_clip_weights(self) -> str | None:
162
+ """Get CLIP encoder weight path β€” local weights/ dir first, HF fallback."""
163
+ local = _WEIGHTS_DIR / "clip_text_encoder.safetensors"
164
+ if local.exists():
165
+ logger.info("[MLX] CLIP weights ready (local)")
166
+ return str(local)
167
+ try:
168
+ path = hf_hub_download(T5_CLIP_REPO, "text_encoder/model.safetensors")
169
+ logger.info("[MLX] CLIP weights ready (HF cache)")
170
+ return path
171
+ except Exception as exc:
172
+ logger.warning("[MLX] CLIP weights not available: %s", exc)
173
+ return None
174
+
175
+ @staticmethod
176
+ def _local_or_hf(local_name: str, hf_path: str) -> str:
177
+ """Return local path if exists, else download from HF."""
178
+ local = _WEIGHTS_DIR / local_name
179
+ if local.exists():
180
+ return str(local)
181
+ return hf_hub_download(T5_CLIP_REPO, hf_path)
182
+
183
    def _load_dit(self, path: str) -> None:
        """Load DiT weights into transformer.

        For 4-bit quantized models, quantizes the model's Linear layers
        to QuantizedLinear first so load_weights can accept uint32 triplets.

        Args:
            path: local safetensors file produced by ``_download_dit``.
        """
        logger.info("[MLX] Loading DiT weights from %s ...", os.path.basename(path))
        # Use weight_loader's robust safetensors reader (handles bfloat16)
        weights = weight_loader._load_safetensors(path)

        # Detect if quantized: check for `.scales` keys
        is_quantized = any(k.endswith(".scales") for k in weights)
        if is_quantized:
            logger.info("[MLX] Detected quantized weights — converting model layers")
            # Quantize Linear layers except x_embedder (its weight is float, not quantized)
            # NOTE(review): the predicate's first parameter shadows the method's
            # `path` argument — harmless here (nn.quantize passes the module path),
            # but worth renaming if this is touched again.
            # group_size=64/bits=4 assumed to match the argmaxinc checkpoint —
            # TODO confirm against the repo's quantization config.
            def _should_quantize(path, module):
                if "x_embedder" in path:
                    return False
                return isinstance(module, nn.Linear)
            nn.quantize(self.transformer, group_size=64, bits=4, class_predicate=_should_quantize)

        # Map x_embedder.proj.* → x_embedder.* (Conv2d weights → Linear)
        remapped = {}
        for k, v in weights.items():
            new_k = k
            if k.startswith("x_embedder.proj."):
                new_k = k.replace("x_embedder.proj.", "x_embedder.")
                # Squeeze conv dimensions: [out, 1, 1, in] → [out, in]
                if v.ndim == 4:
                    v = v.squeeze()
            remapped[new_k] = v

        # strict=False: tolerate keys that don't exist on this model build.
        pairs = list(remapped.items())
        self.transformer.load_weights(pairs, strict=False)
        logger.info("[MLX] DiT loaded: %d weight tensors (quantized=%s)", len(pairs), is_quantized)
218
+
219
+ def _load_vae(self, path: str) -> None:
220
+ """Load VAE weights. Transposes conv weights from PyTorch NCHW to MLX NHWC."""
221
+ logger.info("[MLX] Loading VAE weights from %s ...", os.path.basename(path))
222
+ weights = weight_loader._load_safetensors(path)
223
+
224
+ # Only keep decoder weights (skip encoder)
225
+ transposed = {}
226
+ n_transposed = 0
227
+ for k, v in weights.items():
228
+ if not k.startswith("decoder."):
229
+ continue
230
+ # Transpose conv weights: PyTorch [O, I, kH, kW] β†’ MLX [O, kH, kW, I]
231
+ if v.ndim == 4:
232
+ v = v.transpose(0, 2, 3, 1)
233
+ n_transposed += 1
234
+ transposed[k] = v
235
+
236
+ pairs = list(transposed.items())
237
+ self.vae.load_weights(pairs, strict=False)
238
+ logger.info("[MLX] VAE loaded: %d tensors (%d conv transposed)", len(pairs), n_transposed)
239
+
240
+ # ── Generation ───────────────────────────────────────────────────────
241
+
242
    def generate(
        self,
        prompt: str,
        *,
        width: int = 512,
        height: int = 512,
        steps: int = 4,
        seed: int | None = None,
        progress_callback=None,
    ) -> Image.Image:
        """Generate an image from a text prompt.

        Args:
            prompt: Text description.
            width: Image width (rounded down to multiple of 16).
            height: Image height (rounded down to multiple of 16).
            steps: Denoising steps (default 4 for schnell).
            seed: Random seed for reproducibility.
            progress_callback: fn(step, total_steps) called per step.

        Returns:
            PIL.Image.

        Raises:
            RuntimeError: if :meth:`load` has not been called.

        Note: with ``phased_memory`` enabled, the text encoders and the
        transformer are dropped mid-run; call :meth:`load` again before
        the next generation.
        """
        if not self._loaded:
            raise RuntimeError("Pipeline not loaded. Call load() first.")

        # Round to multiple of 16 (latents are /8, then 2x2-patched -> /16)
        width = (width // 16) * 16
        height = (height // 16) * 16

        if seed is not None:
            mx.random.seed(seed)

        # ── Phase 1: Text encoding ──────────────────────────────────────
        logger.info("[MLX] Phase 1: Text encoding...")
        t0 = time.time()

        t5_ids = self.t5_tokenizer.tokenize(prompt)
        clip_ids = self.clip_tokenizer.tokenize(prompt)

        t5_embed = self.t5_encoder(t5_ids)  # [1, 256, 4096]
        clip_pooled, _ = self.clip_encoder(clip_ids)  # [1, 768]
        # Force evaluation now so the encoders can be freed immediately after.
        mx.eval(t5_embed, clip_pooled)

        logger.info("[MLX] Text encoding done (%.1fs)", time.time() - t0)

        # ── Phase 1→2 transition: free text encoders ────────────────────
        if self.phased_memory:
            logger.info("[MLX] Releasing text encoders to free memory...")
            self.t5_encoder = None
            self.clip_encoder = None
            mx.clear_cache()

        # ── Phase 2: Denoising ──────────────────────────────────────────
        logger.info("[MLX] Phase 2: Denoising (%d steps)...", steps)
        t1 = time.time()

        lat_h = height // 8
        lat_w = width // 8

        # Initial noise in latent space, packed into 2x2 patch tokens.
        noise = mx.random.normal((1, lat_h, lat_w, 16))
        latents = patchify(noise)  # [1, seq, 64]
        img_ids = compute_img_ids(lat_h, lat_w)
        # Text tokens all share position (0, 0, 0) in the 3-axis RoPE.
        txt_ids = mx.zeros((1, t5_embed.shape[1], 3), dtype=mx.int32)

        # Compute sigma schedule with exponential time shift
        image_seq_len = latents.shape[1]
        timesteps_list, sigmas = self.scheduler.compute_sigmas(steps, image_seq_len)

        # Optionally compile the transformer forward pass for speed
        _forward_fn = self.transformer
        if self.use_compile:
            try:
                _forward_fn = mx.compile(self.transformer)
                logger.info("[MLX] Using mx.compile for DiT forward pass")
            except Exception as exc:
                logger.warning("[MLX] mx.compile not available: %s", exc)

        for i in range(steps):
            sigma = sigmas[i]
            sigma_next = sigmas[i + 1]
            t = mx.array([timesteps_list[i]])

            # Scale latents before transformer
            latents_scaled = self.scheduler.scale_latents(latents, sigma)

            v_pred = _forward_fn(
                latents_scaled, img_ids,
                t5_embed, txt_ids,
                clip_pooled, t,
            )
            # Per-step eval keeps the lazy graph (and peak memory) bounded.
            mx.eval(v_pred)

            # Euler update toward the next (lower) sigma level; the step
            # direction/sign convention lives in scheduler.step.
            latents = self.scheduler.step(v_pred, sigma, sigma_next, latents)
            mx.eval(latents)

            if progress_callback:
                progress_callback(i + 1, steps)
            logger.debug("[MLX] Step %d/%d (sigma=%.4f)", i + 1, steps, sigma)

        logger.info("[MLX] Denoising done (%.1fs)", time.time() - t1)

        # ── Phase 2→3 transition: free DiT to make room for VAE ─────────
        if self.phased_memory:
            logger.info("[MLX] Releasing DiT to free memory...")
            self.transformer = None
            mx.clear_cache()

        # ── Phase 3: VAE decode ─────────────────────────────────────────
        logger.info("[MLX] Phase 3: VAE decode...")
        t2 = time.time()

        z = unpatchify(latents, lat_h, lat_w)  # [1, lat_h, lat_w, 16]
        image = self.vae.decode(z)  # [1, H, W, 3]
        mx.eval(image)

        logger.info("[MLX] VAE decode done (%.1fs)", time.time() - t2)

        # ── Convert to PIL ──────────────────────────────────────────────
        import numpy as np

        # NOTE(review): assumes decode() returns values in [0, 1] — the
        # *255 mapping clips anything outside; confirm against AutoencoderKL.
        img_np = np.array(image[0], copy=False)  # [H, W, 3]
        img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
        return Image.fromarray(img_np)
368
+
369
+ # ── Memory management ────────────────────────────────────────────────
370
+
371
+ def unload(self) -> None:
372
+ """Free all model components from memory."""
373
+ self.t5_encoder = None
374
+ self.clip_encoder = None
375
+ self.transformer = None
376
+ self.vae = None
377
+ self._loaded = False
378
+
379
+ try:
380
+ mx.clear_cache()
381
+ except Exception:
382
+ pass
383
+ logger.info("[MLX] Pipeline unloaded")
384
+
385
+ def memory_info(self) -> dict:
386
+ """Return MLX Metal memory usage info."""
387
+ try:
388
+ peak = mx.get_peak_memory() / (1024 ** 3)
389
+ active = mx.get_active_memory() / (1024 ** 3)
390
+ cache = mx.get_cache_memory() / (1024 ** 3)
391
+ return {
392
+ "peak_gb": round(peak, 2),
393
+ "active_gb": round(active, 2),
394
+ "cache_gb": round(cache, 2),
395
+ }
396
+ except AttributeError:
397
+ # Fallback for older MLX with mx.metal.* API
398
+ try:
399
+ peak = mx.metal.get_peak_memory() / (1024 ** 3)
400
+ active = mx.metal.get_active_memory() / (1024 ** 3)
401
+ cache = mx.metal.get_cache_memory() / (1024 ** 3)
402
+ return {
403
+ "peak_gb": round(peak, 2),
404
+ "active_gb": round(active, 2),
405
+ "cache_gb": round(cache, 2),
406
+ }
407
+ except Exception:
408
+ return {}
409
+ except Exception:
410
+ return {}
sampler.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Flow-matching Euler scheduler + latent patchify/unpatchify.
2
+
3
+ Matches mflux FlowMatchEulerDiscreteScheduler implementation.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import math
9
+
10
+ import mlx.core as mx
11
+
12
+
13
class FlowMatchEulerScheduler:
    """Rectified flow matching scheduler with exponential time shift.

    Matches mflux's FlowMatchEulerDiscreteScheduler exactly.
    """

    def compute_sigmas(self, num_steps: int, image_seq_len: int = 256):
        """Build the shifted sigma schedule.

        Returns ``(timesteps, sigmas)``:
          - timesteps: sigma * 1000 (the transformer's conditioning scale)
          - sigmas: shifted values plus a terminal 0.0, so ``sigmas[i + 1]``
            is valid on the final step.
        """
        base = mx.linspace(1.0, 1.0 / num_steps, num_steps, dtype=mx.float32)
        mu = self._compute_mu(image_seq_len, num_steps)
        shifted = self._time_shift(mu, 1.0, base)
        return (shifted * 1000).tolist(), shifted.tolist() + [0.0]

    @staticmethod
    def _compute_mu(image_seq_len: int, num_steps: int) -> float:
        """Empirical shift parameter: mu = 0.5 * ln(sequence length)."""
        return 0.5 * math.log(image_seq_len)

    @staticmethod
    def _time_shift(mu: float, sigma_max: float, sigmas):
        """Exponential interpolation: sigma' = e^mu / (e^mu + 1/sigma - 1)."""
        return mx.exp(mu) / (mx.exp(mu) + (1 / sigmas - 1))

    def scale_latents(self, latents, sigma):
        """Scale latents before the transformer: x / sqrt(sigma^2 + 1)."""
        denom = (sigma ** 2 + 1) ** 0.5
        return latents / denom

    def step(self, velocity, sigma_cur, sigma_next, latents):
        """Advance one Euler step of the flow ODE.

        Args:
            velocity: v_theta prediction [B, seq, D]
            sigma_cur: current sigma level
            sigma_next: next sigma level
            latents: current latent state [B, seq, D]

        Returns:
            Updated latents: x + (sigma_cur - sigma_next) * v.
        """
        dt = mx.array(sigma_cur - sigma_next, dtype=latents.dtype)
        return latents + dt * velocity
73
+
74
+
75
def patchify(latents):
    """Fold spatial latents into a sequence of 2x2 patch tokens.

    [B, H, W, C] -> [B, (H/2)*(W/2), C*4]

    FLUX packs each non-overlapping 2x2xC latent patch into a single
    token of length 4*C.
    """
    B, H, W, C = latents.shape
    rows, cols = H // 2, W // 2
    grid = latents.reshape(B, rows, 2, cols, 2, C)
    grid = grid.transpose(0, 1, 3, 2, 4, 5)  # [B, rows, cols, 2, 2, C]
    return grid.reshape(B, rows * cols, 4 * C)
87
+
88
+
89
def unpatchify(tokens, h, w):
    """Inverse of :func:`patchify`: patch tokens back to spatial latents.

    [B, seq, C*4] -> [B, h, w, C]

    Args:
        tokens: [B, (h/2)*(w/2), C*4]
        h, w: latent spatial dimensions (before patching)
    """
    B = tokens.shape[0]
    C = tokens.shape[-1] // 4  # e.g. 64 // 4 = 16 channels
    grid = tokens.reshape(B, h // 2, w // 2, 2, 2, C)
    grid = grid.transpose(0, 1, 3, 2, 4, 5)  # [B, h/2, 2, w/2, 2, C]
    return grid.reshape(B, h, w, C)
104
+
105
+
106
def compute_img_ids(lat_h: int, lat_w: int):
    """Build (t, row, col) position ids for patchified latent tokens.

    Returns [1, (lat_h/2)*(lat_w/2), 3]; the leading coordinate is the
    always-zero temporal axis of FLUX's 3-axis RoPE.
    """
    rows, cols = lat_h // 2, lat_w // 2
    n = rows * cols

    # Broadcast the row/col indices over a [rows, cols] grid.
    row_grid = mx.repeat(mx.arange(rows)[:, None], cols, axis=1)
    col_grid = mx.repeat(mx.arange(cols)[None, :], rows, axis=0)

    ids = mx.stack(
        [
            mx.zeros((n,), dtype=mx.int32),  # time axis: all zeros
            row_grid.reshape(-1),
            col_grid.reshape(-1),
        ],
        axis=-1,
    )  # [n, 3]
    return ids.reshape(1, n, 3)
t5_encoder.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """T5-XXL text encoder for FLUX pipeline.
2
+
3
+ Implements a 24-layer T5 encoder with relative position bias,
4
+ gated FFN (GeLU), and RMS LayerNorm β€” matching the HuggingFace
5
+ ``google/t5-xxl`` architecture used by FLUX.1.
6
+
7
+ Weight source: ``black-forest-labs/FLUX.1-schnell`` β†’
8
+ ``text_encoder_2/model-0000{1,2}-of-00002.safetensors``
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import math
14
+
15
+ import mlx.core as mx
16
+ import mlx.nn as nn
17
+
18
+
19
+ # ── T5 Config (XXL) ──────────────────────────────────────────────────────────
20
+
21
class T5Config:
    """Hyperparameters of the T5-XXL encoder used by FLUX."""

    vocab_size: int = 32128
    d_model: int = 4096      # hidden size
    d_ff: int = 10240        # gated-FFN inner size
    num_heads: int = 64
    head_dim: int = 64  # d_model / num_heads
    num_layers: int = 24
    relative_attention_num_buckets: int = 32   # learned-bias table rows
    relative_attention_max_distance: int = 128  # clamp for log bucketing
+
31
+
32
+ # ── Building blocks ──────────────────────────────────────────────────────────
33
+
34
class T5RMSNorm(nn.Module):
    """T5-style RMS LayerNorm (no bias, no mean subtraction)."""

    def __init__(self, d: int, eps: float = 1e-6):
        super().__init__()
        self.weight = mx.ones((d,))
        self.eps = eps

    def __call__(self, x):
        # Normalize by root-mean-square of the features, then rescale.
        mean_sq = mx.mean(mx.square(x), axis=-1, keepdims=True)
        normed = x * mx.rsqrt(mean_sq + self.eps)
        return self.weight * normed
46
+
47
+
48
class T5RelativeAttention(nn.Module):
    """Multi-head self-attention with T5 learned relative position bias.

    Note: T5 uses *unscaled* dot-product attention — the 1/sqrt(d) factor
    is folded into the pretrained weight initialization (see the HF
    ``T5Attention`` implementation). Dividing the scores by sqrt(head_dim)
    again, as the previous version did, shrinks the logits 8x and breaks
    parity with the pretrained T5-XXL checkpoint.
    """

    def __init__(self, cfg: T5Config, has_relative_bias: bool = False):
        super().__init__()
        self.num_heads = cfg.num_heads
        self.head_dim = cfg.head_dim
        d = cfg.d_model

        self.q_proj = nn.Linear(d, d, bias=False)
        self.k_proj = nn.Linear(d, d, bias=False)
        self.v_proj = nn.Linear(d, d, bias=False)
        self.out_proj = nn.Linear(d, d, bias=False)

        # Only layer 0 owns the bias table; other layers reuse its output.
        self.has_relative_bias = has_relative_bias
        if has_relative_bias:
            self.rel_bias = nn.Embedding(
                cfg.relative_attention_num_buckets, cfg.num_heads,
            )
        self._num_buckets = cfg.relative_attention_num_buckets
        self._max_distance = cfg.relative_attention_max_distance

    @staticmethod
    def _relative_position_bucket(
        relative_position: mx.array,
        num_buckets: int = 32,
        max_distance: int = 128,
    ) -> mx.array:
        """T5-style relative position bucketing.

        Maps relative positions to bucket indices for the learned bias.
        Bidirectional: first half for negative, second for positive;
        small offsets map linearly, larger ones logarithmically up to
        ``max_distance``.
        """
        # Bidirectional: use half the buckets for each direction
        num_buckets //= 2
        # Sign-based offset: positive offsets land in the upper half
        relative_buckets = (relative_position > 0).astype(mx.int32) * num_buckets
        relative_position = mx.abs(relative_position)

        # Small positions are mapped linearly
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # Larger positions are mapped logarithmically. (log(0) = -inf is
        # computed here but always masked out by the is_small branch below.)
        val = mx.log(relative_position.astype(mx.float32) / max_exact) / math.log(
            max_distance / max_exact
        )
        val = val * (num_buckets - max_exact)
        relative_position_if_large = (max_exact + val).astype(mx.int32)
        relative_position_if_large = mx.minimum(
            relative_position_if_large,
            mx.array(num_buckets - 1, dtype=mx.int32),
        )

        relative_buckets = relative_buckets + mx.where(
            is_small,
            relative_position.astype(mx.int32),
            relative_position_if_large,
        )
        return relative_buckets

    def _compute_bias(self, seq_len: int) -> mx.array:
        """Compute relative position bias [1, num_heads, seq, seq]."""
        positions = mx.arange(seq_len)
        # relative_position[i, j] = j - i (key position minus query position)
        relative_position = positions[None, :] - positions[:, None]
        buckets = self._relative_position_bucket(
            relative_position,
            num_buckets=self._num_buckets,
            max_distance=self._max_distance,
        )
        # [seq, seq] → lookup → [seq, seq, num_heads]
        bias = self.rel_bias(buckets)
        # → [1, num_heads, seq, seq]
        bias = bias.transpose(2, 0, 1).reshape(1, self.num_heads, seq_len, seq_len)
        return bias

    def __call__(
        self, x: mx.array, position_bias: mx.array | None = None,
    ) -> tuple[mx.array, mx.array | None]:
        """Returns (attended output, position_bias) so callers can thread
        the bias from layer 0 through the remaining layers."""
        B, L, _ = x.shape
        H, D = self.num_heads, self.head_dim

        q = self.q_proj(x).reshape(B, L, H, D).transpose(0, 2, 1, 3)
        k = self.k_proj(x).reshape(B, L, H, D).transpose(0, 2, 1, 3)
        v = self.v_proj(x).reshape(B, L, H, D).transpose(0, 2, 1, 3)

        # Unscaled dot-product attention (T5 convention — no 1/sqrt(d)).
        scores = q @ k.transpose(0, 1, 3, 2)  # [B, H, L, L]

        # Add relative position bias
        if self.has_relative_bias:
            position_bias = self._compute_bias(L)
        if position_bias is not None:
            scores = scores + position_bias

        weights = mx.softmax(scores, axis=-1)
        out = weights @ v  # [B, H, L, D]
        out = out.transpose(0, 2, 1, 3).reshape(B, L, -1)  # [B, L, d_model]
        return self.out_proj(out), position_bias
149
+
150
+
151
class T5GatedFFN(nn.Module):
    """T5 v1.1 gated feed-forward: (GeLU(x W0) * (x W1)) W2."""

    def __init__(self, cfg: T5Config):
        super().__init__()
        self.wi_0 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)  # gate branch
        self.wi_1 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)  # value branch
        self.wo = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)

    def __call__(self, x):
        hidden = nn.gelu_approx(self.wi_0(x)) * self.wi_1(x)
        return self.wo(hidden)
164
+
165
+
166
class T5EncoderLayer(nn.Module):
    """One pre-norm T5 encoder layer: RMSNorm→attention, RMSNorm→gated FFN."""

    def __init__(self, cfg: T5Config, has_relative_bias: bool = False):
        super().__init__()
        self.attn_norm = T5RMSNorm(cfg.d_model)
        self.attn = T5RelativeAttention(cfg, has_relative_bias=has_relative_bias)
        self.ffn_norm = T5RMSNorm(cfg.d_model)
        self.ffn = T5GatedFFN(cfg)

    def __call__(self, x, position_bias=None):
        """Returns (hidden, position_bias); the bias is threaded along so
        only layer 0 has to compute it."""
        attn_out, position_bias = self.attn(self.attn_norm(x), position_bias)
        x = x + attn_out

        x = x + self.ffn(self.ffn_norm(x))

        return x, position_bias
192
+
193
+
194
+ # ── T5 Encoder ───────────────────────────────────────────────────────────────
195
+
196
class T5Encoder(nn.Module):
    """T5-XXL encoder: 24-layer transformer with relative position bias.

    Input: token_ids [B, seq_len]
    Output: embeddings [B, seq_len, 4096]

    Only layer 0 owns the relative-position bias table; the computed bias
    tensor is threaded through the remaining layers.
    """

    def __init__(self, cfg: T5Config | None = None):
        super().__init__()
        self.cfg = T5Config() if cfg is None else cfg

        self.wte = nn.Embedding(self.cfg.vocab_size, self.cfg.d_model)
        self.layers = [
            T5EncoderLayer(self.cfg, has_relative_bias=(idx == 0))
            for idx in range(self.cfg.num_layers)
        ]
        self.final_norm = T5RMSNorm(self.cfg.d_model)

    def __call__(self, token_ids, use_transformer: bool = True):
        hidden = self.wte(token_ids)  # [B, L, d_model]

        if use_transformer:
            # Full 24-layer stack; bias computed once by layer 0.
            bias = None
            for layer in self.layers:
                hidden, bias = layer(hidden, bias)

        return self.final_norm(hidden)
tokenizers.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tokenizers for FLUX pipeline β€” T5 (SentencePiece) and CLIP (BPE).
2
+
3
+ Both tokenizers produce mx.array token ID tensors ready for encoder input.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import json
9
+ import logging
10
+ from pathlib import Path
11
+
12
+ import mlx.core as mx
13
+
14
+ logger = logging.getLogger("image-server")
15
+
16
+
17
class T5Tokenizer:
    """T5-XXL SentencePiece tokenizer.

    Loads ``spiece.model`` from the FLUX.1-schnell repo
    (``tokenizer_2/spiece.model``).

    Mirrors the HF ``T5Tokenizer`` contract: the encoded sequence is
    terminated with the ``</s>`` EOS token before zero-padding — the
    pretrained encoder expects this terminator, which the previous
    version omitted.
    """

    def __init__(self, spiece_path: str, max_length: int = 256):
        import sentencepiece as spm

        self._sp = spm.SentencePieceProcessor()
        self._sp.Load(spiece_path)
        self.max_length = max_length
        self.pad_id = 0
        # SentencePiece reports -1 when no EOS is defined; T5's spiece
        # model uses id 1 for </s>.
        eos = self._sp.eos_id()
        self.eos_id = eos if eos >= 0 else 1

    def tokenize(self, text: str) -> mx.array:
        """Tokenize text → [1, max_length] int32 tensor (EOS-terminated, padded)."""
        ids = list(self._sp.Encode(text))
        # Reserve one slot for EOS, truncating if necessary (HF T5 behavior).
        if len(ids) > self.max_length - 1:
            ids = ids[: self.max_length - 1]
        ids.append(self.eos_id)
        # Pad with zeros up to the fixed context length.
        pad_len = self.max_length - len(ids)
        if pad_len > 0:
            ids = ids + [self.pad_id] * pad_len
        return mx.array(ids, dtype=mx.int32).reshape(1, -1)
43
+
44
+
45
class CLIPTokenizer:
    """CLIP BPE tokenizer.

    Loads ``vocab.json`` (token→id) and ``merges.txt`` (BPE merge rules)
    from ``tokenizer/`` in the FLUX.1-schnell repo.

    NOTE(review): pads with id 0 — HF's CLIPTokenizer pads with the EOS
    token (49407) for SD-family models; confirm which convention the
    CLIP encoder's pooling expects.
    """

    BOS_ID = 49406  # <|startoftext|>
    EOS_ID = 49407  # <|endoftext|>

    def __init__(self, vocab_path: str, merges_path: str, max_length: int = 77):
        # Load vocab: token_str → id
        with open(vocab_path, encoding="utf-8") as f:
            self._vocab: dict[str, int] = json.load(f)

        # Load BPE merges from merges.txt; line order defines merge priority
        # (lower index = applied first).
        self._merges: list[tuple[str, str]] = []
        self._merge_rank: dict[tuple[str, str], int] = {}
        with open(merges_path, encoding="utf-8") as f:
            for i, line in enumerate(f):
                line = line.strip()
                if not line or line.startswith("#"):
                    continue
                parts = line.split()
                if len(parts) == 2:
                    pair = (parts[0], parts[1])
                    self._merges.append(pair)
                    self._merge_rank[pair] = i

        self.max_length = max_length
        self.pad_id = 0

        # pre/post processing regex (simplified CLIP pattern): contractions,
        # letter runs, single digits, and punctuation runs.
        import regex
        self._pat = regex.compile(
            r"""'s|'t|'re|'ve|'m|'ll|'d|"""
            r"""[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            regex.IGNORECASE,
        )

    def _bpe(self, token: str) -> list[str]:
        """Apply BPE merges to a single word token.

        The last character carries the ``</w>`` end-of-word marker, as in
        the original CLIP vocabulary.
        """
        if len(token) <= 1:
            return [token + "</w>"] if token else []

        # Add end-of-word marker
        word = list(token[:-1]) + [token[-1] + "</w>"]

        while len(word) > 1:
            # Find the highest-priority merge pair (lowest rank wins)
            best_pair = None
            best_rank = float("inf")
            for i in range(len(word) - 1):
                pair = (word[i], word[i + 1])
                rank = self._merge_rank.get(pair, float("inf"))
                if rank < best_rank:
                    best_rank = rank
                    best_pair = pair

            # No applicable merge left → done
            if best_pair is None or best_rank == float("inf"):
                break

            # Apply the merge everywhere it occurs in this word
            new_word = []
            i = 0
            while i < len(word):
                if (
                    i < len(word) - 1
                    and word[i] == best_pair[0]
                    and word[i + 1] == best_pair[1]
                ):
                    new_word.append(best_pair[0] + best_pair[1])
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = new_word

        return word

    def tokenize(self, text: str) -> mx.array:
        """Tokenize text → [1, max_length] int32 tensor (BOS … EOS, padded)."""
        text = text.lower().strip()

        ids = [self.BOS_ID]

        # Tokenize each regex-matched word piece
        for match in self._pat.finditer(text):
            word = match.group()
            bpe_tokens = self._bpe(word)
            for bt in bpe_tokens:
                # Unknown pieces fall back to id 0
                token_id = self._vocab.get(bt, 0)
                ids.append(token_id)

        ids.append(self.EOS_ID)

        # Truncate (keep BOS at start, EOS at end)
        if len(ids) > self.max_length:
            ids = ids[: self.max_length - 1] + [self.EOS_ID]

        # Pad
        pad_len = self.max_length - len(ids)
        if pad_len > 0:
            ids = ids + [self.pad_id] * pad_len

        return mx.array(ids, dtype=mx.int32).reshape(1, -1)
weight_loader.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Weight loader β€” safetensors β†’ MLX model parameter mapping.
2
+
3
+ Handles multi-shard loading, HuggingFace key β†’ self-defined key mapping,
4
+ and dtype conversion. Each ``load_*`` function takes a model instance
5
+ and populates its parameters in-place.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import logging
11
+ from pathlib import Path
12
+
13
+ import mlx.core as mx
14
+ import mlx.nn as nn
15
+
16
+ logger = logging.getLogger("image-server")
17
+
18
+
19
+ # ── Utilities ────────────────────────────────────────────────────────────────
20
+
21
def _load_safetensors(*paths: str) -> dict[str, mx.array]:
    """Load one or more safetensors files into a flat dict of MLX arrays.

    Tries the native MLX framework reader first (fastest). Falls back to a
    PyTorch reader for files MLX cannot read directly (e.g. bfloat16),
    preserving uint32 quantized weights as-is.

    Args:
        *paths: One or more ``.safetensors`` file paths.

    Returns:
        Mapping of tensor name -> ``mx.array``, merged across all files.
        Files that fail to load in both frameworks are logged and skipped.
    """
    from safetensors import safe_open

    weights: dict[str, mx.array] = {}
    for p in paths:
        fname = Path(p).name  # consistent with the file's pathlib usage
        # Fast path: native MLX reader.
        try:
            with safe_open(p, framework="mlx") as f:
                for key in f.keys():
                    weights[key] = f.get_tensor(key)
            logger.info("[WeightLoader] Loaded %s via MLX framework", fname)
            continue
        except Exception:
            pass  # bfloat16 or other incompatibility → fallback to pt

        # Fallback: load via PyTorch, selectively convert.
        try:
            import torch

            n_before = len(weights)
            with safe_open(p, framework="pt") as f:
                for key in f.keys():
                    t = f.get_tensor(key)
                    if t.dtype == torch.uint32:
                        # Quantized weight — keep raw uint32 bit pattern.
                        weights[key] = mx.array(t.numpy())
                    elif t.dtype == torch.bfloat16:
                        # numpy has no bfloat16; upcast to float32 first.
                        weights[key] = mx.array(t.float().numpy())
                    else:
                        # float32, float16, etc. — direct conversion.
                        weights[key] = mx.array(t.numpy())
            # Bug fix: report tensors loaded from THIS file rather than the
            # cumulative total across all shards processed so far.
            logger.info(
                "[WeightLoader] Loaded %s via PyTorch fallback (%d tensors)",
                fname, len(weights) - n_before,
            )
        except Exception as exc:
            # Best-effort by design: log and continue with remaining shards.
            logger.error("[WeightLoader] Failed to load %s: %s", p, exc)
    return weights
60
+
61
+
62
def _set_nested(model: nn.Module, dotpath: str, value: mx.array) -> bool:
    """Set a parameter on *model* using a dot-separated path.

    Numeric path components index into list-like containers
    (e.g. ``layers.3.attn``); all other components are attribute lookups.

    Args:
        model: Root module to traverse.
        dotpath: Dot-separated path, e.g. ``"layers.0.attn.q_proj.weight"``.
        value: Array assigned at the final path component.

    Returns:
        True if set successfully, False if any intermediate component is
        missing or an index is out of range.
    """
    parts = dotpath.split(".")
    obj = model
    for part in parts[:-1]:
        if part.isdigit():
            # Robustness fix: an out-of-range or unsupported index should
            # report failure like an attribute miss, not raise.
            try:
                obj = obj[int(part)]
            except (IndexError, KeyError, TypeError):
                return False
        elif hasattr(obj, part):
            obj = getattr(obj, part)
        else:
            return False
    final = parts[-1]
    if hasattr(obj, final):
        setattr(obj, final, value)
        return True
    return False
81
+
82
+
83
+ # ── T5 Weight Loader ─────────────────────────────────────────────────────────
84
+
85
# HuggingFace T5 checkpoint key -> our T5Encoder parameter path.
# Entries containing "{i}" are per-layer templates; _build_t5_key_map
# expands them for every encoder layer index.
_T5_KEY_MAP_TEMPLATES = {
    # Token embedding table.
    "encoder.embed_tokens.weight": "wte.weight",
    "shared.weight": "wte.weight",  # some checkpoints use this key

    # Per-layer keys (use {i} placeholder)
    "encoder.block.{i}.layer.0.layer_norm.weight": "layers.{i}.attn_norm.weight",
    "encoder.block.{i}.layer.0.SelfAttention.q.weight": "layers.{i}.attn.q_proj.weight",
    "encoder.block.{i}.layer.0.SelfAttention.k.weight": "layers.{i}.attn.k_proj.weight",
    "encoder.block.{i}.layer.0.SelfAttention.v.weight": "layers.{i}.attn.v_proj.weight",
    "encoder.block.{i}.layer.0.SelfAttention.o.weight": "layers.{i}.attn.out_proj.weight",
    "encoder.block.{i}.layer.1.layer_norm.weight": "layers.{i}.ffn_norm.weight",
    "encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight": "layers.{i}.ffn.wi_0.weight",
    "encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight": "layers.{i}.ffn.wi_1.weight",
    "encoder.block.{i}.layer.1.DenseReluDense.wo.weight": "layers.{i}.ffn.wo.weight",

    # Relative attention bias lives only on layer 0 in T5 checkpoints.
    "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight":
        "layers.0.attn.rel_bias.weight",

    # Final norm
    "encoder.final_layer_norm.weight": "final_norm.weight",
}
109
+
110
+
111
def _build_t5_key_map(num_layers: int = 24) -> dict[str, str]:
    """Expand the ``{i}``-templated T5 key map for *num_layers* layers."""
    mapping: dict[str, str] = {}
    for hf_tpl, our_tpl in _T5_KEY_MAP_TEMPLATES.items():
        # Non-templated entries pass straight through.
        if "{i}" not in hf_tpl:
            mapping[hf_tpl] = our_tpl
            continue
        for layer in range(num_layers):
            idx = str(layer)
            mapping[hf_tpl.replace("{i}", idx)] = our_tpl.replace("{i}", idx)
    return mapping
123
+
124
+
125
def load_t5(paths: list[str], model) -> None:
    """Load T5 encoder weights from safetensors shards into *model*.

    Args:
        paths: Safetensors shard paths, merged in order.
        model: T5Encoder instance whose parameters are populated in-place.
    """
    weights = _load_safetensors(*paths)
    key_map = _build_t5_key_map(num_layers=model.cfg.num_layers)

    loaded = 0
    unmapped: list[str] = []
    for hf_key, tensor in weights.items():
        target = key_map.get(hf_key)
        # No transpose needed: HF stores Linear weights as [out, in], which
        # is exactly the layout MLX nn.Linear expects (it computes x @ W.T).
        if target is None:
            unmapped.append(hf_key)
        elif _set_nested(model, target, tensor):
            loaded += 1
        else:
            unmapped.append(f"{hf_key} → {target} (path not found)")

    logger.info(
        "[WeightLoader] T5: loaded %d/%d params, unmapped %d",
        loaded, len(weights), len(unmapped),
    )
    if unmapped:
        for key in unmapped[:10]:
            logger.debug("  unmapped: %s", key)
        if len(unmapped) > 10:
            logger.debug("  ... and %d more", len(unmapped) - 10)
154
+
155
+
156
+ # ── CLIP Weight Loader ───────────────────────────────────────────────────────
157
+
158
+ _CLIP_KEY_MAP_TEMPLATES = {
159
+ # Embeddings
160
+ "text_model.embeddings.token_embedding.weight": "token_emb.weight",
161
+ "text_model.embeddings.position_embedding.weight": "pos_emb.weight",
162
+
163
+ # Per-layer keys
164
+ "text_model.encoder.layers.{i}.layer_norm1.weight": "layers.{i}.norm1.weight",
165
+ "text_model.encoder.layers.{i}.layer_norm1.bias": "layers.{i}.norm1.bias",
166
+ "text_model.encoder.layers.{i}.self_attn.q_proj.weight": "layers.{i}.attn.q_proj.weight",
167
+ "text_model.encoder.layers.{i}.self_attn.q_proj.bias": "layers.{i}.attn.q_proj.bias",
168
+ "text_model.encoder.layers.{i}.self_attn.k_proj.weight": "layers.{i}.attn.k_proj.weight",
169
+ "text_model.encoder.layers.{i}.self_attn.k_proj.bias": "layers.{i}.attn.k_proj.bias",
170
+ "text_model.encoder.layers.{i}.self_attn.v_proj.weight": "layers.{i}.attn.v_proj.weight",
171
+ "text_model.encoder.layers.{i}.self_attn.v_proj.bias": "layers.{i}.attn.v_proj.bias",
172
+ "text_model.encoder.layers.{i}.self_attn.out_proj.weight": "layers.{i}.attn.out_proj.weight",
173
+ "text_model.encoder.layers.{i}.self_attn.out_proj.bias": "layers.{i}.attn.out_proj.bias",
174
+ "text_model.encoder.layers.{i}.layer_norm2.weight": "layers.{i}.norm2.weight",
175
+ "text_model.encoder.layers.{i}.layer_norm2.bias": "layers.{i}.norm2.bias",
176
+ "text_model.encoder.layers.{i}.mlp.fc1.weight": "layers.{i}.mlp.fc1.weight",
177
+ "text_model.encoder.layers.{i}.mlp.fc1.bias": "layers.{i}.mlp.fc1.bias",
178
+ "text_model.encoder.layers.{i}.mlp.fc2.weight": "layers.{i}.mlp.fc2.weight",
179
+ "text_model.encoder.layers.{i}.mlp.fc2.bias": "layers.{i}.mlp.fc2.bias",
180
+
181
+ # Final norm
182
+ "text_model.final_layer_norm.weight": "final_norm.weight",
183
+ "text_model.final_layer_norm.bias": "final_norm.bias",
184
+ }
185
+
186
+
187
def _build_clip_key_map(num_layers: int = 23) -> dict[str, str]:
    """Expand the ``{i}``-templated CLIP key map for *num_layers* layers."""
    # Entries without a placeholder map one-to-one.
    static = {
        hf: ours
        for hf, ours in _CLIP_KEY_MAP_TEMPLATES.items()
        if "{i}" not in hf
    }
    # Templated entries are replicated once per layer index.
    per_layer = {
        hf.replace("{i}", str(i)): ours.replace("{i}", str(i))
        for hf, ours in _CLIP_KEY_MAP_TEMPLATES.items()
        if "{i}" in hf
        for i in range(num_layers)
    }
    return {**static, **per_layer}
199
+
200
+
201
def load_clip(path: str, model) -> None:
    """Load CLIP text-encoder weights from safetensors into *model*.

    Args:
        path: Path to the CLIP safetensors checkpoint.
        model: CLIPEncoder instance whose parameters are populated in-place.
    """
    weights = _load_safetensors(path)
    key_map = _build_clip_key_map(num_layers=model.cfg.num_layers)

    loaded = 0
    unmapped: list[str] = []
    for hf_key, tensor in weights.items():
        our_key = key_map.get(hf_key)
        if our_key is None:
            unmapped.append(hf_key)
            continue
        if _set_nested(model, our_key, tensor):
            loaded += 1
        else:
            unmapped.append(f"{hf_key} → {our_key} (path not found)")

    logger.info(
        "[WeightLoader] CLIP: loaded %d/%d params, unmapped %d",
        loaded, len(weights), len(unmapped),
    )
    if unmapped:
        for k in unmapped[:10]:
            logger.debug("  unmapped: %s", k)
        # Consistency fix: load_t5 reports how many diagnostics were
        # elided; do the same here.
        if len(unmapped) > 10:
            logger.debug("  ... and %d more", len(unmapped) - 10)
225
+
226
+
227
+ # ── Placeholder for future rounds ────────────────────────────────────────────
228
+
229
def load_flux_dit(path: str, model) -> None:
    """Stub: FLUX DiT weight loading is implemented in round C-122b."""
    raise NotImplementedError("FLUX DiT weight loading — see C-122b")
232
+
233
+
234
def load_vae(path: str, model) -> None:
    """Stub: VAE decoder weight loading is implemented in round C-122c."""
    raise NotImplementedError("VAE weight loading — see C-122c")