| | """ The code is based on https://github.com/apple/ml-gsn/ with adaption. """ |
| |
|
| | import math |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from lib.torch_utils.ops.native_ops import ( |
| | FusedLeakyReLU, |
| | fused_leaky_relu, |
| | upfirdn2d, |
| | ) |
| |
|
| |
|
class DiscriminatorHead(nn.Module):
    """Final discriminator head.

    Optionally appends a minibatch-standard-deviation feature channel, applies one
    3x3 conv, then flattens a 4x4 feature map into a single logit.

    Args:
    ----
    in_channel: int
        Number of channels in the incoming feature map (assumed 4x4 spatially).
    disc_stddev: bool
        If True, concatenate the minibatch stddev channel before the conv.

    """
    def __init__(self, in_channel, disc_stddev=False):
        super().__init__()

        self.disc_stddev = disc_stddev
        extra_ch = 1 if disc_stddev else 0

        self.conv_stddev = ConvLayer2d(
            in_channel=in_channel + extra_ch,
            out_channel=in_channel,
            kernel_size=3,
            activate=True
        )

        self.final_linear = nn.Sequential(
            nn.Flatten(),
            EqualLinear(in_channel=in_channel * 4 * 4, out_channel=in_channel, activate=True),
            EqualLinear(in_channel=in_channel, out_channel=1),
        )

    def cat_stddev(self, x, stddev_group=4, stddev_feat=1):
        """Concatenate a minibatch-stddev channel onto x (shuffled group statistics)."""
        # Shuffle the batch so stddev groups are formed randomly, then restore order.
        perm = torch.randperm(len(x))
        inv_perm = torch.argsort(perm)

        batch, channel, height, width = x.shape
        shuffled = x[perm]

        group = min(batch, stddev_group)
        stats = shuffled.view(group, -1, stddev_feat, channel // stddev_feat, height, width)
        stats = torch.sqrt(stats.var(0, unbiased=False) + 1e-8)
        stats = stats.mean([2, 3, 4], keepdim=True).squeeze(2)
        stats = stats.repeat(group, 1, height, width)

        # Undo the shuffle on both the features and the stddev channel.
        return torch.cat([shuffled[inv_perm], stats[inv_perm]], 1)

    def forward(self, x):
        if self.disc_stddev:
            x = self.cat_stddev(x)
        return self.final_linear(self.conv_stddev(x))
| |
|
| |
|
class ConvDecoder(nn.Module):
    """Convolutional decoder that upsamples from in_res to out_res.

    Each upsampling stage halves the channel count; a final 3x3 conv (no
    activation) maps to the requested number of output channels.

    Args:
    ----
    in_channel: int
        Channels of the input feature map.
    out_channel: int
        Channels of the decoded output.
    in_res: int
        Input spatial resolution (power of two).
    out_res: int
        Output spatial resolution (power of two, >= in_res).

    """
    def __init__(self, in_channel, out_channel, in_res, out_res):
        super().__init__()

        start_log = int(math.log(in_res, 2))
        end_log = int(math.log(out_res, 2))

        blocks = []
        ch = in_channel
        for _ in range(start_log, end_log):
            blocks.append(
                ConvLayer2d(
                    in_channel=ch,
                    out_channel=ch // 2,
                    kernel_size=3,
                    upsample=True,
                    bias=True,
                    activate=True
                )
            )
            ch //= 2

        blocks.append(
            ConvLayer2d(
                in_channel=ch, out_channel=out_channel, kernel_size=3, bias=True, activate=False
            )
        )
        self.layers = nn.Sequential(*blocks)

    def forward(self, x):
        return self.layers(x)
| |
|
| |
|
class StyleDiscriminator(nn.Module):
    """StyleGAN-style discriminator.

    Downsamples the input with residual blocks until a 4x4 feature map remains,
    then produces a single logit via a stddev-augmented head.

    Args:
    ----
    in_channel: int
        Channels of the input image/feature map.
    in_res: int
        Input spatial resolution (power of two).
    ch_mul: int
        Base channel multiplier after the input conv.
    ch_max: int
        Maximum channel count for any block.

    """
    def __init__(self, in_channel, in_res, ch_mul=64, ch_max=512, **kwargs):
        super().__init__()

        log_size_in = int(math.log(in_res, 2))
        log_size_out = int(math.log(4, 2))  # stop once the map reaches 4x4

        self.conv_in = ConvLayer2d(in_channel=in_channel, out_channel=ch_mul, kernel_size=3)

        blocks = []
        ch = ch_mul
        for _ in range(log_size_in, log_size_out, -1):
            next_ch = int(min(ch * 2, ch_max))
            blocks.append(ConvResBlock2d(in_channel=ch, out_channel=next_ch, downsample=True))
            ch = next_ch
        self.layers = nn.Sequential(*blocks)

        self.disc_out = DiscriminatorHead(in_channel=ch, disc_stddev=True)

    def forward(self, x):
        features = self.conv_in(x)
        features = self.layers(features)
        return self.disc_out(features)
| |
|
| |
|
def make_kernel(k):
    """Build a normalized 2D FIR kernel from a 1D or 2D specification.

    A 1D input is expanded to 2D via its outer product (separable kernel);
    the result is normalized so its entries sum to 1.
    """
    kernel = torch.tensor(k, dtype=torch.float32)

    if kernel.ndim == 1:
        # Outer product turns a separable 1D kernel into its 2D form.
        kernel = kernel[:, None] * kernel[None, :]

    return kernel / kernel.sum()
| |
|
| |
|
class Blur(nn.Module):
    """Blur layer.

    Applies an FIR blur kernel to the input. Blurring after convolutional
    upsampling or before convolutional downsampling makes models more robust to
    shifted inputs (https://richzhang.github.io/antialiased-cnns/), which in GANs
    yields cleaner gradients and more stable training.

    Args:
    ----
    kernel: list, int
        A list of integers representing a blur kernel. For example: [1, 3, 3, 1].
    pad: tuple, int
        Rows/columns of padding added to the top/left and bottom/right respectively.
    upsample_factor: int
        Upsample factor.

    """
    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        fir = make_kernel(kernel)
        if upsample_factor > 1:
            # Compensate for the energy spread introduced by upsampling.
            fir = fir * (upsample_factor**2)

        self.register_buffer("kernel", fir)
        self.pad = pad

    def forward(self, input):
        return upfirdn2d(input, self.kernel, pad=self.pad)
| |
|
| |
|
class Upsample(nn.Module):
    """Upsampling layer.

    Performs upsampling with an FIR blur kernel via upfirdn2d.

    Args:
    ----
    kernel: list, int
        A list of integers representing a blur kernel. For example: [1, 3, 3, 1].
    factor: int
        Upsampling factor.

    """
    def __init__(self, kernel=[1, 3, 3, 1], factor=2):
        super().__init__()

        self.factor = factor
        # Scale the kernel so output magnitude is preserved after upsampling.
        fir = make_kernel(kernel) * (factor**2)
        self.register_buffer("kernel", fir)

        p = fir.shape[0] - factor
        self.pad = ((p + 1) // 2 + factor - 1, p // 2)

    def forward(self, input):
        return upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
| |
|
| |
|
class Downsample(nn.Module):
    """Downsampling layer.

    Performs downsampling with an FIR blur kernel via upfirdn2d.

    Args:
    ----
    kernel: list, int
        A list of integers representing a blur kernel. For example: [1, 3, 3, 1].
    factor: int
        Downsampling factor.

    """
    def __init__(self, kernel=[1, 3, 3, 1], factor=2):
        super().__init__()

        self.factor = factor
        fir = make_kernel(kernel)
        self.register_buffer("kernel", fir)

        p = fir.shape[0] - factor
        self.pad = ((p + 1) // 2, p // 2)

    def forward(self, input):
        return upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
| |
|
| |
|
class EqualLinear(nn.Module):
    """Linear layer with equalized learning rate.

    During the forward pass the weights are scaled by the inverse of the He constant
    (i.e. sqrt(in_dim)) to prevent vanishing gradients and accelerate training. This
    constant only works for ReLU or LeakyReLU activation functions.

    Args:
    ----
    in_channel: int
        Input channels.
    out_channel: int
        Output channels.
    bias: bool
        Use bias term.
    bias_init: float
        Initial value for the bias.
    lr_mul: float
        Learning rate multiplier. By scaling weights and the bias we can proportionally
        scale the magnitude of the gradients, effectively increasing/decreasing the
        learning rate for this layer.
    activate: bool
        Apply fused leakyReLU activation.

    """
    def __init__(self, in_channel, out_channel, bias=True, bias_init=0, lr_mul=1, activate=False):
        super().__init__()

        # Weights are divided by lr_mul at init and re-scaled in forward, which
        # effectively multiplies this layer's gradients by lr_mul.
        self.weight = nn.Parameter(torch.randn(out_channel, in_channel).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel).fill_(bias_init))
        else:
            self.bias = None

        self.activate = activate
        self.scale = (1 / math.sqrt(in_channel)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        # Bug fix: the original unconditionally evaluated `self.bias * self.lr_mul`,
        # which raises TypeError when the layer was built with bias=False
        # (self.bias is None). Guard the multiplication instead.
        bias = self.bias * self.lr_mul if self.bias is not None else None
        if self.activate:
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, bias)
        else:
            out = F.linear(input, self.weight * self.scale, bias=bias)
        return out

    def __repr__(self):
        return f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
| |
|
| |
|
class EqualConv2d(nn.Module):
    """2D convolution layer with equalized learning rate.

    During the forward pass the weights are scaled by the inverse of the He constant
    (i.e. sqrt(in_dim)) to prevent vanishing gradients and accelerate training. This
    constant only works for ReLU or LeakyReLU activation functions.

    Args:
    ----
    in_channel: int
        Input channels.
    out_channel: int
        Output channels.
    kernel_size: int
        Kernel size.
    stride: int
        Stride of convolutional kernel across the input.
    padding: int
        Amount of zero padding applied to both sides of the input.
    bias: bool
        Use bias term.

    """
    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
        # He constant: normalizes by fan-in of the conv kernel.
        self.scale = 1 / math.sqrt(in_channel * kernel_size**2)

        self.stride = stride
        self.padding = padding
        self.bias = nn.Parameter(torch.zeros(out_channel)) if bias else None

    def forward(self, input):
        return F.conv2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

    def __repr__(self):
        w = self.weight
        return (
            f"{self.__class__.__name__}({w.shape[1]}, {w.shape[0]},"
            f" {w.shape[2]}, stride={self.stride}, padding={self.padding})"
        )
| |
|
| |
|
class EqualConvTranspose2d(nn.Module):
    """2D transpose convolution layer with equalized learning rate.

    During the forward pass the weights are scaled by the inverse of the He constant
    (i.e. sqrt(in_dim)) to prevent vanishing gradients and accelerate training. This
    constant only works for ReLU or LeakyReLU activation functions.

    Args:
    ----
    in_channel: int
        Input channels.
    out_channel: int
        Output channels.
    kernel_size: int
        Kernel size.
    stride: int
        Stride of convolutional kernel across the input.
    padding: int
        Amount of zero padding applied to both sides of the input.
    output_padding: int
        Extra padding added to input to achieve the desired output size.
    bias: bool
        Use bias term.

    """
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        bias=True
    ):
        super().__init__()

        # Note: conv_transpose2d expects weight shaped (in, out, kH, kW).
        self.weight = nn.Parameter(torch.randn(in_channel, out_channel, kernel_size, kernel_size))
        self.scale = 1 / math.sqrt(in_channel * kernel_size**2)

        self.stride = stride
        self.padding = padding
        self.output_padding = output_padding
        self.bias = nn.Parameter(torch.zeros(out_channel)) if bias else None

    def forward(self, input):
        return F.conv_transpose2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            output_padding=self.output_padding,
        )

    def __repr__(self):
        w = self.weight
        return (
            f'{self.__class__.__name__}({w.shape[0]}, {w.shape[1]},'
            f' {w.shape[2]}, stride={self.stride}, padding={self.padding})'
        )
| |
|
| |
|
class ConvLayer2d(nn.Sequential):
    """Convolution layer with optional blur-based up/downsampling and fused activation.

    Upsampling uses a strided transpose conv followed by a blur; downsampling uses a
    blur followed by a strided conv. When `activate` is True the bias is applied
    inside FusedLeakyReLU rather than in the convolution itself.
    """
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size=3,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        assert not (upsample and downsample), 'Cannot upsample and downsample simultaneously'

        # When activating, FusedLeakyReLU owns the bias; skip it in the conv.
        conv_bias = bias and not activate
        modules = []

        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            modules.append(
                EqualConvTranspose2d(
                    in_channel,
                    out_channel,
                    kernel_size,
                    padding=0,
                    stride=2,
                    bias=conv_bias
                )
            )
            modules.append(
                Blur(
                    blur_kernel,
                    pad=((p + 1) // 2 + factor - 1, p // 2 + 1),
                    upsample_factor=factor,
                )
            )
        elif downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            modules.append(Blur(blur_kernel, pad=((p + 1) // 2, p // 2)))
            modules.append(
                EqualConv2d(
                    in_channel,
                    out_channel,
                    kernel_size,
                    padding=0,
                    stride=2,
                    bias=conv_bias
                )
            )
        else:
            # Same-resolution conv: "same" padding for odd kernel sizes.
            modules.append(
                EqualConv2d(
                    in_channel,
                    out_channel,
                    kernel_size,
                    padding=kernel_size // 2,
                    stride=1,
                    bias=conv_bias
                )
            )

        if activate:
            modules.append(FusedLeakyReLU(out_channel, bias=bias))

        super().__init__(*modules)
| |
|
| |
|
class ConvResBlock2d(nn.Module):
    """2D convolutional residual block with equalized learning rate.

    Residual block composed of 3x3 convolutions and leaky ReLUs. A 1x1 skip
    connection is created whenever the channel count or resolution changes.

    Args:
    ----
    in_channel: int
        Input channels.
    out_channel: int
        Output channels.
    upsample: bool
        Apply upsampling via strided convolution in the first conv.
    downsample: bool
        Apply downsampling via strided convolution in the second conv.

    """
    def __init__(self, in_channel, out_channel, upsample=False, downsample=False):
        super().__init__()

        assert not (upsample and downsample), 'Cannot upsample and downsample simultaneously'
        hidden_ch = in_channel if downsample else out_channel

        self.conv1 = ConvLayer2d(in_channel, hidden_ch, upsample=upsample, kernel_size=3)
        self.conv2 = ConvLayer2d(hidden_ch, out_channel, downsample=downsample, kernel_size=3)

        # Skip path is only needed when identity would not match the output.
        if (in_channel != out_channel) or upsample or downsample:
            self.skip = ConvLayer2d(
                in_channel,
                out_channel,
                upsample=upsample,
                downsample=downsample,
                kernel_size=1,
                activate=False,
                bias=False,
            )

    def forward(self, input):
        out = self.conv2(self.conv1(input))
        residual = self.skip(input) if hasattr(self, 'skip') else input
        # Divide by sqrt(2) to keep the output variance comparable to the input's.
        return (out + residual) / math.sqrt(2)
| |
|