import re

import torch
import torch.nn as nn
import torch.nn.utils.spectral_norm as spectral_norm

from models.networks.sync_batchnorm import SynchronizedBatchNorm2d


def get_norm_layer(opt, norm_type='instance'):
    """Return a function that wraps a layer with the requested normalization.

    Args:
        opt: Options object. It is not read by this function; it is kept in
            the signature for caller compatibility.
        norm_type: Name of the normalization to apply. An optional
            ``'spectral'`` prefix wraps the layer in spectral normalization;
            the remainder selects the norm that follows the layer:
            ``'batch'``, ``'syncbatch'``, ``'instance'``, or ``'none'`` / empty
            for no additional normalization layer.

    Returns:
        A callable ``add_norm_layer(layer)`` that returns either the
        (possibly spectral-normalized) layer itself or an ``nn.Sequential``
        of the layer followed by the selected normalization layer.
    """

    def get_out_channel(layer):
        """Return the number of output channels produced by ``layer``."""
        if hasattr(layer, 'out_channels'):
            return getattr(layer, 'out_channels')
        # Fall back to the first dimension of the weight tensor.
        return layer.weight.size(0)

    def add_norm_layer(layer):
        """Apply the normalization selected by ``norm_type`` to ``layer``.

        Raises:
            ValueError: If the normalization name is not recognized.
        """
        if norm_type.startswith('spectral'):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len('spectral'):]
        else:
            # Without the 'spectral' prefix the whole string names the norm.
            # Previously this name was left unassigned, so every non-spectral
            # norm_type (including the default) raised UnboundLocalError.
            subnorm_type = norm_type

        if subnorm_type == 'none' or len(subnorm_type) == 0:
            return layer

        # Remove the bias of the previous layer: it is meaningless because it
        # has no effect after normalization.
        if getattr(layer, 'bias', None) is not None:
            delattr(layer, 'bias')
            layer.register_parameter('bias', None)

        if subnorm_type == 'batch':
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'syncbatch':
            norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'instance':
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        else:
            raise ValueError('normalization layer %s is not recognized' % subnorm_type)

        return nn.Sequential(layer, norm_layer)

    return add_norm_layer


class FADE(nn.Module):
    """FADE normalization layer built from a textual configuration.

    The input is normalized by a parameter-free norm and then modulated by a
    scale and bias that are predicted from a feature map by two convolutions.

    Args:
        config_text: Configuration string that must start with ``'fade'`` and
            match ``fade<norm><k>x<d>``, where ``<norm>`` is ``'instance'``,
            ``'syncbatch'`` or ``'batch'`` and ``<k>`` is a single digit used
            as the convolution kernel size (e.g. ``'fadesyncbatch3x3'``).
        norm_nc: Number of channels of the normalized activations.
        label_nc: Number of channels of the conditioning feature map.

    Raises:
        ValueError: If ``config_text`` cannot be parsed or names an unknown
            parameter-free normalization type.
    """

    def __init__(self, config_text, norm_nc, label_nc):
        super().__init__()

        assert config_text.startswith('fade')
        # Raw string: '\D' and '\d' are invalid escapes in a normal literal.
        parsed = re.search(r'fade(\D+)(\d)x\d', config_text)
        if parsed is None:
            # Previously this surfaced as AttributeError on None.group().
            raise ValueError('%s is not a valid FADE configuration' % config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))

        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'syncbatch':
            self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError('%s is not a recognized param-free norm type in FADE'
                             % param_free_norm_type)

        # Half-kernel padding for both modulation convolutions.
        pw = ks // 2
        self.mlp_gamma = nn.Conv2d(label_nc, norm_nc, kernel_size=ks, padding=pw)
        self.mlp_beta = nn.Conv2d(label_nc, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, x, feat):
        """Normalize ``x`` and modulate it with parameters predicted from ``feat``.

        Args:
            x: Activations to normalize (``norm_nc`` channels).
            feat: Conditioning feature map (``label_nc`` channels).
                NOTE(review): presumably has the same spatial size as ``x`` so
                the element-wise modulation broadcasts -- confirm with callers.

        Returns:
            ``normalized * (1 + gamma) + beta``.
        """
        # Step 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Step 2. produce scale and bias conditioned on feature map
        gamma = self.mlp_gamma(feat)
        beta = self.mlp_beta(feat)

        # Step 3. apply scale and bias
        out = normalized * (1 + gamma) + beta

        return out


class AdaptiveInstanceNorm(nn.Module):
    """Instance normalization modulated by a latent code and an ``alpha`` tensor.

    Args:
        in_channel: Number of channels of the input activations.
        opt: Options object; only ``opt.latent_dim`` is read.
    """

    def __init__(self, in_channel, opt):
        super().__init__()
        self.norm = nn.InstanceNorm2d(in_channel)
        self.gamma = nn.Conv2d(in_channel, in_channel, 1)
        self.beta = nn.Conv2d(in_channel, in_channel, 1)
        # Input has 2 * latent_dim channels: the latent code concatenated
        # with ``alpha`` expanded to the latent code's size.
        self.conv_l = nn.Conv2d(2 * opt.latent_dim, in_channel, 1)
        self.relu = nn.ReLU()

    def forward(self, input, latent_code, alpha):
        """Apply instance norm to ``input`` and modulate it.

        Args:
            input: Activations to normalize (``in_channel`` channels).
            latent_code: Latent tensor.
                NOTE(review): assumed to be (batch, latent_dim) given the
                concatenation on dim 1 and the two unsqueezes -- confirm.
            alpha: Tensor expandable to ``latent_code.size()``.

        Returns:
            ``norm(input) * (1 + gamma) + beta`` where gamma and beta are
            computed from the concatenated latent code and alpha.
        """
        size = latent_code.size()
        alpha = alpha.expand(size)
        latent_code = torch.cat([latent_code, alpha], 1)
        # Add two trailing singleton dims so the 1x1 convolutions can be used.
        latent_code = latent_code.unsqueeze(2).unsqueeze(3)
        latent_code = self.relu(self.conv_l(latent_code))

        gamma = self.gamma(latent_code)
        beta = self.beta(latent_code)

        out = self.norm(input)
        out = out * (1 + gamma) + beta

        return out


# NOTE(review): disabled alternative implementation retained from the
# original file.
# class AdaptiveInstanceNorm(nn.Module):
#     def __init__(self, in_channel, opt):
#         super().__init__()
#         self.norm = nn.InstanceNorm2d(in_channel)
#         self.gamma = nn.Conv2d(in_channel, in_channel, 1)
#         self.beta = nn.Conv2d(in_channel, in_channel, 1)
#         self.conv1x1 = nn.Conv2d(in_channel, in_channel, 1)
#         self.relu = nn.ReLU()
#     def forward(self, input, style, alpha):
#         size = input.size()
#         alpha = alpha.expand(size)
#         gamma = self.gamma(alpha)
#         beta = self.beta(alpha)
#         normalized = self.norm(input)
#         out = normalized * gamma + beta
#         out = self.relu(self.conv1x1(out))
#         return out