lucky0146 commited on
Commit
9cf390a
·
verified ·
1 Parent(s): c63db6a

Create vqgan_arch.py

Browse files
Files changed (1) hide show
  1. vqgan_arch.py +102 -0
vqgan_arch.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from basicsr.utils.registry import ARCH_REGISTRY
6
+
7
+
8
+ class ResBlock(nn.Module):
9
+ def __init__(self, in_channel, out_channel, downsample=False):
10
+ super().__init__()
11
+ self.downsample = downsample
12
+ self.conv1 = nn.Conv2d(in_channel, out_channel, 3, 1, 1, bias=False)
13
+ self.conv2 = nn.Conv2d(out_channel, out_channel, 3, 1, 1, bias=False)
14
+ self.norm1 = nn.BatchNorm2d(out_channel)
15
+ self.norm2 = nn.BatchNorm2d(out_channel)
16
+ self.relu = nn.LeakyReLU(0.2, inplace=True)
17
+ if downsample:
18
+ self.down = nn.Conv2d(in_channel, out_channel, 3, 2, 1)
19
+
20
+ def forward(self, x):
21
+ identity = x
22
+ out = self.conv1(x)
23
+ out = self.norm1(out)
24
+ out = self.relu(out)
25
+ out = self.conv2(out)
26
+ out = self.norm2(out)
27
+ if self.downsample:
28
+ identity = self.down(x)
29
+ out += identity
30
+ out = self.relu(out)
31
+ return out
32
+
33
+
34
+ class VectorQuantizer(nn.Module):
35
+ def __init__(self, n_e, e_dim, beta=0.25):
36
+ super().__init__()
37
+ self.n_e = n_e
38
+ self.e_dim = e_dim
39
+ self.beta = beta
40
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
41
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
42
+
43
+ def forward(self, z):
44
+ z = z.permute(0, 2, 3, 1).contiguous()
45
+ z_flattened = z.view(-1, self.e_dim)
46
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
47
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
48
+ torch.matmul(z_flattened, self.embedding.weight.t())
49
+ min_encoding_indices = torch.argmin(d, dim=1)
50
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
51
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
52
+ z_q = z + (z_q - z).detach()
53
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
54
+ return z_q, loss, min_encoding_indices
55
+
56
+ def get_codebook_feat(self, indices, shape):
57
+ z_q = self.embedding(indices)
58
+ if len(z_q.shape) == 2:
59
+ z_q = z_q.view(shape[0], shape[1], shape[2], shape[3])
60
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
61
+ return z_q
62
+
63
+
64
+ @ARCH_REGISTRY.register()
65
+ class VQAutoEncoder(nn.Module):
66
+ def __init__(self, in_channel, channel, down_factor, downsample_method, downsample_steps, z_channels, codebook_size):
67
+ super().__init__()
68
+ self.encoder = nn.ModuleList()
69
+ self.decoder = nn.ModuleList()
70
+ curr_channel = in_channel
71
+ for i in range(downsample_steps):
72
+ next_channel = channel * down_factor[i]
73
+ down = i < len(down_factor) - 1
74
+ self.encoder.append(ResBlock(curr_channel, next_channel, downsample=down))
75
+ curr_channel = next_channel
76
+ self.encoder.append(nn.Conv2d(curr_channel, z_channels, 3, 1, 1))
77
+ self.quantize = VectorQuantizer(codebook_size, z_channels)
78
+ self.decoder.append(nn.Conv2d(z_channels, curr_channel, 3, 1, 1))
79
+ for i in range(downsample_steps - 1, -1, -1):
80
+ next_channel = channel * down_factor[i]
81
+ up = i > 0
82
+ self.decoder.append(ResBlock(curr_channel, next_channel, downsample=False))
83
+ if up:
84
+ self.decoder.append(nn.Upsample(scale_factor=down_factor[i], mode=downsample_method))
85
+ curr_channel = next_channel
86
+ self.decoder.append(nn.Conv2d(curr_channel, in_channel, 3, 1, 1))
87
+
88
+ def encode(self, x):
89
+ for module in self.encoder:
90
+ x = module(x)
91
+ return x
92
+
93
+ def decode(self, z):
94
+ for module in self.decoder:
95
+ z = module(z)
96
+ return z
97
+
98
+ def forward(self, x):
99
+ z = self.encode(x)
100
+ z_q, quant_loss, _ = self.quantize(z)
101
+ out = self.decode(z_q)
102
+ return out, quant_loss