import torch
import torch.nn as nn
import torch.nn.functional as F

from basicsr.utils.registry import ARCH_REGISTRY


class VectorQuantizer(nn.Module):
    """VQ-VAE style codebook lookup with a straight-through estimator.

    Args:
        n_e: number of codebook entries.
        e_dim: dimensionality of each codebook vector; must equal the
            channel dimension of the feature maps passed to ``forward``.
    """

    def __init__(self, n_e, e_dim):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        # Codebook: one learnable e_dim vector per entry.
        self.embedding = nn.Parameter(torch.randn(n_e, e_dim))

    def get_codebook_feat(self, indices, shape):
        """Decode integer code indices back into a feature map.

        Args:
            indices: long tensor of codebook indices.
            shape: target channels-last shape ``(b, h, w, c)`` applied
                before permuting to channels-first.

        Returns:
            Feature tensor of shape ``(b, c, h, w)``.
        """
        feat = self.embedding[indices]
        feat = feat.view(shape[0], shape[1], shape[2], shape[3]).permute(0, 3, 1, 2)
        return feat

    def forward(self, z):
        """Quantize ``z`` to its nearest codebook entries.

        BUGFIX: the original body did ``z_q = self.embedding[z]`` — indexing
        the codebook with the *continuous* encoder output, which raises for
        float tensors and never performs a nearest-neighbour search. The
        standard argmin-over-squared-distances lookup is implemented here.

        Args:
            z: continuous features of shape ``(b, e_dim, h, w)``.

        Returns:
            Tuple ``(z_q, loss, indices)``: quantized features with the same
            shape as ``z`` (gradients flow straight through), the
            codebook + commitment loss (equal weighting, as in the original
            formula), and the ``(b, h, w)`` tensor of selected code indices
            (previously ``None``; suitable input for ``get_codebook_feat``).
        """
        # Channels-last, then flatten to (b*h*w, e_dim) for the lookup.
        z_perm = z.permute(0, 2, 3, 1).contiguous()
        flat = z_perm.view(-1, self.e_dim)
        # Squared L2 distance to every codebook entry: |x|^2 - 2*x.e + |e|^2.
        dist = (flat.pow(2).sum(dim=1, keepdim=True)
                - 2.0 * flat @ self.embedding.t()
                + self.embedding.pow(2).sum(dim=1))
        indices = torch.argmin(dist, dim=1)
        z_q = self.embedding[indices].view(z_perm.shape).permute(0, 3, 1, 2).contiguous()
        # Codebook loss + commitment loss, preserved from the original.
        loss = torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
        # Straight-through estimator: forward uses z_q, backward sees identity.
        z_q = z + (z_q - z).detach()
        return z_q, loss, indices.view(z_perm.shape[:-1])


class ResBlock(nn.Module):
    """Residual conv block with optional 2x spatial downsampling.

    The identity path is projected with a 1x1 conv when the channel count
    changes and average-pooled when downsampling, so it always matches the
    main path's output shape.

    Args:
        in_channel: input channel count.
        out_channel: output channel count.
        downsample: if True, halve the spatial resolution.
        downsample_method: interpolation mode passed to ``F.interpolate``
            for the main path (e.g. ``'nearest'``).
    """

    def __init__(self, in_channel, out_channel, downsample=False, downsample_method='nearest'):
        super().__init__()
        self.downsample = downsample
        self.downsample_method = downsample_method
        self.conv = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1)
        )
        # Projection for the identity path when channels or spatial size change.
        self.proj = nn.Conv2d(in_channel, out_channel, 1, 1, 0) if in_channel != out_channel else nn.Identity()
        self.downsample_identity = nn.AvgPool2d(kernel_size=2, stride=2) if downsample else nn.Identity()

    def forward(self, x):
        identity = x
        out = self.conv(x)
        if self.downsample:
            # Halve spatial resolution on the main path.
            out = F.interpolate(out, scale_factor=0.5, mode=self.downsample_method)
        identity = self.proj(identity)  # match channel count
        identity = self.downsample_identity(identity)  # match spatial size
        out += identity
        out = F.relu(out)
        return out


@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
    """Convolutional autoencoder with a vector-quantized bottleneck.

    Args:
        in_channel: channels of the input image.
        channel: base channel count; per-stage width is
            ``channel * down_factor[i]``.
        down_factor: per-stage *channel* multipliers (one per step).
        downsample_method: interpolation mode used for down/upsampling.
        downsample_steps: number of encoder/decoder stages.
        z_channels: channel count of the latent feature map.
        codebook_size: number of entries in the VQ codebook.
    """

    def __init__(self, in_channel, channel, down_factor, downsample_method, downsample_steps, z_channels, codebook_size):
        super().__init__()
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        curr_channel = in_channel
        for i in range(downsample_steps):
            next_channel = channel * down_factor[i]
            # NOTE(review): gates on len(down_factor), not downsample_steps —
            # identical when the two agree; confirm intent if they can differ.
            down = i < len(down_factor) - 1
            self.encoder.append(ResBlock(curr_channel, next_channel, downsample=down, downsample_method=downsample_method))
            curr_channel = next_channel
        self.encoder.append(nn.Conv2d(curr_channel, z_channels, 3, 1, 1))
        self.quantize = VectorQuantizer(codebook_size, z_channels)
        self.decoder.append(nn.Conv2d(z_channels, curr_channel, 3, 1, 1))
        for i in range(downsample_steps - 1, -1, -1):
            next_channel = channel * down_factor[i]
            up = i > 0
            self.decoder.append(ResBlock(curr_channel, next_channel, downsample=False))
            if up:
                # BUGFIX: the original upsampled by down_factor[i], but that
                # value is a channel multiplier; each encoder stage halves the
                # spatial resolution, so the decoder must double it to restore
                # the input size.
                self.decoder.append(nn.Upsample(scale_factor=2, mode=downsample_method))
            curr_channel = next_channel
        self.decoder.append(nn.Conv2d(curr_channel, in_channel, 3, 1, 1))

    def encode(self, x):
        """Map an image to the (pre-quantization) latent feature map."""
        for module in self.encoder:
            x = module(x)
        return x

    def decode(self, z):
        """Map a (quantized) latent feature map back to image space."""
        for module in self.decoder:
            z = module(z)
        return z

    def forward(self, x):
        """Encode, quantize, decode.

        Returns:
            Tuple ``(out, quant_loss)``: the reconstruction and the VQ
            codebook/commitment loss.
        """
        z = self.encode(x)
        z_q, quant_loss, _ = self.quantize(z)
        out = self.decode(z_q)
        return out, quant_loss