"""FLUX VAE decoder — param names match argmaxinc ae.safetensors keys.
| |
| Weight key structure: |
| decoder.conv_in.* |
| decoder.mid.block_{1,2}.{norm1,conv1,norm2,conv2}.* |
| decoder.mid.attn_1.{norm,q,k,v,proj_out}.* |
| decoder.up.{0-3}.block.{0-2}.{norm1,conv1,norm2,conv2,nin_shortcut}.* |
| decoder.up.{1-3}.upsample.conv.* |
| decoder.norm_out.* |
| decoder.conv_out.* |
| |
Note: up blocks are indexed in reverse — up.3 is the first decoder stage
(highest channels), up.0 is the last (lowest channels).
| |
| All conv weights loaded as PyTorch [O,I,kH,kW] are transposed to MLX |
| [O,kH,kW,I] in the pipeline's _load_vae(). |
| """ |
|
|
| from __future__ import annotations |
|
|
| import mlx.core as mx |
| import mlx.nn as nn |
|
|
|
|
| |
|
|
class ResnetBlock(nn.Module):
    """Pre-activation residual block.

    Matches: block_{i}.{norm1,conv1,norm2,conv2,nin_shortcut}.*
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        # The 1x1 shortcut projection exists in the checkpoint only when the
        # block changes its channel count; otherwise the skip is identity.
        self.nin_shortcut = (
            nn.Conv2d(in_ch, out_ch, kernel_size=1) if in_ch != out_ch else None
        )

    def __call__(self, x):
        skip = x if self.nin_shortcut is None else self.nin_shortcut(x)
        out = self.conv1(nn.silu(self.norm1(x)))
        out = self.conv2(nn.silu(self.norm2(out)))
        return skip + out
|
|
|
|
class AttnBlock(nn.Module):
    """Single-head self-attention over all spatial positions.

    Matches: attn_1.{norm,q,k,v,proj_out}.*

    Q/K/V/O projections are 1x1 Conv2d layers so parameter shapes line up
    with the checkpoint's conv-style attention weights.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.norm = nn.GroupNorm(32, channels)
        self.q = nn.Conv2d(channels, channels, kernel_size=1)
        self.k = nn.Conv2d(channels, channels, kernel_size=1)
        self.v = nn.Conv2d(channels, channels, kernel_size=1)
        self.proj_out = nn.Conv2d(channels, channels, kernel_size=1)

    def __call__(self, x):
        batch, height, width, channels = x.shape
        tokens = height * width

        normed = self.norm(x)
        # Flatten the spatial grid into a token axis: [B, H*W, C].
        q = self.q(normed).reshape(batch, tokens, channels)
        k = self.k(normed).reshape(batch, tokens, channels)
        v = self.v(normed).reshape(batch, tokens, channels)

        # Scaled dot-product attention across the H*W positions.
        scores = (q @ k.transpose(0, 2, 1)) * channels ** -0.5
        weights = mx.softmax(scores, axis=-1)
        out = (weights @ v).reshape(batch, height, width, channels)
        return x + self.proj_out(out)
|
|
|
|
class Upsample(nn.Module):
    """2x nearest-neighbor spatial upsample followed by a 3x3 conv.

    Matches: upsample.conv.*
    """

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def __call__(self, x):
        # Nearest-neighbor 2x upsampling along the spatial axes.
        # (Input is channels-last [B, H, W, C]; axes 1 and 2 are H and W.)
        x = mx.repeat(x, 2, axis=1)
        x = mx.repeat(x, 2, axis=2)
        return self.conv(x)
|
|
|
|
class UpBlock(nn.Module):
    """One decoder up-stage.

    Matches: up.{i}.block.{0-2}.* + up.{i}.upsample.*
    """

    def __init__(self, in_ch: int, out_ch: int, num_blocks: int = 3, has_upsample: bool = True):
        super().__init__()
        # Only the first resnet block performs the channel transition;
        # the remaining blocks keep out_ch -> out_ch.
        blocks = []
        for idx in range(num_blocks):
            blocks.append(ResnetBlock(in_ch if idx == 0 else out_ch, out_ch))
        self.block = blocks
        self.upsample = Upsample(out_ch) if has_upsample else None

    def __call__(self, x):
        for resnet in self.block:
            x = resnet(x)
        return x if self.upsample is None else self.upsample(x)
|
|
|
|
class MidBlock(nn.Module):
    """Bottleneck stage: resnet -> attention -> resnet.

    Matches: mid.{block_1, attn_1, block_2}.*
    """

    def __init__(self, channels: int):
        super().__init__()
        self.block_1 = ResnetBlock(channels, channels)
        self.attn_1 = AttnBlock(channels)
        self.block_2 = ResnetBlock(channels, channels)

    def __call__(self, x):
        return self.block_2(self.attn_1(self.block_1(x)))
|
|
|
|
| |
|
|
class Decoder(nn.Module):
    """VAE Decoder. Param paths match: decoder.{conv_in,mid,up,norm_out,conv_out}.*

    Up block order (matching weight keys):
        up.3 -- 512 -> 512 + upsample (first stage)
        up.2 -- 512 -> 512 + upsample
        up.1 -- 512 -> 256 + upsample
        up.0 -- 256 -> 128 (no upsample, last stage)
    """

    def __init__(self):
        super().__init__()
        self.conv_in = nn.Conv2d(16, 512, kernel_size=3, padding=1)
        self.mid = MidBlock(512)
        # Stored in key order (up.0 first) so weight paths line up,
        # but executed in reverse: up.3 runs first at decode time.
        self.up = [
            UpBlock(256, 128, num_blocks=3, has_upsample=False),
            UpBlock(512, 256, num_blocks=3, has_upsample=True),
            UpBlock(512, 512, num_blocks=3, has_upsample=True),
            UpBlock(512, 512, num_blocks=3, has_upsample=True),
        ]
        self.norm_out = nn.GroupNorm(32, 128)
        self.conv_out = nn.Conv2d(128, 3, kernel_size=3, padding=1)

    def __call__(self, z):
        h = self.mid(self.conv_in(z))
        # Highest-channel stage first (up.3 -> up.0).
        for stage in reversed(self.up):
            h = stage(h)
        return self.conv_out(nn.silu(self.norm_out(h)))
|
|
|
|
| |
|
|
class AutoencoderKL(nn.Module):
    """FLUX VAE -- decode-only path.

    Input:  z     [B, H/8, W/8, 16] (latent, channels-last)
    Output: image [B, H, W, 3] (RGB in [0, 1])
    """

    # Latent scaling constants applied at encode time; decode inverts them.
    SCALE_FACTOR = 0.3611
    SHIFT_FACTOR = 0.1159

    def __init__(self):
        super().__init__()
        self.decoder = Decoder()

    def decode(self, z: mx.array) -> mx.array:
        # Undo the latent scale/shift, run the decoder, then map the
        # decoder's [-1, 1] output range into a clipped [0, 1] image.
        latents = z / self.SCALE_FACTOR + self.SHIFT_FACTOR
        decoded = self.decoder(latents)
        return mx.clip((decoded + 1.0) / 2.0, 0.0, 1.0)
|
|