lucky0146 commited on
Commit
38e5619
·
verified ·
1 Parent(s): bad7bf5

Update vqgan_arch.py

Browse files
Files changed (1) hide show
  1. vqgan_arch.py +34 -48
vqgan_arch.py CHANGED
@@ -1,66 +1,52 @@
1
- import math
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
  from basicsr.utils.registry import ARCH_REGISTRY
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  class ResBlock(nn.Module):
9
- def __init__(self, in_channel, out_channel, downsample=False):
10
  super().__init__()
11
  self.downsample = downsample
12
- self.conv1 = nn.Conv2d(in_channel, out_channel, 3, 1, 1, bias=False)
13
- self.conv2 = nn.Conv2d(out_channel, out_channel, 3, 1, 1, bias=False)
14
- self.norm1 = nn.BatchNorm2d(out_channel)
15
- self.norm2 = nn.BatchNorm2d(out_channel)
16
- self.relu = nn.LeakyReLU(0.2, inplace=True)
17
- if downsample:
18
- self.down = nn.Conv2d(in_channel, out_channel, 3, 2, 1)
 
 
19
 
20
  def forward(self, x):
21
  identity = x
22
- out = self.conv1(x)
23
- out = self.norm1(out)
24
- out = self.relu(out)
25
- out = self.conv2(out)
26
- out = self.norm2(out)
27
  if self.downsample:
28
- identity = self.down(x)
 
 
 
29
  out += identity
30
- out = self.relu(out)
31
  return out
32
 
33
-
34
- class VectorQuantizer(nn.Module):
35
- def __init__(self, n_e, e_dim, beta=0.25):
36
- super().__init__()
37
- self.n_e = n_e
38
- self.e_dim = e_dim
39
- self.beta = beta
40
- self.embedding = nn.Embedding(self.n_e, self.e_dim)
41
- self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
42
-
43
- def forward(self, z):
44
- z = z.permute(0, 2, 3, 1).contiguous()
45
- z_flattened = z.view(-1, self.e_dim)
46
- d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
47
- torch.sum(self.embedding.weight**2, dim=1) - 2 * \
48
- torch.matmul(z_flattened, self.embedding.weight.t())
49
- min_encoding_indices = torch.argmin(d, dim=1)
50
- z_q = self.embedding(min_encoding_indices).view(z.shape)
51
- loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
52
- z_q = z + (z_q - z).detach()
53
- z_q = z_q.permute(0, 3, 1, 2).contiguous()
54
- return z_q, loss, min_encoding_indices
55
-
56
- def get_codebook_feat(self, indices, shape):
57
- z_q = self.embedding(indices)
58
- if len(z_q.shape) == 2:
59
- z_q = z_q.view(shape[0], shape[1], shape[2], shape[3])
60
- z_q = z_q.permute(0, 3, 1, 2).contiguous()
61
- return z_q
62
-
63
-
64
  @ARCH_REGISTRY.register()
65
  class VQAutoEncoder(nn.Module):
66
  def __init__(self, in_channel, channel, down_factor, downsample_method, downsample_steps, z_channels, codebook_size):
@@ -71,7 +57,7 @@ class VQAutoEncoder(nn.Module):
71
  for i in range(downsample_steps):
72
  next_channel = channel * down_factor[i]
73
  down = i < len(down_factor) - 1
74
- self.encoder.append(ResBlock(curr_channel, next_channel, downsample=down))
75
  curr_channel = next_channel
76
  self.encoder.append(nn.Conv2d(curr_channel, z_channels, 3, 1, 1))
77
  self.quantize = VectorQuantizer(codebook_size, z_channels)
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  from basicsr.utils.registry import ARCH_REGISTRY
5
 
6
+ class VectorQuantizer(nn.Module):
7
+ def __init__(self, n_e, e_dim):
8
+ super().__init__()
9
+ self.n_e = n_e
10
+ self.e_dim = e_dim
11
+ self.embedding = nn.Parameter(torch.randn(n_e, e_dim))
12
+
13
+ def get_codebook_feat(self, indices, shape):
14
+ feat = self.embedding[indices]
15
+ feat = feat.view(shape[0], shape[1], shape[2], shape[3]).permute(0, 3, 1, 2)
16
+ return feat
17
+
18
+ def forward(self, z):
19
+ z_q = self.embedding[z]
20
+ loss = torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
21
+ z_q = z + (z_q - z).detach()
22
+ return z_q, loss, None
23
 
24
  class ResBlock(nn.Module):
25
+ def __init__(self, in_channel, out_channel, downsample=False, downsample_method='nearest'):
26
  super().__init__()
27
  self.downsample = downsample
28
+ self.downsample_method = downsample_method
29
+ self.conv = nn.Sequential(
30
+ nn.Conv2d(in_channel, out_channel, 3, 1, 1),
31
+ nn.ReLU(True),
32
+ nn.Conv2d(out_channel, out_channel, 3, 1, 1)
33
+ )
34
+ # Add a projection layer for the identity path if channels or spatial dimensions change
35
+ self.proj = nn.Conv2d(in_channel, out_channel, 1, 1, 0) if in_channel != out_channel else nn.Identity()
36
+ self.downsample_identity = nn.AvgPool2d(kernel_size=2, stride=2) if downsample else nn.Identity()
37
 
38
  def forward(self, x):
39
  identity = x
40
+ out = self.conv(x)
 
 
 
 
41
  if self.downsample:
42
+ out = F.interpolate(out, scale_factor=0.5, mode=self.downsample_method)
43
+ # Adjust the identity path to match out's dimensions
44
+ identity = self.proj(identity) # Match channel dimensions
45
+ identity = self.downsample_identity(identity) # Match spatial dimensions if downsampling
46
  out += identity
47
+ out = F.relu(out)
48
  return out
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  @ARCH_REGISTRY.register()
51
  class VQAutoEncoder(nn.Module):
52
  def __init__(self, in_channel, channel, down_factor, downsample_method, downsample_steps, z_channels, codebook_size):
 
57
  for i in range(downsample_steps):
58
  next_channel = channel * down_factor[i]
59
  down = i < len(down_factor) - 1
60
+ self.encoder.append(ResBlock(curr_channel, next_channel, downsample=down, downsample_method=downsample_method))
61
  curr_channel = next_channel
62
  self.encoder.append(nn.Conv2d(curr_channel, z_channels, 3, 1, 1))
63
  self.quantize = VectorQuantizer(codebook_size, z_channels)