# CodeFormer / vqgan_arch.py
# Provenance: uploaded by lucky0146 — "Update vqgan_arch.py" (commit 38e5619, verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
from basicsr.utils.registry import ARCH_REGISTRY
class VectorQuantizer(nn.Module):
    """Vector-quantization bottleneck with a learnable codebook.

    Maps each spatial position of a continuous latent (B, e_dim, H, W) to its
    nearest codebook vector, returning the quantized latent, the combined
    codebook/commitment loss, and the chosen code indices.
    """

    def __init__(self, n_e, e_dim):
        super().__init__()
        self.n_e = n_e      # number of codebook entries
        self.e_dim = e_dim  # dimensionality of each code vector
        # Codebook: one learnable e_dim-vector per code.
        self.embedding = nn.Parameter(torch.randn(n_e, e_dim))

    def get_codebook_feat(self, indices, shape):
        """Look up code vectors for integer `indices` and reshape to NCHW.

        Args:
            indices: integer tensor of codebook indices.
            shape: target (B, H, W, C) layout before the permute to (B, C, H, W).

        Returns:
            Tensor of shape (shape[0], shape[3], shape[1], shape[2]).
        """
        feat = self.embedding[indices]
        feat = feat.view(shape[0], shape[1], shape[2], shape[3]).permute(0, 3, 1, 2)
        return feat

    def forward(self, z):
        """Quantize continuous latents `z` of shape (B, e_dim, H, W).

        BUG FIX: the original indexed the codebook directly with the
        continuous tensor (`self.embedding[z]`), which only works for integer
        index tensors and fails on the float encoder output. Proper VQ finds
        the nearest codebook entry per spatial position.

        Returns:
            (z_q, loss, indices) — straight-through quantized latents, the
            codebook + commitment loss (equal weighting, as in the original
            loss expression), and the flat code indices.
        """
        b, c, h, w = z.shape
        # (B, C, H, W) -> (B*H*W, e_dim) so each row is one latent vector.
        z_flat = z.permute(0, 2, 3, 1).reshape(-1, self.e_dim)
        # Squared L2 distance to every code: |x|^2 - 2 x.e + |e|^2.
        dist = (z_flat.pow(2).sum(dim=1, keepdim=True)
                - 2 * z_flat @ self.embedding.t()
                + self.embedding.pow(2).sum(dim=1))
        indices = torch.argmin(dist, dim=1)
        z_q = self.embedding[indices].view(b, h, w, c).permute(0, 3, 1, 2)
        # Codebook loss (moves codes toward encoder output) + commitment loss
        # (moves encoder output toward codes), equally weighted as before.
        loss = torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
        # Straight-through estimator: forward uses z_q, gradient flows to z.
        z_q = z + (z_q - z).detach()
        return z_q, loss, indices
class ResBlock(nn.Module):
    """Residual block: two 3x3 convs plus a channel-projected skip connection.

    With ``downsample=True`` the conv branch is shrunk 2x via
    ``F.interpolate`` (using ``downsample_method``) and the skip branch via
    2x2 average pooling, so both paths agree spatially before being summed.
    A 1x1 conv projects the skip path whenever channel counts differ.
    """

    def __init__(self, in_channel, out_channel, downsample=False, downsample_method='nearest'):
        super().__init__()
        self.downsample = downsample
        self.downsample_method = downsample_method
        self.conv = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1),
        )
        # Align channel counts on the identity path when they change.
        if in_channel != out_channel:
            self.proj = nn.Conv2d(in_channel, out_channel, 1, 1, 0)
        else:
            self.proj = nn.Identity()
        # Halve the identity path spatially to match the conv branch.
        if downsample:
            self.downsample_identity = nn.AvgPool2d(kernel_size=2, stride=2)
        else:
            self.downsample_identity = nn.Identity()

    def forward(self, x):
        branch = self.conv(x)
        if self.downsample:
            branch = F.interpolate(branch, scale_factor=0.5, mode=self.downsample_method)
        skip = self.downsample_identity(self.proj(x))
        return F.relu(branch + skip)
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
    """Convolutional autoencoder with a vector-quantized latent bottleneck.

    The encoder stacks ResBlocks (downsampling on all but the last
    ``down_factor`` entry) and projects to ``z_channels``; the decoder
    mirrors it with ResBlocks and Upsample stages back to ``in_channel``.
    """

    def __init__(self, in_channel, channel, down_factor, downsample_method, downsample_steps, z_channels, codebook_size):
        super().__init__()
        # --- encoder ---
        encoder_layers = []
        curr = in_channel
        for step in range(downsample_steps):
            target = channel * down_factor[step]
            # Shrink spatially on every step except the final factor entry.
            shrink = step < len(down_factor) - 1
            encoder_layers.append(ResBlock(curr, target, downsample=shrink, downsample_method=downsample_method))
            curr = target
        encoder_layers.append(nn.Conv2d(curr, z_channels, 3, 1, 1))
        self.encoder = nn.ModuleList(encoder_layers)

        self.quantize = VectorQuantizer(codebook_size, z_channels)

        # --- decoder (mirrors the encoder channel progression) ---
        decoder_layers = [nn.Conv2d(z_channels, curr, 3, 1, 1)]
        for step in reversed(range(downsample_steps)):
            target = channel * down_factor[step]
            decoder_layers.append(ResBlock(curr, target, downsample=False))
            if step > 0:
                # NOTE(review): the spatial scale factor reuses the channel
                # multiplier down_factor[step]; this only mirrors the encoder's
                # fixed 2x downsampling when that factor equals 2 — confirm
                # against the training config.
                decoder_layers.append(nn.Upsample(scale_factor=down_factor[step], mode=downsample_method))
            curr = target
        decoder_layers.append(nn.Conv2d(curr, in_channel, 3, 1, 1))
        self.decoder = nn.ModuleList(decoder_layers)

    def encode(self, x):
        """Run `x` through every encoder stage in order."""
        for layer in self.encoder:
            x = layer(x)
        return x

    def decode(self, z):
        """Run latent `z` through every decoder stage in order."""
        for layer in self.decoder:
            z = layer(z)
        return z

    def forward(self, x):
        """Encode, quantize, decode; returns (reconstruction, quant_loss)."""
        latent = self.encode(x)
        quantized, quant_loss, _ = self.quantize(latent)
        return self.decode(quantized), quant_loss