Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from basicsr.utils.registry import ARCH_REGISTRY | |
class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a learnable codebook.

    Every ``e_dim``-dimensional feature vector of an NCHW input map is replaced
    by its closest codebook entry (squared Euclidean distance). Gradients reach
    the encoder through the straight-through estimator.

    Args:
        n_e: Number of codebook entries.
        e_dim: Size of each codebook entry; must equal the channel count of
            the feature maps passed to :meth:`forward`.
        beta: Weight of the commitment term in the returned loss. The default
            of 1.0 weights the commitment and codebook terms equally.
    """

    def __init__(self, n_e, e_dim, beta=1.0):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        # Codebook of shape (n_e, e_dim), initialised from N(0, 1).
        self.embedding = nn.Parameter(torch.randn(n_e, e_dim))

    def get_codebook_feat(self, indices, shape):
        """Look up codebook entries and arrange them as an NCHW feature map.

        Args:
            indices: Codebook indices, ``shape[0] * shape[1] * shape[2]`` of
                them in total, ordered batch-major then row-major over (H, W).
            shape: Target layout ``(B, H, W, C)`` with ``C == e_dim``.

        Returns:
            Tensor of shape ``(B, C, H, W)``.
        """
        feat = self.embedding[indices]
        feat = feat.view(shape[0], shape[1], shape[2], shape[3]).permute(0, 3, 1, 2)
        return feat

    def forward(self, z):
        """Quantize an NCHW feature map against the codebook.

        Args:
            z: Float tensor of shape ``(B, e_dim, H, W)``.

        Returns:
            ``(z_q, loss, None)``:
            - ``z_q`` has the same shape as ``z``. Its forward value is the
              selected codebook entries, and its gradient w.r.t. ``z`` is the
              identity (straight-through).
            - ``loss`` is ``beta * commitment + codebook`` mean-squared error.
            - The third element is ``None``, kept for compatibility with
              existing callers.

        Raises:
            ValueError: if ``z`` is not 4-D with ``e_dim`` channels.
        """
        if z.dim() != 4 or z.shape[1] != self.e_dim:
            raise ValueError(
                f"expected z of shape (B, {self.e_dim}, H, W), got {tuple(z.shape)}"
            )
        b, _, h, w = z.shape
        with torch.no_grad():
            # (B, C, H, W) -> (B*H*W, C): one row per spatial feature vector.
            z_flat = z.permute(0, 2, 3, 1).reshape(-1, self.e_dim)
            # ||z||^2 + ||e||^2 - 2 z.e for every (vector, codebook entry) pair.
            dist = (
                z_flat.pow(2).sum(dim=1, keepdim=True)
                + self.embedding.pow(2).sum(dim=1)
                - 2 * z_flat @ self.embedding.t()
            )
            indices = dist.argmin(dim=1)
        # Looked up outside no_grad so the codebook term can train the embedding.
        z_q = self.get_codebook_feat(indices, (b, h, w, self.e_dim))
        # Commitment term pulls encoder outputs toward their codes. The codebook
        # term pulls the selected codes toward the encoder outputs.
        loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
            (z_q - z.detach()) ** 2
        )
        # Straight-through estimator: forward value z_q, identity gradient to z.
        z_q = z + (z_q - z).detach()
        return z_q, loss, None
class ResBlock(nn.Module):
    """Residual block of two 3x3 convolutions with optional 2x downsampling.

    The main path is conv3x3 -> ReLU -> conv3x3. When ``downsample`` is set,
    it is then resized by 0.5 with ``F.interpolate(mode=downsample_method)``.

    The shortcut path is a 1x1 projection when the channel count changes and
    the identity otherwise. When ``downsample`` is set, it is followed by 2x2
    average pooling so it matches the main path's spatial size.

    The two paths are summed and passed through a final ReLU.
    """

    def __init__(self, in_channel, out_channel, downsample=False, downsample_method='nearest'):
        super().__init__()
        self.downsample = downsample
        self.downsample_method = downsample_method

        layers = [
            nn.Conv2d(in_channel, out_channel, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1),
        ]
        self.conv = nn.Sequential(*layers)

        # Shortcut channel match: a 1x1 conv is only needed when widths differ.
        if in_channel == out_channel:
            self.proj = nn.Identity()
        else:
            self.proj = nn.Conv2d(in_channel, out_channel, 1, 1, 0)

        # Shortcut spatial match: halve H and W alongside the main path.
        if downsample:
            self.downsample_identity = nn.AvgPool2d(kernel_size=2, stride=2)
        else:
            self.downsample_identity = nn.Identity()

    def forward(self, x):
        main = self.conv(x)
        if self.downsample:
            main = F.interpolate(main, scale_factor=0.5, mode=self.downsample_method)
        shortcut = self.downsample_identity(self.proj(x))
        return F.relu(main + shortcut)
class VQAutoEncoder(nn.Module):
    """Convolutional VQ autoencoder: ResBlock encoder, codebook, mirrored decoder.

    NOTE(review): ``ARCH_REGISTRY`` is imported by this file, but this class is
    not decorated with ``@ARCH_REGISTRY.register()``. Confirm whether it should
    be buildable by name through basicsr.

    Args:
        in_channel: Channels of the input and of the reconstructed output.
        channel: Base channel width.
        down_factor: Per-stage channel multipliers; stage ``i`` has
            ``channel * down_factor[i]`` channels. Needs at least
            ``downsample_steps`` entries.
        downsample_method: Interpolation mode used for encoder downsampling
            (inside ``ResBlock``) and decoder upsampling.
        downsample_steps: Number of ResBlock stages. Every stage except the
            last halves the spatial size, so the latent grid is
            ``2 ** (downsample_steps - 1)`` times smaller than the input. This
            assumes H and W are divisible by that factor (TODO confirm for
            callers); otherwise the reconstruction will be smaller than the input.
        z_channels: Latent / codebook vector dimension.
        codebook_size: Number of codebook entries.

    Raises:
        ValueError: if ``down_factor`` has fewer than ``downsample_steps`` entries.
    """

    def __init__(self, in_channel, channel, down_factor, downsample_method, downsample_steps, z_channels, codebook_size):
        super().__init__()
        if len(down_factor) < downsample_steps:
            raise ValueError(
                f"down_factor has {len(down_factor)} entries but "
                f"downsample_steps={downsample_steps}"
            )
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        curr_channel = in_channel
        for i in range(downsample_steps):
            next_channel = channel * down_factor[i]
            # Halve resolution at every stage but the last. Using
            # downsample_steps (not len(down_factor)) keeps the count in sync
            # with the decoder's upsampling below.
            down = i < downsample_steps - 1
            self.encoder.append(ResBlock(curr_channel, next_channel, downsample=down, downsample_method=downsample_method))
            curr_channel = next_channel
        self.encoder.append(nn.Conv2d(curr_channel, z_channels, 3, 1, 1))

        self.quantize = VectorQuantizer(codebook_size, z_channels)

        self.decoder.append(nn.Conv2d(z_channels, curr_channel, 3, 1, 1))
        for i in range(downsample_steps - 1, -1, -1):
            next_channel = channel * down_factor[i]
            self.decoder.append(ResBlock(curr_channel, next_channel, downsample=False))
            if i > 0:
                # ResBlock always downsamples by exactly 2, so mirror it with a
                # 2x upsample. down_factor[i] is a channel multiplier, not a
                # spatial scale.
                self.decoder.append(nn.Upsample(scale_factor=2, mode=downsample_method))
            curr_channel = next_channel
        self.decoder.append(nn.Conv2d(curr_channel, in_channel, 3, 1, 1))

    def encode(self, x):
        """Run the encoder stack; returns a map with ``z_channels`` channels."""
        for module in self.encoder:
            x = module(x)
        return x

    def decode(self, z):
        """Run the decoder stack; returns a map with ``in_channel`` channels."""
        for module in self.decoder:
            z = module(z)
        return z

    def forward(self, x):
        """Encode, quantize and decode ``x``.

        Returns:
            ``(reconstruction, quant_loss)``; ``quant_loss`` is the loss
            returned by ``self.quantize``.
        """
        z = self.encode(x)
        z_q, quant_loss, _ = self.quantize(z)
        out = self.decode(z_q)
        return out, quant_loss