| """AutoencoderKL Decoder — pure MLX implementation. |
| |
| Decodes latent representations to RGB images without PyTorch/diffusers |
| dependency. Architecture matches diffusers AutoencoderKL with the |
| Z-Image-Turbo VAE config: |
| |
| latent_channels = 16 |
| block_out_channels = [128, 256, 512, 512] |
| layers_per_block = 2 (decoder uses layers_per_block + 1 = 3) |
| norm_num_groups = 32 |
| mid_block_add_attention = true |
| force_upcast = true (all ops in float32) |
| scaling_factor = 0.3611 |
| shift_factor = 0.1159 |
| |
| Data format: NHWC throughout (MLX convention). |
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
|
|
| import mlx.core as mx |
| import mlx.nn as nn |
|
|
| |
# Epsilon handed to every GroupNorm layer built by ``_gn``.
_GN_EPS = 1e-6


def _gn(groups: int, channels: int) -> nn.GroupNorm:
    """Build a GroupNorm layer using the module-wide epsilon.

    Args:
        groups: Number of normalization groups.
        channels: Number of feature channels being normalized.

    Returns:
        An ``nn.GroupNorm`` created with ``eps=_GN_EPS`` and
        ``pytorch_compatible=True``.
    """
    return nn.GroupNorm(
        num_groups=groups,
        dims=channels,
        eps=_GN_EPS,
        pytorch_compatible=True,
    )
|
|
|
|
| |
|
|
|
|
class ResnetBlock2D(nn.Module):
    """Pre-activation residual block.

    Main path: GroupNorm -> SiLU -> 3x3 conv -> GroupNorm -> SiLU -> 3x3 conv.
    The block input is added to the main-path output. When ``in_channels``
    differs from ``out_channels`` the input is first passed through a 1x1
    convolution (``conv_shortcut``) so both operands have the same channels.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the block.
        groups: Number of groups for both GroupNorm layers.
    """

    def __init__(self, in_channels: int, out_channels: int, groups: int = 32):
        super().__init__()
        self.norm1 = _gn(groups, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = _gn(groups, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        # The 1x1 projection exists only when the skip path must change its
        # channel count; otherwise the attribute is None.
        self.conv_shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels
            else None
        )

    def __call__(self, x: mx.array) -> mx.array:
        """Apply the block to ``x`` and return main path + skip path."""
        skip = x if self.conv_shortcut is None else self.conv_shortcut(x)
        hidden = self.conv1(nn.silu(self.norm1(x)))
        hidden = self.conv2(nn.silu(self.norm2(hidden)))
        return hidden + skip
|
|
|
|
class AttentionBlock(nn.Module):
    """Single-head self-attention across the spatial positions of a 4-D input.

    The input is group-normalized, flattened from ``(B, H, W, C)`` to
    ``(B, H*W, C)``, run through scaled dot-product attention with one head
    of width ``C``, projected, reshaped back and added to the original input.

    Args:
        channels: Feature width of the input (and of q/k/v/out projections).
        groups: Number of groups for the GroupNorm applied before attention.
    """

    def __init__(self, channels: int, groups: int = 32):
        super().__init__()
        self.group_norm = _gn(groups, channels)
        self.to_q = nn.Linear(channels, channels)
        self.to_k = nn.Linear(channels, channels)
        self.to_v = nn.Linear(channels, channels)
        # Stored as a one-element list; the forward pass uses ``to_out[0]``.
        self.to_out = [nn.Linear(channels, channels)]

    def __call__(self, x: mx.array) -> mx.array:
        """Return ``x`` plus the attention output, same shape as ``x``."""
        batch, height, width, channels = x.shape
        tokens = self.group_norm(x).reshape(batch, height * width, channels)

        queries = self.to_q(tokens)
        keys = self.to_k(tokens)
        values = self.to_v(tokens)

        # (B, HW, C) @ (B, C, HW) -> (B, HW, HW) similarity matrix.
        inv_sqrt_dim = 1.0 / math.sqrt(channels)
        scores = (queries @ keys.transpose(0, 2, 1)) * inv_sqrt_dim
        weights = mx.softmax(scores, axis=-1)

        attended = self.to_out[0](weights @ values)
        return attended.reshape(batch, height, width, channels) + x
|
|
|
|
class Upsample2D(nn.Module):
    """2x nearest-neighbour upsample followed by a 3x3 conv.

    Args:
        channels: Channel count of the input; the convolution keeps it
            unchanged.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def __call__(self, x: mx.array) -> mx.array:
        """Double axes 1 and 2 of ``x`` by element repetition, then convolve.

        Args:
            x: 4-D array; axes 1 and 2 are treated as the spatial axes
                (NHWC, per the module-level convention).

        Returns:
            Array with axes 1 and 2 twice as long as in ``x``.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        # Explicit rank check instead of unpacking x.shape into four
        # otherwise-unused locals; the exception type stays ValueError.
        if x.ndim != 4:
            raise ValueError(
                f"Upsample2D expects a 4-D NHWC array, got shape {x.shape}"
            )
        x = mx.repeat(x, 2, axis=1)
        x = mx.repeat(x, 2, axis=2)
        return self.conv(x)
|
|
|
|
class UpDecoderBlock2D(nn.Module):
    """Decoder up-block: a stack of resnet blocks, then an optional 2x upsample.

    Args:
        in_channels: Channels entering the first resnet block.
        out_channels: Channels produced by every resnet block (and kept by
            the upsampler).
        num_layers: Number of resnet blocks in the stack.
        add_upsample: Whether to append an ``Upsample2D`` after the resnets.
        groups: GroupNorm group count passed to each resnet block.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 3,
        add_upsample: bool = True,
        groups: int = 32,
    ):
        super().__init__()
        # Only the first resnet sees ``in_channels``; later ones map
        # ``out_channels`` to ``out_channels``.
        self.resnets = [
            ResnetBlock2D(
                in_channels if idx == 0 else out_channels,
                out_channels,
                groups,
            )
            for idx in range(num_layers)
        ]
        # Kept as a list (empty or one element) and iterated in the forward pass.
        self.upsamplers = [Upsample2D(out_channels)] if add_upsample else []

    def __call__(self, x: mx.array) -> mx.array:
        """Run ``x`` through all resnets, then through any upsamplers."""
        for layer in (*self.resnets, *self.upsamplers):
            x = layer(x)
        return x
|
|
|
|
class MidBlock2D(nn.Module):
    """Mid block: resnet -> self-attention -> resnet, all at one channel width.

    Args:
        channels: Channel count used by both resnets and the attention block.
        groups: GroupNorm group count passed to the sub-blocks.
    """

    def __init__(self, channels: int, groups: int = 32):
        super().__init__()
        # Two channel-preserving resnets, built before the attention block.
        self.resnets = [
            ResnetBlock2D(channels, channels, groups) for _ in range(2)
        ]
        self.attentions = [AttentionBlock(channels, groups)]

    def __call__(self, x: mx.array) -> mx.array:
        """Apply first resnet, attention, then second resnet."""
        hidden = self.resnets[0](x)
        hidden = self.attentions[0](hidden)
        return self.resnets[1](hidden)
|
|
|
|
| |
|
|
|
|
class Decoder(nn.Module):
    """AutoencoderKL Decoder (NHWC, pure MLX).

    Module hierarchy matches diffusers weight-key paths after stripping
    the ``decoder.`` prefix, so weights can be loaded directly.

    Pipeline: ``conv_in`` -> ``mid_block`` -> ``up_blocks`` ->
    ``conv_norm_out`` + SiLU -> ``conv_out``.

    Args:
        latent_channels: Channel count of the latent input.
        block_out_channels: Per-level channel widths, listed from the
            highest-resolution level to the lowest; the decoder walks them
            in reverse. Must contain at least one entry.
        layers_per_block: Each up-block gets ``layers_per_block + 1`` resnets.
        norm_num_groups: Group count for every GroupNorm in the decoder.
        out_channels: Channel count of the decoded output (3 by default).
    """

    def __init__(
        self,
        latent_channels: int = 16,
        block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        out_channels: int = 3,
    ):
        super().__init__()
        reversed_ch = list(reversed(block_out_channels))
        widest = reversed_ch[0]

        # Lift the latent to the widest feature map.
        self.conv_in = nn.Conv2d(latent_channels, widest, kernel_size=3, padding=1)

        # Resnet / attention / resnet at the lowest resolution.
        self.mid_block = MidBlock2D(widest, norm_num_groups)

        # One up-block per level. Each block consumes the previous level's
        # width (the first consumes its own) and every block except the last
        # doubles the spatial resolution.
        last_index = len(reversed_ch) - 1
        self.up_blocks = [
            UpDecoderBlock2D(
                in_channels=reversed_ch[max(level - 1, 0)],
                out_channels=level_ch,
                num_layers=layers_per_block + 1,
                add_upsample=level < last_index,
                groups=norm_num_groups,
            )
            for level, level_ch in enumerate(reversed_ch)
        ]

        # Output head.
        self.conv_norm_out = _gn(norm_num_groups, reversed_ch[-1])
        self.conv_out = nn.Conv2d(
            reversed_ch[-1], out_channels, kernel_size=3, padding=1
        )

    def __call__(self, z: mx.array) -> mx.array:
        """Decode latents to an image.

        Args:
            z: (B, H, W, C) latent tensor in NHWC. The decoder applies no
                scaling or shift itself, so ``z`` must already be scaled.

        Returns:
            (B, s*H, s*W, out_channels) decoded image, where
            ``s = 2 ** (len(block_out_channels) - 1)``. With the default
            configuration this is (B, 8H, 8W, 3).
        """
        x = self.conv_in(z)
        x = self.mid_block(x)
        for block in self.up_blocks:
            x = block(x)
        x = nn.silu(self.conv_norm_out(x))
        return self.conv_out(x)
|
|