| | import torch |
| | import torch.nn.functional as F |
| | from torch import nn |
| |
|
| |
|
class ResBlocks(nn.Module):
    """Chain ``num_blocks`` identically configured ``ResBlock`` modules.

    Args:
        num_blocks: how many residual blocks to stack.
        dim: channel count handed to every ``ResBlock``.
        norm: normalization key forwarded to each block.
        activation: activation key forwarded to each block.
        pad_type: padding key forwarded to each block.
    """

    def __init__(self, num_blocks, dim, norm, activation, pad_type):
        super(ResBlocks, self).__init__()
        blocks = [
            ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)
            for _ in range(num_blocks)
        ]
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        """Pass ``x`` through every residual block, in order."""
        return self.model(x)
| |
|
| |
|
class ResBlock(nn.Module):
    """Residual block: two 3x3 ``Conv2dBlock`` layers plus an identity skip.

    Both convolutions keep ``dim`` channels (stride 1, padding 1). Only the
    first one applies ``activation``; the second uses ``'none'``, so the
    returned value is ``F(x) + x`` with no trailing nonlinearity.

    Args:
        dim: input and output channel count.
        norm: normalization key for both conv blocks.
        activation: activation key for the first conv block.
        pad_type: padding key for both conv blocks.
    """

    def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
        super(ResBlock, self).__init__()
        shared = dict(norm=norm, pad_type=pad_type)
        self.model = nn.Sequential(
            Conv2dBlock(dim, dim, 3, 1, 1, activation=activation, **shared),
            Conv2dBlock(dim, dim, 3, 1, 1, activation='none', **shared),
        )

    def forward(self, x):
        """Return ``self.model(x) + x``."""
        return self.model(x) + x
| |
|
| |
|
class ActFirstResBlock(nn.Module):
    """Pre-activation residual block with an optional learned 1x1 shortcut.

    Each 3x3 conv block applies its activation *before* the reflect-padded
    convolution and its normalization after it. When ``fin != fout`` the skip
    path is a bias-free 1x1 convolution; otherwise it is the identity.

    Args:
        fin: input channels.
        fout: output channels.
        fhid: channels between the two 3x3 convs; ``min(fin, fout)`` if None.
        activation: activation key for both 3x3 conv blocks.
        norm: normalization key for both 3x3 conv blocks.
    """

    def __init__(self, fin, fout, fhid=None,
                 activation='lrelu', norm='none'):
        super().__init__()
        self.fin, self.fout = fin, fout
        self.learned_shortcut = fin != fout
        self.fhid = fhid if fhid is not None else min(fin, fout)
        conv3x3_opts = dict(padding=1, pad_type='reflect', norm=norm,
                            activation=activation, activation_first=True)
        self.conv_0 = Conv2dBlock(self.fin, self.fhid, 3, 1, **conv3x3_opts)
        self.conv_1 = Conv2dBlock(self.fhid, self.fout, 3, 1, **conv3x3_opts)
        if self.learned_shortcut:
            # Project the skip path to ``fout`` channels so the sum is valid.
            self.conv_s = Conv2dBlock(self.fin, self.fout, 1, 1,
                                      activation='none', use_bias=False)

    def forward(self, x):
        """Return ``shortcut(x) + conv_1(conv_0(x))``."""
        shortcut = self.conv_s(x) if self.learned_shortcut else x
        return shortcut + self.conv_1(self.conv_0(x))
| |
|
| |
|
class LinearBlock(nn.Module):
    """Fully connected layer followed by optional normalization and activation.

    Args:
        in_dim: number of input features.
        out_dim: number of output features.
        norm: ``'bn'`` (BatchNorm1d), ``'in'`` (InstanceNorm1d) or ``'none'``.
        activation: ``'relu'``, ``'lrelu'`` (negative slope 0.2), ``'tanh'``
            or ``'none'``.

    Raises:
        ValueError: if ``norm`` or ``activation`` is not a supported key.
    """

    def __init__(self, in_dim, out_dim, norm='none', activation='relu'):
        super(LinearBlock, self).__init__()
        self.fc = nn.Linear(in_dim, out_dim, bias=True)

        norm_dim = out_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm1d(norm_dim)
        elif norm == 'in':
            # NOTE(review): InstanceNorm1d expects (N, C, L) or unbatched
            # (C, L). If ``x`` is 2-D (N, in_dim), the fc output is 2-D and
            # will be read as (C, L) -- confirm the intended input shape.
            self.norm = nn.InstanceNorm1d(norm_dim)
        elif norm == 'none':
            self.norm = None
        else:
            # Explicit raise (not ``assert``) so the check survives ``python -O``.
            raise ValueError("Unsupported normalization: {}".format(norm))

        if activation == 'relu':
            self.activation = nn.ReLU(inplace=False)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=False)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            raise ValueError("Unsupported activation: {}".format(activation))

    def forward(self, x):
        """Apply ``fc``, then the normalization and activation if configured."""
        out = self.fc(x)
        if self.norm is not None:
            out = self.norm(out)
        if self.activation is not None:
            out = self.activation(out)
        return out
| |
|
| |
|
class Conv2dBlock(nn.Module):
    """Explicit padding -> Conv2d -> optional norm -> optional activation.

    With ``activation_first=True`` the activation is applied to the input
    *before* padding/convolution, and the block ends after normalization.

    Args:
        in_dim: input channels of the convolution.
        out_dim: output channels of the convolution (also the norm width).
        ks: convolution kernel size.
        st: convolution stride.
        padding: padding applied by the dedicated padding layer; the conv
            itself is built without padding.
        norm: ``'bn'``, ``'in'``, ``'adain'`` or ``'none'``.
        activation: ``'relu'``, ``'lrelu'`` (negative slope 0.2), ``'tanh'``
            or ``'none'``.
        pad_type: ``'reflect'``, ``'replicate'`` or ``'zero'``.
        use_bias: whether the convolution has a bias term.
        activation_first: apply the activation before the convolution.

    Raises:
        ValueError: if ``pad_type``, ``norm`` or ``activation`` is not a
            supported key.
    """

    def __init__(self, in_dim, out_dim, ks, st, padding=0,
                 norm='none', activation='relu', pad_type='zero',
                 use_bias=True, activation_first=False):
        super(Conv2dBlock, self).__init__()
        self.use_bias = use_bias
        self.activation_first = activation_first

        if pad_type == 'reflect':
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == 'replicate':
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == 'zero':
            self.pad = nn.ZeroPad2d(padding)
        else:
            # Explicit raise (not ``assert``) so the check survives ``python -O``.
            raise ValueError("Unsupported padding type: {}".format(pad_type))

        norm_dim = out_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm2d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm2d(norm_dim)
        elif norm == 'adain':
            # AdaptiveInstanceNorm2d needs its weight/bias assigned before use.
            self.norm = AdaptiveInstanceNorm2d(norm_dim)
        elif norm == 'none':
            self.norm = None
        else:
            raise ValueError("Unsupported normalization: {}".format(norm))

        if activation == 'relu':
            self.activation = nn.ReLU(inplace=False)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=False)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            raise ValueError("Unsupported activation: {}".format(activation))

        self.conv = nn.Conv2d(in_dim, out_dim, ks, st, bias=self.use_bias)

    def forward(self, x):
        """Apply (activation), pad, conv, norm, (activation) to ``x``.

        The activation runs either first or last depending on
        ``activation_first``; it is skipped entirely when it is ``None``.
        """
        if self.activation_first and self.activation is not None:
            x = self.activation(x)
        x = self.conv(self.pad(x))
        if self.norm is not None:
            x = self.norm(x)
        if not self.activation_first and self.activation is not None:
            x = self.activation(x)
        return x
| |
|
| |
|
class AdaptiveInstanceNorm2d(nn.Module):
    """Instance normalization whose affine parameters are assigned externally.

    ``weight`` and ``bias`` start as ``None`` and must be set on the module
    before ``forward`` is called. Presumably another network writes them;
    that caller is not visible here. The batch is folded into the channel
    axis, so each must hold ``batch_size * num_features`` values.

    Args:
        num_features: number of input channels ``C``.
        eps: value added to the variance for numerical stability.
        momentum: forwarded to ``F.batch_norm``. It only affects temporary
            copies of the running statistics, so the registered buffers
            keep their initial values.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(AdaptiveInstanceNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # Set externally before each use; see the class docstring.
        self.weight = None
        self.bias = None
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        """Normalize each (sample, channel) slice of ``x``, then apply the
        assigned per-(sample, channel) ``weight`` and ``bias``.

        Raises:
            RuntimeError: if ``weight`` or ``bias`` has not been assigned.
        """
        # Explicit raise (not ``assert``): under ``python -O`` an assert is
        # stripped and F.batch_norm would silently skip the affine transform.
        if self.weight is None or self.bias is None:
            raise RuntimeError("Please assign AdaIN weight first")
        b, c = x.size(0), x.size(1)
        # ``repeat`` returns new tensors, so the in-place running-stat update
        # done by F.batch_norm (training=True) never touches the buffers.
        running_mean = self.running_mean.repeat(b)
        running_var = self.running_var.repeat(b)
        # Fold the batch into channels. Batch norm over a (1, b*c, ...)
        # tensor computes statistics per (sample, channel): instance norm.
        x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
        out = F.batch_norm(
            x_reshaped, running_mean, running_var, self.weight, self.bias,
            True, self.momentum, self.eps)
        return out.view(b, c, *x.size()[2:])

    def __repr__(self):
        return f"{self.__class__.__name__}({self.num_features})"
| |
|