Spaces:
Runtime error
Runtime error
Update vqgan_arch.py
Browse files- vqgan_arch.py +34 -48
vqgan_arch.py
CHANGED
|
@@ -1,66 +1,52 @@
|
|
| 1 |
-
import math
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
import torch.nn.functional as F
|
| 5 |
from basicsr.utils.registry import ARCH_REGISTRY
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
class ResBlock(nn.Module):
|
| 9 |
-
def __init__(self, in_channel, out_channel, downsample=False):
|
| 10 |
super().__init__()
|
| 11 |
self.downsample = downsample
|
| 12 |
-
self.
|
| 13 |
-
self.
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
| 19 |
|
| 20 |
def forward(self, x):
|
| 21 |
identity = x
|
| 22 |
-
out = self.
|
| 23 |
-
out = self.norm1(out)
|
| 24 |
-
out = self.relu(out)
|
| 25 |
-
out = self.conv2(out)
|
| 26 |
-
out = self.norm2(out)
|
| 27 |
if self.downsample:
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
| 29 |
out += identity
|
| 30 |
-
out =
|
| 31 |
return out
|
| 32 |
|
| 33 |
-
|
| 34 |
-
class VectorQuantizer(nn.Module):
    """Nearest-neighbour codebook quantizer (VQ-VAE style).

    Maps each spatial feature vector of the encoder output to its closest
    codebook entry, returning the quantized features (with a
    straight-through gradient), the VQ training loss, and the chosen
    codebook indices.
    """

    def __init__(self, n_e, e_dim, beta=0.25):
        """n_e: codebook size; e_dim: embedding dim; beta: commitment weight."""
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

    def forward(self, z):
        """Quantize z of shape (B, C, H, W); returns (z_q, loss, indices)."""
        # Channels-last, then flatten to (N, e_dim) rows for the distance search.
        z = z.permute(0, 2, 3, 1).contiguous()
        flat = z.view(-1, self.e_dim)
        codes = self.embedding.weight
        # Squared L2 distance expanded as ||a||^2 + ||b||^2 - 2 a.b^T.
        d = torch.sum(flat ** 2, dim=1, keepdim=True) + \
            torch.sum(codes ** 2, dim=1) - 2 * \
            torch.matmul(flat, codes.t())
        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        # Codebook loss + beta-weighted commitment loss.
        loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
        # Straight-through estimator: gradients flow to z, values come from z_q.
        z_q = z + (z_q - z).detach()
        z_q = z_q.permute(0, 3, 1, 2).contiguous()
        return z_q, loss, min_encoding_indices

    def get_codebook_feat(self, indices, shape):
        """Look up codebook entries for indices; reshape flat indices to NCHW.

        shape is (B, H, W, e_dim); only used when indices are 1-D.
        """
        z_q = self.embedding(indices)
        if z_q.dim() == 2:
            z_q = z_q.view(shape[0], shape[1], shape[2], shape[3])
            z_q = z_q.permute(0, 3, 1, 2).contiguous()
        return z_q
|
| 62 |
-
|
| 63 |
-
|
| 64 |
@ARCH_REGISTRY.register()
|
| 65 |
class VQAutoEncoder(nn.Module):
|
| 66 |
def __init__(self, in_channel, channel, down_factor, downsample_method, downsample_steps, z_channels, codebook_size):
|
|
@@ -71,7 +57,7 @@ class VQAutoEncoder(nn.Module):
|
|
| 71 |
for i in range(downsample_steps):
|
| 72 |
next_channel = channel * down_factor[i]
|
| 73 |
down = i < len(down_factor) - 1
|
| 74 |
-
self.encoder.append(ResBlock(curr_channel, next_channel, downsample=down))
|
| 75 |
curr_channel = next_channel
|
| 76 |
self.encoder.append(nn.Conv2d(curr_channel, z_channels, 3, 1, 1))
|
| 77 |
self.quantize = VectorQuantizer(codebook_size, z_channels)
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
import torch.nn.functional as F
|
| 4 |
from basicsr.utils.registry import ARCH_REGISTRY
|
| 5 |
|
| 6 |
+
class VectorQuantizer(nn.Module):
    """Nearest-neighbour codebook quantizer (VQ-VAE style).

    BUG FIX: the previous forward did ``z_q = self.embedding[z]``, indexing
    the codebook parameter with the *continuous float* feature tensor ``z``.
    Indexing a Parameter with a float tensor raises IndexError at runtime,
    and no nearest-neighbour search was ever performed. forward now computes
    the argmin over squared L2 distances to the codebook, as a vector
    quantizer must.

    Interface is preserved: ``__init__(n_e, e_dim)``, ``self.embedding`` is
    an ``nn.Parameter`` of shape (n_e, e_dim), and ``forward`` returns a
    3-tuple (the third element is now the chosen indices instead of None,
    which is compatible with callers that unpack three values).
    """

    def __init__(self, n_e, e_dim):
        """n_e: number of codebook entries; e_dim: embedding dimension."""
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        # Codebook stored as a plain parameter, indexed with integer tensors.
        self.embedding = nn.Parameter(torch.randn(n_e, e_dim))

    def get_codebook_feat(self, indices, shape):
        """Look up codebook rows for integer ``indices`` and return NCHW.

        shape is (B, H, W, e_dim) — only the four entries are used.
        """
        feat = self.embedding[indices]
        feat = feat.view(shape[0], shape[1], shape[2], shape[3]).permute(0, 3, 1, 2)
        return feat

    def forward(self, z):
        """Quantize encoder features z of shape (B, C, H, W).

        Returns (z_q, loss, min_encoding_indices) where z_q has z's shape
        and carries straight-through gradients.
        """
        # Channels-last, then flatten to (N, e_dim) rows for the search.
        z_perm = z.permute(0, 2, 3, 1).contiguous()
        flat = z_perm.view(-1, self.e_dim)
        # Squared L2 distance expanded as ||a||^2 + ||b||^2 - 2 a.b^T.
        d = torch.sum(flat ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding ** 2, dim=1) - 2 * \
            torch.matmul(flat, self.embedding.t())
        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding[min_encoding_indices].view(z_perm.shape)
        # Codebook + commitment terms (equal weight, as in the prior version).
        loss = torch.mean((z_q.detach() - z_perm) ** 2) + torch.mean((z_q - z_perm.detach()) ** 2)
        # Straight-through estimator: values from z_q, gradients to z.
        z_q = z_perm + (z_q - z_perm).detach()
        z_q = z_q.permute(0, 3, 1, 2).contiguous()
        return z_q, loss, min_encoding_indices
|
| 23 |
|
| 24 |
class ResBlock(nn.Module):
    """Residual conv block with optional 2x spatial downsampling.

    BUG FIX: the identity projection was previously applied only inside the
    ``if self.downsample:`` branch, so constructing the block with
    ``downsample=False`` and ``in_channel != out_channel`` crashed on the
    residual add with a channel mismatch. The 1x1 projection is now applied
    unconditionally; it is ``nn.Identity`` when the channel count does not
    change, so every previously-working configuration behaves identically.
    """

    def __init__(self, in_channel, out_channel, downsample=False, downsample_method='nearest'):
        super().__init__()
        self.downsample = downsample
        # Interpolation mode used on the main path when downsampling.
        self.downsample_method = downsample_method
        self.conv = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1)
        )
        # 1x1 projection so the skip path matches the main path's channels.
        self.proj = nn.Conv2d(in_channel, out_channel, 1, 1, 0) if in_channel != out_channel else nn.Identity()
        # Spatial match for the skip path when downsampling.
        self.downsample_identity = nn.AvgPool2d(kernel_size=2, stride=2) if downsample else nn.Identity()

    def forward(self, x):
        # Always project the skip path; nn.Identity is a no-op when
        # channels are unchanged, so this fixes only the broken case.
        identity = self.proj(x)
        out = self.conv(x)
        if self.downsample:
            out = F.interpolate(out, scale_factor=0.5, mode=self.downsample_method)
            identity = self.downsample_identity(identity)  # match spatial dims
        out = out + identity
        out = F.relu(out)
        return out
|
| 49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
@ARCH_REGISTRY.register()
|
| 51 |
class VQAutoEncoder(nn.Module):
|
| 52 |
def __init__(self, in_channel, channel, down_factor, downsample_method, downsample_steps, z_channels, codebook_size):
|
|
|
|
| 57 |
for i in range(downsample_steps):
|
| 58 |
next_channel = channel * down_factor[i]
|
| 59 |
down = i < len(down_factor) - 1
|
| 60 |
+
self.encoder.append(ResBlock(curr_channel, next_channel, downsample=down, downsample_method=downsample_method))
|
| 61 |
curr_channel = next_channel
|
| 62 |
self.encoder.append(nn.Conv2d(curr_channel, z_channels, 3, 1, 1))
|
| 63 |
self.quantize = VectorQuantizer(codebook_size, z_channels)
|