"""AutoencoderKL implementation compatible with diffusers weights.""" # Modified from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/autoencoder.py from dataclasses import dataclass from typing import Optional, Tuple import torch import torch.nn as nn @dataclass class AutoencoderKLOutput: sample: torch.Tensor class AutoencoderConfig: def __init__(self, **kwargs): self.__dict__.update(kwargs) def get(self, key, default=None): return self.__dict__.get(key, default) def __getattr__(self, name): return self.__dict__.get(name) def swish(x): return x * torch.sigmoid(x) class ResnetBlock2D(nn.Module): def __init__(self, in_channels, out_channels=None, dropout=0.0, temb_channels=512, groups=32, eps=1e-6): super().__init__() out_channels = out_channels or in_channels self.in_channels = in_channels self.out_channels = out_channels self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) self.dropout = nn.Dropout(dropout) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) self.nonlinearity = swish if self.in_channels != self.out_channels: self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) else: self.conv_shortcut = None def forward(self, input_tensor, temb=None): hidden_states = input_tensor hidden_states = self.norm1(hidden_states) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.conv1(hidden_states) hidden_states = self.norm2(hidden_states) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: input_tensor = self.conv_shortcut(input_tensor) output_tensor = (input_tensor + hidden_states) / 1.0 return output_tensor class Attention(nn.Module): def __init__(self, in_channels, heads=1, dim_head=None, groups=32, eps=1e-6): super().__init__() self.heads = heads self.in_channels = in_channels self.group_norm = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) self.to_q = nn.Linear(in_channels, in_channels) self.to_k = nn.Linear(in_channels, in_channels) self.to_v = nn.Linear(in_channels, in_channels) self.to_out = nn.ModuleList([nn.Linear(in_channels, in_channels)]) def forward(self, hidden_states): b, c, h, w = hidden_states.shape residual = hidden_states hidden_states = self.group_norm(hidden_states) hidden_states = hidden_states.view(b, c, -1).transpose(1, 2) # (B, H*W, C) query = self.to_q(hidden_states) key = self.to_k(hidden_states) value = self.to_v(hidden_states) import torch.nn.functional as F hidden_states = F.scaled_dot_product_attention(query, key, value) hidden_states = self.to_out[0](hidden_states) hidden_states = hidden_states.transpose(1, 2).view(b, c, h, w) return residual + hidden_states class Downsample2D(nn.Module): def __init__(self, channels, with_conv=True, out_channels=None, padding=1): super().__init__() out_channels = out_channels or channels self.with_conv = with_conv if with_conv: self.conv = nn.Conv2d(channels, out_channels, kernel_size=3, stride=2, padding=padding) def forward(self, hidden_states): if self.with_conv: return self.conv(hidden_states) else: return torch.nn.functional.avg_pool2d(hidden_states, kernel_size=2, stride=2) class Upsample2D(nn.Module): def __init__(self, channels, with_conv=True, out_channels=None): super().__init__() out_channels = out_channels or channels self.with_conv = with_conv if with_conv: self.conv = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) def forward(self, hidden_states): hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest") if self.with_conv: hidden_states = self.conv(hidden_states) return hidden_states class DownEncoderBlock2D(nn.Module): def __init__(self, in_channels, out_channels, num_layers=1, resnet_eps=1e-6, resnet_groups=32, add_downsample=True): super().__init__() resnets = [] for i in range(num_layers): in_c = in_channels if i == 0 else out_channels resnets.append(ResnetBlock2D(in_c, out_channels, eps=resnet_eps, groups=resnet_groups)) self.resnets = nn.ModuleList(resnets) if add_downsample: self.downsamplers = nn.ModuleList( [Downsample2D(out_channels, with_conv=True, out_channels=out_channels, padding=0)] ) else: self.downsamplers = None def forward(self, hidden_states): for resnet in self.resnets: hidden_states = resnet(hidden_states) if self.downsamplers is not None: for downsampler in self.downsamplers: pad = (0, 1, 0, 1) hidden_states = torch.nn.functional.pad(hidden_states, pad, mode="constant", value=0) hidden_states = downsampler(hidden_states) return hidden_states class UpDecoderBlock2D(nn.Module): def __init__(self, in_channels, out_channels, num_layers=1, resnet_eps=1e-6, resnet_groups=32, add_upsample=True): super().__init__() resnets = [] for i in range(num_layers): in_c = in_channels if i == 0 else out_channels resnets.append(ResnetBlock2D(in_c, out_channels, eps=resnet_eps, groups=resnet_groups)) self.resnets = nn.ModuleList(resnets) if add_upsample: self.upsamplers = nn.ModuleList([Upsample2D(out_channels, with_conv=True, out_channels=out_channels)]) else: self.upsamplers = None def forward(self, hidden_states): for resnet in self.resnets: hidden_states = resnet(hidden_states) if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states) return hidden_states class UNetMidBlock2D(nn.Module): def __init__(self, in_channels, resnet_eps=1e-6, resnet_groups=32, attention_head_dim=None): super().__init__() self.resnets = nn.ModuleList( [ ResnetBlock2D(in_channels, in_channels, eps=resnet_eps, groups=resnet_groups), ResnetBlock2D(in_channels, in_channels, eps=resnet_eps, groups=resnet_groups), ] ) self.attentions = nn.ModuleList([Attention(in_channels, heads=1, groups=resnet_groups, eps=resnet_eps)]) def forward(self, hidden_states): hidden_states = self.resnets[0](hidden_states) for attn in self.attentions: hidden_states = attn(hidden_states) hidden_states = self.resnets[1](hidden_states) return hidden_states class Encoder(nn.Module): def __init__( self, in_channels=3, out_channels=3, block_out_channels=(64,), layers_per_block=2, norm_num_groups=32, double_z=True, ): super().__init__() self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1) self.down_blocks = nn.ModuleList([]) output_channel = block_out_channels[0] for i, block_out_channel in enumerate(block_out_channels): input_channel = output_channel output_channel = block_out_channel is_final_block = i == len(block_out_channels) - 1 block = DownEncoderBlock2D( input_channel, output_channel, num_layers=layers_per_block, resnet_groups=norm_num_groups, add_downsample=not is_final_block, ) self.down_blocks.append(block) self.mid_block = UNetMidBlock2D( block_out_channels[-1], resnet_groups=norm_num_groups, ) self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) self.conv_act = nn.SiLU() conv_out_channels = 2 * out_channels if double_z else out_channels self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) def forward(self, x): x = self.conv_in(x) for block in self.down_blocks: x = block(x) x = self.mid_block(x) x = self.conv_norm_out(x) x = self.conv_act(x) x = self.conv_out(x) return x class Decoder(nn.Module): def __init__( self, in_channels=3, out_channels=3, block_out_channels=(64,), layers_per_block=2, norm_num_groups=32, ): super().__init__() self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1) self.mid_block = UNetMidBlock2D( block_out_channels[-1], resnet_groups=norm_num_groups, ) self.up_blocks = nn.ModuleList([]) reversed_block_out_channels = list(reversed(block_out_channels)) output_channel = reversed_block_out_channels[0] for i, block_out_channel in enumerate(reversed_block_out_channels): input_channel = output_channel output_channel = block_out_channel is_final_block = i == len(block_out_channels) - 1 block = UpDecoderBlock2D( input_channel, output_channel, num_layers=layers_per_block + 1, resnet_groups=norm_num_groups, add_upsample=not is_final_block, ) self.up_blocks.append(block) self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) def forward(self, x): x = self.conv_in(x) x = self.mid_block(x) for block in self.up_blocks: x = block(x) x = self.conv_norm_out(x) x = self.conv_act(x) x = self.conv_out(x) return x class AutoencoderKL(nn.Module): def __init__( self, in_channels: int = 3, out_channels: int = 3, down_block_types: Tuple[str] = ("DownEncoderBlock2D",), up_block_types: Tuple[str] = ("UpDecoderBlock2D",), block_out_channels: Tuple[int] = (64,), layers_per_block: int = 1, act_fn: str = "silu", latent_channels: int = 4, norm_num_groups: int = 32, sample_size: int = 32, scaling_factor: float = 0.18215, shift_factor: Optional[float] = None, force_upcast: bool = True, use_quant_conv: bool = True, use_post_quant_conv: bool = True, mid_block_add_attention: bool = True, **kwargs, ): super().__init__() self.config = AutoencoderConfig( in_channels=in_channels, out_channels=out_channels, block_out_channels=block_out_channels, layers_per_block=layers_per_block, latent_channels=latent_channels, scaling_factor=scaling_factor, shift_factor=shift_factor, ) self.encoder = Encoder( in_channels=in_channels, out_channels=latent_channels, block_out_channels=block_out_channels, layers_per_block=layers_per_block, norm_num_groups=norm_num_groups, double_z=True, ) self.decoder = Decoder( in_channels=latent_channels, out_channels=out_channels, block_out_channels=block_out_channels, layers_per_block=layers_per_block, norm_num_groups=norm_num_groups, ) self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None @property def dtype(self): return next(self.parameters()).dtype def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: if self.post_quant_conv is not None: z = self.post_quant_conv(z) dec = self.decoder(z) if not return_dict: return (dec,) return AutoencoderKLOutput(sample=dec)