# CR-Net: models/networks/normalization.py
# Uploaded by datnguyentien204 ("Upload 147 files", commit 0f52c9d).
import re
import torch.nn as nn
from models.networks.sync_batchnorm import SynchronizedBatchNorm2d
import torch.nn.utils.spectral_norm as spectral_norm
import torch
def get_norm_layer(opt, norm_type='instance'):
    """Return a function that wraps a layer with the requested normalization.

    ``norm_type`` is an optional ``'spectral'`` prefix followed by a sub-norm
    name, one of ``'batch'``, ``'syncbatch'``, ``'instance'`` or ``'none'``.
    An empty sub-norm after the ``'spectral'`` prefix also means "no extra
    norm". Examples: ``'instance'``, ``'spectralinstance'``,
    ``'spectralsyncbatch'``, ``'spectralnone'``.

    Args:
        opt: Options object. It is not read here; it is kept for interface
            compatibility with existing callers.
        norm_type: Normalization spec string, as described above.

    Returns:
        A callable ``add_norm_layer(layer)``. It returns ``layer`` itself
        (spectral-normalized if requested) when no sub-norm applies, and
        ``nn.Sequential(layer, norm_layer)`` otherwise.
    """

    def get_out_channel(layer):
        """Return the number of output channels produced by ``layer``."""
        if hasattr(layer, 'out_channels'):
            return getattr(layer, 'out_channels')
        # Fall back to the leading dimension of the weight tensor.
        return layer.weight.size(0)

    def add_norm_layer(layer):
        """Apply the configured normalization to ``layer``.

        Raises:
            ValueError: if the sub-norm type is not recognized.
        """
        if norm_type.startswith('spectral'):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len('spectral'):]
        else:
            # Without this branch, a plain spec such as 'instance' never
            # reached the norm selection below.
            subnorm_type = norm_type

        if subnorm_type == 'none' or len(subnorm_type) == 0:
            return layer

        # Remove the bias of the wrapped layer: a per-channel bias is
        # cancelled by the mean subtraction of the following norm.
        if getattr(layer, 'bias', None) is not None:
            delattr(layer, 'bias')
            layer.register_parameter('bias', None)

        if subnorm_type == 'batch':
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'syncbatch':
            norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'instance':
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        else:
            raise ValueError('normalization layer %s is not recognized' % subnorm_type)

        return nn.Sequential(layer, norm_layer)

    return add_norm_layer
class FADE(nn.Module):
    """Feature-adaptive denormalization (FADE) layer.

    Normalizes ``x`` with a parameter-free norm, then modulates it with a
    per-position scale and bias predicted from a conditioning feature map:
    ``out = norm(x) * (1 + mlp_gamma(feat)) + mlp_beta(feat)``.

    Args:
        config_text: Spec string of the form ``'fade<norm><k>x<k>'``, for
            example ``'fadesyncbatch3x3'``. ``<norm>`` is ``'instance'``,
            ``'syncbatch'`` or ``'batch'``. The first ``<k>`` is the kernel
            size of the two modulation convolutions.
        norm_nc: Number of channels of the normalized input ``x``.
        label_nc: Number of channels of the conditioning map ``feat``.

    Raises:
        ValueError: if ``config_text`` does not start with ``'fade'``, does
            not match the pattern above, or names an unknown norm type.
    """

    # Raw string: '\D' and '\d' are regex escapes, not Python string escapes.
    _CONFIG_PATTERN = re.compile(r'fade(\D+)(\d+)x\d+')

    def __init__(self, config_text, norm_nc, label_nc):
        super().__init__()
        if not config_text.startswith('fade'):
            raise ValueError('FADE config must start with "fade", got %r' % config_text)
        parsed = self._CONFIG_PATTERN.search(config_text)
        if parsed is None:
            raise ValueError('malformed FADE config %r (expected e.g. "fadeinstance3x3")'
                             % config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))

        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'syncbatch':
            self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError('%s is not a recognized param-free norm type in FADE'
                             % param_free_norm_type)

        # 'same' padding for odd kernel sizes. An even ks would change the
        # spatial size, and gamma/beta would then no longer match x.
        pw = ks // 2
        self.mlp_gamma = nn.Conv2d(label_nc, norm_nc, kernel_size=ks, padding=pw)
        self.mlp_beta = nn.Conv2d(label_nc, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, x, feat):
        """Normalize ``x`` and modulate it with scale/bias computed from ``feat``.

        Args:
            x: Input with ``norm_nc`` channels in dim 1. This presumably means
                (N, norm_nc, H, W); confirm against callers.
            feat: Conditioning map with ``label_nc`` channels. Its spatial
                size must match ``x`` so the predicted gamma/beta broadcast
                against the normalized activations.

        Returns:
            ``param_free_norm(x) * (1 + gamma) + beta``.
        """
        # Step 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)
        # Step 2. produce scale and bias conditioned on feature map
        gamma = self.mlp_gamma(feat)
        beta = self.mlp_beta(feat)
        # Step 3. apply scale and bias (1 + gamma keeps identity at gamma == 0)
        return normalized * (1 + gamma) + beta
class AdaptiveInstanceNorm(nn.Module):
    """Instance norm whose scale and bias are predicted from a latent code.

    ``alpha`` is expanded to the latent code's shape and concatenated with it
    along dim 1, so ``conv_l`` receives ``2 * opt.latent_dim`` channels. The
    ReLU'd projection is mapped to per-channel ``gamma`` / ``beta`` by 1x1
    convolutions, and the instance-normalized input is modulated as
    ``norm(input) * (1 + gamma) + beta``.

    Args:
        in_channel: Number of channels of the input feature map.
        opt: Options object; only ``opt.latent_dim`` is read.
    """

    def __init__(self, in_channel, opt):
        super().__init__()
        # Submodules are created in the same order as before, because the
        # creation order determines which random values each layer's
        # default initialization draws.
        self.norm = nn.InstanceNorm2d(in_channel)
        self.gamma = nn.Conv2d(in_channel, in_channel, kernel_size=1)
        self.beta = nn.Conv2d(in_channel, in_channel, kernel_size=1)
        latent_channels = 2 * opt.latent_dim
        self.conv_l = nn.Conv2d(latent_channels, in_channel, kernel_size=1)
        self.relu = nn.ReLU()

    def forward(self, input, latent_code, alpha):
        """Modulate ``input`` with scale/bias derived from ``latent_code`` and ``alpha``.

        Args:
            input: Feature map with ``in_channel`` channels in dim 1. This is
                presumably (N, in_channel, H, W); TODO confirm.
            latent_code: Latent tensor that, once ``alpha`` is concatenated on
                dim 1, has ``2 * opt.latent_dim`` channels. It is probably
                (N, latent_dim); verify against callers.
            alpha: Tensor expandable to ``latent_code.size()``, for example
                shape (N, 1) or a 0-d tensor.

        Returns:
            ``norm(input) * (1 + gamma) + beta``, with gamma/beta of spatial
            size 1x1 broadcast over ``input``.
        """
        alpha_map = alpha.expand(latent_code.size())
        # Append two singleton spatial dims so the 1x1 convs can consume it.
        code = torch.cat((latent_code, alpha_map), dim=1)[:, :, None, None]
        hidden = self.relu(self.conv_l(code))
        normalized = self.norm(input)
        return normalized * (1 + self.gamma(hidden)) + self.beta(hidden)
# class AdaptiveInstanceNorm(nn.Module):
# def __init__(self, in_channel, opt):
# super().__init__()
# self.norm = nn.InstanceNorm2d(in_channel)
# self.gamma = nn.Conv2d(in_channel, in_channel, 1)
# self.beta = nn.Conv2d(in_channel, in_channel, 1)
# self.conv1x1 = nn.Conv2d(in_channel, in_channel, 1)
# self.relu = nn.ReLU()
# def forward(self, input, style, alpha):
# size = input.size()
# alpha = alpha.expand(size)
# gamma = self.gamma(alpha)
# beta = self.beta(alpha)
# normalized = self.norm(input)
# out = normalized * gamma + beta
# out = self.relu(self.conv1x1(out))
# return out