| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchvision |
| import torch.nn.utils.spectral_norm as spectral_norm |
| from models.networks.normalization import FADE |
| from models.networks.sync_batchnorm import SynchronizedBatchNorm2d |
| from torchvision.models import vgg19, VGG19_Weights |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
class FADEResnetBlock(nn.Module):
    """Residual block whose normalization layers are FADE modules driven by ``feat``.

    Main path: FADE -> leaky ReLU -> 3x3 conv, applied twice.
    Skip path: identity when ``fin == fout``, otherwise FADE -> 1x1 conv (no bias).

    Args:
        fin: number of input channels.
        fout: number of output channels.
        opt: options object; only ``opt.norm_G`` is read. When it contains the
            substring ``'spectral'`` every convolution is wrapped in spectral
            normalization; the string with ``'spectral'`` removed is handed to
            ``FADE`` as its configuration.
    """

    def __init__(self, fin, fout, opt):
        super().__init__()

        # The skip path needs its own projection only when the width changes.
        self.learned_shortcut = fin != fout
        mid_channels = fin  # hidden width between the two 3x3 convolutions

        # Construction order is kept as-is: changing it would change the random
        # initialization obtained under a fixed seed.
        self.conv_0 = nn.Conv2d(fin, mid_channels, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(mid_channels, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        norm_spec = opt.norm_G
        if 'spectral' in norm_spec:
            conv_names = ['conv_0', 'conv_1']
            if self.learned_shortcut:
                conv_names.append('conv_s')
            for conv_name in conv_names:
                setattr(self, conv_name, spectral_norm(getattr(self, conv_name)))

        # Both FADE channel arguments are set to the width of the tensor being
        # normalized. NOTE(review): if FADE's second argument is the channel
        # count of ``feat``, this assumes ``feat`` has that same width — confirm.
        fade_config = norm_spec.replace('spectral', '')
        self.norm_0 = FADE(fade_config, fin, fin)
        self.norm_1 = FADE(fade_config, mid_channels, mid_channels)
        if self.learned_shortcut:
            self.norm_s = FADE(fade_config, fin, fin)

    def forward(self, x, feat):
        """Return ``shortcut(x, feat)`` plus the two-stage residual branch.

        ``feat`` is passed unchanged to every FADE layer; presumably a
        conditioning feature map spatially compatible with ``x`` — TODO confirm.
        """
        skip = self.shortcut(x, feat)

        hidden = self.conv_0(self.actvn(self.norm_0(x, feat)))
        residual = self.conv_1(self.actvn(self.norm_1(hidden, feat)))

        return skip + residual

    def shortcut(self, x, feat):
        """Skip path: identity, or FADE followed by the 1x1 projection."""
        if not self.learned_shortcut:
            return x
        return self.conv_s(self.norm_s(x, feat))

    def actvn(self, x):
        """Leaky ReLU with negative slope 0.2."""
        return F.leaky_relu(x, negative_slope=0.2)
|
|
|
|
class StreamResnetBlock(nn.Module):
    """Residual block with ordinary (unconditioned) normalization layers.

    Main path: 3x3 conv -> norm -> leaky ReLU, applied twice.
    Skip path: identity when ``fin == fout``, otherwise
    1x1 conv (no bias) -> norm -> leaky ReLU.

    Args:
        fin: number of input channels.
        fout: number of output channels.
        opt: options object; only ``opt.norm_S`` is read. When it contains the
            substring ``'spectral'`` every convolution is wrapped in spectral
            normalization; the remainder must name the normalization layer:
            ``'batch'``, ``'syncbatch'`` or ``'instance'``.

    Raises:
        ValueError: if the normalization name is not one of the above.
    """

    def __init__(self, fin, fout, opt):
        super().__init__()

        # The skip path needs its own projection only when the width changes.
        self.learned_shortcut = fin != fout
        mid_channels = fin  # hidden width between the two 3x3 convolutions

        # Construction order is kept as-is: changing it would change the random
        # initialization obtained under a fixed seed.
        self.conv_0 = nn.Conv2d(fin, mid_channels, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(mid_channels, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        norm_spec = opt.norm_S
        if 'spectral' in norm_spec:
            conv_names = ['conv_0', 'conv_1']
            if self.learned_shortcut:
                conv_names.append('conv_s')
            for conv_name in conv_names:
                setattr(self, conv_name, spectral_norm(getattr(self, conv_name)))

        # Each supported name maps to a factory taking a channel count. Batch
        # variants are affine; instance norm is not.
        subnorm_type = norm_spec.replace('spectral', '')
        norm_factories = {
            'batch': lambda channels: nn.BatchNorm2d(channels, affine=True),
            'syncbatch': lambda channels: SynchronizedBatchNorm2d(channels, affine=True),
            'instance': lambda channels: nn.InstanceNorm2d(channels, affine=False),
        }
        make_norm = norm_factories.get(subnorm_type)
        if make_norm is None:
            raise ValueError('normalization layer %s is not recognized' % subnorm_type)

        self.norm_layer_in = make_norm(fin)
        self.norm_layer_out = make_norm(fout)
        if self.learned_shortcut:
            self.norm_layer_s = make_norm(fout)

    def forward(self, x):
        """Return ``shortcut(x)`` plus the two-stage residual branch."""
        skip = self.shortcut(x)

        hidden = self.actvn(self.norm_layer_in(self.conv_0(x)))
        residual = self.actvn(self.norm_layer_out(self.conv_1(hidden)))

        return skip + residual

    def shortcut(self, x):
        """Skip path: identity, or 1x1 projection -> norm -> leaky ReLU."""
        if not self.learned_shortcut:
            return x
        return self.actvn(self.norm_layer_s(self.conv_s(x)))

    def actvn(self, x):
        """Leaky ReLU with negative slope 0.2."""
        return F.leaky_relu(x, negative_slope=0.2)
|
|
|
|
| |
| |
class ResnetBlock(nn.Module):
    """Residual block: pad -> conv -> activation -> pad -> conv, plus identity skip.

    Args:
        dim: channel count of both the input and the output.
        norm_layer: callable applied to each ``nn.Conv2d``; whatever it returns
            is placed in the block (e.g. a normalization wrapper).
        activation: module placed between the two convolutions. ``None`` (the
            default) creates a new ``nn.ReLU(False)`` for this block.
        kernel_size: convolution kernel size. Reflection padding of
            ``(kernel_size - 1) // 2`` preserves the spatial size only for odd
            sizes, which the residual addition in ``forward`` requires.
    """

    def __init__(self, dim, norm_layer, activation=None, kernel_size=3):
        super().__init__()
        # A module used as a default argument would be created once, when the
        # ``def`` runs, and shared by every block built with the default
        # (including any hooks registered on it). Build a fresh one per block.
        if activation is None:
            activation = nn.ReLU(False)

        pw = (kernel_size - 1) // 2
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(pw),
            norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),
            activation,
            nn.ReflectionPad2d(pw),
            norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),
        )

    def forward(self, x):
        """Return ``x + conv_block(x)``."""
        return x + self.conv_block(x)
|
|
|
|
| |
class FCMapping(nn.Module):
    """Fully connected network mapping a flattened input to ``opt.latent_dim`` features.

    Layer widths: ``3 * crop_size * crop_size -> 1024 -> 1024 -> 512 -> 256 ->
    latent_dim``, with LeakyReLU(0.2) after every layer except the last.

    Args:
        opt: options object; reads ``opt.crop_size`` and ``opt.latent_dim``.
    """

    def __init__(self, opt):
        super().__init__()
        # Input width assumes 3 values per pixel of a crop_size x crop_size
        # image (presumably an RGB crop — TODO confirm with callers).
        self.fc1 = nn.Linear(3 * opt.crop_size * opt.crop_size, 1024)
        self.fc2 = nn.Linear(1024, 1024)
        self.fc3 = nn.Linear(1024, 512)
        self.fc4 = nn.Linear(512, 256)
        self.fc5 = nn.Linear(256, opt.latent_dim)
        self.actvn = nn.LeakyReLU(2e-1)

    def forward(self, x):
        """Flatten every dim of ``x`` after the first and run the MLP.

        Returns:
            Tensor of shape ``(x.size(0), latent_dim)``.
        """
        # reshape instead of view: view raises on non-contiguous inputs
        # (permuted or channels_last tensors); reshape returns a view when
        # possible and copies only when it must.
        x = x.reshape(x.size(0), -1)
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            x = self.actvn(layer(x))
        return self.fc5(x)
|
|
|
|
| |
class VGG19(torch.nn.Module):
    """Pretrained VGG-19 feature extractor returning activations at five depths.

    ``torchvision``'s ``vgg19(...).features`` is cut into five consecutive
    stages, ``slice1`` .. ``slice5``, covering feature indices [0, 2), [2, 7),
    [7, 12), [12, 21) and [21, 30). Each stage ends just after a ReLU; in
    torchvision's VGG-19 layout these are relu1_1, relu2_1, relu3_1, relu4_1
    and relu5_1.

    Args:
        requires_grad: when False (the default) every parameter is frozen.

    Note:
        ``VGG19_Weights.DEFAULT`` weights are fetched by torchvision, which may
        download them on first use.
    """

    def __init__(self, requires_grad=False):
        super().__init__()
        features = vgg19(weights=VGG19_Weights.DEFAULT).features

        # [start, stop) index ranges into ``features`` for slice1 .. slice5.
        boundaries = (0, 2, 7, 12, 21, 30)
        for stage_no, (start, stop) in enumerate(zip(boundaries, boundaries[1:]), start=1):
            stage = torch.nn.Sequential()
            for layer_idx in range(start, stop):
                # Keep the original feature index as the submodule name so
                # state_dict keys match those of the unsliced network.
                stage.add_module(str(layer_idx), features[layer_idx])
            setattr(self, f'slice{stage_no}', stage)

        if not requires_grad:
            self.requires_grad_(False)

    def forward(self, X):
        """Run ``X`` through the five stages in order.

        Returns:
            List of the five stage outputs, shallowest first.
        """
        activations = []
        h = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = stage(h)
            activations.append(h)
        return activations
|
|