"""FLUX VAE decoder — param names match argmaxinc ae.safetensors keys.

Weight key structure:
    decoder.conv_in.*
    decoder.mid.block_{1,2}.{norm1,conv1,norm2,conv2}.*
    decoder.mid.attn_1.{norm,q,k,v,proj_out}.*
    decoder.up.{0-3}.block.{0-2}.{norm1,conv1,norm2,conv2,nin_shortcut}.*
    decoder.up.{1-3}.upsample.conv.*
    decoder.norm_out.*
    decoder.conv_out.*

Note: up blocks are indexed in reverse — up.3 is the first decoder stage
(highest channels), up.0 is the last (lowest channels).

All conv weights loaded as PyTorch [O,I,kH,kW] are transposed to MLX
[O,kH,kW,I] in the pipeline's _load_vae().

All activations in this module are channels-last ([B, H, W, C]).
"""

from __future__ import annotations

import mlx.core as mx
import mlx.nn as nn

# ── Normalization ────────────────────────────────────────────────────────────

# Group count and epsilon used by every GroupNorm in the reference FLUX
# autoencoder (torch.nn.GroupNorm(num_groups=32, ..., eps=1e-6)).
_NUM_GROUPS = 32
_NORM_EPS = 1e-6


def _group_norm(channels: int) -> nn.GroupNorm:
    """Return a GroupNorm over ``channels`` that groups channels like PyTorch.

    MLX's GroupNorm defaults to ``pytorch_compatible=False``, which puts
    channel ``c`` in group ``c % num_groups`` (interleaved). PyTorch puts
    contiguous runs of ``channels // num_groups`` channels in each group.
    Because the weights come from a PyTorch checkpoint, the PyTorch grouping
    is required for the normalization statistics to match training. The
    learnable parameters are still named ``weight`` / ``bias``, so weight
    keys are unaffected.
    """
    return nn.GroupNorm(
        _NUM_GROUPS,
        channels,
        eps=_NORM_EPS,
        pytorch_compatible=True,
    )


# ── Building blocks (param names match weight keys) ──────────────────────────


class ResnetBlock(nn.Module):
    """Pre-activation residual block.

    Matches: block_{i}.{norm1,conv1,norm2,conv2,nin_shortcut}.*

    Computes ``shortcut(x) + conv2(silu(norm2(conv1(silu(norm1(x))))))``,
    where ``shortcut`` is the 1×1 ``nin_shortcut`` conv when the channel
    count changes and the identity otherwise.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.norm1 = _group_norm(in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.norm2 = _group_norm(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        # The residual path only needs a projection when channels change.
        if in_ch != out_ch:
            self.nin_shortcut = nn.Conv2d(in_ch, out_ch, kernel_size=1)
        else:
            self.nin_shortcut = None

    def __call__(self, x: mx.array) -> mx.array:
        h = self.conv1(nn.silu(self.norm1(x)))
        h = self.conv2(nn.silu(self.norm2(h)))
        if self.nin_shortcut is not None:
            x = self.nin_shortcut(x)
        return x + h


class AttnBlock(nn.Module):
    """Single-head self-attention across all H*W spatial positions.

    Matches: attn_1.{norm,q,k,v,proj_out}.*
    Uses 1×1 Conv2d for Q/K/V/O projections (matching weight shapes).
    Input and output are ``[B, H, W, C]``; the block is residual.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.norm = _group_norm(channels)
        self.q = nn.Conv2d(channels, channels, kernel_size=1)
        self.k = nn.Conv2d(channels, channels, kernel_size=1)
        self.v = nn.Conv2d(channels, channels, kernel_size=1)
        self.proj_out = nn.Conv2d(channels, channels, kernel_size=1)

    def __call__(self, x: mx.array) -> mx.array:
        B, H, W, C = x.shape
        h = self.norm(x)

        # Flatten the spatial grid into a sequence of H*W tokens of width C.
        q = self.q(h).reshape(B, H * W, C)
        k = self.k(h).reshape(B, H * W, C)
        v = self.v(h).reshape(B, H * W, C)

        # Scaled dot-product attention. The score matrix is [B, H*W, H*W],
        # so memory grows quadratically with the latent's spatial size.
        scale = C ** -0.5
        attn = (q @ k.transpose(0, 2, 1)) * scale
        attn = mx.softmax(attn, axis=-1)

        h = (attn @ v).reshape(B, H, W, C)
        return x + self.proj_out(h)


class Upsample(nn.Module):
    """2× nearest-neighbour upsampling followed by a 3×3 conv.

    Matches: upsample.conv.*
    """

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def __call__(self, x: mx.array) -> mx.array:
        # Duplicate every row (axis 1 = H), then every column (axis 2 = W).
        x = mx.repeat(x, 2, axis=1)
        x = mx.repeat(x, 2, axis=2)
        return self.conv(x)


class UpBlock(nn.Module):
    """One decoder up-stage: residual blocks, then an optional 2× upsample.

    Matches: up.{i}.block.{0-2}.* + up.{i}.upsample.*

    The first residual block maps ``in_ch → out_ch``; the rest keep
    ``out_ch``.
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        num_blocks: int = 3,
        has_upsample: bool = True,
    ):
        super().__init__()
        self.block = [
            ResnetBlock(in_ch if j == 0 else out_ch, out_ch)
            for j in range(num_blocks)
        ]
        if has_upsample:
            self.upsample = Upsample(out_ch)
        else:
            self.upsample = None

    def __call__(self, x: mx.array) -> mx.array:
        for b in self.block:
            x = b(x)
        if self.upsample is not None:
            x = self.upsample(x)
        return x


class MidBlock(nn.Module):
    """Bottleneck: ResnetBlock → AttnBlock → ResnetBlock at fixed width.

    Matches: mid.{block_1, attn_1, block_2}.*
    """

    def __init__(self, channels: int):
        super().__init__()
        self.block_1 = ResnetBlock(channels, channels)
        self.attn_1 = AttnBlock(channels)
        self.block_2 = ResnetBlock(channels, channels)

    def __call__(self, x: mx.array) -> mx.array:
        x = self.block_1(x)
        x = self.attn_1(x)
        x = self.block_2(x)
        return x


# ── Decoder ──────────────────────────────────────────────────────────────────


class Decoder(nn.Module):
    """VAE Decoder.

    Param paths match: decoder.{conv_in,mid,up,norm_out,conv_out}.*

    Up block order (matching weight keys):
        up.3 → 512→512 + upsample (first stage)
        up.2 → 512→512 + upsample
        up.1 → 512→256 + upsample
        up.0 → 256→128 (no upsample, last stage)

    Three upsamples give an overall 8× spatial increase: ``[B, h, w, 16]``
    latents become ``[B, 8h, 8w, 3]`` outputs.
    """

    def __init__(self):
        super().__init__()
        self.conv_in = nn.Conv2d(16, 512, kernel_size=3, padding=1)
        self.mid = MidBlock(512)
        # Stored in weight-key order (index i ↔ up.{i}) but executed in
        # reverse (3→2→1→0).
        self.up = [
            UpBlock(256, 128, num_blocks=3, has_upsample=False),  # up.0
            UpBlock(512, 256, num_blocks=3, has_upsample=True),   # up.1
            UpBlock(512, 512, num_blocks=3, has_upsample=True),   # up.2
            UpBlock(512, 512, num_blocks=3, has_upsample=True),   # up.3
        ]
        self.norm_out = _group_norm(128)
        self.conv_out = nn.Conv2d(128, 3, kernel_size=3, padding=1)

    def __call__(self, z: mx.array) -> mx.array:
        h = self.conv_in(z)
        h = self.mid(h)
        for stage in reversed(self.up):
            h = stage(h)
        h = nn.silu(self.norm_out(h))
        h = self.conv_out(h)
        return h


# ── AutoencoderKL ────────────────────────────────────────────────────────────


class AutoencoderKL(nn.Module):
    """FLUX VAE — decode-only path.

    Input:  z     [B, H/8, W/8, 16]  (latent, channels-last)
    Output: image [B, H, W, 3]       (RGB in [0, 1])
    """

    # Latent normalization constants: decode() inverts
    # ``z_norm = (z - SHIFT_FACTOR) * SCALE_FACTOR``.
    SCALE_FACTOR = 0.3611
    SHIFT_FACTOR = 0.1159

    def __init__(self):
        super().__init__()
        self.decoder = Decoder()

    def decode(self, z: mx.array) -> mx.array:
        """Decode normalized latents ``z`` into an RGB image clipped to [0, 1]."""
        z = z / self.SCALE_FACTOR + self.SHIFT_FACTOR
        image = self.decoder(z)
        # Map the decoder's [-1, 1] range to [0, 1]; clip any overshoot.
        image = mx.clip((image + 1.0) / 2.0, 0.0, 1.0)
        return image