Upload folder using huggingface_hub
Browse files- README.md +109 -0
- __init__.py +6 -0
- autoencoder.py +188 -0
- clip_encoder.py +155 -0
- download_weights.py +41 -0
- flux_model.py +447 -0
- pipeline.py +410 -0
- sampler.py +125 -0
- t5_encoder.py +226 -0
- tokenizers.py +150 -0
- weight_loader.py +236 -0
README.md
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
library_name: mlx
|
| 6 |
+
tags:
|
| 7 |
+
- mlx
|
| 8 |
+
- text-to-image
|
| 9 |
+
- apple-silicon
|
| 10 |
+
- image-generation
|
| 11 |
+
- diffusion
|
| 12 |
+
- flux
|
| 13 |
+
base_model: black-forest-labs/FLUX.1-schnell
|
| 14 |
+
pipeline_tag: text-to-image
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
# FLUX.1-schnell MLX Pipeline
|
| 18 |
+
|
| 19 |
+
**Pure MLX (Apple Silicon) inference pipeline for [FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell)** — a fast text-to-image model by Black Forest Labs.
|
| 20 |
+
|
| 21 |
+
Zero PyTorch dependency. Runs natively on Apple Silicon via Metal GPU.
|
| 22 |
+
|
| 23 |
+
## Highlights
|
| 24 |
+
|
| 25 |
+
- **100% MLX native** — no torch, no diffusers needed
|
| 26 |
+
- **4-bit quantization** support via `argmaxinc/mlx-FLUX.1-schnell-4bit-quantized`
|
| 27 |
+
- **Fast 4-step generation** (FLUX.1-schnell is distilled for speed)
|
| 28 |
+
- **T5-XXL + CLIP-L** dual text encoders
|
| 29 |
+
- **FluxTransformer** with 19 Joint Blocks + 38 Single Blocks + N-dim RoPE
|
| 30 |
+
|
| 31 |
+
## Architecture
|
| 32 |
+
|
| 33 |
+
```
|
| 34 |
+
FluxPipeline
|
| 35 |
+
βββ T5-XXL Encoder (24 layers, hidden=4096)
|
| 36 |
+
β βββ Relative positional attention + GatedFFN
|
| 37 |
+
βββ CLIP-L Encoder (23 layers, hidden=768)
|
| 38 |
+
β βββ Causal mask + EOS pooling
|
| 39 |
+
βββ FluxTransformer (DiT)
|
| 40 |
+
β βββ 19 JointTransformerBlock (txt+img joint attention)
|
| 41 |
+
β βββ 38 SingleTransformerBlock (img self-attention)
|
| 42 |
+
β βββ N-dim RoPE (axes_dim=[16,56,56])
|
| 43 |
+
βββ AutoencoderKL Decoder
|
| 44 |
+
β βββ Latent channels=16, block_out=[128,256,512,512]
|
| 45 |
+
βββ FlowMatchEuler Sampler
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
## Quick Start
|
| 49 |
+
|
| 50 |
+
### Install
|
| 51 |
+
|
| 52 |
+
```bash
|
| 53 |
+
pip install mlx safetensors sentencepiece tokenizers pillow numpy
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
### Download Weights
|
| 57 |
+
|
| 58 |
+
```bash
|
| 59 |
+
# 4-bit quantized (recommended, ~5GB)
|
| 60 |
+
huggingface-cli download argmaxinc/mlx-FLUX.1-schnell-4bit-quantized
|
| 61 |
+
|
| 62 |
+
# Or full precision
|
| 63 |
+
huggingface-cli download argmaxinc/mlx-FLUX.1-schnell
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
### Generate
|
| 67 |
+
|
| 68 |
+
```python
|
| 69 |
+
from pipeline import FluxPipeline
|
| 70 |
+
|
| 71 |
+
pipe = FluxPipeline()
|
| 72 |
+
pipe.load()
|
| 73 |
+
|
| 74 |
+
result = pipe.generate_and_save(
|
| 75 |
+
prompt="a beautiful sunset over mountains",
|
| 76 |
+
output_path="output.png",
|
| 77 |
+
width=512,
|
| 78 |
+
height=512,
|
| 79 |
+
num_steps=4,
|
| 80 |
+
seed=42,
|
| 81 |
+
)
|
| 82 |
+
print(f"Generated in {result['elapsed_s']}s")
|
| 83 |
+
|
| 84 |
+
pipe.unload()
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
## Files
|
| 88 |
+
|
| 89 |
+
```
|
| 90 |
+
βββ pipeline.py # Main inference pipeline
|
| 91 |
+
βββ flux_model.py # FluxTransformer (JointBlock + SingleBlock)
|
| 92 |
+
βββ t5_encoder.py # T5-XXL text encoder
|
| 93 |
+
βββ clip_encoder.py # CLIP-L text encoder
|
| 94 |
+
βββ autoencoder.py # VAE decoder
|
| 95 |
+
βββ sampler.py # FlowMatch Euler sampler
|
| 96 |
+
βββ tokenizers.py # T5 + CLIP tokenizers
|
| 97 |
+
βββ weight_loader.py # Weight loading + key mapping
|
| 98 |
+
βββ download_weights.py # HF Hub download helper
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
## Model Source
|
| 102 |
+
|
| 103 |
+
Inference code is original work. Weights are loaded from:
|
| 104 |
+
- [argmaxinc/mlx-FLUX.1-schnell-4bit-quantized](https://huggingface.co/argmaxinc/mlx-FLUX.1-schnell-4bit-quantized) (default)
|
| 105 |
+
- [black-forest-labs/FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) (original)
|
| 106 |
+
|
| 107 |
+
## License
|
| 108 |
+
|
| 109 |
+
Apache 2.0
|
__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MLX FLUX pipeline package.
|
| 2 |
+
|
| 3 |
+
Provides a minimal FLUX.1-schnell diffusion pipeline implemented in
|
| 4 |
+
pure MLX for Apple Silicon inference. Uses pre-converted weights from
|
| 5 |
+
HuggingFace (argmaxinc/mlx-FLUX.1-schnell or 4bit variant).
|
| 6 |
+
"""
|
autoencoder.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FLUX VAE decoder β param names match argmaxinc ae.safetensors keys.
|
| 2 |
+
|
| 3 |
+
Weight key structure:
|
| 4 |
+
decoder.conv_in.*
|
| 5 |
+
decoder.mid.block_{1,2}.{norm1,conv1,norm2,conv2}.*
|
| 6 |
+
decoder.mid.attn_1.{norm,q,k,v,proj_out}.*
|
| 7 |
+
decoder.up.{0-3}.block.{0-2}.{norm1,conv1,norm2,conv2,nin_shortcut}.*
|
| 8 |
+
decoder.up.{1-3}.upsample.conv.*
|
| 9 |
+
decoder.norm_out.*
|
| 10 |
+
decoder.conv_out.*
|
| 11 |
+
|
| 12 |
+
Note: up blocks are indexed in reverse β up.3 is the first decoder stage
|
| 13 |
+
(highest channels), up.0 is the last (lowest channels).
|
| 14 |
+
|
| 15 |
+
All conv weights loaded as PyTorch [O,I,kH,kW] are transposed to MLX
|
| 16 |
+
[O,kH,kW,I] in the pipeline's _load_vae().
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import mlx.core as mx
|
| 22 |
+
import mlx.nn as nn
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ββ Building blocks (param names match weight keys) ββββββββββββββββββββββββββ
|
| 26 |
+
|
| 27 |
+
class ResnetBlock(nn.Module):
    """Residual block: two GroupNorm + SiLU + 3x3 Conv stages with a skip path.

    Parameter names mirror the checkpoint keys:
    block_{i}.{norm1,conv1,norm2,conv2,nin_shortcut}.*
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        # 1x1 projection on the skip path, needed only when channels change.
        self.nin_shortcut = (
            nn.Conv2d(in_ch, out_ch, kernel_size=1) if in_ch != out_ch else None
        )

    def __call__(self, x):
        residual = x
        out = self.conv1(nn.silu(self.norm1(x)))
        out = self.conv2(nn.silu(self.norm2(out)))
        if self.nin_shortcut is not None:
            residual = self.nin_shortcut(residual)
        return residual + out
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class AttnBlock(nn.Module):
    """Single-head spatial self-attention over the H*W token grid.

    Matches checkpoint keys attn_1.{norm,q,k,v,proj_out}.*; the Q/K/V/O
    projections are 1x1 Conv2d so the weight shapes line up with the file.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.norm = nn.GroupNorm(32, channels)
        self.q = nn.Conv2d(channels, channels, kernel_size=1)
        self.k = nn.Conv2d(channels, channels, kernel_size=1)
        self.v = nn.Conv2d(channels, channels, kernel_size=1)
        self.proj_out = nn.Conv2d(channels, channels, kernel_size=1)

    def __call__(self, x):
        batch, height, width, channels = x.shape
        tokens = height * width
        normed = self.norm(x)

        # Flatten the spatial grid into a token sequence for attention.
        queries = self.q(normed).reshape(batch, tokens, channels)
        keys = self.k(normed).reshape(batch, tokens, channels)
        values = self.v(normed).reshape(batch, tokens, channels)

        # Plain softmax attention; scores scaled by 1/sqrt(C).
        scores = (queries @ keys.transpose(0, 2, 1)) * (channels ** -0.5)
        weights = mx.softmax(scores, axis=-1)
        attended = (weights @ values).reshape(batch, height, width, channels)
        return x + self.proj_out(attended)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class Upsample(nn.Module):
    """Nearest-neighbor 2x upsample followed by a learned 3x3 conv.

    Matches checkpoint keys: upsample.conv.*

    Fix: removed the unused ``B, H, W, C = x.shape`` unpack — none of the
    names were referenced.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def __call__(self, x):
        # Duplicate each pixel along H and W (nearest-neighbor 2x upsample),
        # then smooth with the conv. Input is channels-last [B, H, W, C].
        x = mx.repeat(x, 2, axis=1)
        x = mx.repeat(x, 2, axis=2)
        return self.conv(x)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class UpBlock(nn.Module):
    """One decoder up-stage: a chain of ResnetBlocks plus an optional upsample.

    Matches checkpoint keys: up.{i}.block.{0-2}.* and up.{i}.upsample.conv.*
    """

    def __init__(self, in_ch: int, out_ch: int, num_blocks: int = 3, has_upsample: bool = True):
        super().__init__()
        # Only the first resnet performs the channel change.
        blocks = []
        for j in range(num_blocks):
            blocks.append(ResnetBlock(in_ch if j == 0 else out_ch, out_ch))
        self.block = blocks
        self.upsample = Upsample(out_ch) if has_upsample else None

    def __call__(self, x):
        h = x
        for resnet in self.block:
            h = resnet(h)
        return h if self.upsample is None else self.upsample(h)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class MidBlock(nn.Module):
    """Bottleneck stage: resnet -> spatial attention -> resnet.

    Matches checkpoint keys: mid.{block_1,attn_1,block_2}.*
    """

    def __init__(self, channels: int):
        super().__init__()
        self.block_1 = ResnetBlock(channels, channels)
        self.attn_1 = AttnBlock(channels)
        self.block_2 = ResnetBlock(channels, channels)

    def __call__(self, x):
        return self.block_2(self.attn_1(self.block_1(x)))
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# ββ Decoder ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 129 |
+
|
| 130 |
+
class Decoder(nn.Module):
    """VAE decoder. Param paths match decoder.{conv_in,mid,up,norm_out,conv_out}.*

    The ``up`` list is stored in checkpoint index order (0..3) but executed
    from index 3 down to 0:
        up.3 -> 512->512 + upsample (first stage, highest channels)
        up.2 -> 512->512 + upsample
        up.1 -> 512->256 + upsample
        up.0 -> 256->128 (no upsample, last stage)
    """

    def __init__(self):
        super().__init__()
        self.conv_in = nn.Conv2d(16, 512, kernel_size=3, padding=1)
        self.mid = MidBlock(512)

        # Stored by checkpoint index; see class docstring for execution order.
        self.up = [
            UpBlock(256, 128, num_blocks=3, has_upsample=False),  # up.0 (runs last)
            UpBlock(512, 256, num_blocks=3, has_upsample=True),   # up.1
            UpBlock(512, 512, num_blocks=3, has_upsample=True),   # up.2
            UpBlock(512, 512, num_blocks=3, has_upsample=True),   # up.3 (runs first)
        ]

        self.norm_out = nn.GroupNorm(32, 128)
        self.conv_out = nn.Conv2d(128, 3, kernel_size=3, padding=1)

    def __call__(self, z):
        h = self.mid(self.conv_in(z))
        # Execute up stages in reverse index order: 3, 2, 1, 0.
        for index in range(len(self.up) - 1, -1, -1):
            h = self.up[index](h)
        return self.conv_out(nn.silu(self.norm_out(h)))
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
# ββ AutoencoderKL ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 169 |
+
|
| 170 |
+
class AutoencoderKL(nn.Module):
    """FLUX VAE, decode-only path.

    Input:  z     [B, H/8, W/8, 16] channels-last latent
    Output: image [B, H, W, 3] RGB clipped to [0, 1]
    """

    # Latent de-normalization constants used by FLUX.
    SCALE_FACTOR = 0.3611
    SHIFT_FACTOR = 0.1159

    def __init__(self):
        super().__init__()
        self.decoder = Decoder()

    def decode(self, z: mx.array) -> mx.array:
        # Undo the latent scaling/shift, decode, then map [-1, 1] -> [0, 1].
        latents = z / self.SCALE_FACTOR + self.SHIFT_FACTOR
        decoded = self.decoder(latents)
        return mx.clip((decoded + 1.0) / 2.0, 0.0, 1.0)
|
clip_encoder.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CLIP-L text encoder for FLUX pipeline.
|
| 2 |
+
|
| 3 |
+
Implements a 23-layer CLIP text encoder with absolute position embeddings
|
| 4 |
+
and causal self-attention β matching the HuggingFace
|
| 5 |
+
``openai/clip-vit-large-patch14`` architecture used by FLUX.1.
|
| 6 |
+
|
| 7 |
+
Weight source: ``black-forest-labs/FLUX.1-schnell`` β
|
| 8 |
+
``text_encoder/model.safetensors``
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import math
|
| 14 |
+
|
| 15 |
+
import mlx.core as mx
|
| 16 |
+
import mlx.nn as nn
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# ββ CLIP Config ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 20 |
+
|
| 21 |
+
class CLIPConfig:
    """Hyper-parameters for the CLIP-L text encoder used by FLUX."""

    vocab_size: int = 49408
    d_model: int = 768
    num_heads: int = 12
    head_dim: int = 64  # = d_model // num_heads
    intermediate_size: int = 3072
    num_layers: int = 23  # FLUX uses 23 layers, not the usual 12
    max_position_embeddings: int = 77
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ββ Building blocks ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 32 |
+
|
| 33 |
+
class CLIPAttention(nn.Module):
    """Multi-head self-attention used inside each CLIP encoder layer."""

    def __init__(self, cfg: CLIPConfig):
        super().__init__()
        self.num_heads = cfg.num_heads
        self.head_dim = cfg.head_dim
        dim = cfg.d_model

        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)

    def __call__(self, x: mx.array, causal_mask: mx.array | None = None) -> mx.array:
        batch, seq_len, _ = x.shape
        heads, dim_per_head = self.num_heads, self.head_dim

        def split_heads(t):
            # [B, L, d_model] -> [B, H, L, D]
            return t.reshape(batch, seq_len, heads, dim_per_head).transpose(0, 2, 1, 3)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))

        # Scaled dot-product scores, optionally masked (additive mask).
        scores = (q @ k.transpose(0, 1, 3, 2)) / math.sqrt(dim_per_head)
        if causal_mask is not None:
            scores = scores + causal_mask

        probs = mx.softmax(scores, axis=-1)
        merged = (probs @ v).transpose(0, 2, 1, 3).reshape(batch, seq_len, -1)
        return self.out_proj(merged)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class CLIPMLP(nn.Module):
    """Two-layer feed-forward block with approximate-GELU activation."""

    def __init__(self, cfg: CLIPConfig):
        super().__init__()
        self.fc1 = nn.Linear(cfg.d_model, cfg.intermediate_size)
        self.fc2 = nn.Linear(cfg.intermediate_size, cfg.d_model)

    def __call__(self, x: mx.array) -> mx.array:
        hidden = nn.gelu_approx(self.fc1(x))
        return self.fc2(hidden)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class CLIPEncoderLayer(nn.Module):
    """Pre-norm transformer layer: LN -> attention and LN -> MLP, both residual."""

    def __init__(self, cfg: CLIPConfig):
        super().__init__()
        self.norm1 = nn.LayerNorm(cfg.d_model)
        self.attn = CLIPAttention(cfg)
        self.norm2 = nn.LayerNorm(cfg.d_model)
        self.mlp = CLIPMLP(cfg)

    def __call__(self, x: mx.array, causal_mask: mx.array | None = None) -> mx.array:
        x = x + self.attn(self.norm1(x), causal_mask)
        return x + self.mlp(self.norm2(x))
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# ββ CLIP Encoder βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 96 |
+
|
| 97 |
+
class CLIPEncoder(nn.Module):
    """CLIP-L text encoder: 23-layer transformer with absolute position embeddings.

    Input:  token_ids [B, 77]
    Output: (pooled [B, 768], hidden_states [B, 77, 768])

    The pooled output is taken from the first EOS-token (id 49407) position,
    following the CLIP convention. If no EOS is present, ``argmax`` of the
    all-zero mask yields position 0 — callers must always append EOS when
    tokenizing.

    Fix: removed the dead ``has_eos`` variable; it was computed but never
    used, and the accompanying comment falsely claimed a "last position"
    fallback that was never implemented.
    """

    def __init__(self, cfg: CLIPConfig | None = None):
        super().__init__()
        if cfg is None:
            cfg = CLIPConfig()
        self.cfg = cfg

        self.token_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.max_position_embeddings, cfg.d_model)
        self.layers = [CLIPEncoderLayer(cfg) for _ in range(cfg.num_layers)]
        self.final_norm = nn.LayerNorm(cfg.d_model)

    def _build_causal_mask(self, seq_len: int) -> mx.array:
        """Build an additive causal mask [1, 1, L, L].

        Upper triangle (future positions) is -1e9; diagonal and below are 0.
        """
        mask = mx.full((seq_len, seq_len), -1e9)
        mask = mx.triu(mask, k=1)
        return mask.reshape(1, 1, seq_len, seq_len)

    def __call__(self, token_ids: mx.array) -> tuple[mx.array, mx.array]:
        B, L = token_ids.shape

        # Token + absolute position embeddings.
        positions = mx.arange(L)
        x = self.token_emb(token_ids) + self.pos_emb(positions)

        causal_mask = self._build_causal_mask(L)
        for layer in self.layers:
            x = layer(x, causal_mask)
        x = self.final_norm(x)  # [B, L, d_model]

        # Pooled output: hidden state at the first EOS token per sequence.
        # mx.argmax returns the first index of the maximum, i.e. the first
        # EOS; with no EOS present it degrades to position 0.
        eos_id = 49407
        eos_pos = mx.argmax((token_ids == eos_id).astype(mx.int32), axis=-1)  # [B]

        idx = mx.broadcast_to(eos_pos.reshape(B, 1, 1), (B, 1, x.shape[-1]))
        pooled = mx.take_along_axis(x, idx, axis=1).squeeze(1)  # [B, d_model]

        return pooled, x
|
download_weights.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Download shared T5/CLIP weights for the self-built FLUX MLX pipeline.
|
| 3 |
+
|
| 4 |
+
Stores files locally in backends/mlx_flux/weights/ so they are co-located
|
| 5 |
+
with our code and won't be accidentally deleted as "unused HF cache".
|
| 6 |
+
"""
|
| 7 |
+
import os
|
| 8 |
+
import shutil
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from huggingface_hub import hf_hub_download
|
| 11 |
+
|
| 12 |
+
# Source repository for the shared T5/CLIP weights and tokenizer files.
REPO = "black-forest-labs/FLUX.1-schnell"
# Local weights directory, co-located with this module.
WEIGHTS_DIR = Path(__file__).parent / "weights"

# (HF repo path, local filename) pairs to fetch.
FILES = [
    # (HF repo path, local filename)
    ("text_encoder_2/model-00001-of-00002.safetensors", "t5_shard1.safetensors"),
    ("text_encoder_2/model-00002-of-00002.safetensors", "t5_shard2.safetensors"),
    ("text_encoder/model.safetensors", "clip_text_encoder.safetensors"),
    ("tokenizer_2/spiece.model", "t5_spiece.model"),
    ("tokenizer/vocab.json", "clip_vocab.json"),
    ("tokenizer/merges.txt", "clip_merges.txt"),
]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def main():
    """Download each required file into WEIGHTS_DIR, skipping existing copies."""
    WEIGHTS_DIR.mkdir(exist_ok=True)
    for hf_path, local_name in FILES:
        dest = WEIGHTS_DIR / local_name
        # A non-empty local copy counts as already downloaded.
        if dest.exists() and dest.stat().st_size > 0:
            print(f" SKIP {local_name} (already exists, {dest.stat().st_size / 1024 / 1024:.1f} MB)")
            continue
        print(f" DOWNLOADING {hf_path} -> {local_name} ...")
        # hf_hub_download caches under the HF cache dir; copy into our own
        # weights folder so the files survive cache cleanup.
        cached = hf_hub_download(REPO, hf_path)
        shutil.copy2(cached, dest)
        print(f" OK {local_name} ({dest.stat().st_size / 1024 / 1024:.1f} MB)")
    print("\nAll weights ready in", WEIGHTS_DIR)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
if __name__ == "__main__":
|
| 41 |
+
main()
|
flux_model.py
ADDED
|
@@ -0,0 +1,447 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FLUX DiT β rewritten to match mflux reference implementation exactly.
|
| 2 |
+
|
| 3 |
+
Parameter names match argmaxinc/mlx-FLUX.1-schnell weights.
|
| 4 |
+
Forward pass logic matches filipstrand/mflux.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import math
|
| 10 |
+
|
| 11 |
+
import mlx.core as mx
|
| 12 |
+
import mlx.nn as nn
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# ββ Config βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 16 |
+
|
| 17 |
+
class FluxConfig:
    """Architecture hyper-parameters for the FLUX.1-schnell transformer."""

    hidden_size: int = 3072
    num_heads: int = 24
    head_dim: int = 128  # hidden_size // num_heads
    mlp_ratio: float = 4.0
    num_joint_blocks: int = 19
    num_single_blocks: int = 38
    axes_dim: tuple[int, ...] = (16, 56, 56)  # RoPE channels per (t, h, w) axis
    theta: float = 10_000.0
    in_channels: int = 64
    context_dim: int = 4096  # T5-XXL hidden size
    pooled_dim: int = 768    # CLIP-L pooled size
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ββ RoPE (matches mflux EmbedND) ββββββββββββββββββββββββββββββββββββββββββββ
|
| 32 |
+
|
| 33 |
+
def _rope_single_axis(pos: mx.array, dim: int, theta: float) -> mx.array:
    """Rotary-embedding rotation matrices for a single positional axis.

    Args:
        pos: [B, seq] positions along this axis.
        dim: number of embedding channels assigned to this axis (even).
        theta: RoPE base frequency.

    Returns:
        [B, seq, dim//2, 2, 2] rotation matrices.
    """
    exponents = mx.arange(0, dim, 2, dtype=mx.float32) / dim
    freqs = 1.0 / (theta ** exponents)  # [dim//2]
    angles = pos[:, :, None].astype(mx.float32) * freqs[None, :]  # [B, seq, dim//2]
    c = mx.cos(angles)
    s = mx.sin(angles)
    # Flattened 2x2 rotation matrices: [[cos, -sin], [sin, cos]].
    rot = mx.stack([c, -s, s, c], axis=-1)
    return rot.reshape(pos.shape[0], -1, dim // 2, 2, 2)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def compute_rope(ids: mx.array, axes_dim=(16, 56, 56), theta=10000.0) -> mx.array:
    """N-dimensional RoPE embeddings over (time, height, width) position IDs.

    Args:
        ids: [B, seq, 3] position IDs, one column per axis.
        axes_dim: channels allotted to each axis; sums to head_dim.
        theta: RoPE base frequency.

    Returns:
        [B, 1, seq, head_dim//2, 2, 2] rotation matrices; the size-1 axis
        broadcasts over attention heads.
    """
    per_axis = [
        _rope_single_axis(ids[..., axis], axes_dim[axis], theta)
        for axis in range(3)
    ]
    # Concatenate along the frequency axis: total = sum(axes_dim)//2 = 64.
    emb = mx.concatenate(per_axis, axis=-3)
    return emb[:, None]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def apply_rope(q: mx.array, k: mx.array, freqs: mx.array):
    """Apply rotary position embeddings to query and key tensors.

    Mirrors the mflux reference implementation.

    Args:
        q, k: [B, H, L, D] attention tensors, D = head_dim.
        freqs: [B, 1, L, D//2, 2, 2] rotation matrices from compute_rope().

    Returns:
        Rotated (q, k) with the original shapes; both are cast to float32.
    """
    # View each head vector as D//2 (x0, x1) pairs: [B, H, L, D//2, 1, 2]
    xq_ = q.astype(mx.float32).reshape(*q.shape[:-1], -1, 1, 2)
    xk_ = k.astype(mx.float32).reshape(*k.shape[:-1], -1, 1, 2)

    # freqs[..., 0] is the (cos, sin) column and freqs[..., 1] the
    # (-sin, cos) column of each 2x2 rotation matrix, so the sum below is
    # (cos*x0 - sin*x1, sin*x0 + cos*x1) for every pair — a 2-D rotation.
    xq_out = freqs[..., 0] * xq_[..., 0] + freqs[..., 1] * xq_[..., 1]
    xk_out = freqs[..., 0] * xk_[..., 0] + freqs[..., 1] * xk_[..., 1]

    # NOTE(review): outputs stay float32 rather than restoring q.dtype —
    # confirm downstream attention expects float32 here.
    return xq_out.reshape(*q.shape).astype(mx.float32), xk_out.reshape(*k.shape).astype(mx.float32)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# ββ Timestep embedding (matches mflux TimeTextEmbed) βββββββββββββββββββββββββ
|
| 85 |
+
|
| 86 |
+
def timestep_embedding(t: mx.array, dim: int = 256) -> mx.array:
    """Sinusoidal timestep embedding in mflux's [cos | sin] layout.

    Args:
        t: [B] timestep values.
        dim: embedding width (even).

    Returns:
        [B, dim] embedding: first half cosines, second half sines.
    """
    half = dim // 2
    freqs = mx.exp(-math.log(10000.0) * mx.arange(half, dtype=mx.float32) / half)
    args = t[:, None].astype(mx.float32) * freqs[None, :]
    # The mflux reference builds [sin | cos] and then swaps the halves;
    # emitting [cos | sin] directly produces the identical result.
    return mx.concatenate([mx.cos(args), mx.sin(args)], axis=-1)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# ββ AdaLN modulation (matches mflux AdaLayerNormZero) ββββββββββββββββββββββββ
|
| 101 |
+
|
| 102 |
+
class AdaLNModulation(nn.Module):
    """SiLU -> Linear producing ``n_params`` modulation vectors of width ``dim``.

    Checkpoint keys live under adaLN_modulation.layers.1.* (layers.0 is the
    parameter-free SiLU), so both stages are kept in a ``layers`` list.
    """

    def __init__(self, dim: int, n_params: int):
        super().__init__()
        self.layers = [nn.SiLU(), nn.Linear(dim, n_params * dim)]

    def __call__(self, x: mx.array) -> mx.array:
        activation, projection = self.layers
        return projection(activation(x))
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# ββ QK Norm ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 118 |
+
|
| 119 |
+
class QKNorm(nn.Module):
    """Per-head RMS normalization applied to query and key vectors.

    Matches checkpoint keys: qk_norm.{q_norm,k_norm}.weight
    """
    def __init__(self, dim: int):
        super().__init__()
        # `dim` is the per-head dimension (instantiated with cfg.head_dim),
        # not the full model width.
        self.q_norm = nn.RMSNorm(dim)
        self.k_norm = nn.RMSNorm(dim)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# ββ Attention ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 128 |
+
|
| 129 |
+
class Attention(nn.Module):
    """Container for separate Q/K/V/O linear projections.

    Matches checkpoint keys attn.{q,k,v,o}_proj.*. No ``__call__`` is defined
    here — the enclosing transformer block drives the attention math.
    """

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.o_proj = nn.Linear(dim, dim)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# ββ MLP ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 142 |
+
|
| 143 |
+
class MLP(nn.Module):
    """Two-layer GELU feed-forward. Matches checkpoint keys mlp.{fc1,fc2}.*"""

    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden)
        self.fc2 = nn.Linear(hidden, dim)

    def __call__(self, x: mx.array) -> mx.array:
        hidden_states = nn.gelu(self.fc1(x))
        return self.fc2(hidden_states)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# ββ Joint Transformer Block βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 155 |
+
|
| 156 |
+
class ImageTransformerSubBlock(nn.Module):
    """Image-side parameters of a joint (multimodal) transformer block.

    Matches checkpoint keys:
    image_transformer_block.{adaLN_modulation, attn, mlp, qk_norm}.*

    Holds parameters only; the forward pass lives in JointTransformerBlock.
    """

    def __init__(self, cfg: FluxConfig):
        super().__init__()
        H = cfg.hidden_size
        # 6 modulation groups: (shift, scale, gate) for attention and for the FFN.
        self.adaLN_modulation = AdaLNModulation(H, 6)
        self.attn = Attention(H, cfg.num_heads)
        self.mlp = MLP(H, int(H * cfg.mlp_ratio))
        # QK norm operates per-head, hence head_dim rather than hidden_size.
        self.qk_norm = QKNorm(cfg.head_dim)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class TextTransformerSubBlock(nn.Module):
    """Text-side parameters of a joint (multimodal) transformer block.

    Matches checkpoint keys:
    text_transformer_block.{adaLN_modulation, attn, mlp, qk_norm}.*

    NOTE(review): the upstream comments claimed the text FFN uses
    hidden_size/2 = 1536 (mflux convention) with gelu_approx, but the code
    below builds the same MLP(H, int(H * mlp_ratio)) with plain GELU as the
    image side — confirm against the checkpoint's text-side FFN shapes.
    """

    def __init__(self, cfg: FluxConfig):
        super().__init__()
        H = cfg.hidden_size
        self.adaLN_modulation = AdaLNModulation(H, 6)
        self.attn = Attention(H, cfg.num_heads)
        # See class NOTE: hidden dim / activation may diverge from mflux.
        self.mlp = MLP(H, int(H * cfg.mlp_ratio))
        self.qk_norm = QKNorm(cfg.head_dim)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class JointTransformerBlock(nn.Module):
    """Dual-stream (MMDiT-style) block. Matches: multimodal_transformer_blocks.{i}.*

    Image and text streams are each AdaLN-modulated by `vec`, attend
    jointly over the concatenated [txt, img] sequence, then pass through
    independently gated FFNs. Returns both updated streams.
    """

    def __init__(self, cfg: FluxConfig):
        super().__init__()
        self.image_transformer_block = ImageTransformerSubBlock(cfg)
        self.text_transformer_block = TextTransformerSubBlock(cfg)
        # Plain ints with underscore prefix so they are not treated as params.
        self._num_heads = cfg.num_heads
        self._head_dim = cfg.head_dim

    def __call__(self, img, txt, vec, rope_emb):
        """Apply one joint block.

        Args:
            img: [B, img_len, hidden] image-token stream.
            txt: [B, txt_len, hidden] text-token stream.
            vec: [B, hidden] pooled conditioning vector (timestep + CLIP).
            rope_emb: rotary embeddings covering the full [txt, img] sequence.

        Returns:
            Tuple (img, txt) with residual updates applied to both streams.
        """
        B = img.shape[0]
        H, D = self._num_heads, self._head_dim
        img_len = img.shape[1]
        txt_len = txt.shape[1]

        # 1. AdaLN modulation — 6 groups per stream, split along the last axis:
        #    (shift1, scale1, gate1) for attention, (shift2, scale2, gate2) for FFN.
        img_params = self.image_transformer_block.adaLN_modulation(vec)
        i_s1, i_sc1, i_g1, i_s2, i_sc2, i_g2 = mx.split(img_params, 6, axis=-1)

        txt_params = self.text_transformer_block.adaLN_modulation(vec)
        t_s1, t_sc1, t_g1, t_s2, t_sc2, t_g2 = mx.split(txt_params, 6, axis=-1)

        # 2. LayerNorm(affine=False) + modulate. The norm is parameter-free,
        #    so constructing it per call is functionally safe (though it
        #    allocates a throwaway module each forward pass).
        img_norm = nn.LayerNorm(img.shape[-1], affine=False, eps=1e-6)(img)
        img_norm = img_norm * (1 + i_sc1[:, None, :]) + i_s1[:, None, :]
        txt_norm = nn.LayerNorm(txt.shape[-1], affine=False, eps=1e-6)(txt)
        txt_norm = txt_norm * (1 + t_sc1[:, None, :]) + t_s1[:, None, :]

        # 3. Q/K/V projections reshaped to [B, len, heads, head_dim] + QK RMS norm
        img_q = self.image_transformer_block.attn.q_proj(img_norm).reshape(B, img_len, H, D)
        img_k = self.image_transformer_block.attn.k_proj(img_norm).reshape(B, img_len, H, D)
        img_v = self.image_transformer_block.attn.v_proj(img_norm).reshape(B, img_len, H, D)
        txt_q = self.text_transformer_block.attn.q_proj(txt_norm).reshape(B, txt_len, H, D)
        txt_k = self.text_transformer_block.attn.k_proj(txt_norm).reshape(B, txt_len, H, D)
        txt_v = self.text_transformer_block.attn.v_proj(txt_norm).reshape(B, txt_len, H, D)

        img_q = self.image_transformer_block.qk_norm.q_norm(img_q)
        img_k = self.image_transformer_block.qk_norm.k_norm(img_k)
        txt_q = self.text_transformer_block.qk_norm.q_norm(txt_q)
        txt_k = self.text_transformer_block.qk_norm.k_norm(txt_k)

        # 4. Concat for joint attention; sequence order is [txt, img]
        q = mx.concatenate([txt_q, img_q], axis=1).transpose(0, 2, 1, 3)  # [B,H,L,D]
        k = mx.concatenate([txt_k, img_k], axis=1).transpose(0, 2, 1, 3)
        v = mx.concatenate([txt_v, img_v], axis=1).transpose(0, 2, 1, 3)

        # 5. RoPE over the full joint sequence
        q, k = apply_rope(q, k, rope_emb)

        # 6. Scaled dot-product attention (scores divided by sqrt(head_dim))
        scale = math.sqrt(D)
        scores = (q @ k.transpose(0, 1, 3, 2)) / scale
        weights = mx.softmax(scores, axis=-1)
        attn_out = (weights @ v).transpose(0, 2, 1, 3).reshape(B, txt_len + img_len, -1)

        # 7. Split back into the two streams (txt tokens come first)
        txt_attn = attn_out[:, :txt_len, :]
        img_attn = attn_out[:, txt_len:, :]

        # 8. Output projection + gated residual for the attention branch
        img_attn = self.image_transformer_block.attn.o_proj(img_attn)
        txt_attn = self.text_transformer_block.attn.o_proj(txt_attn)
        img = img + i_g1[:, None, :] * img_attn
        txt = txt + t_g1[:, None, :] * txt_attn

        # 9. FFN branch: LayerNorm(affine=False) + modulate, then gated residual
        img_ff_in = nn.LayerNorm(img.shape[-1], affine=False, eps=1e-6)(img)
        img_ff_in = img_ff_in * (1 + i_sc2[:, None, :]) + i_s2[:, None, :]
        img = img + i_g2[:, None, :] * self.image_transformer_block.mlp(img_ff_in)

        txt_ff_in = nn.LayerNorm(txt.shape[-1], affine=False, eps=1e-6)(txt)
        txt_ff_in = txt_ff_in * (1 + t_sc2[:, None, :]) + t_s2[:, None, :]
        txt = txt + t_g2[:, None, :] * self.text_transformer_block.mlp(txt_ff_in)

        return img, txt
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
# ββ Single Transformer Block ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 263 |
+
|
| 264 |
+
class SingleTransformerSubBlock(nn.Module):
    """Single-stream DiT block with parallel attention + MLP branches.

    Matches checkpoint keys: unified_transformer_blocks.{i}.transformer_block.*

    Fix vs. original: head geometry is taken from the config captured at
    construction time instead of the previously hardcoded (24, 128), so
    non-default configs are handled correctly (default config unchanged).
    The convention of stashing plain ints on underscore attributes mirrors
    JointTransformerBlock so load_weights ignores them.
    """

    def __init__(self, cfg: FluxConfig):
        super().__init__()
        H = cfg.hidden_size
        mlp_hidden = int(H * cfg.mlp_ratio)
        # 3 modulation groups: one shared (shift, scale, gate) for attn + MLP.
        self.adaLN_modulation = AdaLNModulation(H, 3)
        self.attn = Attention(H, cfg.num_heads)
        self.mlp = MLP(H, mlp_hidden)
        self.qk_norm = QKNorm(cfg.head_dim)
        # Store head geometry from cfg (was hardcoded 24/128 in _get_heads).
        self._num_heads = cfg.num_heads
        self._head_dim = cfg.head_dim

    def __call__(self, x, vec, rope_emb):
        """Apply the block to the unified [txt, img] token sequence.

        Args:
            x: [B, L, hidden] token stream.
            vec: [B, hidden] conditioning vector.
            rope_emb: rotary embeddings for the full sequence.

        Returns:
            [B, L, hidden] stream with the gated residual applied.
        """
        B, L, D = x.shape
        H, HD = self._get_heads()
        residual = x

        # 1. AdaLN: single (shift, scale, gate) triple shared by both branches
        params = self.adaLN_modulation(vec)
        shift, scale, gate = mx.split(params, 3, axis=-1)
        x_norm = nn.LayerNorm(D, affine=False, eps=1e-6)(x)
        x_norm = x_norm * (1 + scale[:, None, :]) + shift[:, None, :]

        # 2. Attention with per-head QK RMS norm
        q = self.attn.q_proj(x_norm).reshape(B, L, H, HD)
        k = self.attn.k_proj(x_norm).reshape(B, L, H, HD)
        v = self.attn.v_proj(x_norm).reshape(B, L, H, HD)

        q = self.qk_norm.q_norm(q)
        k = self.qk_norm.k_norm(k)

        q = q.transpose(0, 2, 1, 3)
        k = k.transpose(0, 2, 1, 3)
        v = v.transpose(0, 2, 1, 3)

        q, k = apply_rope(q, k, rope_emb)

        sc = math.sqrt(HD)
        scores = (q @ k.transpose(0, 1, 3, 2)) / sc
        w = mx.softmax(scores, axis=-1)
        attn_out = (w @ v).transpose(0, 2, 1, 3).reshape(B, L, -1)

        # 3. Parallel MLP branch (shares the modulated x_norm with attention)
        mlp_out = nn.gelu_approx(self.mlp.fc1(x_norm))

        # 4. There is no fused proj_out param in this layout: project the
        #    attn half through attn.o_proj and the mlp half through mlp.fc2,
        #    sum, then gate. (The original built an unused concat here —
        #    removed as dead code.)
        attn_projected = self.attn.o_proj(attn_out)
        mlp_projected = self.mlp.fc2(mlp_out)
        out = gate[:, None, :] * (attn_projected + mlp_projected)

        return residual + out

    def _get_heads(self):
        """Return (num_heads, head_dim) captured from the config at init."""
        return self._num_heads, self._head_dim
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
class SingleTransformerBlock(nn.Module):
    """Thin wrapper whose only job is to match the checkpoint key path:
    unified_transformer_blocks.{i}.transformer_block.*
    """

    def __init__(self, cfg: FluxConfig):
        super().__init__()
        self.transformer_block = SingleTransformerSubBlock(cfg)

    def __call__(self, x, vec, rope_emb):
        """Delegate directly to the wrapped sub-block."""
        inner = self.transformer_block
        return inner(x, vec, rope_emb)
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
# ββ FLUX Transformer ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 335 |
+
|
| 336 |
+
class FluxTransformer(nn.Module):
    """FLUX DiT matching argmaxinc weights + mflux forward pass.

    Consumes patchified latents ([B, seq, in_channels]) and T5 text
    embeddings, conditioned on a pooled CLIP vector and the timestep,
    and predicts the rectified-flow velocity for the image tokens.
    """

    def __init__(self, cfg: FluxConfig | None = None):
        super().__init__()
        if cfg is None:
            cfg = FluxConfig()
        self.cfg = cfg
        H = cfg.hidden_size

        # x_embedder: Linear (NOT Conv2d) — matches mflux; the weight loader
        # remaps conv-style x_embedder.proj.* keys to this layout.
        self.x_embedder = nn.Linear(cfg.in_channels, H)

        # context_embedder: projects T5 embeddings into the model width
        self.context_embedder = nn.Linear(cfg.context_dim, H)

        # Timestep + pooled-text embeddings (match mflux naming)
        self.t_embedder = _TimestepEmbedder(H)
        self.y_embedder = _TextEmbedder(cfg.pooled_dim, H)

        # Transformer blocks: dual-stream joint blocks, then single-stream
        self.multimodal_transformer_blocks = [
            JointTransformerBlock(cfg) for _ in range(cfg.num_joint_blocks)
        ]
        self.unified_transformer_blocks = [
            SingleTransformerBlock(cfg) for _ in range(cfg.num_single_blocks)
        ]

        # Final layer (matches mflux AdaLayerNormContinuous)
        self.final_layer = _FinalLayer(H, cfg.in_channels)

    def __call__(self, img, img_ids, txt, txt_ids, y, timesteps):
        """Run the full DiT forward pass.

        Args:
            img: [B, seq, in_channels] patchified latent tokens.
            img_ids: [B, seq, 3] positional ids for image tokens.
            txt: [B, txt_seq, context_dim] T5 embeddings.
            txt_ids: [B, txt_seq, 3] positional ids for text tokens.
            y: [B, pooled_dim] pooled CLIP vector.
            timesteps: timesteps, already scaled to [0, 1000] by the scheduler.

        Returns:
            [B, seq, in_channels] prediction for the image tokens only.
        """
        # 1. Embeddings
        img = self.x_embedder(img)  # [B, seq, 64] → [B, seq, 3072]
        txt = self.context_embedder(txt)  # [B, seq, 4096] → [B, seq, 3072]

        # 2. Timestep conditioning — timesteps already scaled to [0, 1000] by scheduler
        t_emb = timestep_embedding(timesteps, 256)
        vec = self.t_embedder(t_emb) + self.y_embedder(y)

        # 3. RoPE for full sequence; order is [txt, img] and must match the
        #    concatenation order used inside the blocks.
        all_ids = mx.concatenate([txt_ids, img_ids], axis=1)
        rope_emb = compute_rope(all_ids, self.cfg.axes_dim, self.cfg.theta)

        # 4. Joint (dual-stream) blocks
        for block in self.multimodal_transformer_blocks:
            img, txt = block(img, txt, vec, rope_emb)

        # 5. Concat for single blocks: txt prefix, then img tokens
        img = mx.concatenate([txt, img], axis=1)

        # 6. Single blocks (rope covers full sequence)
        for block in self.unified_transformer_blocks:
            img = block(img, vec, rope_emb)

        # 7. Drop the txt prefix (txt still holds the pre-concat stream, so
        #    txt.shape[1] is the text length), then apply the final layer
        img = img[:, txt.shape[1]:, :]
        img = self.final_layer(img, vec)

        return img
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
# ββ Helper modules (match mflux TimestepEmbedder/TextEmbedder naming) ββββββββ
|
| 399 |
+
|
| 400 |
+
class _TimestepEmbedder(nn.Module):
    """Projects the 256-dim sinusoidal timestep embedding to model width.

    Matches checkpoint keys: t_embedder.mlp.layers.{0,2}.*
    """

    def __init__(self, dim):
        super().__init__()
        # Input width 256 matches timestep_embedding(timesteps, 256).
        self.mlp = _MLP2(256, dim)

    def __call__(self, x):
        return self.mlp(x)
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
class _TextEmbedder(nn.Module):
    """Projects the pooled CLIP text vector to model width.

    Matches checkpoint keys: y_embedder.mlp.layers.{0,2}.*
    """

    def __init__(self, in_dim, dim):
        super().__init__()
        self.mlp = _MLP2(in_dim, dim)

    def __call__(self, x):
        return self.mlp(x)
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
class _MLP2(nn.Module):
    """Linear → SiLU → Linear. Matches checkpoint keys: mlp.layers.{0,2}.*

    The list indices are load-bearing (layers.1, the SiLU, has no params).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.layers = [nn.Linear(in_dim, out_dim), nn.SiLU(), nn.Linear(out_dim, out_dim)]

    def __call__(self, x):
        out = x
        for fn in self.layers:
            out = fn(out)
        return out
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
class _FinalLayer(nn.Module):
    """AdaLN-modulated output projection back to latent channels.

    Matches checkpoint keys: final_layer.{adaLN_modulation, linear}.*

    NOTE(review): upstream described this as AdaLayerNormContinuous with a
    Linear(bias=False) modulation, but AdaLNModulation below uses nn.Linear
    with its default bias — confirm against the checkpoint keys.
    """

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        # 2 modulation groups: (scale, shift) — see split order in __call__.
        self.adaLN_modulation = AdaLNModulation(hidden_size, 2)
        self.linear = nn.Linear(hidden_size, out_channels)

    def __call__(self, x, vec):
        params = self.adaLN_modulation(vec)
        # Split order is scale-first, then shift (diffusers convention).
        scale, shift = mx.split(params, 2, axis=-1)
        x = nn.LayerNorm(x.shape[-1], affine=False, eps=1e-6)(x)
        x = x * (1 + scale[:, None, :]) + shift[:, None, :]
        return self.linear(x)
|
pipeline.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FLUX.1-schnell pipeline β end-to-end text-to-image on MLX.
|
| 2 |
+
|
| 3 |
+
Orchestrates the full inference chain:
|
| 4 |
+
1. Tokenize prompt (T5 + CLIP)
|
| 5 |
+
2. Encode text (T5 β embeddings, CLIP β pooled)
|
| 6 |
+
3. Initialize latent noise + patchify
|
| 7 |
+
4. Denoising loop (rectified flow, Euler steps)
|
| 8 |
+
5. Unpatchify + VAE decode β PIL.Image
|
| 9 |
+
|
| 10 |
+
Memory strategy: components are loaded/unloaded in phases so that
|
| 11 |
+
only one large model occupies unified memory at a time.
|
| 12 |
+
|
| 13 |
+
Usage::
|
| 14 |
+
|
| 15 |
+
pipe = FluxPipeline()
|
| 16 |
+
pipe.load("argmaxinc/mlx-FLUX.1-schnell-4bit-quantized")
|
| 17 |
+
img = pipe.generate("a cat in a garden", steps=4, width=512, height=512)
|
| 18 |
+
img.save("output.png")
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
import logging
|
| 24 |
+
import os
|
| 25 |
+
import time
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
|
| 28 |
+
import mlx.core as mx
|
| 29 |
+
import mlx.nn as nn
|
| 30 |
+
from huggingface_hub import hf_hub_download
|
| 31 |
+
from PIL import Image
|
| 32 |
+
|
| 33 |
+
from .autoencoder import AutoencoderKL
|
| 34 |
+
from .clip_encoder import CLIPEncoder
|
| 35 |
+
from .flux_model import FluxTransformer
|
| 36 |
+
from .sampler import FlowMatchEulerScheduler, compute_img_ids, patchify, unpatchify
|
| 37 |
+
from .t5_encoder import T5Encoder
|
| 38 |
+
from .tokenizers import CLIPTokenizer, T5Tokenizer
|
| 39 |
+
from . import weight_loader
|
| 40 |
+
|
| 41 |
+
logger = logging.getLogger("image-server")
|
| 42 |
+
|
| 43 |
+
# ── Model repos ──────────────────────────────────────────────────────────────

# Default DiT: 4-bit quantized (smaller unified-memory footprint).
DEFAULT_MODEL = "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized"
# Full-precision alternative DiT repo.
FULL_MODEL = "argmaxinc/mlx-FLUX.1-schnell"
# Tokenizer assets and T5/CLIP text-encoder weights come from the upstream BFL repo.
T5_CLIP_REPO = "black-forest-labs/FLUX.1-schnell"

# Local weights directory (co-located with pipeline code, won't be
# accidentally deleted when cleaning HF cache)
_WEIGHTS_DIR = Path(__file__).parent / "weights"
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class FluxPipeline:
|
| 55 |
+
"""FLUX.1-schnell pipeline for MLX (Apple Silicon).
|
| 56 |
+
|
| 57 |
+
Manages model loading, text encoding, denoising, and VAE decoding
|
| 58 |
+
with phase-based memory management.
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
def __init__(self, model_id: str | None = None, quantize: bool = False):
|
| 62 |
+
self.model_id = model_id or DEFAULT_MODEL
|
| 63 |
+
self.quantize = quantize
|
| 64 |
+
|
| 65 |
+
# Performance options
|
| 66 |
+
import os
|
| 67 |
+
self.use_compile = os.environ.get("IMAGE_MLX_COMPILE", "0") in ("1", "true", "yes")
|
| 68 |
+
self.phased_memory = os.environ.get("IMAGE_MLX_PHASED_MEM", "1") not in ("0", "false", "no")
|
| 69 |
+
|
| 70 |
+
# Components (lazily loaded)
|
| 71 |
+
self.t5_tokenizer: T5Tokenizer | None = None
|
| 72 |
+
self.clip_tokenizer: CLIPTokenizer | None = None
|
| 73 |
+
self.t5_encoder: T5Encoder | None = None
|
| 74 |
+
self.clip_encoder: CLIPEncoder | None = None
|
| 75 |
+
self.transformer: FluxTransformer | None = None
|
| 76 |
+
self.vae: AutoencoderKL | None = None
|
| 77 |
+
self.scheduler = FlowMatchEulerScheduler()
|
| 78 |
+
|
| 79 |
+
self._loaded = False
|
| 80 |
+
|
| 81 |
+
# ββ Loading ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 82 |
+
|
| 83 |
+
    def load(self, model_id: str | None = None) -> None:
        """Download and load all model components.

        Fetches tokenizer assets and weights (local weights/ dir first,
        HF hub fallback), builds the four models, then loads weights into
        each one independently: a failure in any single component logs a
        warning and leaves that component randomly initialized instead of
        aborting the whole pipeline.

        Args:
            model_id: HF repo id for the DiT weights; defaults to the id
                given at construction time.
        """
        repo_id = model_id or self.model_id
        t0 = time.time()
        logger.info("[MLX] Loading FLUX pipeline from %s ...", repo_id)

        # 1. Download files
        dit_path = self._download_dit(repo_id)
        ae_path = hf_hub_download(repo_id, "ae.safetensors")
        t5_spiece = self._local_or_hf("t5_spiece.model", "tokenizer_2/spiece.model")
        clip_vocab = self._local_or_hf("clip_vocab.json", "tokenizer/vocab.json")
        clip_merges = self._local_or_hf("clip_merges.txt", "tokenizer/merges.txt")

        # T5 weights (multi-shard); [] when unavailable
        t5_paths = self._download_t5_weights()

        # CLIP weights; None when unavailable
        clip_path = self._download_clip_weights()

        # 2. Init tokenizers (T5: 256 tokens, CLIP: 77 tokens)
        self.t5_tokenizer = T5Tokenizer(t5_spiece, max_length=256)
        self.clip_tokenizer = CLIPTokenizer(clip_vocab, clip_merges, max_length=77)

        # 3. Build models (randomly initialized until weights are loaded)
        self.t5_encoder = T5Encoder()
        self.clip_encoder = CLIPEncoder()
        self.transformer = FluxTransformer()
        self.vae = AutoencoderKL()

        # 4. Load weights (each component independently → partial loading OK)
        if t5_paths:
            try:
                weight_loader.load_t5(t5_paths, self.t5_encoder)
            except Exception as exc:
                logger.warning("[MLX] T5 weight loading failed: %s → using random init", exc)
        if clip_path:
            try:
                weight_loader.load_clip(clip_path, self.clip_encoder)
            except Exception as exc:
                logger.warning("[MLX] CLIP weight loading failed: %s → using random init", exc)

        # DiT weights — load directly into transformer model
        try:
            self._load_dit(dit_path)
        except Exception as exc:
            logger.warning("[MLX] DiT weight loading failed: %s → using random init", exc)

        # VAE weights
        try:
            self._load_vae(ae_path)
        except Exception as exc:
            logger.warning("[MLX] VAE weight loading failed: %s → using random init", exc)

        self._loaded = True
        logger.info("[MLX] FLUX pipeline ready (%.1fs)", time.time() - t0)
|
| 138 |
+
|
| 139 |
+
def _download_dit(self, repo_id: str) -> str:
|
| 140 |
+
"""Download DiT weights file."""
|
| 141 |
+
if "4bit" in repo_id:
|
| 142 |
+
return hf_hub_download(repo_id, "flux-schnell-4bit-quantized.safetensors")
|
| 143 |
+
return hf_hub_download(repo_id, "flux-schnell.safetensors")
|
| 144 |
+
|
| 145 |
+
    def _download_t5_weights(self) -> list[str]:
        """Get T5-XXL encoder weight paths — local weights/ dir first, HF fallback.

        Returns:
            The two shard paths, or [] when neither local files nor the HF
            download are available (the caller then skips T5 weight loading).
        """
        local1 = _WEIGHTS_DIR / "t5_shard1.safetensors"
        local2 = _WEIGHTS_DIR / "t5_shard2.safetensors"
        # Both shards must be present locally; otherwise fall through to HF.
        if local1.exists() and local2.exists():
            logger.info("[MLX] T5 weights ready (local)")
            return [str(local1), str(local2)]
        try:
            p1 = hf_hub_download(T5_CLIP_REPO, "text_encoder_2/model-00001-of-00002.safetensors")
            p2 = hf_hub_download(T5_CLIP_REPO, "text_encoder_2/model-00002-of-00002.safetensors")
            logger.info("[MLX] T5 weights ready (HF cache)")
            return [p1, p2]
        except Exception as exc:
            # Best-effort: missing weights degrade quality, not availability.
            logger.warning("[MLX] T5 weights not available: %s → text encoding will be limited", exc)
            return []
|
| 160 |
+
|
| 161 |
+
    def _download_clip_weights(self) -> str | None:
        """Get CLIP encoder weight path — local weights/ dir first, HF fallback.

        Returns:
            Path string, or None when neither source is available (the
            caller then skips CLIP weight loading).
        """
        local = _WEIGHTS_DIR / "clip_text_encoder.safetensors"
        if local.exists():
            logger.info("[MLX] CLIP weights ready (local)")
            return str(local)
        try:
            path = hf_hub_download(T5_CLIP_REPO, "text_encoder/model.safetensors")
            logger.info("[MLX] CLIP weights ready (HF cache)")
            return path
        except Exception as exc:
            # Best-effort: missing weights degrade quality, not availability.
            logger.warning("[MLX] CLIP weights not available: %s", exc)
            return None
|
| 174 |
+
|
| 175 |
+
@staticmethod
|
| 176 |
+
def _local_or_hf(local_name: str, hf_path: str) -> str:
|
| 177 |
+
"""Return local path if exists, else download from HF."""
|
| 178 |
+
local = _WEIGHTS_DIR / local_name
|
| 179 |
+
if local.exists():
|
| 180 |
+
return str(local)
|
| 181 |
+
return hf_hub_download(T5_CLIP_REPO, hf_path)
|
| 182 |
+
|
| 183 |
+
    def _load_dit(self, path: str) -> None:
        """Load DiT weights into transformer.

        For 4-bit quantized models, quantizes the model's Linear layers
        to QuantizedLinear first so load_weights can accept uint32 triplets.
        Also remaps conv-style x_embedder.proj.* keys to the Linear
        x_embedder.* layout used by this model.

        Args:
            path: local path to the DiT .safetensors file.
        """
        logger.info("[MLX] Loading DiT weights from %s ...", os.path.basename(path))
        # Use weight_loader's robust safetensors reader (handles bfloat16)
        weights = weight_loader._load_safetensors(path)

        # Detect if quantized: quantized checkpoints ship `.scales` tensors
        is_quantized = any(k.endswith(".scales") for k in weights)
        if is_quantized:
            logger.info("[MLX] Detected quantized weights → converting model layers")
            # Quantize Linear layers except x_embedder (its weight is float, not quantized)
            def _should_quantize(path, module):
                if "x_embedder" in path:
                    return False
                return isinstance(module, nn.Linear)
            # group_size/bits must match how the checkpoint was quantized.
            nn.quantize(self.transformer, group_size=64, bits=4, class_predicate=_should_quantize)

        # Map x_embedder.proj.* → x_embedder.* (Conv2d weights → Linear)
        remapped = {}
        for k, v in weights.items():
            new_k = k
            if k.startswith("x_embedder.proj."):
                new_k = k.replace("x_embedder.proj.", "x_embedder.")
                # Squeeze conv dimensions: [out, 1, 1, in] → [out, in]
                if v.ndim == 4:
                    v = v.squeeze()
            remapped[new_k] = v

        pairs = list(remapped.items())
        # strict=False tolerates checkpoint keys this architecture lacks.
        self.transformer.load_weights(pairs, strict=False)
        logger.info("[MLX] DiT loaded: %d weight tensors (quantized=%s)", len(pairs), is_quantized)
|
| 218 |
+
|
| 219 |
+
    def _load_vae(self, path: str) -> None:
        """Load VAE weights. Transposes conv weights from PyTorch NCHW to MLX NHWC.

        Only decoder.* tensors are kept — this pipeline never encodes, so
        the VAE encoder weights are skipped entirely.

        Args:
            path: local path to ae.safetensors.
        """
        logger.info("[MLX] Loading VAE weights from %s ...", os.path.basename(path))
        weights = weight_loader._load_safetensors(path)

        # Only keep decoder weights (skip encoder)
        transposed = {}
        n_transposed = 0
        for k, v in weights.items():
            if not k.startswith("decoder."):
                continue
            # Transpose conv weights: PyTorch [O, I, kH, kW] → MLX [O, kH, kW, I]
            if v.ndim == 4:
                v = v.transpose(0, 2, 3, 1)
                n_transposed += 1
            transposed[k] = v

        pairs = list(transposed.items())
        self.vae.load_weights(pairs, strict=False)
        logger.info("[MLX] VAE loaded: %d tensors (%d conv transposed)", len(pairs), n_transposed)
|
| 239 |
+
|
| 240 |
+
# ββ Generation βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 241 |
+
|
| 242 |
+
    def generate(
        self,
        prompt: str,
        *,
        width: int = 512,
        height: int = 512,
        steps: int = 4,
        seed: int | None = None,
        progress_callback=None,
    ) -> Image.Image:
        """Generate an image from a text prompt.

        Args:
            prompt: Text description.
            width: Image width (rounded down to a multiple of 16).
            height: Image height (rounded down to a multiple of 16).
            steps: Denoising steps (default 4 for schnell).
            seed: Random seed for reproducibility.
            progress_callback: fn(step, total_steps) called per step.

        Returns:
            PIL.Image.

        Raises:
            RuntimeError: if load() has not been called.

        NOTE(review): when phased_memory is enabled (the default), this
        method sets self.t5_encoder/self.clip_encoder/self.transformer to
        None to free unified memory while self._loaded stays True — a
        second generate() call would hit the released components. Call
        load() again (or set IMAGE_MLX_PHASED_MEM=0) before reusing the
        pipeline.
        """
        if not self._loaded:
            raise RuntimeError("Pipeline not loaded. Call load() first.")

        # Round to multiple of 16
        width = (width // 16) * 16
        height = (height // 16) * 16

        if seed is not None:
            mx.random.seed(seed)

        # ── Phase 1: Text encoding ──────────────────────────────────────
        logger.info("[MLX] Phase 1: Text encoding...")
        t0 = time.time()

        t5_ids = self.t5_tokenizer.tokenize(prompt)
        clip_ids = self.clip_tokenizer.tokenize(prompt)

        t5_embed = self.t5_encoder(t5_ids)  # [1, 256, 4096]
        clip_pooled, _ = self.clip_encoder(clip_ids)  # [1, 768]
        # Force evaluation so the encoders can be released right after.
        mx.eval(t5_embed, clip_pooled)

        logger.info("[MLX] Text encoding done (%.1fs)", time.time() - t0)

        # ── Phase 1→2 transition: free text encoders ────────────────────
        if self.phased_memory:
            logger.info("[MLX] Releasing text encoders to free memory...")
            self.t5_encoder = None
            self.clip_encoder = None
            mx.clear_cache()

        # ── Phase 2: Denoising ──────────────────────────────────────────
        logger.info("[MLX] Phase 2: Denoising (%d steps)...", steps)
        t1 = time.time()

        # VAE operates at 1/8 spatial resolution.
        lat_h = height // 8
        lat_w = width // 8

        # Initial noise, patchified into [1, seq, 64] tokens
        noise = mx.random.normal((1, lat_h, lat_w, 16))
        latents = patchify(noise)  # [1, seq, 64]
        img_ids = compute_img_ids(lat_h, lat_w)
        # Text positional ids are all-zero by convention.
        txt_ids = mx.zeros((1, t5_embed.shape[1], 3), dtype=mx.int32)

        # Compute sigma schedule with exponential time shift
        image_seq_len = latents.shape[1]
        timesteps_list, sigmas = self.scheduler.compute_sigmas(steps, image_seq_len)

        # Optionally compile the transformer forward pass for speed
        _forward_fn = self.transformer
        if self.use_compile:
            try:
                _forward_fn = mx.compile(self.transformer)
                logger.info("[MLX] Using mx.compile for DiT forward pass")
            except Exception as exc:
                logger.warning("[MLX] mx.compile not available: %s", exc)

        for i in range(steps):
            sigma = sigmas[i]
            sigma_next = sigmas[i + 1]
            t = mx.array([timesteps_list[i]])

            # Scale latents before transformer
            latents_scaled = self.scheduler.scale_latents(latents, sigma)

            v_pred = _forward_fn(
                latents_scaled, img_ids,
                t5_embed, txt_ids,
                clip_pooled, t,
            )
            # Evaluate per step to bound lazy-graph memory growth.
            mx.eval(v_pred)

            # Euler step: dt = sigma_next - sigma
            latents = self.scheduler.step(v_pred, sigma, sigma_next, latents)
            mx.eval(latents)

            if progress_callback:
                progress_callback(i + 1, steps)
            logger.debug("[MLX] Step %d/%d (sigma=%.4f)", i + 1, steps, sigma)

        logger.info("[MLX] Denoising done (%.1fs)", time.time() - t1)

        # ── Phase 2→3 transition: free DiT to make room for VAE ─────────
        if self.phased_memory:
            logger.info("[MLX] Releasing DiT to free memory...")
            self.transformer = None
            mx.clear_cache()

        # ── Phase 3: VAE decode ─────────────────────────────────────────
        logger.info("[MLX] Phase 3: VAE decode...")
        t2 = time.time()

        z = unpatchify(latents, lat_h, lat_w)  # [1, lat_h, lat_w, 16]
        image = self.vae.decode(z)  # [1, H, W, 3]
        mx.eval(image)

        logger.info("[MLX] VAE decode done (%.1fs)", time.time() - t2)

        # ── Convert to PIL ──────────────────────────────────────────────
        import numpy as np

        # Decoder output is assumed to be in [0, 1] — clipped to be safe.
        img_np = np.array(image[0], copy=False)  # [H, W, 3]
        img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
        return Image.fromarray(img_np)
|
| 368 |
+
|
| 369 |
+
# ββ Memory management ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 370 |
+
|
| 371 |
+
def unload(self) -> None:
|
| 372 |
+
"""Free all model components from memory."""
|
| 373 |
+
self.t5_encoder = None
|
| 374 |
+
self.clip_encoder = None
|
| 375 |
+
self.transformer = None
|
| 376 |
+
self.vae = None
|
| 377 |
+
self._loaded = False
|
| 378 |
+
|
| 379 |
+
try:
|
| 380 |
+
mx.clear_cache()
|
| 381 |
+
except Exception:
|
| 382 |
+
pass
|
| 383 |
+
logger.info("[MLX] Pipeline unloaded")
|
| 384 |
+
|
| 385 |
+
def memory_info(self) -> dict:
|
| 386 |
+
"""Return MLX Metal memory usage info."""
|
| 387 |
+
try:
|
| 388 |
+
peak = mx.get_peak_memory() / (1024 ** 3)
|
| 389 |
+
active = mx.get_active_memory() / (1024 ** 3)
|
| 390 |
+
cache = mx.get_cache_memory() / (1024 ** 3)
|
| 391 |
+
return {
|
| 392 |
+
"peak_gb": round(peak, 2),
|
| 393 |
+
"active_gb": round(active, 2),
|
| 394 |
+
"cache_gb": round(cache, 2),
|
| 395 |
+
}
|
| 396 |
+
except AttributeError:
|
| 397 |
+
# Fallback for older MLX with mx.metal.* API
|
| 398 |
+
try:
|
| 399 |
+
peak = mx.metal.get_peak_memory() / (1024 ** 3)
|
| 400 |
+
active = mx.metal.get_active_memory() / (1024 ** 3)
|
| 401 |
+
cache = mx.metal.get_cache_memory() / (1024 ** 3)
|
| 402 |
+
return {
|
| 403 |
+
"peak_gb": round(peak, 2),
|
| 404 |
+
"active_gb": round(active, 2),
|
| 405 |
+
"cache_gb": round(cache, 2),
|
| 406 |
+
}
|
| 407 |
+
except Exception:
|
| 408 |
+
return {}
|
| 409 |
+
except Exception:
|
| 410 |
+
return {}
|
sampler.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Flow-matching Euler scheduler + latent patchify/unpatchify.
|
| 2 |
+
|
| 3 |
+
Matches mflux FlowMatchEulerDiscreteScheduler implementation.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import math
|
| 9 |
+
|
| 10 |
+
import mlx.core as mx
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class FlowMatchEulerScheduler:
    """Rectified flow-matching scheduler with exponential time shift.

    Matches mflux's FlowMatchEulerDiscreteScheduler exactly.
    """

    def compute_sigmas(self, num_steps: int, image_seq_len: int = 256) -> tuple[list[float], list[float]]:
        """Build the sigma schedule with exponential time shift.

        Returns ``(timesteps, sigmas)``:
        - timesteps: sigma * 1000, fed to the transformer
        - sigmas: shifted sigma values plus a terminal 0.0
        """
        # Linear base schedule from 1.0 down to 1/num_steps.
        base = mx.linspace(1.0, 1.0 / num_steps, num_steps, dtype=mx.float32)

        # Apply the mflux exponential time shift.
        shifted = self._time_shift(self._compute_mu(image_seq_len, num_steps), 1.0, base)

        return (shifted * 1000).tolist(), shifted.tolist() + [0.0]

    @staticmethod
    def _compute_mu(image_seq_len: int, num_steps: int) -> float:
        """Empirical shift parameter: mu = 0.5 * ln(image_seq_len) (mflux)."""
        return 0.5 * math.log(image_seq_len)

    @staticmethod
    def _time_shift(mu: float, sigma_max: float, sigmas: mx.array) -> mx.array:
        """Exponential-interpolation time shift of the sigma schedule."""
        return mx.exp(mu) / (mx.exp(mu) + (1 / sigmas - 1))

    def scale_latents(self, latents: mx.array, sigma: float) -> mx.array:
        """Scale latents for transformer input: x / sqrt(sigma^2 + 1)."""
        return latents / ((sigma ** 2 + 1) ** 0.5)

    def step(
        self,
        velocity: mx.array,
        sigma_cur: float,
        sigma_next: float,
        latents: mx.array,
    ) -> mx.array:
        """Advance one Euler step of the flow ODE.

        Args:
            velocity: model prediction v_theta, [B, seq, D]
            sigma_cur: current sigma level
            sigma_next: next sigma level
            latents: current latent state, [B, seq, D]

        Returns:
            Updated latents: x + (sigma_cur - sigma_next) * v.
        """
        step_size = mx.array(sigma_cur - sigma_next, dtype=latents.dtype)
        return latents + step_size * velocity
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def patchify(latents: mx.array) -> mx.array:
    """Flatten 2x2 latent patches into tokens.

    [B, H, W, C] -> [B, (H/2)*(W/2), C*4]: FLUX patches the latent grid in
    2x2 tiles and flattens each 2x2xC tile into one C*4-dim token.
    """
    b, h, w, c = latents.shape
    tiled = latents.reshape(b, h // 2, 2, w // 2, 2, c)
    tiled = tiled.transpose(0, 1, 3, 2, 4, 5)  # [B, H/2, W/2, 2, 2, C]
    return tiled.reshape(b, (h // 2) * (w // 2), c * 4)  # [B, seq, 64]
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def unpatchify(tokens: mx.array, h: int, w: int) -> mx.array:
    """Inverse of :func:`patchify`: [B, seq, C*4] -> [B, h, w, C].

    Args:
        tokens: [B, (h/2)*(w/2), 64] patch tokens
        h, w: latent spatial dimensions (before patching)
    """
    b = tokens.shape[0]
    c = tokens.shape[-1] // 4  # 64 / 4 = 16 latent channels

    grid = tokens.reshape(b, h // 2, w // 2, 2, 2, c)
    grid = grid.transpose(0, 1, 3, 2, 4, 5)  # [B, h/2, 2, w/2, 2, C]
    return grid.reshape(b, h, w, c)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def compute_img_ids(lat_h: int, lat_w: int) -> mx.array:
    """Position IDs for patchified latent tokens.

    Returns [1, (lat_h/2)*(lat_w/2), 3] holding (time=0, h_pos, w_pos)
    per token, used for the transformer's rotary position embedding.
    """
    rows, cols = lat_h // 2, lat_w // 2
    seq_len = rows * cols

    # Meshgrid over the patch grid, flattened row-major.
    row_ids = mx.repeat(mx.arange(rows)[:, None], cols, axis=1).reshape(-1)
    col_ids = mx.repeat(mx.arange(cols)[None, :], rows, axis=0).reshape(-1)
    time_ids = mx.zeros((seq_len,), dtype=mx.int32)

    ids = mx.stack([time_ids, row_ids, col_ids], axis=-1)  # [seq, 3]
    return ids.reshape(1, seq_len, 3)
|
t5_encoder.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""T5-XXL text encoder for FLUX pipeline.
|
| 2 |
+
|
| 3 |
+
Implements a 24-layer T5 encoder with relative position bias,
|
| 4 |
+
gated FFN (GeLU), and RMS LayerNorm β matching the HuggingFace
|
| 5 |
+
``google/t5-xxl`` architecture used by FLUX.1.
|
| 6 |
+
|
| 7 |
+
Weight source: ``black-forest-labs/FLUX.1-schnell`` β
|
| 8 |
+
``text_encoder_2/model-0000{1,2}-of-00002.safetensors``
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import math
|
| 14 |
+
|
| 15 |
+
import mlx.core as mx
|
| 16 |
+
import mlx.nn as nn
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# ββ T5 Config (XXL) ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 20 |
+
|
| 21 |
+
class T5Config:
    """Hyper-parameters of the T5-XXL encoder used by FLUX (HF google/t5-xxl)."""

    vocab_size: int = 32128
    d_model: int = 4096  # hidden size
    d_ff: int = 10240  # gated-FFN inner dimension
    num_heads: int = 64
    head_dim: int = 64  # d_model / num_heads
    num_layers: int = 24
    # Relative-attention bias table (owned by layer 0 only).
    relative_attention_num_buckets: int = 32
    relative_attention_max_distance: int = 128
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# ββ Building blocks ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 33 |
+
|
| 34 |
+
class T5RMSNorm(nn.Module):
    """RMS LayerNorm, T5 flavor: no bias, no mean subtraction."""

    def __init__(self, d: int, eps: float = 1e-6):
        super().__init__()
        self.weight = mx.ones((d,))
        self.eps = eps

    def __call__(self, x: mx.array) -> mx.array:
        # Normalize by the root-mean-square of the last axis.
        mean_sq = mx.mean(mx.square(x), axis=-1, keepdims=True)
        normed = x * mx.rsqrt(mean_sq + self.eps)
        return normed * self.weight
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class T5RelativeAttention(nn.Module):
    """Multi-head self-attention with T5 learned relative position bias.

    Matches HF ``T5Attention``: attention scores are *not* divided by
    sqrt(head_dim) — T5 folds that scaling into its weight initialization,
    so applying 1/sqrt(d) here would disagree with the pretrained
    checkpoints and distort the text embeddings.
    """

    def __init__(self, cfg: T5Config, has_relative_bias: bool = False):
        super().__init__()
        self.num_heads = cfg.num_heads
        self.head_dim = cfg.head_dim
        d = cfg.d_model

        self.q_proj = nn.Linear(d, d, bias=False)
        self.k_proj = nn.Linear(d, d, bias=False)
        self.v_proj = nn.Linear(d, d, bias=False)
        self.out_proj = nn.Linear(d, d, bias=False)

        # Only layer 0 owns the learned bias table; later layers reuse the
        # bias tensor handed back from layer 0.
        self.has_relative_bias = has_relative_bias
        if has_relative_bias:
            self.rel_bias = nn.Embedding(
                cfg.relative_attention_num_buckets, cfg.num_heads,
            )
            self._num_buckets = cfg.relative_attention_num_buckets
            self._max_distance = cfg.relative_attention_max_distance

    @staticmethod
    def _relative_position_bucket(
        relative_position: mx.array,
        num_buckets: int = 32,
        max_distance: int = 128,
    ) -> mx.array:
        """Map signed relative positions to bias-table bucket indices.

        Bidirectional T5 bucketing: half the buckets per direction; small
        distances map linearly, larger ones logarithmically up to
        *max_distance*.
        """
        # Bidirectional: use half the buckets for each direction.
        num_buckets //= 2
        # Positive (future) positions land in the second half of the table.
        relative_buckets = (relative_position > 0).astype(mx.int32) * num_buckets
        relative_position = mx.abs(relative_position)

        # Distances below max_exact get one bucket each.
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # Larger distances are spaced logarithmically. log(0) yields -inf
        # for zero-distance entries, but those are masked by mx.where below.
        val = mx.log(relative_position.astype(mx.float32) / max_exact) / math.log(
            max_distance / max_exact
        )
        val = val * (num_buckets - max_exact)
        relative_position_if_large = (max_exact + val).astype(mx.int32)
        relative_position_if_large = mx.minimum(
            relative_position_if_large,
            mx.array(num_buckets - 1, dtype=mx.int32),
        )

        relative_buckets = relative_buckets + mx.where(
            is_small,
            relative_position.astype(mx.int32),
            relative_position_if_large,
        )
        return relative_buckets

    def _compute_bias(self, seq_len: int) -> mx.array:
        """Compute the relative position bias, shape [1, num_heads, seq, seq]."""
        positions = mx.arange(seq_len)
        # relative_position[i, j] = j - i
        relative_position = positions[None, :] - positions[:, None]
        buckets = self._relative_position_bucket(
            relative_position,
            num_buckets=self._num_buckets,
            max_distance=self._max_distance,
        )
        # [seq, seq] -> lookup -> [seq, seq, num_heads] -> [1, num_heads, seq, seq]
        bias = self.rel_bias(buckets)
        bias = bias.transpose(2, 0, 1).reshape(1, self.num_heads, seq_len, seq_len)
        return bias

    def __call__(
        self, x: mx.array, position_bias: mx.array | None = None,
    ) -> tuple[mx.array, mx.array | None]:
        """Self-attend over ``x`` [B, L, d_model]; returns (output, position_bias)."""
        B, L, _ = x.shape
        H, D = self.num_heads, self.head_dim

        q = self.q_proj(x).reshape(B, L, H, D).transpose(0, 2, 1, 3)
        k = self.k_proj(x).reshape(B, L, H, D).transpose(0, 2, 1, 3)
        v = self.v_proj(x).reshape(B, L, H, D).transpose(0, 2, 1, 3)

        # Unscaled dot-product attention (T5 convention — see class docstring;
        # the previous 1/sqrt(D) scaling here was a bug).
        scores = q @ k.transpose(0, 1, 3, 2)  # [B, H, L, L]

        # Layer 0 computes the bias; every layer adds whichever bias is in hand.
        if self.has_relative_bias:
            position_bias = self._compute_bias(L)
        if position_bias is not None:
            scores = scores + position_bias

        weights = mx.softmax(scores, axis=-1)
        out = weights @ v  # [B, H, L, D]
        out = out.transpose(0, 2, 1, 3).reshape(B, L, -1)  # [B, L, d_model]
        return self.out_proj(out), position_bias
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class T5GatedFFN(nn.Module):
    """T5 v1.1 gated feed-forward block: GeLU(wi_0(x)) * wi_1(x) -> wo."""

    def __init__(self, cfg: T5Config):
        super().__init__()
        self.wi_0 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)  # gate branch
        self.wi_1 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)  # linear branch
        self.wo = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        hidden = nn.gelu_approx(self.wi_0(x)) * self.wi_1(x)
        return self.wo(hidden)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class T5EncoderLayer(nn.Module):
    """One T5 encoder block: pre-norm attention then pre-norm gated FFN."""

    def __init__(self, cfg: T5Config, has_relative_bias: bool = False):
        super().__init__()
        self.attn_norm = T5RMSNorm(cfg.d_model)
        self.attn = T5RelativeAttention(cfg, has_relative_bias=has_relative_bias)
        self.ffn_norm = T5RMSNorm(cfg.d_model)
        self.ffn = T5GatedFFN(cfg)

    def __call__(
        self, x: mx.array, position_bias: mx.array | None = None,
    ) -> tuple[mx.array, mx.array | None]:
        # Pre-norm attention with residual; the position bias produced by
        # layer 0 is threaded through for reuse by later layers.
        attn_out, position_bias = self.attn(self.attn_norm(x), position_bias)
        x = x + attn_out

        # Pre-norm FFN with residual.
        x = x + self.ffn(self.ffn_norm(x))

        return x, position_bias
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
# ββ T5 Encoder βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 195 |
+
|
| 196 |
+
class T5Encoder(nn.Module):
    """T5-XXL encoder: 24-layer transformer with relative position bias.

    Input: token_ids [B, seq_len]
    Output: embeddings [B, seq_len, 4096]
    """

    def __init__(self, cfg: T5Config | None = None):
        super().__init__()
        self.cfg = cfg if cfg is not None else T5Config()

        self.wte = nn.Embedding(self.cfg.vocab_size, self.cfg.d_model)
        # Only layer 0 owns the relative-attention bias table.
        self.layers = [
            T5EncoderLayer(self.cfg, has_relative_bias=(idx == 0))
            for idx in range(self.cfg.num_layers)
        ]
        self.final_norm = T5RMSNorm(self.cfg.d_model)

    def __call__(self, token_ids: mx.array, use_transformer: bool = True) -> mx.array:
        hidden = self.wte(token_ids)  # [B, L, d_model]

        if use_transformer:
            # Full 24-layer transformer; layer 0 seeds the position bias.
            bias = None
            for block in self.layers:
                hidden, bias = block(hidden, bias)
            hidden = self.final_norm(hidden)

        # NOTE(review): with use_transformer=False this returns the raw token
        # embeddings (debug path) — confirm callers expect no final_norm there.
        return hidden
|
tokenizers.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tokenizers for FLUX pipeline β T5 (SentencePiece) and CLIP (BPE).
|
| 2 |
+
|
| 3 |
+
Both tokenizers produce mx.array token ID tensors ready for encoder input.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import logging
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
import mlx.core as mx
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger("image-server")
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class T5Tokenizer:
    """SentencePiece tokenizer for T5-XXL.

    Loads ``spiece.model`` from the FLUX.1-schnell repo
    (``tokenizer_2/spiece.model``).
    """

    def __init__(self, spiece_path: str, max_length: int = 256):
        import sentencepiece as spm

        self._sp = spm.SentencePieceProcessor()
        self._sp.Load(spiece_path)
        self.max_length = max_length
        self.pad_id = 0

    def tokenize(self, text: str) -> mx.array:
        """Encode *text* into a padded [1, max_length] int32 tensor."""
        # NOTE(review): HF's T5 tokenizer appends the </s> EOS token (id 1)
        # before padding; this implementation does not — confirm the encoder
        # weights were validated against this convention.
        ids = list(self._sp.Encode(text))[: self.max_length]
        padding = [self.pad_id] * (self.max_length - len(ids))
        return mx.array(ids + padding, dtype=mx.int32).reshape(1, -1)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class CLIPTokenizer:
    """Byte-pair-encoding tokenizer for the CLIP text encoder.

    Loads ``vocab.json`` (token -> id) and ``merges.txt`` (ordered BPE merge
    rules) from ``tokenizer/`` in the FLUX.1-schnell repo.
    """

    BOS_ID = 49406  # <|startoftext|>
    EOS_ID = 49407  # <|endoftext|>

    def __init__(self, vocab_path: str, merges_path: str, max_length: int = 77):
        # token string -> id
        with open(vocab_path, encoding="utf-8") as f:
            self._vocab: dict[str, int] = json.load(f)

        # BPE merge rules; earlier line number == higher priority.
        self._merges: list[tuple[str, str]] = []
        self._merge_rank: dict[tuple[str, str], int] = {}
        with open(merges_path, encoding="utf-8") as f:
            for rank, raw in enumerate(f):
                raw = raw.strip()
                if not raw or raw.startswith("#"):
                    continue
                pieces = raw.split()
                if len(pieces) == 2:
                    pair = (pieces[0], pieces[1])
                    self._merges.append(pair)
                    self._merge_rank[pair] = rank

        self.max_length = max_length
        self.pad_id = 0

        # Simplified CLIP word-splitting pattern.
        import regex
        self._pat = regex.compile(
            r"""'s|'t|'re|'ve|'m|'ll|'d|"""
            r"""[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            regex.IGNORECASE,
        )

    def _bpe(self, token: str) -> list[str]:
        """Split one word into BPE sub-tokens, with a </w> end-of-word marker."""
        if len(token) <= 1:
            return [token + "</w>"] if token else []

        # Start from characters, marking the last one as end-of-word.
        symbols = list(token[:-1]) + [token[-1] + "</w>"]

        while len(symbols) > 1:
            # Lowest-rank (highest-priority) adjacent pair; ties keep the
            # first occurrence, matching the scan order.
            best_pair = min(
                zip(symbols, symbols[1:]),
                key=lambda pair: self._merge_rank.get(pair, float("inf")),
            )
            if best_pair not in self._merge_rank:
                break  # no applicable merge left

            merged: list[str] = []
            i = 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best_pair:
                    merged.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            symbols = merged

        return symbols

    def tokenize(self, text: str) -> mx.array:
        """Encode *text* into a padded [1, max_length] int32 tensor."""
        text = text.lower().strip()

        ids = [self.BOS_ID]
        for match in self._pat.finditer(text):
            for sub in self._bpe(match.group()):
                # Unknown sub-tokens fall back to id 0 (simplification).
                ids.append(self._vocab.get(sub, 0))
        ids.append(self.EOS_ID)

        # Truncate while keeping BOS at the start and EOS at the end.
        if len(ids) > self.max_length:
            ids = ids[: self.max_length - 1] + [self.EOS_ID]

        ids.extend([self.pad_id] * (self.max_length - len(ids)))
        return mx.array(ids, dtype=mx.int32).reshape(1, -1)
|
weight_loader.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Weight loader β safetensors β MLX model parameter mapping.
|
| 2 |
+
|
| 3 |
+
Handles multi-shard loading, HuggingFace key β self-defined key mapping,
|
| 4 |
+
and dtype conversion. Each ``load_*`` function takes a model instance
|
| 5 |
+
and populates its parameters in-place.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
import mlx.core as mx
|
| 14 |
+
import mlx.nn as nn
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger("image-server")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# ββ Utilities ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 20 |
+
|
| 21 |
+
def _load_safetensors(*paths: str) -> dict[str, mx.array]:
|
| 22 |
+
"""Load one or more safetensors files into a flat dict.
|
| 23 |
+
|
| 24 |
+
Tries MLX framework first. Falls back to PyTorch for bfloat16 weights,
|
| 25 |
+
preserving uint32 quantized weights as-is.
|
| 26 |
+
"""
|
| 27 |
+
from safetensors import safe_open
|
| 28 |
+
|
| 29 |
+
weights: dict[str, mx.array] = {}
|
| 30 |
+
for p in paths:
|
| 31 |
+
# Try mlx framework first (fastest)
|
| 32 |
+
try:
|
| 33 |
+
with safe_open(p, framework="mlx") as f:
|
| 34 |
+
for key in f.keys():
|
| 35 |
+
weights[key] = f.get_tensor(key)
|
| 36 |
+
logger.info("[WeightLoader] Loaded %s via MLX framework", p.split("/")[-1])
|
| 37 |
+
continue
|
| 38 |
+
except Exception:
|
| 39 |
+
pass # bfloat16 or other incompatibility β fallback to pt
|
| 40 |
+
|
| 41 |
+
# Fallback: load via PyTorch, selectively convert
|
| 42 |
+
try:
|
| 43 |
+
import torch
|
| 44 |
+
with safe_open(p, framework="pt") as f:
|
| 45 |
+
for key in f.keys():
|
| 46 |
+
t = f.get_tensor(key)
|
| 47 |
+
if t.dtype == torch.uint32:
|
| 48 |
+
# Quantized weight β keep as uint32
|
| 49 |
+
weights[key] = mx.array(t.numpy())
|
| 50 |
+
elif t.dtype == torch.bfloat16:
|
| 51 |
+
# bfloat16 β float32 β mx (bfloat16 not supported by numpy)
|
| 52 |
+
weights[key] = mx.array(t.float().numpy())
|
| 53 |
+
else:
|
| 54 |
+
# float32, float16, etc. β direct conversion
|
| 55 |
+
weights[key] = mx.array(t.numpy())
|
| 56 |
+
logger.info("[WeightLoader] Loaded %s via PyTorch fallback (%d tensors)", p.split("/")[-1], len(weights))
|
| 57 |
+
except Exception as exc:
|
| 58 |
+
logger.error("[WeightLoader] Failed to load %s: %s", p, exc)
|
| 59 |
+
return weights
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _set_nested(model: nn.Module, dotpath: str, value: mx.array) -> bool:
|
| 63 |
+
"""Set a parameter on *model* using a dot-separated path.
|
| 64 |
+
|
| 65 |
+
Returns True if set successfully, False if path does not exist.
|
| 66 |
+
"""
|
| 67 |
+
parts = dotpath.split(".")
|
| 68 |
+
obj = model
|
| 69 |
+
for part in parts[:-1]:
|
| 70 |
+
if part.isdigit():
|
| 71 |
+
obj = obj[int(part)]
|
| 72 |
+
elif hasattr(obj, part):
|
| 73 |
+
obj = getattr(obj, part)
|
| 74 |
+
else:
|
| 75 |
+
return False
|
| 76 |
+
final = parts[-1]
|
| 77 |
+
if hasattr(obj, final):
|
| 78 |
+
setattr(obj, final, value)
|
| 79 |
+
return True
|
| 80 |
+
return False
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# ββ T5 Weight Loader βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 84 |
+
|
| 85 |
+
# HuggingFace T5 key β our T5Encoder parameter path
|
| 86 |
+
_T5_KEY_MAP_TEMPLATES = {
|
| 87 |
+
# Embedding
|
| 88 |
+
"encoder.embed_tokens.weight": "wte.weight",
|
| 89 |
+
"shared.weight": "wte.weight", # some checkpoints use this key
|
| 90 |
+
|
| 91 |
+
# Per-layer keys (use {i} placeholder)
|
| 92 |
+
"encoder.block.{i}.layer.0.layer_norm.weight": "layers.{i}.attn_norm.weight",
|
| 93 |
+
"encoder.block.{i}.layer.0.SelfAttention.q.weight": "layers.{i}.attn.q_proj.weight",
|
| 94 |
+
"encoder.block.{i}.layer.0.SelfAttention.k.weight": "layers.{i}.attn.k_proj.weight",
|
| 95 |
+
"encoder.block.{i}.layer.0.SelfAttention.v.weight": "layers.{i}.attn.v_proj.weight",
|
| 96 |
+
"encoder.block.{i}.layer.0.SelfAttention.o.weight": "layers.{i}.attn.out_proj.weight",
|
| 97 |
+
"encoder.block.{i}.layer.1.layer_norm.weight": "layers.{i}.ffn_norm.weight",
|
| 98 |
+
"encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight": "layers.{i}.ffn.wi_0.weight",
|
| 99 |
+
"encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight": "layers.{i}.ffn.wi_1.weight",
|
| 100 |
+
"encoder.block.{i}.layer.1.DenseReluDense.wo.weight": "layers.{i}.ffn.wo.weight",
|
| 101 |
+
|
| 102 |
+
# Relative attention bias (only layer 0)
|
| 103 |
+
"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight":
|
| 104 |
+
"layers.0.attn.rel_bias.weight",
|
| 105 |
+
|
| 106 |
+
# Final norm
|
| 107 |
+
"encoder.final_layer_norm.weight": "final_norm.weight",
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _build_t5_key_map(num_layers: int = 24) -> dict[str, str]:
|
| 112 |
+
"""Expand layer-templated keys for all layers."""
|
| 113 |
+
mapping: dict[str, str] = {}
|
| 114 |
+
for hf_template, our_template in _T5_KEY_MAP_TEMPLATES.items():
|
| 115 |
+
if "{i}" in hf_template:
|
| 116 |
+
for i in range(num_layers):
|
| 117 |
+
hf_key = hf_template.replace("{i}", str(i))
|
| 118 |
+
our_key = our_template.replace("{i}", str(i))
|
| 119 |
+
mapping[hf_key] = our_key
|
| 120 |
+
else:
|
| 121 |
+
mapping[hf_template] = our_template
|
| 122 |
+
return mapping
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def load_t5(paths: list[str], model) -> None:
    """Load T5 encoder weights from safetensors shards into a T5Encoder."""
    weights = _load_safetensors(*paths)
    key_map = _build_t5_key_map(num_layers=model.cfg.num_layers)

    loaded = 0
    unmapped: list[str] = []
    for hf_key, tensor in weights.items():
        our_key = key_map.get(hf_key)
        if our_key is None:
            unmapped.append(hf_key)
        elif _set_nested(model, our_key, tensor):
            # HF stores Linear weights as [out, in], which is exactly what
            # MLX nn.Linear expects (it computes x @ W.T) — no transpose.
            loaded += 1
        else:
            unmapped.append(f"{hf_key} → {our_key} (path not found)")

    logger.info(
        "[WeightLoader] T5: loaded %d/%d params, unmapped %d",
        loaded, len(weights), len(unmapped),
    )
    if unmapped:
        for k in unmapped[:10]:
            logger.debug("  unmapped: %s", k)
        if len(unmapped) > 10:
            logger.debug("  ... and %d more", len(unmapped) - 10)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# ββ CLIP Weight Loader βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 157 |
+
|
| 158 |
+
_CLIP_KEY_MAP_TEMPLATES = {
|
| 159 |
+
# Embeddings
|
| 160 |
+
"text_model.embeddings.token_embedding.weight": "token_emb.weight",
|
| 161 |
+
"text_model.embeddings.position_embedding.weight": "pos_emb.weight",
|
| 162 |
+
|
| 163 |
+
# Per-layer keys
|
| 164 |
+
"text_model.encoder.layers.{i}.layer_norm1.weight": "layers.{i}.norm1.weight",
|
| 165 |
+
"text_model.encoder.layers.{i}.layer_norm1.bias": "layers.{i}.norm1.bias",
|
| 166 |
+
"text_model.encoder.layers.{i}.self_attn.q_proj.weight": "layers.{i}.attn.q_proj.weight",
|
| 167 |
+
"text_model.encoder.layers.{i}.self_attn.q_proj.bias": "layers.{i}.attn.q_proj.bias",
|
| 168 |
+
"text_model.encoder.layers.{i}.self_attn.k_proj.weight": "layers.{i}.attn.k_proj.weight",
|
| 169 |
+
"text_model.encoder.layers.{i}.self_attn.k_proj.bias": "layers.{i}.attn.k_proj.bias",
|
| 170 |
+
"text_model.encoder.layers.{i}.self_attn.v_proj.weight": "layers.{i}.attn.v_proj.weight",
|
| 171 |
+
"text_model.encoder.layers.{i}.self_attn.v_proj.bias": "layers.{i}.attn.v_proj.bias",
|
| 172 |
+
"text_model.encoder.layers.{i}.self_attn.out_proj.weight": "layers.{i}.attn.out_proj.weight",
|
| 173 |
+
"text_model.encoder.layers.{i}.self_attn.out_proj.bias": "layers.{i}.attn.out_proj.bias",
|
| 174 |
+
"text_model.encoder.layers.{i}.layer_norm2.weight": "layers.{i}.norm2.weight",
|
| 175 |
+
"text_model.encoder.layers.{i}.layer_norm2.bias": "layers.{i}.norm2.bias",
|
| 176 |
+
"text_model.encoder.layers.{i}.mlp.fc1.weight": "layers.{i}.mlp.fc1.weight",
|
| 177 |
+
"text_model.encoder.layers.{i}.mlp.fc1.bias": "layers.{i}.mlp.fc1.bias",
|
| 178 |
+
"text_model.encoder.layers.{i}.mlp.fc2.weight": "layers.{i}.mlp.fc2.weight",
|
| 179 |
+
"text_model.encoder.layers.{i}.mlp.fc2.bias": "layers.{i}.mlp.fc2.bias",
|
| 180 |
+
|
| 181 |
+
# Final norm
|
| 182 |
+
"text_model.final_layer_norm.weight": "final_norm.weight",
|
| 183 |
+
"text_model.final_layer_norm.bias": "final_norm.bias",
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def _build_clip_key_map(num_layers: int = 23) -> dict[str, str]:
|
| 188 |
+
"""Expand layer-templated keys for all layers."""
|
| 189 |
+
mapping: dict[str, str] = {}
|
| 190 |
+
for hf_template, our_template in _CLIP_KEY_MAP_TEMPLATES.items():
|
| 191 |
+
if "{i}" in hf_template:
|
| 192 |
+
for i in range(num_layers):
|
| 193 |
+
hf_key = hf_template.replace("{i}", str(i))
|
| 194 |
+
our_key = our_template.replace("{i}", str(i))
|
| 195 |
+
mapping[hf_key] = our_key
|
| 196 |
+
else:
|
| 197 |
+
mapping[hf_template] = our_template
|
| 198 |
+
return mapping
|
| 199 |
+
|
| 200 |
def load_clip(path: str, model) -> None:
    """Load CLIP encoder weights from safetensors into a CLIPEncoder model.

    Args:
        path: Path to an HF-format CLIP text-encoder safetensors checkpoint.
        model: Target CLIPEncoder; ``model.cfg.num_layers`` controls how many
            per-layer key templates are expanded for the mapping.

    Checkpoint keys that have no mapping, or whose mapped attribute path does
    not exist on the model, are collected and logged (never raised), so a
    partially matching checkpoint still loads whatever it can.
    """
    weights = _load_safetensors(path)
    key_map = _build_clip_key_map(num_layers=model.cfg.num_layers)

    loaded = 0
    unmapped: list[str] = []
    for hf_key, tensor in weights.items():
        our_key = key_map.get(hf_key)
        if our_key is None:
            unmapped.append(hf_key)
            continue
        if _set_nested(model, our_key, tensor):
            loaded += 1
        else:
            # FIX: the arrow here was mojibaked to "β" in the original file.
            unmapped.append(f"{hf_key} → {our_key} (path not found)")

    logger.info(
        "[WeightLoader] CLIP: loaded %d/%d params, unmapped %d",
        loaded, len(weights), len(unmapped),
    )
    if unmapped:
        # Log only the first few at debug level to keep output readable.
        for k in unmapped[:10]:
            logger.debug(" unmapped: %s", k)
# ── Placeholder for future rounds ────────────────────────────────────────────

def load_flux_dit(path: str, model) -> None:
    """Load FLUX DiT weights. Implemented in C-122b.

    Raises:
        NotImplementedError: always, until ticket C-122b lands.
    """
    # FIX: the em dash in this message was mojibaked to "β" in the original.
    raise NotImplementedError("FLUX DiT weight loading — see C-122b")
def load_vae(path: str, model) -> None:
    """Load VAE decoder weights. Implemented in C-122c.

    Raises:
        NotImplementedError: always, until ticket C-122c lands.
    """
    # FIX: the em dash in this message was mojibaked to "β" in the original.
    raise NotImplementedError("VAE weight loading — see C-122c")