"""AutoencoderKL Decoder — pure MLX implementation. Decodes latent representations to RGB images without PyTorch/diffusers dependency. Architecture matches diffusers AutoencoderKL with the Z-Image-Turbo VAE config: latent_channels = 16 block_out_channels = [128, 256, 512, 512] layers_per_block = 2 (decoder uses layers_per_block + 1 = 3) norm_num_groups = 32 mid_block_add_attention = true force_upcast = true (all ops in float32) scaling_factor = 0.3611 shift_factor = 0.1159 Data format: NHWC throughout (MLX convention). """ from __future__ import annotations import math import mlx.core as mx import mlx.nn as nn # Match diffusers VAE GroupNorm: eps=1e-6, pytorch_compatible=True _GN_EPS = 1e-6 def _gn(groups: int, channels: int) -> nn.GroupNorm: return nn.GroupNorm(groups, channels, eps=_GN_EPS, pytorch_compatible=True) # ── Building blocks ────────────────────────────────────────────── class ResnetBlock2D(nn.Module): """Residual block: GroupNorm → SiLU → Conv → GroupNorm → SiLU → Conv + skip.""" def __init__(self, in_channels: int, out_channels: int, groups: int = 32): super().__init__() self.norm1 = _gn(groups, in_channels) self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) self.norm2 = _gn(groups, out_channels) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) self.conv_shortcut = None if in_channels != out_channels: self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1) def __call__(self, x: mx.array) -> mx.array: residual = x x = nn.silu(self.norm1(x)) x = self.conv1(x) x = nn.silu(self.norm2(x)) x = self.conv2(x) if self.conv_shortcut is not None: residual = self.conv_shortcut(residual) return x + residual class AttentionBlock(nn.Module): """Single-head self-attention over spatial positions (NHWC).""" def __init__(self, channels: int, groups: int = 32): super().__init__() self.group_norm = _gn(groups, channels) self.to_q = nn.Linear(channels, channels) self.to_k = nn.Linear(channels, channels) 
self.to_v = nn.Linear(channels, channels) # diffusers wraps out-proj in a list (Sequential): to_out.0 self.to_out = [nn.Linear(channels, channels)] def __call__(self, x: mx.array) -> mx.array: residual = x B, H, W, C = x.shape x = self.group_norm(x) x = x.reshape(B, H * W, C) q = self.to_q(x) k = self.to_k(x) v = self.to_v(x) scale = 1.0 / math.sqrt(C) attn = (q @ k.transpose(0, 2, 1)) * scale attn = mx.softmax(attn, axis=-1) x = attn @ v x = self.to_out[0](x) x = x.reshape(B, H, W, C) return x + residual class Upsample2D(nn.Module): """2× nearest-neighbour upsample followed by a 3×3 conv.""" def __init__(self, channels: int): super().__init__() self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1) def __call__(self, x: mx.array) -> mx.array: # Nearest-neighbour 2× in NHWC B, H, W, C = x.shape x = mx.repeat(x, 2, axis=1) x = mx.repeat(x, 2, axis=2) x = self.conv(x) return x class UpDecoderBlock2D(nn.Module): """Decoder up-block: N resnet blocks + optional 2× upsample.""" def __init__( self, in_channels: int, out_channels: int, num_layers: int = 3, add_upsample: bool = True, groups: int = 32, ): super().__init__() self.resnets = [] for i in range(num_layers): res_in = in_channels if i == 0 else out_channels self.resnets.append(ResnetBlock2D(res_in, out_channels, groups)) self.upsamplers = [] if add_upsample: self.upsamplers.append(Upsample2D(out_channels)) def __call__(self, x: mx.array) -> mx.array: for resnet in self.resnets: x = resnet(x) for up in self.upsamplers: x = up(x) return x class MidBlock2D(nn.Module): """Mid block: resnet → self-attention → resnet.""" def __init__(self, channels: int, groups: int = 32): super().__init__() self.resnets = [ ResnetBlock2D(channels, channels, groups), ResnetBlock2D(channels, channels, groups), ] self.attentions = [AttentionBlock(channels, groups)] def __call__(self, x: mx.array) -> mx.array: x = self.resnets[0](x) x = self.attentions[0](x) x = self.resnets[1](x) return x # ── Decoder 
class Decoder(nn.Module):
    """AutoencoderKL Decoder (NHWC, pure MLX).

    Module hierarchy matches diffusers weight-key paths after stripping the
    ``decoder.`` prefix, so weights can be loaded directly.
    """

    def __init__(
        self,
        latent_channels: int = 16,
        block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
    ):
        super().__init__()
        # Decoder runs widest-first: e.g. [512, 512, 256, 128].
        channels = list(reversed(block_out_channels))

        # Latent → widest feature map.
        self.conv_in = nn.Conv2d(
            latent_channels, channels[0], kernel_size=3, padding=1
        )

        # Bottleneck with self-attention.
        self.mid_block = MidBlock2D(channels[0], norm_num_groups)

        # Up path: every block except the last doubles the resolution, so the
        # total spatial gain is 2^(len(channels) - 1) = 8× for four stages.
        self.up_blocks = []
        prev_ch = channels[0]
        for idx, ch in enumerate(channels):
            self.up_blocks.append(
                UpDecoderBlock2D(
                    in_channels=prev_ch,
                    out_channels=ch,
                    num_layers=layers_per_block + 1,
                    add_upsample=idx != len(channels) - 1,
                    groups=norm_num_groups,
                )
            )
            prev_ch = ch

        # Feature map → RGB.
        self.conv_norm_out = _gn(norm_num_groups, channels[-1])
        self.conv_out = nn.Conv2d(channels[-1], 3, kernel_size=3, padding=1)

    def __call__(self, z: mx.array) -> mx.array:
        """Decode latents → image.

        Args:
            z: (B, H, W, C) latent tensor in NHWC, **already scaled**.

        Returns:
            (B, 8H, 8W, 3) decoded image.
        """
        h = self.mid_block(self.conv_in(z))
        for block in self.up_blocks:
            h = block(h)
        return self.conv_out(nn.silu(self.conv_norm_out(h)))