# File size: 5,792 Bytes (revision 0c120cf)
import torch
import torch.nn as nn
from model.blocks import DownBlock, MidBlock, UpBlock
class VAE(nn.Module):
    """Convolutional variational autoencoder assembled from project blocks.

    The encoder is ``conv_in -> DownBlocks -> MidBlocks -> GroupNorm -> SiLU
    -> conv_out -> 1x1 pre_quant_conv`` and emits ``2 * z_channels`` channels
    that are split into a mean and a log-variance.  The decoder is
    ``1x1 post_quant_conv -> conv_in -> MidBlocks -> UpBlocks -> GroupNorm
    -> SiLU -> conv_out`` and maps a ``z_channels`` latent back to
    ``im_channels`` channels.

    ``DownBlock``, ``MidBlock`` and ``UpBlock`` come from ``model.blocks``;
    their internals are not visible here.
    NOTE(review): presumably ``down_sample[i]`` toggles spatial down/up
    sampling and ``attn_down[i]`` toggles attention inside the block --
    confirm against ``model.blocks``.

    Args:
        im_channels: Number of channels of the input (and reconstructed)
            image tensor.
        model_config: Mapping that provides the keys ``down_channels``,
            ``mid_channels``, ``down_sample``, ``num_down_layers``,
            ``num_mid_layers``, ``num_up_layers``, ``attn_down``,
            ``z_channels``, ``norm_channels`` and ``num_heads``.

    Raises:
        AssertionError: If the channel lists are inconsistent with each
            other (see the checks in ``__init__``).
        KeyError: If ``model_config`` is a dict lacking a required key.
    """

    def __init__(self, im_channels, model_config):
        super().__init__()
        self.down_channels = model_config["down_channels"]
        self.mid_channels = model_config["mid_channels"]
        self.down_sample = model_config["down_sample"]
        self.num_down_layers = model_config["num_down_layers"]
        self.num_mid_layers = model_config["num_mid_layers"]
        self.num_up_layers = model_config["num_up_layers"]

        # Per-stage attention switch, shared by the encoder DownBlocks and
        # the mirrored decoder UpBlocks.
        self.attns = model_config["attn_down"]

        # Latent dimension.
        self.z_channels = model_config["z_channels"]
        self.norm_channels = model_config["norm_channels"]
        self.num_heads = model_config["num_heads"]

        # Validate the channel information.  The mid stack must start and
        # end on the last down channel because the layers on either side of
        # it (encoder_norm_out, the first UpBlock) are sized with
        # down_channels[-1].
        assert self.mid_channels[0] == self.down_channels[-1], (
            "mid_channels[0] must equal down_channels[-1]"
        )
        assert self.mid_channels[-1] == self.down_channels[-1], (
            "mid_channels[-1] must equal down_channels[-1]"
        )
        assert len(self.down_sample) == len(self.down_channels) - 1, (
            "down_sample needs one entry per consecutive down_channels pair"
        )
        assert len(self.attns) == len(self.down_channels) - 1, (
            "attn_down needs one entry per consecutive down_channels pair"
        )

        # Wherever the encoder downsamples, the decoder upsamples at the
        # mirrored stage.  The decoder loop below indexes ``down_sample``
        # directly in reverse, which yields this same order; the attribute
        # is kept for external readers.
        self.up_sample = list(reversed(self.down_sample))

        num_stages = len(self.down_channels) - 1
        num_mids = len(self.mid_channels) - 1

        ##################### Encoder ######################
        self.encoder_conv_in = nn.Conv2d(
            im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1)
        )

        # Downblock + Midblock
        self.encoder_layers = nn.ModuleList(
            [
                DownBlock(
                    self.down_channels[i],
                    self.down_channels[i + 1],
                    t_emb_dim=None,
                    down_sample=self.down_sample[i],
                    num_heads=self.num_heads,
                    num_layers=self.num_down_layers,
                    attn=self.attns[i],
                    norm_channels=self.norm_channels,
                )
                for i in range(num_stages)
            ]
        )
        self.encoder_mids = nn.ModuleList(
            [
                MidBlock(
                    self.mid_channels[i],
                    self.mid_channels[i + 1],
                    t_emb_dim=None,
                    num_heads=self.num_heads,
                    num_layers=self.num_mid_layers,
                    norm_channels=self.norm_channels,
                )
                for i in range(num_mids)
            ]
        )

        self.encoder_norm_out = nn.GroupNorm(
            self.norm_channels, self.down_channels[-1]
        )
        # The encoder emits 2 * z_channels because it predicts both the mean
        # and the log-variance of the latent distribution.
        self.encoder_conv_out = nn.Conv2d(
            self.down_channels[-1], 2 * self.z_channels, kernel_size=3, padding=1
        )
        self.pre_quant_conv = nn.Conv2d(
            2 * self.z_channels, 2 * self.z_channels, kernel_size=1
        )
        ####################################################

        ##################### Decoder ######################
        self.post_quant_conv = nn.Conv2d(
            self.z_channels, self.z_channels, kernel_size=1
        )
        self.decoder_conv_in = nn.Conv2d(
            self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1)
        )

        # Midblock + Upblock, walking the channel lists back to front so the
        # decoder mirrors the encoder.
        self.decoder_mids = nn.ModuleList(
            [
                MidBlock(
                    self.mid_channels[i],
                    self.mid_channels[i - 1],
                    t_emb_dim=None,
                    num_heads=self.num_heads,
                    num_layers=self.num_mid_layers,
                    norm_channels=self.norm_channels,
                )
                for i in reversed(range(1, len(self.mid_channels)))
            ]
        )
        self.decoder_layers = nn.ModuleList(
            [
                UpBlock(
                    self.down_channels[i],
                    self.down_channels[i - 1],
                    t_emb_dim=None,
                    up_sample=self.down_sample[i - 1],
                    num_heads=self.num_heads,
                    num_layers=self.num_up_layers,
                    attn=self.attns[i - 1],
                    norm_channels=self.norm_channels,
                )
                for i in reversed(range(1, len(self.down_channels)))
            ]
        )

        self.decoder_norm_out = nn.GroupNorm(
            self.norm_channels, self.down_channels[0]
        )
        self.decoder_conv_out = nn.Conv2d(
            self.down_channels[0], im_channels, kernel_size=3, padding=1
        )
        ####################################################

    def encode(self, x):
        """Encode ``x`` and draw one reparameterized latent sample.

        Args:
            x: Input tensor fed to ``encoder_conv_in``; presumably shaped
                ``(batch, im_channels, H, W)`` -- TODO confirm with callers.

        Returns:
            A tuple ``(sample, out)`` where ``out`` is the ``pre_quant_conv``
            output (mean and log-variance concatenated along dim 1) and
            ``sample`` is ``mean + exp(0.5 * logvar) * noise`` with standard
            normal noise.
        """
        out = self.encoder_conv_in(x)
        for down in self.encoder_layers:
            out = down(out)
        for mid in self.encoder_mids:
            out = mid(out)
        out = self.encoder_norm_out(out)
        out = nn.functional.silu(out)
        out = self.encoder_conv_out(out)
        out = self.pre_quant_conv(out)

        # First half of the channels is the mean, second half the log-variance.
        mean, logvar = torch.chunk(out, 2, dim=1)
        std = torch.exp(0.5 * logvar)
        # randn_like draws the noise directly with mean's dtype and device.
        # A float32 CPU tensor here would promote a bf16/fp16 sample to
        # float32 and cost an extra host-to-device copy.
        sample = mean + std * torch.randn_like(mean)
        return sample, out

    def decode(self, z):
        """Decode latent ``z`` back to an ``im_channels``-channel tensor.

        Args:
            z: Latent tensor with ``z_channels`` channels, e.g. the first
                element returned by :meth:`encode`.

        Returns:
            The output of ``decoder_conv_out``.
        """
        out = self.post_quant_conv(z)
        out = self.decoder_conv_in(out)
        for mid in self.decoder_mids:
            out = mid(out)
        for up in self.decoder_layers:
            out = up(out)
        out = self.decoder_norm_out(out)
        out = nn.functional.silu(out)
        out = self.decoder_conv_out(out)
        return out

    def forward(self, x):
        """Run a full encode/decode pass.

        Args:
            x: Input tensor, passed to :meth:`encode`.

        Returns:
            A tuple ``(out, encoder_output)``: the reconstruction produced by
            :meth:`decode` from the sampled latent, and the second element
            returned by :meth:`encode` (concatenated mean and log-variance).
        """
        z, encoder_output = self.encode(x)
        out = self.decode(z)
        return out, encoder_output
|