# FLUX.1-schnell-MLX / autoencoder.py
# Provenance: uploaded by illusion615 via huggingface_hub (rev 31f3da5, verified).
"""FLUX VAE decoder β€” param names match argmaxinc ae.safetensors keys.
Weight key structure:
decoder.conv_in.*
decoder.mid.block_{1,2}.{norm1,conv1,norm2,conv2}.*
decoder.mid.attn_1.{norm,q,k,v,proj_out}.*
decoder.up.{0-3}.block.{0-2}.{norm1,conv1,norm2,conv2,nin_shortcut}.*
decoder.up.{1-3}.upsample.conv.*
decoder.norm_out.*
decoder.conv_out.*
Note: up blocks are indexed in reverse β€” up.3 is the first decoder stage
(highest channels), up.0 is the last (lowest channels).
All conv weights loaded as PyTorch [O,I,kH,kW] are transposed to MLX
[O,kH,kW,I] in the pipeline's _load_vae().
"""
from __future__ import annotations
import mlx.core as mx
import mlx.nn as nn
# ── Building blocks (param names match weight keys) ──────────────────────────
class ResnetBlock(nn.Module):
    """Residual block: two GroupNorm+SiLU+Conv stages plus a skip connection.

    Parameter paths match the checkpoint keys:
    block_{i}.{norm1,conv1,norm2,conv2,nin_shortcut}.*
    """
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        # 1x1 projection on the skip path, needed only when channels change.
        self.nin_shortcut = (
            nn.Conv2d(in_ch, out_ch, kernel_size=1) if in_ch != out_ch else None
        )
    def __call__(self, x):
        residual = self.conv1(nn.silu(self.norm1(x)))
        residual = self.conv2(nn.silu(self.norm2(residual)))
        if self.nin_shortcut is not None:
            x = self.nin_shortcut(x)
        return x + residual
class AttnBlock(nn.Module):
    """Single-head self-attention over all spatial positions.

    Parameter paths match the checkpoint keys: attn_1.{norm,q,k,v,proj_out}.*
    Q/K/V/O projections are 1x1 Conv2d layers so parameter shapes line up
    with the stored weights.
    """
    def __init__(self, channels: int):
        super().__init__()
        self.norm = nn.GroupNorm(32, channels)
        self.q = nn.Conv2d(channels, channels, kernel_size=1)
        self.k = nn.Conv2d(channels, channels, kernel_size=1)
        self.v = nn.Conv2d(channels, channels, kernel_size=1)
        self.proj_out = nn.Conv2d(channels, channels, kernel_size=1)
    def __call__(self, x):
        batch, height, width, channels = x.shape
        tokens = height * width
        normed = self.norm(x)
        # Flatten the spatial axes into a token axis for attention.
        queries = self.q(normed).reshape(batch, tokens, channels)
        keys = self.k(normed).reshape(batch, tokens, channels)
        values = self.v(normed).reshape(batch, tokens, channels)
        # Scaled dot-product attention, scale = 1/sqrt(channels).
        weights = mx.softmax(
            (queries @ keys.transpose(0, 2, 1)) * channels ** -0.5, axis=-1
        )
        attended = (weights @ values).reshape(batch, height, width, channels)
        return x + self.proj_out(attended)
class Upsample(nn.Module):
    """2x nearest-neighbor spatial upsample followed by a 3x3 conv.

    Parameter paths match the checkpoint keys: upsample.conv.*
    """
    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
    def __call__(self, x):
        # Nearest-neighbor 2x upsampling on the H and W axes (NHWC layout).
        # (Removed an unused `B, H, W, C = x.shape` unpack — dead code.)
        x = mx.repeat(x, 2, axis=1)
        x = mx.repeat(x, 2, axis=2)
        return self.conv(x)
class UpBlock(nn.Module):
    """One decoder up-stage: a chain of ResnetBlocks plus an optional upsample.

    Parameter paths match the checkpoint keys:
    up.{i}.block.{0-2}.* + up.{i}.upsample.*
    """
    def __init__(self, in_ch: int, out_ch: int, num_blocks: int = 3, has_upsample: bool = True):
        super().__init__()
        # First block maps in_ch -> out_ch; the rest keep out_ch.
        blocks = []
        channels = in_ch
        for _ in range(num_blocks):
            blocks.append(ResnetBlock(channels, out_ch))
            channels = out_ch
        self.block = blocks
        self.upsample = Upsample(out_ch) if has_upsample else None
    def __call__(self, x):
        for resnet in self.block:
            x = resnet(x)
        return x if self.upsample is None else self.upsample(x)
class MidBlock(nn.Module):
    """Bottleneck stage: resnet -> attention -> resnet at constant width.

    Parameter paths match the checkpoint keys: mid.{block_1,attn_1,block_2}.*
    """
    def __init__(self, channels: int):
        super().__init__()
        self.block_1 = ResnetBlock(channels, channels)
        self.attn_1 = AttnBlock(channels)
        self.block_2 = ResnetBlock(channels, channels)
    def __call__(self, x):
        return self.block_2(self.attn_1(self.block_1(x)))
# ── Decoder ──────────────────────────────────────────────────────────────────
class Decoder(nn.Module):
    """VAE decoder. Param paths match: decoder.{conv_in,mid,up,norm_out,conv_out}.*

    Up stages follow the checkpoint's reversed index convention — up.3 runs
    first (widest), up.0 runs last:
        up.3 -> 512->512 + upsample (first stage)
        up.2 -> 512->512 + upsample
        up.1 -> 512->256 + upsample
        up.0 -> 256->128 (no upsample, last stage)
    """
    def __init__(self):
        super().__init__()
        self.conv_in = nn.Conv2d(16, 512, kernel_size=3, padding=1)
        self.mid = MidBlock(512)
        # Stored in index order 0..3 so parameter paths match the weight
        # keys; execution walks the list backwards (3 -> 0).
        self.up = [
            UpBlock(256, 128, num_blocks=3, has_upsample=False),  # up.0
            UpBlock(512, 256, num_blocks=3, has_upsample=True),   # up.1
            UpBlock(512, 512, num_blocks=3, has_upsample=True),   # up.2
            UpBlock(512, 512, num_blocks=3, has_upsample=True),   # up.3
        ]
        self.norm_out = nn.GroupNorm(32, 128)
        self.conv_out = nn.Conv2d(128, 3, kernel_size=3, padding=1)
    def __call__(self, z):
        h = self.mid(self.conv_in(z))
        # Widest stage first: up.3, up.2, up.1, up.0.
        for stage in self.up[::-1]:
            h = stage(h)
        return self.conv_out(nn.silu(self.norm_out(h)))
# ── AutoencoderKL ────────────────────────────────────────────────────────────
class AutoencoderKL(nn.Module):
    """FLUX VAE — decode-only path.

    Input:  z [B, H/8, W/8, 16] latent, channels-last.
    Output: image [B, H, W, 3], RGB clipped into [0, 1].
    """
    # Latent normalization constants undone before decoding.
    SCALE_FACTOR = 0.3611
    SHIFT_FACTOR = 0.1159
    def __init__(self):
        super().__init__()
        self.decoder = Decoder()
    def decode(self, z: mx.array) -> mx.array:
        # Undo latent scaling/shift, decode, then map [-1, 1] -> [0, 1].
        latents = z / self.SCALE_FACTOR + self.SHIFT_FACTOR
        decoded = self.decoder(latents)
        return mx.clip((decoded + 1.0) * 0.5, 0.0, 1.0)