Yash Nagraj commited on
Commit
7377e9c
·
1 Parent(s): 70a401a

Add AutoEncoder (VQVAE)

Browse files
Files changed (2) hide show
  1. models/blocks.py +108 -0
  2. models/vqvae.py +156 -0
models/blocks.py CHANGED
@@ -398,3 +398,111 @@ class UpBlockUnet(nn.Module):
398
  out = out + out_attn
399
 
400
  return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
  out = out + out_attn
399
 
400
  return out
401
+
402
+
403
+ class UpBlock(nn.Module):
404
+ r"""
405
+ Up conv block with attention.
406
+ Sequence of following blocks
407
+ 1. Upsample
408
+ 1. Concatenate Down block output
409
+ 2. Resnet block with time embedding
410
+ 3. Attention Block
411
+ """
412
+
413
+ def __init__(self, in_channels, out_channels, t_emb_dim,
414
+ up_sample, num_heads, num_layers, attn, norm_channels):
415
+ super().__init__()
416
+ self.num_layers = num_layers
417
+ self.up_sample = up_sample
418
+ self.t_emb_dim = t_emb_dim
419
+ self.attn = attn
420
+ self.resnet_conv_first = nn.ModuleList(
421
+ [
422
+ nn.Sequential(
423
+ nn.GroupNorm(norm_channels, in_channels if i ==
424
+ 0 else out_channels),
425
+ nn.SiLU(),
426
+ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
427
+ padding=1),
428
+ )
429
+ for i in range(num_layers)
430
+ ]
431
+ )
432
+
433
+ if self.t_emb_dim is not None:
434
+ self.t_emb_layers = nn.ModuleList([
435
+ nn.Sequential(
436
+ nn.SiLU(),
437
+ nn.Linear(t_emb_dim, out_channels)
438
+ )
439
+ for _ in range(num_layers)
440
+ ])
441
+
442
+ self.resnet_conv_second = nn.ModuleList(
443
+ [
444
+ nn.Sequential(
445
+ nn.GroupNorm(norm_channels, out_channels),
446
+ nn.SiLU(),
447
+ nn.Conv2d(out_channels, out_channels,
448
+ kernel_size=3, stride=1, padding=1),
449
+ )
450
+ for _ in range(num_layers)
451
+ ]
452
+ )
453
+ if self.attn:
454
+ self.attention_norms = nn.ModuleList(
455
+ [
456
+ nn.GroupNorm(norm_channels, out_channels)
457
+ for _ in range(num_layers)
458
+ ]
459
+ )
460
+
461
+ self.attentions = nn.ModuleList(
462
+ [
463
+ nn.MultiheadAttention(
464
+ out_channels, num_heads, batch_first=True)
465
+ for _ in range(num_layers)
466
+ ]
467
+ )
468
+
469
+ self.residual_input_conv = nn.ModuleList(
470
+ [
471
+ nn.Conv2d(in_channels if i == 0 else out_channels,
472
+ out_channels, kernel_size=1)
473
+ for i in range(num_layers)
474
+ ]
475
+ )
476
+ self.up_sample_conv = nn.ConvTranspose2d(in_channels, in_channels,
477
+ 4, 2, 1) \
478
+ if self.up_sample else nn.Identity()
479
+
480
+ def forward(self, x, out_down=None, t_emb=None):
481
+ # Upsample
482
+ x = self.up_sample_conv(x)
483
+
484
+ # Concat with Downblock output
485
+ if out_down is not None:
486
+ x = torch.cat([x, out_down], dim=1)
487
+
488
+ out = x
489
+ for i in range(self.num_layers):
490
+ # Resnet Block
491
+ resnet_input = out
492
+ out = self.resnet_conv_first[i](out)
493
+ if self.t_emb_dim is not None:
494
+ out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
495
+ out = self.resnet_conv_second[i](out)
496
+ out = out + self.residual_input_conv[i](resnet_input)
497
+
498
+ # Self Attention
499
+ if self.attn:
500
+ batch_size, channels, h, w = out.shape
501
+ in_attn = out.reshape(batch_size, channels, h * w)
502
+ in_attn = self.attention_norms[i](in_attn)
503
+ in_attn = in_attn.transpose(1, 2)
504
+ out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
505
+ out_attn = out_attn.transpose(1, 2).reshape(
506
+ batch_size, channels, h, w)
507
+ out = out + out_attn
508
+ return out
models/vqvae.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import enum
2
+ from sre_compile import dis
3
+ import torch
4
+ import torch.nn as nn
5
+ from models.blocks import DownBlock, UpBlock, MidBlock
6
+
7
+
8
+ class VQVAE(nn.Module):
9
+ def __init__(self, im_channels, model_config):
10
+ super().__init__()
11
+ self.down_channels = model_config['down_channels']
12
+ self.mid_channels = model_config['mid_channels']
13
+ self.down_sample = model_config['down_sample']
14
+ self.num_down_layers = model_config['num_down_layers']
15
+ self.num_up_layers = model_config['num_up_layers']
16
+ self.num_mid_layers = model_config['num_mid_layers']
17
+
18
+ # To disable attn in encoder and decoder blocks
19
+ self.attn = model_config['attn']
20
+
21
+ # Latent Dimension
22
+ self.z_channels = model_config["z_channels"]
23
+ self.codebook_size = model_config["codebook_size"]
24
+ self.norm_channels = model_config["norm_channels"]
25
+ self.num_heads = model_config["num_heads"]
26
+
27
+ assert self.mid_channels[0] == self.down_channels[-1]
28
+ assert self.mid_channels[-1] == self.down_channels[-1]
29
+ assert len(self.down_sample) == len(self.down_channels) - 1
30
+ assert len(self.attns) == len(self.down_channels) - 1
31
+
32
+ self.upsample = list(reversed(self.down_sample))
33
+
34
+ # Encoder
35
+ self.encoder_conv_one = nn.Conv2d(
36
+ im_channels, self.down_channels[0], kernel_size=3, padding=1, stride=1)
37
+
38
+ self.encoder_layers = nn.ModuleList([])
39
+ for i in range(len(self.down_channels) - 1):
40
+ self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i+1],
41
+ t_emd_dim=None, down_sample=self.down_sample[i],
42
+ num_heads=self.num_heads, num_layers=self.num_down_layers,
43
+ attn=self.attns[i], norm_channels=self.norm_channels))
44
+ self.encode_mid_blocks = nn.ModuleList([])
45
+ for i in range(len(self.down_channels)-1):
46
+ self.encode_mid_blocks.append(MidBlock(self.down_channels[i], self.down_channels[i+1],
47
+ t_emb_dim=None, num_heads=self.num_heads, num_layers=self.num_mid_layers,
48
+ norm_dim=self.norm_channels))
49
+ self.encoder_norm_out = nn.GroupNorm(
50
+ self.norm_channels, self.down_channels[-1])
51
+ self.encoder_conv_out = nn.Conv2d(
52
+ self.down_channels[-1], self.z_channels, kernel_size=3, padding=1)
53
+
54
+ # Pre-Quantization Convolution (Before comparing to code blocks to get embedding matrix)
55
+ self.pre_quant_conv = nn.Conv2d(
56
+ self.z_channels, self.z_channels, kernel_size=1)
57
+
58
+ # Code book
59
+ self.embedding = nn.Embedding(self.codebook_size, self.z_channels)
60
+
61
+ # Decoder
62
+ self.post_quant_conv = nn.Conv2d(
63
+ self.z_channels, self.z_channels, kernel_size=1)
64
+ self.decoder_conv_out = nn.Conv2d(
65
+ self.z_channels, self.mid_channels[-1], kernel_size=3, padding=1)
66
+
67
+ # Midblock + UpBlock
68
+ self.decode_mids = nn.ModuleList([])
69
+ for i in reversed(range(1, len(self.mid_channels))):
70
+ self.decode_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i-1],
71
+ t_emb_dim=None, num_heads=self.num_heads,
72
+ num_layers=self.num_mid_layers,
73
+ norm_dim=self.norm_channels))
74
+ self.decoder_layers = nn.ModuleList([])
75
+ for i in reversed(range(1, len(self.down_channels))):
76
+ self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i-1],
77
+ t_emb_dim=None, up_sample=self.down_sample[i-1], num_heads=self.num_heads,
78
+ num_layers=self.num_up_layers,
79
+ attn=self.attn[i-1],
80
+ norm_channels=self.norm_channels))
81
+
82
+ self.decoder_norm_out = nn.GroupNorm(
83
+ self.norm_channels, self.down_channels[0])
84
+ self.decoder_conv_out = nn.Conv2d(
85
+ self.down_channels[0], im_channels, kernel_size=3, padding=1)
86
+
87
+ def quantize(self, x):
88
+ B, C, H, W = x.shape,
89
+
90
+ # B,C,H,W -> B,H,W,C
91
+ x = x.permute(0, 2, 3, 1)
92
+
93
+ # B,H,W,C -> B, H*W, C
94
+ x = x.reshape(x.size(0), -1, x.size(-1))
95
+
96
+ # Find nearest neighbours/codebook vectors
97
+ # Distance between B,H*W,C and B,K,C
98
+ dist = torch.cdist(
99
+ x, self.embedding.weight[None, :].repeat((x.size(0), 1, 1)))
100
+
101
+ min_encoding_indices = torch.argmin(dist, dim=-1)
102
+
103
+ # Replace encoder output with codebook vector
104
+ quant_out = torch.index_select(
105
+ self.embedding.weight, 0, min_encoding_indices.view(-1))
106
+
107
+ # x -> B*H*W,C
108
+ x = x.reshape((-1, x.size(-1)))
109
+ commitment_loss = torch.mean((quant_out.detach() - x) ** 2)
110
+ codebook_loss = torch.mean((quant_out - x.detach()) ** 2)
111
+ quantize_loss = {
112
+ "codebook_loss": codebook_loss,
113
+ "commitment_loss": commitment_loss
114
+ }
115
+
116
+ # Straight through estimation
117
+ quant_out = x - (quant_out - x).detach()
118
+
119
+ # quant_out -> B,C,H,W
120
+ quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)
121
+ min_encoding_indices = min_encoding_indices.reshape(
122
+ (-1, quant_out.size(-2), quant_out.size(-1)))
123
+
124
+ return quant_out, quantize_loss, min_encoding_indices
125
+
126
+ def encode(self, x):
127
+ out = self.encoder_conv_one(x)
128
+ for _, down in enumerate(self.encoder_layers):
129
+ out = down(out)
130
+ for mid in self.encode_mid_blocks:
131
+ out = mid(out)
132
+ out = self.encoder_norm_out(out)
133
+ out = nn.SiLU()(out)
134
+ out = self.encoder_conv_out(out)
135
+ out = self.pre_quant_conv(out)
136
+ out, quant_losses, _ = self.quantize(out)
137
+ return out, quant_losses
138
+
139
+ def decode(self, z):
140
+ out = z
141
+ out = self.post_quant_conv(out)
142
+ out = self.decoder_conv_in(out)
143
+ for mid in self.decode_mids:
144
+ out = mid(out)
145
+ for up in self.decoder_layers:
146
+ out = up(out)
147
+
148
+ out = self.decoder_norm_out(out)
149
+ out = nn.SiLU(out)
150
+ out = self.decoder_conv_out(out)
151
+ return out
152
+
153
+ def forward(self, x):
154
+ z, quant_losses = self.encode(x)
155
+ out = self.decode(z)
156
+ return out, z, quant_losses