| """ | |
| AUTOENCODER WITH ARCHTECTURE FROM VERSION 2 | |
| """ | |
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def swish(x):
    # Swish activation, x * sigmoid(x); equivalent to F.silu used below.
    return x * torch.sigmoid(x)


def Normalize(in_channels):
    return nn.GroupNorm(
        num_groups=32,
        num_channels=in_channels,
        eps=1e-6,
        affine=True,
    )


class Upsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        # 2x nearest-neighbour upsampling in every spatial dim, then a 3x3x3 conv.
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=2,
            padding=0,
        )

    def forward(self, x):
        # Pad only the trailing side of each spatial dim so the stride-2 conv
        # halves the resolution exactly.
        pad = (0, 1, 0, 1, 0, 1)
        x = F.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.norm1 = Normalize(in_channels)
        # Use self.out_channels here: the bare out_channels argument is None
        # by default and would crash nn.Conv3d.
        self.conv1 = nn.Conv3d(
            in_channels,
            self.out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.norm2 = Normalize(self.out_channels)
        self.conv2 = nn.Conv3d(
            self.out_channels,
            self.out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        if self.in_channels != self.out_channels:
            # 1x1x1 conv to match channel counts on the residual path.
            self.nin_shortcut = nn.Conv3d(
                in_channels,
                self.out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            )

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = F.silu(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = F.silu(h)
        h = self.conv2(h)
        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)
        return x + h


class Encoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        n_channels: int,
        z_channels: int,
        ch_mult: Tuple[int, ...],
        num_res_blocks: int,
        resolution: Tuple[int, ...],
        attn_resolutions: Tuple[int, ...],
        **ignorekwargs,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.n_channels = n_channels
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.attn_resolutions = attn_resolutions
        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        blocks = []
        # initial convolution
        blocks.append(
            nn.Conv3d(
                in_channels,
                n_channels,
                kernel_size=3,
                stride=1,
                padding=1,
            )
        )
        # residual and downsampling blocks; attn_resolutions is accepted for
        # config compatibility, but this version builds no attention blocks
        for i in range(self.num_resolutions):
            block_in_ch = n_channels * in_ch_mult[i]
            block_out_ch = n_channels * ch_mult[i]
            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch
            if i != self.num_resolutions - 1:
                blocks.append(Downsample(block_in_ch))
                curr_res = tuple(ti // 2 for ti in curr_res)
        # normalise and convert to latent size
        blocks.append(Normalize(block_in_ch))
        blocks.append(
            nn.Conv3d(
                block_in_ch,
                z_channels,
                kernel_size=3,
                stride=1,
                padding=1,
            )
        )
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


class Decoder(nn.Module):
    def __init__(
        self,
        n_channels: int,
        z_channels: int,
        out_channels: int,
        ch_mult: Tuple[int, ...],
        num_res_blocks: int,
        resolution: Tuple[int, ...],
        attn_resolutions: Tuple[int, ...],
        **ignorekwargs,
    ) -> None:
        super().__init__()
        self.n_channels = n_channels
        self.z_channels = z_channels
        self.out_channels = out_channels
        self.ch_mult = ch_mult
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.attn_resolutions = attn_resolutions
        block_in_ch = n_channels * self.ch_mult[-1]
        curr_res = tuple(ti // 2 ** (self.num_resolutions - 1) for ti in resolution)
        blocks = []
        # initial conv
        blocks.append(
            nn.Conv3d(
                z_channels,
                block_in_ch,
                kernel_size=3,
                stride=1,
                padding=1,
            )
        )
        # residual and upsampling blocks, mirroring the encoder
        for i in reversed(range(self.num_resolutions)):
            block_out_ch = n_channels * self.ch_mult[i]
            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch
            if i != 0:
                blocks.append(Upsample(block_in_ch))
                curr_res = tuple(ti * 2 for ti in curr_res)
        # normalise and convert back to image channels
        blocks.append(Normalize(block_in_ch))
        blocks.append(
            nn.Conv3d(
                block_in_ch,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
            )
        )
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


class AutoencoderKL(nn.Module):
    def __init__(self, embed_dim: int, hparams) -> None:
        super().__init__()
        self.encoder = Encoder(**hparams)
        self.decoder = Decoder(**hparams)
        # 1x1x1 convs mapping the encoder output to the posterior mean and
        # log-variance, and sampled latents back to the decoder input size.
        self.quant_conv_mu = nn.Conv3d(hparams["z_channels"], embed_dim, 1)
        self.quant_conv_log_sigma = nn.Conv3d(hparams["z_channels"], embed_dim, 1)
        self.post_quant_conv = nn.Conv3d(embed_dim, hparams["z_channels"], 1)
        self.embed_dim = embed_dim

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec
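
    # NOTE: the original file defines only the decode path; quant_conv_mu and
    # quant_conv_log_sigma are constructed but never used here. The methods
    # below are a minimal sketch of the matching encode path, assuming the
    # standard VAE reparameterization used by LDM-style autoencoders (and
    # reading quant_conv_log_sigma's output as a log-variance); they are an
    # inferred addition, not part of the original source.
    def encode(self, x):
        h = self.encoder(x)
        z_mu = self.quant_conv_mu(h)
        z_log_var = self.quant_conv_log_sigma(h)
        return z_mu, z_log_var

    def sampling(self, z_mu, z_log_var):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I).
        eps = torch.randn_like(z_log_var)
        return z_mu + torch.exp(0.5 * z_log_var) * eps

    def forward(self, x):
        z_mu, z_log_var = self.encode(x)
        z = self.sampling(z_mu, z_log_var)
        x_hat = self.decode(z)
        return x_hat, z_mu, z_log_var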

    def reconstruct_ldm_outputs(self, z):
        # Decode latents (e.g. produced by a latent diffusion model) back to
        # image space.
        x_hat = self.decode(z)
        return x_hat
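

# Minimal smoke test using the sketched forward above. The hparams here are
# illustrative placeholders, not values from the original project; any volume
# whose spatial dims are divisible by 2 ** (len(ch_mult) - 1) should work.
if __name__ == "__main__":
    hparams = {
        "in_channels": 1,
        "out_channels": 1,
        "n_channels": 32,
        "z_channels": 3,
        "ch_mult": (1, 2, 2),
        "num_res_blocks": 1,
        "resolution": (32, 32, 32),
        "attn_resolutions": (),
    }
    model = AutoencoderKL(embed_dim=3, hparams=hparams)
    x = torch.randn(1, 1, 32, 32, 32)
    x_hat, z_mu, z_log_var = model(x)
    print(x_hat.shape)  # torch.Size([1, 1, 32, 32, 32])
    print(z_mu.shape)   # torch.Size([1, 3, 8, 8, 8])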