| ''' Layers |
| This file contains various layers for the BigGAN models. |
| ''' |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.nn import init |
| import torch.optim as optim |
| import torch.nn.functional as F |
| from torch.nn import Parameter as P |
|
|
| from .sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d |
|
|
| |
def proj(x, y):
    """Return the projection of row vector ``x`` onto row vector ``y``.

    Both arguments are passed to ``torch.mm``, so they must be 2-D tensors;
    callers in this file pass shapes of the form ``(1, n)``.
    """
    # Keep the original evaluation order: (<y, x> * y) / <y, y>.
    scaled_direction = torch.mm(y, x.t()) * y
    squared_norm = torch.mm(y, y.t())
    return scaled_direction / squared_norm
|
|
|
|
| |
def gram_schmidt(x, ys):
    """Orthogonalize ``x`` against every vector in ``ys``.

    The projection of the running residual onto each ``y`` is subtracted in
    turn. With an empty ``ys`` the input is returned unchanged.
    """
    residual = x
    for basis_vec in ys:
        residual = residual - proj(residual, basis_vec)
    return residual
|
|
|
|
| |
def power_iteration(W, u_, update=True, eps=1e-12):
    """Run one power-iteration step for each vector in ``u_``.

    Args:
        W: 2-D weight matrix.
        u_: list of left-vector estimates (each multiplied as ``u @ W``).
        update: if True, write the refreshed ``u`` back into ``u_[i]`` in place.
        eps: epsilon forwarded to ``F.normalize``.

    Returns:
        ``(svs, us, vs)``: lists of singular-value estimates and the
        corresponding left/right vectors, one entry per element of ``u_``.
    """
    left_vecs, right_vecs, sigmas = [], [], []
    for idx, u_prev in enumerate(u_):
        # The vector refresh is excluded from autograd.
        with torch.no_grad():
            v = torch.matmul(u_prev, W)
            # Remove components along previously found right vectors.
            v = F.normalize(gram_schmidt(v, right_vecs), eps=eps)
            right_vecs.append(v)
            u_new = torch.matmul(v, W.t())
            # Remove components along previously found left vectors.
            u_new = F.normalize(gram_schmidt(u_new, left_vecs), eps=eps)
            left_vecs.append(u_new)
            if update:
                u_[idx][:] = u_new
        # Computed outside no_grad so the estimate stays differentiable w.r.t. W.
        sigma = torch.matmul(torch.matmul(v, W.t()), u_new.t())
        sigmas.append(torch.squeeze(sigma))
    return sigmas, left_vecs, right_vecs
|
|
|
|
| |
class identity(nn.Module):
    """Pass-through module: ``forward`` returns its argument untouched."""

    def forward(self, input):
        # No computation; the very same object is handed back.
        return input
|
|
|
|
| |
class SN(object):
    """Spectral-normalization mix-in.

    Intended to be combined with an ``nn.Module`` subclass that owns a
    ``weight`` attribute and provides ``register_buffer`` / ``training``
    (see the SN* layer classes in this file).

    Args:
        num_svs: number of singular values to track.
        num_itrs: number of power-iteration calls per ``W_()`` call.
        num_outputs: length of each ``u`` vector (buffers are ``(1, num_outputs)``).
        transpose: if True, the flattened weight is transposed before iterating.
        eps: epsilon forwarded to the normalization in ``power_iteration``.
    """

    def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
        self.num_itrs = num_itrs
        self.num_svs = num_svs
        self.transpose = transpose
        self.eps = eps
        # One (u, sv) buffer pair per tracked singular value, created in order.
        for idx in range(self.num_svs):
            self.register_buffer('u%d' % idx, torch.randn(1, num_outputs))
            self.register_buffer('sv%d' % idx, torch.ones(1))

    @property
    def u(self):
        """List of the registered ``u<i>`` buffers."""
        return [getattr(self, 'u%d' % idx) for idx in range(self.num_svs)]

    @property
    def sv(self):
        """List of the registered ``sv<i>`` buffers."""
        return [getattr(self, 'sv%d' % idx) for idx in range(self.num_svs)]

    def W_(self):
        """Return ``self.weight`` divided by the first singular-value estimate."""
        weight_2d = self.weight.view(self.weight.size(0), -1)
        if self.transpose:
            weight_2d = weight_2d.t()
        # The u buffers are only refreshed in place while training.
        for _ in range(self.num_itrs):
            svs, us, vs = power_iteration(
                weight_2d, self.u, update=self.training, eps=self.eps)
        if self.training:
            # Record the estimates without tracking gradients.
            with torch.no_grad():
                for sv_buffer, estimate in zip(self.sv, svs):
                    sv_buffer[:] = estimate
        return self.weight / svs[0]
|
|
|
|
| |
class SNConv2d(nn.Conv2d, SN):
    """2-D convolution whose weight is spectrally normalized via ``SN.W_``."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        # Initialise the convolution first so ``weight`` and the buffer
        # machinery exist before SN registers its buffers.
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size,
                           stride=stride, padding=padding, dilation=dilation,
                           groups=groups, bias=bias)
        SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)

    def forward(self, x):
        normalized_weight = self.W_()
        return F.conv2d(x, normalized_weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
|
|
|
|
| |
class SNLinear(nn.Linear, SN):
    """Linear layer whose weight is spectrally normalized via ``SN.W_``."""

    def __init__(self, in_features, out_features, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        # The linear layer must exist before SN registers its buffers.
        nn.Linear.__init__(self, in_features, out_features, bias=bias)
        SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)

    def forward(self, x):
        normalized_weight = self.W_()
        return F.linear(x, normalized_weight, self.bias)
|
|
|
|
| |
| |
| |
class SNEmbedding(nn.Embedding, SN):
    """Embedding table whose weight is spectrally normalized via ``SN.W_``.

    The ``nn.Embedding`` options are forwarded to the base constructor, but
    ``forward`` performs the lookup with only the indices and the normalized
    weight.
    """

    def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                 max_norm=None, norm_type=2, scale_grad_by_freq=False,
                 sparse=False, _weight=None,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim,
                              padding_idx=padding_idx, max_norm=max_norm,
                              norm_type=norm_type,
                              scale_grad_by_freq=scale_grad_by_freq,
                              sparse=sparse, _weight=_weight)
        # The u vectors are sized by the number of embeddings (rows of weight).
        SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)

    def forward(self, x):
        normalized_weight = self.W_()
        return F.embedding(x, normalized_weight)
|
|
|
|
| |
| |
| |
class Attention(nn.Module):
    """Attention block over the spatial positions of a 4-D feature map.

    Args:
        ch: number of input (and output) channels.
        which_conv: convolution class used for the four 1x1 projections.
        name: accepted for interface compatibility; not used by this class.
    """

    def __init__(self, ch, which_conv=SNConv2d, name='attention'):
        super(Attention, self).__init__()
        self.ch = ch
        self.which_conv = which_conv
        # 1x1 projections: theta/phi reduce to ch // 8, g to ch // 2,
        # and o maps the attended features back to ch channels.
        self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
        self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
        self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False)
        self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False)
        # Learnable residual gain, initialised to zero so the block starts
        # out returning its input unchanged.
        self.gamma = P(torch.tensor(0.), requires_grad=True)

    def forward(self, x, y=None):
        """Return ``gamma * attention(x) + x``; ``y`` is ignored."""
        batch, height, width = x.shape[0], x.shape[2], x.shape[3]
        theta = self.theta(x)
        phi = F.max_pool2d(self.phi(x), [2, 2])
        g = F.max_pool2d(self.g(x), [2, 2])
        # Flatten the spatial dimensions. The pooled size is inferred from the
        # tensors themselves rather than assumed to be H * W // 4, so a
        # mismatch (odd H or W) no longer needs a swallow-and-print handler:
        # it simply works, and even sizes give identical results.
        theta = theta.view(batch, self.ch // 8, height * width)
        phi = phi.view(batch, self.ch // 8, -1)
        g = g.view(batch, self.ch // 2, -1)
        # Attention weights over pooled positions for every full-res position.
        beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
        # Weighted sum of g features, reshaped back to a feature map.
        attended = torch.bmm(g, beta.transpose(1, 2))
        o = self.o(attended.view(batch, self.ch // 2, height, width))
        return self.gamma * o + x
|
|
|
|
| |
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
    """Normalize ``x`` with the given statistics as one scale and one shift.

    Computes ``x * scale - shift`` where ``scale = rsqrt(var + eps)``
    (times ``gain`` if given) and ``shift = mean * scale`` (minus ``bias``
    if given).
    """
    inv_std = torch.rsqrt(var + eps)
    # Fold the optional gain into the multiplicative factor.
    scale = inv_std if gain is None else inv_std * gain
    offset = mean * scale
    # Fold the optional bias into the subtracted offset.
    if bias is not None:
        offset = offset - bias
    return x * scale - offset
| |
|
|
|
|
| |
| |
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
    """Batch-normalize ``x`` using statistics computed from ``x`` itself.

    Mean and variance are reduced over dimensions 0, 2 and 3 in float32 and
    cast back to the type of ``x`` before being applied with ``fused_bn``.

    Returns:
        The normalized tensor, or ``(normalized, mean, var)`` with the
        statistics squeezed when ``return_mean_var`` is True.
    """
    x_fp32 = x.float()
    reduce_dims = [0, 2, 3]
    mean = torch.mean(x_fp32, reduce_dims, keepdim=True)
    mean_of_squares = torch.mean(x_fp32 ** 2, reduce_dims, keepdim=True)
    # Var[x] = E[x^2] - E[x]^2, then cast both statistics to x's tensor type.
    input_type = x.type()
    var = (mean_of_squares - mean ** 2).type(input_type)
    mean = mean.type(input_type)
    normalized = fused_bn(x, mean, var, gain, bias, eps)
    if return_mean_var:
        return normalized, mean.squeeze(), var.squeeze()
    return normalized
|
|
|
|
| |
class myBN(nn.Module):
    """Batch norm with externally supplied gain/bias and stored statistics.

    Args:
        num_channels: size of the stored mean/variance buffers.
        eps: epsilon added to the variance.
        momentum: weight of the current batch in the running averages.

    Set ``accumulate_standing = True`` to sum batch statistics (and count
    batches) instead of keeping running averages; in eval mode the sums are
    then divided by the count.
    """

    def __init__(self, num_channels, eps=1e-5, momentum=0.1):
        super(myBN, self).__init__()
        self.momentum = momentum
        self.eps = eps
        self.register_buffer('stored_mean', torch.zeros(num_channels))
        self.register_buffer('stored_var', torch.ones(num_channels))
        self.register_buffer('accumulation_counter', torch.zeros(1))
        self.accumulate_standing = False

    def reset_stats(self):
        """Zero the stored mean, stored variance and accumulation counter."""
        self.stored_mean[:] = 0
        self.stored_var[:] = 0
        self.accumulation_counter[:] = 0

    def forward(self, x, gain, bias):
        if self.training:
            out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
            # Buffer bookkeeping must not be recorded by autograd: writing the
            # grad-tracked batch statistics in place would attach the buffers
            # to the graph and chain every iteration's graph to the next one.
            with torch.no_grad():
                if self.accumulate_standing:
                    self.stored_mean[:] = self.stored_mean + mean
                    self.stored_var[:] = self.stored_var + var
                    self.accumulation_counter += 1.0
                else:
                    self.stored_mean[:] = (self.stored_mean * (1 - self.momentum)
                                           + mean * self.momentum)
                    self.stored_var[:] = (self.stored_var * (1 - self.momentum)
                                          + var * self.momentum)
            return out

        mean = self.stored_mean.view(1, -1, 1, 1)
        var = self.stored_var.view(1, -1, 1, 1)
        if self.accumulate_standing:
            # NOTE: the counter is zero until at least one training batch has
            # been accumulated; dividing before that yields non-finite values.
            mean = mean / self.accumulation_counter
            var = var / self.accumulation_counter
        return fused_bn(x, mean, var, gain, bias, self.eps)
|
|
|
|
| |
def groupnorm(x, norm_style):
    """Apply ``F.group_norm`` to ``x`` with a group count parsed from ``norm_style``.

    * contains ``'ch'``  -> the trailing ``_<n>`` is the number of channels
      per group; the group count is ``x.shape[1] // n`` (at least 1).
    * contains ``'grp'`` -> the trailing ``_<n>`` is the group count itself.
    * anything else      -> 16 groups.
    """
    if 'ch' in norm_style:
        channels_per_group = int(norm_style.split('_')[-1])
        num_groups = max(int(x.shape[1]) // channels_per_group, 1)
        return F.group_norm(x, num_groups)
    if 'grp' in norm_style:
        return F.group_norm(x, int(norm_style.split('_')[-1]))
    return F.group_norm(x, 16)
|
|
|
|
| |
| |
| |
| |
| |
class ccbn(nn.Module):
    """Conditional normalization: gain and bias are predicted from ``y``.

    Args:
        output_size: number of channels being normalized.
        input_size: size of the conditioning vector ``y``.
        which_linear: callable ``(input_size, output_size) -> module`` used to
            build the gain and bias predictors.
        eps: epsilon for the normalization.
        momentum: running-statistics momentum.
        cross_replica: use ``SyncBN2d`` for normalization.
        mybn: use ``myBN`` for normalization.
        norm_style: one of ``'bn'``, ``'in'``, ``'gn'``, ``'nonorm'``; only
            consulted when neither ``cross_replica`` nor ``mybn`` is set.
    """

    def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
                 cross_replica=False, mybn=False, norm_style='bn', ):
        super(ccbn, self).__init__()
        self.output_size, self.input_size = output_size, input_size
        # Per-sample gain and bias are linear functions of the conditioning vector.
        self.gain = which_linear(input_size, output_size)
        self.bias = which_linear(input_size, output_size)
        self.eps = eps
        self.momentum = momentum
        self.cross_replica = cross_replica
        self.mybn = mybn
        self.norm_style = norm_style

        if self.cross_replica:
            self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
        elif self.mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        elif self.norm_style in ['bn', 'in']:
            self.register_buffer('stored_mean', torch.zeros(output_size))
            self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y):
        """Normalize ``x`` and apply the gain/bias predicted from ``y``.

        Raises:
            ValueError: if ``norm_style`` is not a supported value.
        """
        # Gain is predicted as an offset from 1.
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)
        if self.mybn or self.cross_replica:
            return self.bn(x, gain=gain, bias=bias)

        # Use the configured momentum (previously a hard-coded 0.1, identical
        # for the default) so this path agrees with the SyncBN/myBN paths.
        if self.norm_style == 'bn':
            out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                               self.training, self.momentum, self.eps)
        elif self.norm_style == 'in':
            out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
                                  self.training, self.momentum, self.eps)
        elif self.norm_style == 'gn':
            # Fixed attribute name: was ``self.normstyle`` (AttributeError).
            out = groupnorm(x, self.norm_style)
        elif self.norm_style == 'nonorm':
            out = x
        else:
            raise ValueError('Unsupported norm_style: %r' % (self.norm_style,))
        return out * gain + bias

    def extra_repr(self):
        s = 'out: {output_size}, in: {input_size},'
        s += ' cross_replica={cross_replica}'
        return s.format(**self.__dict__)
|
|
|
|
| |
class bn(nn.Module):
    """Batch norm with learnable per-channel gain and bias.

    Args:
        output_size: number of channels being normalized.
        eps: epsilon for the normalization.
        momentum: running-statistics momentum.
        cross_replica: use ``SyncBN2d`` for normalization.
        mybn: use ``myBN`` for normalization.
    """

    def __init__(self, output_size, eps=1e-5, momentum=0.1,
                 cross_replica=False, mybn=False):
        super(bn, self).__init__()
        self.output_size = output_size
        # Affine parameters start as the identity transform.
        self.gain = P(torch.ones(output_size), requires_grad=True)
        self.bias = P(torch.zeros(output_size), requires_grad=True)
        self.eps = eps
        self.momentum = momentum
        self.cross_replica = cross_replica
        self.mybn = mybn

        if self.cross_replica:
            self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
        elif self.mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        else:
            # The plain path keeps its own running statistics for F.batch_norm.
            self.register_buffer('stored_mean', torch.zeros(output_size))
            self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y=None):
        """Normalize ``x``; ``y`` is accepted but ignored."""
        uses_custom_bn = self.cross_replica or self.mybn
        if not uses_custom_bn:
            return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
                                self.bias, self.training, self.momentum, self.eps)
        # The wrapped normalizers take broadcastable gain and bias tensors.
        gain = self.gain.view(1, -1, 1, 1)
        bias = self.bias.view(1, -1, 1, 1)
        return self.bn(x, gain=gain, bias=bias)
|
|
|
|
| |
| |
| |
| |
| |
| |
class GBlock(nn.Module):
    """Residual block: (norm, activation, conv) twice, plus a shortcut.

    Args:
        in_channels, out_channels: channel counts of the block.
        which_conv1, which_conv2: factories for the first/second convolution;
            ``which_conv1`` also builds the 1x1 shortcut convolution.
        which_bn: factory for the normalization layers, called as ``bn(x, y)``.
        activation: callable applied after each normalization.
        upsample: optional callable applied to both branches before ``conv1``.
    """

    def __init__(self, in_channels, out_channels,
                 which_conv1=nn.Conv2d, which_conv2=nn.Conv2d, which_bn=bn, activation=None,
                 upsample=None):
        super(GBlock, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.which_conv1 = which_conv1
        self.which_conv2 = which_conv2
        self.which_bn = which_bn
        self.activation = activation
        self.upsample = upsample
        # Layers are created in a fixed order: convs, shortcut, then norms.
        self.conv1 = self.which_conv1(self.in_channels, self.out_channels)
        self.conv2 = self.which_conv2(self.out_channels, self.out_channels)
        # A learned shortcut is needed whenever the shape changes.
        self.learnable_sc = in_channels != out_channels or upsample
        if self.learnable_sc:
            self.conv_sc = self.which_conv1(in_channels, out_channels,
                                            kernel_size=1, padding=0)
        self.bn1 = self.which_bn(in_channels)
        self.bn2 = self.which_bn(out_channels)

    def forward(self, x, y):
        residual = self.activation(self.bn1(x, y))
        if self.upsample:
            # Both branches are upsampled so their sizes agree at the sum.
            residual = self.upsample(residual)
            x = self.upsample(x)
        residual = self.conv1(residual)
        residual = self.conv2(self.activation(self.bn2(residual, y)))
        if self.learnable_sc:
            x = self.conv_sc(x)
        return residual + x
|
|
|
|
| |
class DBlock(nn.Module):
    """Residual block of two convolutions with an optional downsample.

    Args:
        in_channels, out_channels: channel counts of the block.
        which_conv: factory for all convolutions in the block.
        wide: if True the hidden width is ``out_channels``, else ``in_channels``.
        preactivation: if True, a ReLU is applied before the first convolution
            and the shortcut convolves before it downsamples.
        activation: callable applied between the two convolutions.
        downsample: optional callable applied to both branches.
    """

    def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True,
                 preactivation=False, activation=None, downsample=None, ):
        super(DBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        if wide:
            self.hidden_channels = self.out_channels
        else:
            self.hidden_channels = self.in_channels
        self.which_conv = which_conv
        self.preactivation = preactivation
        self.activation = activation
        self.downsample = downsample

        self.conv1 = self.which_conv(self.in_channels, self.hidden_channels)
        self.conv2 = self.which_conv(self.hidden_channels, self.out_channels)
        # A learned shortcut is needed whenever the shape changes.
        self.learnable_sc = bool((in_channels != out_channels) or downsample)
        if self.learnable_sc:
            self.conv_sc = self.which_conv(in_channels, out_channels,
                                           kernel_size=1, padding=0)

    def shortcut(self, x):
        """Apply the shortcut path; op order depends on ``preactivation``."""

        def project(t):
            return self.conv_sc(t) if self.learnable_sc else t

        def pool(t):
            return self.downsample(t) if self.downsample else t

        if self.preactivation:
            return pool(project(x))
        return project(pool(x))

    def forward(self, x):
        # F.relu is out of place by default, so ``x`` reaches the shortcut
        # unmodified.
        h = F.relu(x) if self.preactivation else x
        h = self.conv1(h)
        h = self.conv2(self.activation(h))
        if self.downsample:
            h = self.downsample(h)
        return h + self.shortcut(x)
|
|
| |