import torch
import torch.nn as nn
from models.blocks import DownBlock, MidBlock, UpBlock
class VQVAE(nn.Module):
    """Vector-quantised autoencoder: conv encoder -> codebook lookup -> conv decoder.

    The encoder is ``conv_in -> DownBlocks -> MidBlocks -> GroupNorm -> SiLU ->
    conv_out -> pre_quant_conv``; its output is snapped to the nearest row of
    ``self.embedding`` and the decoder mirrors the encoder
    (``post_quant_conv -> conv_in -> MidBlocks -> UpBlocks -> GroupNorm ->
    SiLU -> conv_out``).

    ``DownBlock``, ``MidBlock`` and ``UpBlock`` come from ``models.blocks``;
    their internals are not visible from this file.

    Args:
        im_channels: Number of channels of the input (and reconstructed) image.
        model_config: Mapping that must provide the keys ``down_channels``,
            ``mid_channels``, ``down_sample``, ``num_down_layers``,
            ``num_mid_layers``, ``num_up_layers``, ``attn_down``,
            ``z_channels``, ``codebook_size``, ``norm_channels`` and
            ``num_heads``.

    Raises:
        AssertionError: If the channel / down-sample / attention lists in
            ``model_config`` are inconsistent with each other.
    """

    def __init__(self, im_channels, model_config):
        super().__init__()
        self.down_channels = model_config['down_channels']
        self.mid_channels = model_config['mid_channels']
        self.down_sample = model_config['down_sample']
        self.num_down_layers = model_config['num_down_layers']
        self.num_mid_layers = model_config['num_mid_layers']
        self.num_up_layers = model_config['num_up_layers']

        # Per-stage attention switch, shared by the encoder DownBlocks and the
        # mirrored decoder UpBlocks.
        self.attns = model_config['attn_down']

        # Latent dimension and codebook size
        self.z_channels = model_config['z_channels']
        self.codebook_size = model_config['codebook_size']
        self.norm_channels = model_config['norm_channels']
        self.num_heads = model_config['num_heads']

        # Validate the channel information before building any layer
        assert self.mid_channels[0] == self.down_channels[-1], \
            "mid_channels[0] must equal down_channels[-1]"
        assert self.mid_channels[-1] == self.down_channels[-1], \
            "mid_channels[-1] must equal down_channels[-1]"
        assert len(self.down_sample) == len(self.down_channels) - 1, \
            "down_sample needs one entry per consecutive pair of down_channels"
        assert len(self.attns) == len(self.down_channels) - 1, \
            "attn_down needs one entry per consecutive pair of down_channels"

        # Wherever we use downsampling in encoder correspondingly use
        # upsampling in decoder
        self.up_sample = list(reversed(self.down_sample))

        num_stages = len(self.down_channels) - 1
        num_mids = len(self.mid_channels) - 1

        ##################### Encoder ######################
        self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0],
                                         kernel_size=3, padding=(1, 1))

        # Downblock + Midblock
        self.encoder_layers = nn.ModuleList([
            DownBlock(self.down_channels[i], self.down_channels[i + 1],
                      t_emb_dim=None, down_sample=self.down_sample[i],
                      num_heads=self.num_heads,
                      num_layers=self.num_down_layers,
                      attn=self.attns[i],
                      norm_channels=self.norm_channels)
            for i in range(num_stages)
        ])
        self.encoder_mids = nn.ModuleList([
            MidBlock(self.mid_channels[i], self.mid_channels[i + 1],
                     t_emb_dim=None,
                     num_heads=self.num_heads,
                     num_layers=self.num_mid_layers,
                     norm_channels=self.norm_channels)
            for i in range(num_mids)
        ])

        self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
        self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], self.z_channels,
                                          kernel_size=3, padding=1)

        # Pre Quantization Convolution
        self.pre_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)

        # Codebook
        self.embedding = nn.Embedding(self.codebook_size, self.z_channels)

        ##################### Decoder ######################
        # Post Quantization Convolution
        self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
        self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1],
                                         kernel_size=3, padding=(1, 1))

        # Midblock + Upblock, built by walking the encoder channel lists backwards
        self.decoder_mids = nn.ModuleList([
            MidBlock(self.mid_channels[i], self.mid_channels[i - 1],
                     t_emb_dim=None,
                     num_heads=self.num_heads,
                     num_layers=self.num_mid_layers,
                     norm_channels=self.norm_channels)
            for i in reversed(range(1, num_mids + 1))
        ])
        self.decoder_layers = nn.ModuleList([
            # down_sample[i - 1] with i descending visits the same values as
            # self.up_sample read front to back.
            UpBlock(self.down_channels[i], self.down_channels[i - 1],
                    t_emb_dim=None, up_sample=self.down_sample[i - 1],
                    num_heads=self.num_heads,
                    num_layers=self.num_up_layers,
                    attn=self.attns[i - 1],
                    norm_channels=self.norm_channels)
            for i in reversed(range(1, num_stages + 1))
        ])

        self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
        self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels,
                                          kernel_size=3, padding=1)

    def quantize(self, x):
        """Replace every spatial vector of ``x`` with its nearest codebook entry.

        Args:
            x: Tensor of shape ``(B, C, H, W)`` where ``C`` matches the
                embedding dimension of ``self.embedding``.

        Returns:
            A tuple ``(quant_out, quantize_losses, min_encoding_indices)``:
            ``quant_out`` is ``(B, C, H, W)`` and carries straight-through
            gradients to ``x``; ``quantize_losses`` is a dict with the keys
            ``'codebook_loss'`` and ``'commitment_loss'``;
            ``min_encoding_indices`` is ``(B, H, W)`` and holds the selected
            codebook row per position.
        """
        B, C, H, W = x.shape
        codebook = self.embedding.weight

        # B, C, H, W -> B, H, W, C -> B, H*W, C
        x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)

        # Find nearest embedding/codebook vector
        # dist between (B, H*W, C) and (B, K, C) -> (B, H*W, K)
        # expand() hands cdist a broadcast view instead of B explicit copies
        # of the codebook.
        dist = torch.cdist(x, codebook.unsqueeze(0).expand(B, -1, -1))
        # (B, H*W)
        min_encoding_indices = torch.argmin(dist, dim=-1)

        # Replace encoder output with nearest codebook
        # quant_out -> B*H*W, C
        quant_out = torch.index_select(codebook, 0, min_encoding_indices.view(-1))
        # x -> B*H*W, C
        x = x.reshape(-1, C)

        # Commitment loss moves the encoder output towards the (frozen) code;
        # codebook loss moves the code towards the (frozen) encoder output.
        commitment_loss = torch.mean((quant_out.detach() - x) ** 2)
        codebook_loss = torch.mean((quant_out - x.detach()) ** 2)
        quantize_losses = {
            'codebook_loss': codebook_loss,
            'commitment_loss': commitment_loss,
        }

        # Straight through estimation: forward value is the code, gradient
        # flows to x as if quantisation were the identity.
        quant_out = x + (quant_out - x).detach()

        # quant_out -> B, C, H, W
        quant_out = quant_out.reshape(B, H, W, C).permute(0, 3, 1, 2)
        min_encoding_indices = min_encoding_indices.reshape(-1, H, W)
        return quant_out, quantize_losses, min_encoding_indices

    def encode(self, x):
        """Run the encoder and quantise its output.

        Args:
            x: Input tensor fed to ``encoder_conv_in``.

        Returns:
            A tuple ``(out, quant_losses)`` with the quantised latent and the
            loss dict produced by :meth:`quantize`.
        """
        out = self.encoder_conv_in(x)
        for down in self.encoder_layers:
            out = down(out)
        for mid in self.encoder_mids:
            out = mid(out)
        out = self.encoder_norm_out(out)
        # Functional SiLU: same result as nn.SiLU() without building a module
        # on every call.
        out = nn.functional.silu(out)
        out = self.encoder_conv_out(out)
        out = self.pre_quant_conv(out)
        out, quant_losses, _ = self.quantize(out)
        return out, quant_losses

    def decode(self, z):
        """Map a (quantised) latent ``z`` back to image space.

        Args:
            z: Latent tensor fed to ``post_quant_conv``.

        Returns:
            The output of ``decoder_conv_out``.
        """
        out = self.post_quant_conv(z)
        out = self.decoder_conv_in(out)
        for mid in self.decoder_mids:
            out = mid(out)
        for up in self.decoder_layers:
            out = up(out)
        out = self.decoder_norm_out(out)
        out = nn.functional.silu(out)
        out = self.decoder_conv_out(out)
        return out

    def forward(self, x):
        """Encode, quantise and decode ``x``.

        Returns:
            A tuple ``(out, z, quant_losses)``: the reconstruction, the
            quantised latent and the quantisation loss dict.
        """
        z, quant_losses = self.encode(x)
        out = self.decode(z)
        return out, z, quant_losses
|