"""FLUX VAE decoder — param names match argmaxinc ae.safetensors keys.
Weight key structure:
decoder.conv_in.*
decoder.mid.block_{1,2}.{norm1,conv1,norm2,conv2}.*
decoder.mid.attn_1.{norm,q,k,v,proj_out}.*
decoder.up.{0-3}.block.{0-2}.{norm1,conv1,norm2,conv2,nin_shortcut}.*
decoder.up.{1-3}.upsample.conv.*
decoder.norm_out.*
decoder.conv_out.*
Note: up blocks are indexed in reverse — up.3 is the first decoder stage
(highest channels), up.0 is the last (lowest channels).
All conv weights loaded as PyTorch [O,I,kH,kW] are transposed to MLX
[O,kH,kW,I] in the pipeline's _load_vae().
"""
from __future__ import annotations
import mlx.core as mx
import mlx.nn as nn
# ── Building blocks (param names match weight keys) ──────────────────────────
class ResnetBlock(nn.Module):
    """Pre-activation residual block.

    Param paths mirror the weight keys
    block_{i}.{norm1,conv1,norm2,conv2,nin_shortcut}.*; the 1×1
    ``nin_shortcut`` projection exists only when the channel count changes.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        # Project the residual path only when channels differ (matches which
        # checkpoints carry a nin_shortcut weight for this block).
        self.nin_shortcut = (
            nn.Conv2d(in_ch, out_ch, kernel_size=1) if in_ch != out_ch else None
        )

    def __call__(self, x):
        residual = x if self.nin_shortcut is None else self.nin_shortcut(x)
        h = self.conv1(nn.silu(self.norm1(x)))
        h = self.conv2(nn.silu(self.norm2(h)))
        return residual + h
class AttnBlock(nn.Module):
    """Single-head self-attention over the flattened spatial positions.

    Param paths mirror the weight keys attn_1.{norm,q,k,v,proj_out}.*;
    Q/K/V/O projections are 1×1 Conv2d layers so their shapes match the
    checkpoint tensors.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.norm = nn.GroupNorm(32, channels)
        self.q = nn.Conv2d(channels, channels, kernel_size=1)
        self.k = nn.Conv2d(channels, channels, kernel_size=1)
        self.v = nn.Conv2d(channels, channels, kernel_size=1)
        self.proj_out = nn.Conv2d(channels, channels, kernel_size=1)

    def __call__(self, x):
        batch, height, width, channels = x.shape
        tokens = height * width
        normed = self.norm(x)
        # Flatten H×W into a token axis: [B, H*W, C].
        q = self.q(normed).reshape(batch, tokens, channels)
        k = self.k(normed).reshape(batch, tokens, channels)
        v = self.v(normed).reshape(batch, tokens, channels)
        # Scaled dot-product attention, softmax over the key axis.
        weights = mx.softmax((q @ k.transpose(0, 2, 1)) * channels**-0.5, axis=-1)
        out = (weights @ v).reshape(batch, height, width, channels)
        return x + self.proj_out(out)
class Upsample(nn.Module):
    """2× nearest-neighbor upsample followed by a 3×3 conv.

    Param path mirrors the weight key upsample.conv.*.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def __call__(self, x):
        # Nearest-neighbor 2× upsample: duplicate each pixel along both
        # spatial axes (channels-last layout, axes 1 and 2).
        # Fix: dropped the unused `B, H, W, C = x.shape` unpacking.
        x = mx.repeat(x, 2, axis=1)
        x = mx.repeat(x, 2, axis=2)
        return self.conv(x)
class UpBlock(nn.Module):
    """One decoder up-stage.

    Param paths mirror the weight keys up.{i}.block.{0-2}.* plus, when
    ``has_upsample`` is True, up.{i}.upsample.*.
    """

    def __init__(self, in_ch: int, out_ch: int, num_blocks: int = 3, has_upsample: bool = True):
        super().__init__()
        # Only block 0 sees in_ch; every later block maps out_ch -> out_ch.
        self.block = [
            ResnetBlock(in_ch if idx == 0 else out_ch, out_ch)
            for idx in range(num_blocks)
        ]
        self.upsample = Upsample(out_ch) if has_upsample else None

    def __call__(self, x):
        for resnet in self.block:
            x = resnet(x)
        return x if self.upsample is None else self.upsample(x)
class MidBlock(nn.Module):
    """Decoder middle stage: resnet → attention → resnet.

    Param paths mirror the weight keys mid.{block_1,attn_1,block_2}.*.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.block_1 = ResnetBlock(channels, channels)
        self.attn_1 = AttnBlock(channels)
        self.block_2 = ResnetBlock(channels, channels)

    def __call__(self, x):
        return self.block_2(self.attn_1(self.block_1(x)))
# ── Decoder ──────────────────────────────────────────────────────────────────
class Decoder(nn.Module):
    """VAE decoder: latent [B, H/8, W/8, 16] → image features [B, H, W, 3].

    Param paths match decoder.{conv_in,mid,up,norm_out,conv_out}.*.
    Up-stage order (matching the weight keys, which index in reverse):
        up.3 — 512→512 + upsample (runs first)
        up.2 — 512→512 + upsample
        up.1 — 512→256 + upsample
        up.0 — 256→128, no upsample (runs last)
    """

    def __init__(self):
        super().__init__()
        self.conv_in = nn.Conv2d(16, 512, kernel_size=3, padding=1)
        self.mid = MidBlock(512)
        # Stored by weight-key index (up.0 .. up.3); executed 3 → 0.
        self.up = [
            UpBlock(256, 128, num_blocks=3, has_upsample=False),  # up.0 — last stage
            UpBlock(512, 256, num_blocks=3, has_upsample=True),   # up.1
            UpBlock(512, 512, num_blocks=3, has_upsample=True),   # up.2
            UpBlock(512, 512, num_blocks=3, has_upsample=True),   # up.3 — first stage
        ]
        self.norm_out = nn.GroupNorm(32, 128)
        self.conv_out = nn.Conv2d(128, 3, kernel_size=3, padding=1)

    def __call__(self, z):
        h = self.mid(self.conv_in(z))
        # Weight keys index the up-stages in reverse, so run 3, 2, 1, 0.
        for stage in self.up[::-1]:
            h = stage(h)
        return self.conv_out(nn.silu(self.norm_out(h)))
# ── AutoencoderKL ────────────────────────────────────────────────────────────
class AutoencoderKL(nn.Module):
    """FLUX VAE — decode-only path.

    Input:  z     [B, H/8, W/8, 16] latent, channels-last.
    Output: image [B, H, W, 3] RGB in [0, 1].
    """

    # Latent scale/shift baked into the FLUX checkpoint; decode undoes them.
    SCALE_FACTOR = 0.3611
    SHIFT_FACTOR = 0.1159

    def __init__(self):
        super().__init__()
        self.decoder = Decoder()

    def decode(self, z: mx.array) -> mx.array:
        # Undo the latent normalization applied at encode time.
        latent = z / self.SCALE_FACTOR + self.SHIFT_FACTOR
        decoded = self.decoder(latent)
        # Map decoder output from [-1, 1] to [0, 1] and clamp.
        return mx.clip((decoded + 1.0) / 2.0, 0.0, 1.0)
|