"""Contains the implementation of discriminator described in StyleGAN.

Paper: https://arxiv.org/pdf/1812.04948.pdf

Official TensorFlow implementation: https://github.com/NVlabs/stylegan
"""

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['StyleGANDiscriminator']

# Resolutions allowed.
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]

# Initial resolution (i.e., resolution of the last convolutional block).
_INIT_RES = 4

# Fused-scale options allowed.
_FUSED_SCALE_ALLOWED = [True, False, 'auto']

# Minimal resolution at which the `auto` fused-scale strategy fuses `conv2d`
# and `downsample` into a single strided convolution.
_AUTO_FUSED_SCALE_MIN_RES = 128

# Default gain factor for weight scaling (He initialization gain).
_WSCALE_GAIN = np.sqrt(2.0)

class StyleGANDiscriminator(nn.Module):
    """Defines the discriminator network in StyleGAN.

    NOTE: The discriminator takes images with `RGB` channel order and pixel
    range [-1, 1] as inputs.

    Settings for the network:

    (1) resolution: The resolution of the input image.
    (2) image_channels: Number of channels of the input image. (default: 3)
    (3) label_size: Size of the additional label for conditional generation.
        (default: 0)
    (4) fused_scale: Whether to fuse `conv2d` and `downsample` together,
        resulting in `conv2d` with strides. (default: `auto`)
    (5) use_wscale: Whether to use weight scaling. (default: True)
    (6) minibatch_std_group_size: Group size for the minibatch standard
        deviation layer. 0 means disable. (default: 4)
    (7) minibatch_std_channels: Number of new channels after the minibatch
        standard deviation layer. (default: 1)
    (8) fmaps_base: Factor to control number of feature maps for each layer.
        (default: 16 << 10)
    (9) fmaps_max: Maximum number of feature maps in each layer.
        (default: 512)
    """

    def __init__(self,
                 resolution,
                 image_channels=3,
                 label_size=0,
                 fused_scale='auto',
                 use_wscale=True,
                 minibatch_std_group_size=4,
                 minibatch_std_channels=1,
                 fmaps_base=16 << 10,
                 fmaps_max=512):
        """Initializes with basic settings.

        Raises:
            ValueError: If the `resolution` is not supported, or `fused_scale`
                is not supported.
        """
        super().__init__()

        if resolution not in _RESOLUTIONS_ALLOWED:
            raise ValueError(f'Invalid resolution: `{resolution}`!\n'
                             f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
        if fused_scale not in _FUSED_SCALE_ALLOWED:
            raise ValueError(f'Invalid fused-scale option: `{fused_scale}`!\n'
                             f'Options allowed: {_FUSED_SCALE_ALLOWED}.')

        self.init_res = _INIT_RES
        self.init_res_log2 = int(np.log2(self.init_res))
        self.resolution = resolution
        self.final_res_log2 = int(np.log2(self.resolution))
        self.image_channels = image_channels
        self.label_size = label_size
        self.fused_scale = fused_scale
        self.use_wscale = use_wscale
        self.minibatch_std_group_size = minibatch_std_group_size
        self.minibatch_std_channels = minibatch_std_channels
        self.fmaps_base = fmaps_base
        self.fmaps_max = fmaps_max

        # Level of detail (lod): controls progressive growing and is stored
        # as a buffer so that it is saved together with the model weights.
        self.register_buffer('lod', torch.zeros(()))
        self.pth_to_tf_var_mapping = {'lod': 'lod'}

        for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
            res = 2 ** res_log2
            block_idx = self.final_res_log2 - res_log2

            # Input convolution layer for each resolution (`FromRGB`).
            self.add_module(
                f'input{block_idx}',
                ConvBlock(in_channels=self.image_channels,
                          out_channels=self.get_nf(res),
                          kernel_size=1,
                          padding=0,
                          use_wscale=self.use_wscale))
            self.pth_to_tf_var_mapping[f'input{block_idx}.weight'] = (
                f'FromRGB_lod{block_idx}/weight')
            self.pth_to_tf_var_mapping[f'input{block_idx}.bias'] = (
                f'FromRGB_lod{block_idx}/bias')

            # Convolution block for each resolution (except the initial one).
            if res != self.init_res:
                if self.fused_scale == 'auto':
                    fused_scale = (res >= _AUTO_FUSED_SCALE_MIN_RES)
                else:
                    fused_scale = self.fused_scale
                self.add_module(
                    f'layer{2 * block_idx}',
                    ConvBlock(in_channels=self.get_nf(res),
                              out_channels=self.get_nf(res),
                              use_wscale=self.use_wscale))
                tf_layer0_name = 'Conv0'
                self.add_module(
                    f'layer{2 * block_idx + 1}',
                    ConvBlock(in_channels=self.get_nf(res),
                              out_channels=self.get_nf(res // 2),
                              downsample=True,
                              fused_scale=fused_scale,
                              use_wscale=self.use_wscale))
                tf_layer1_name = 'Conv1_down'

            # Convolution block for the initial (4x4) resolution.
            else:
                self.add_module(
                    f'layer{2 * block_idx}',
                    ConvBlock(in_channels=self.get_nf(res),
                              out_channels=self.get_nf(res),
                              use_wscale=self.use_wscale,
                              minibatch_std_group_size=minibatch_std_group_size,
                              minibatch_std_channels=minibatch_std_channels))
                tf_layer0_name = 'Conv'
                self.add_module(
                    f'layer{2 * block_idx + 1}',
                    DenseBlock(in_channels=self.get_nf(res) * res * res,
                               out_channels=self.get_nf(res // 2),
                               use_wscale=self.use_wscale))
                tf_layer1_name = 'Dense0'

            self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.weight'] = (
                f'{res}x{res}/{tf_layer0_name}/weight')
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.bias'] = (
                f'{res}x{res}/{tf_layer0_name}/bias')
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.weight'] = (
                f'{res}x{res}/{tf_layer1_name}/weight')
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.bias'] = (
                f'{res}x{res}/{tf_layer1_name}/bias')

        # Final dense block, added once after the loop. NOTE: `block_idx` and
        # `res` intentionally keep their values from the last loop iteration
        # (the initial-resolution block).
        self.add_module(
            f'layer{2 * block_idx + 2}',
            DenseBlock(in_channels=self.get_nf(res // 2),
                       out_channels=max(self.label_size, 1),
                       use_wscale=self.use_wscale,
                       wscale_gain=1.0,
                       activation_type='linear'))
        self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 2}.weight'] = (
            f'{res}x{res}/Dense1/weight')
        self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 2}.bias'] = (
            f'{res}x{res}/Dense1/bias')

        self.downsample = DownsamplingLayer()

    def get_nf(self, res):
        """Gets number of feature maps according to the given resolution."""
        return min(self.fmaps_base // res, self.fmaps_max)
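
    # For example, with the default `fmaps_base=16 << 10` (i.e., 16384) and
    # `fmaps_max=512`, `get_nf()` yields the following schedule:
    #     res:  4    8    16   32   64   128  256  512  1024
    #     nf:   512  512  512  512  256  128  64   32   16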

    def forward(self, image, label=None, lod=None, **_unused_kwargs):
        expected_shape = (self.image_channels, self.resolution, self.resolution)
        if image.ndim != 4 or image.shape[1:] != expected_shape:
            raise ValueError(f'The input tensor should be with shape '
                             f'[batch_size, channel, height, width], where '
                             f'`channel` equals {self.image_channels} and '
                             f'`height` and `width` equal {self.resolution}!\n'
                             f'But `{image.shape}` was received!')

        lod = self.lod.cpu().tolist() if lod is None else lod
        if lod + self.init_res_log2 > self.final_res_log2:
            raise ValueError(f'Maximum level-of-detail (lod) is '
                             f'{self.final_res_log2 - self.init_res_log2}, '
                             f'but `{lod}` was received!')

        if self.label_size:
            if label is None:
                raise ValueError(f'Model requires an additional label '
                                 f'(with size {self.label_size}) as input, '
                                 f'but no label was received!')
            batch_size = image.shape[0]
            if label.ndim != 2 or label.shape != (batch_size, self.label_size):
                raise ValueError(f'Input label should be with shape '
                                 f'[batch_size, label_size], where '
                                 f'`batch_size` equals that of the images '
                                 f'({image.shape[0]}) and `label_size` '
                                 f'equals {self.label_size}!\n'
                                 f'But `{label.shape}` was received!')

        for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
            block_idx = current_lod = self.final_res_log2 - res_log2
            if current_lod <= lod < current_lod + 1:
                # Entry block for the current lod: convert the (possibly
                # downsampled) image to feature maps.
                x = self.__getattr__(f'input{block_idx}')(image)
            elif current_lod - 1 < lod < current_lod:
                # Transition phase of progressive growing: blend the `FromRGB`
                # features of the downsampled image with the features from the
                # preceding higher-resolution block.
                alpha = lod - np.floor(lod)
                x = (self.__getattr__(f'input{block_idx}')(image) * alpha +
                     x * (1 - alpha))
            if lod < current_lod + 1:
                x = self.__getattr__(f'layer{2 * block_idx}')(x)
                x = self.__getattr__(f'layer{2 * block_idx + 1}')(x)
            if lod > current_lod:
                image = self.downsample(image)
        # The final dense block is applied once, after the 4x4 block.
        x = self.__getattr__(f'layer{2 * block_idx + 2}')(x)

        if self.label_size:
            x = torch.sum(x * label, dim=1, keepdim=True)

        return x


class MiniBatchSTDLayer(nn.Module):
    """Implements the minibatch standard deviation layer."""

    def __init__(self, group_size=4, new_channels=1, epsilon=1e-8):
        super().__init__()
        self.group_size = group_size
        self.new_channels = new_channels
        self.epsilon = epsilon

    def forward(self, x):
        if self.group_size <= 1:
            return x
        ng = min(self.group_size, x.shape[0])  # Group size.
        nc = self.new_channels                 # Number of channels to append.
        temp_c = x.shape[1] // nc
        # [N, C, H, W] -> [ng, N / ng, nc, C / nc, H, W]
        y = x.view(ng, -1, nc, temp_c, x.shape[2], x.shape[3])
        # Standard deviation over the group dimension.
        y = y - torch.mean(y, dim=0, keepdim=True)
        y = torch.mean(y ** 2, dim=0)
        y = torch.sqrt(y + self.epsilon)
        # Average over feature maps and pixels, then replicate spatially.
        y = torch.mean(y, dim=[2, 3, 4], keepdim=True)
        y = torch.mean(y, dim=2)
        y = y.repeat(ng, 1, x.shape[2], x.shape[3])
        return torch.cat([x, y], dim=1)
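
    # Shape sketch: with `group_size=4` and `new_channels=1`, an input of
    # shape [8, 512, 4, 4] is split into two groups of 4 samples each; the
    # per-group standard deviation statistics are reduced to a single scalar,
    # which is broadcast back as one extra feature map, giving an output of
    # shape [8, 513, 4, 4].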


class DownsamplingLayer(nn.Module):
    """Implements the downsampling layer.

    Basically, this layer can be used to downsample feature maps with average
    pooling.
    """

    def __init__(self, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        if self.scale_factor <= 1:
            return x
        return F.avg_pool2d(x,
                            kernel_size=self.scale_factor,
                            stride=self.scale_factor,
                            padding=0)
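
    # For example, with the default `scale_factor=2`, a tensor of shape
    # [N, C, 8, 8] becomes [N, C, 4, 4] by averaging each non-overlapping
    # 2x2 window.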


class Blur(torch.autograd.Function):
    """Defines blur operation with customized gradient computation."""

    @staticmethod
    def forward(ctx, x, kernel):
        ctx.save_for_backward(kernel)
        # NOTE: `padding=1` assumes the 3x3 kernel constructed by `BlurLayer`.
        y = F.conv2d(input=x,
                     weight=kernel,
                     bias=None,
                     stride=1,
                     padding=1,
                     groups=x.shape[1])
        return y

    @staticmethod
    def backward(ctx, dy):
        kernel, = ctx.saved_tensors
        dx = BlurBackPropagation.apply(dy, kernel)
        return dx, None, None


class BlurBackPropagation(torch.autograd.Function):
    """Defines the back propagation of blur operation.

    NOTE: This is used to speed up the backward pass of the gradient penalty.
    """

    @staticmethod
    def forward(ctx, dy, kernel):
        ctx.save_for_backward(kernel)
        # The gradient of a convolution is a convolution with the spatially
        # flipped kernel.
        dx = F.conv2d(input=dy,
                      weight=kernel.flip((2, 3)),
                      bias=None,
                      stride=1,
                      padding=1,
                      groups=dy.shape[1])
        return dx

    @staticmethod
    def backward(ctx, ddx):
        kernel, = ctx.saved_tensors
        ddy = F.conv2d(input=ddx,
                       weight=kernel,
                       bias=None,
                       stride=1,
                       padding=1,
                       groups=ddx.shape[1])
        return ddy, None, None


class BlurLayer(nn.Module):
    """Implements the blur layer."""

    def __init__(self,
                 channels,
                 kernel=(1, 2, 1),
                 normalize=True):
        super().__init__()
        # Build a separable 2D kernel from the 1D kernel via outer product.
        kernel = np.array(kernel, dtype=np.float32).reshape(1, -1)
        kernel = kernel.T.dot(kernel)
        if normalize:
            kernel = kernel / np.sum(kernel)
        # Replicate the kernel per channel for depthwise convolution.
        kernel = kernel[np.newaxis, np.newaxis]
        kernel = np.tile(kernel, [channels, 1, 1, 1])
        self.register_buffer('kernel', torch.from_numpy(kernel))

    def forward(self, x):
        return Blur.apply(x, self.kernel)
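
    # With the default `kernel=(1, 2, 1)` and `normalize=True`, the outer
    # product yields the binomial filter
    #     [[1, 2, 1],
    #      [2, 4, 2],
    #      [1, 2, 1]] / 16,
    # which is tiled into a [channels, 1, 3, 3] depthwise kernel.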


class ConvBlock(nn.Module):
    """Implements the convolutional block.

    Basically, this block executes minibatch standard deviation layer (if
    needed), convolutional layer, activation layer, and downsampling layer
    (if needed) in sequence.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 add_bias=True,
                 downsample=False,
                 fused_scale=False,
                 use_wscale=True,
                 wscale_gain=_WSCALE_GAIN,
                 lr_mul=1.0,
                 activation_type='lrelu',
                 minibatch_std_group_size=0,
                 minibatch_std_channels=1):
        """Initializes with block settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            kernel_size: Size of the convolutional kernels. (default: 3)
            stride: Stride parameter for convolution operation. (default: 1)
            padding: Padding parameter for convolution operation. (default: 1)
            add_bias: Whether to add bias onto the convolutional result.
                (default: True)
            downsample: Whether to downsample the result after convolution.
                (default: False)
            fused_scale: Whether to fuse `conv2d` and `downsample` together,
                resulting in `conv2d` with strides. (default: False)
            use_wscale: Whether to use weight scaling. (default: True)
            wscale_gain: Gain factor for weight scaling.
                (default: _WSCALE_GAIN)
            lr_mul: Learning rate multiplier for both weight and bias.
                (default: 1.0)
            activation_type: Type of activation. Support `linear` and `lrelu`.
                (default: `lrelu`)
            minibatch_std_group_size: Group size for the minibatch standard
                deviation layer. 0 means disable. (default: 0)
            minibatch_std_channels: Number of new channels after the minibatch
                standard deviation layer. (default: 1)

        Raises:
            NotImplementedError: If the `activation_type` is not supported.
        """
        super().__init__()

        if minibatch_std_group_size > 1:
            in_channels = in_channels + minibatch_std_channels
            self.mbstd = MiniBatchSTDLayer(group_size=minibatch_std_group_size,
                                           new_channels=minibatch_std_channels)
        else:
            self.mbstd = nn.Identity()

        if downsample:
            # Blur the feature maps before any downsampling, following the
            # official implementation.
            self.blur = BlurLayer(channels=in_channels)
        else:
            self.blur = nn.Identity()

        if downsample and not fused_scale:
            self.downsample = DownsamplingLayer()
        else:
            self.downsample = nn.Identity()

        if downsample and fused_scale:
            self.use_stride = True
            self.stride = 2
            self.padding = 1
        else:
            self.use_stride = False
            self.stride = stride
            self.padding = padding

        weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
        fan_in = kernel_size * kernel_size * in_channels
        wscale = wscale_gain / np.sqrt(fan_in)
        if use_wscale:
            # Equalized learning rate: store unit-variance weights and apply
            # the He-style scale `wscale` at runtime instead of baking it
            # into the initialization.
            self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
            self.wscale = wscale * lr_mul
        else:
            self.weight = nn.Parameter(
                torch.randn(*weight_shape) * wscale / lr_mul)
            self.wscale = lr_mul

        if add_bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
            self.bscale = lr_mul
        else:
            self.bias = None

        if activation_type == 'linear':
            self.activate = nn.Identity()
        elif activation_type == 'lrelu':
            self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            raise NotImplementedError(f'Not implemented activation function: '
                                      f'`{activation_type}`!')

    def forward(self, x):
        x = self.mbstd(x)
        x = self.blur(x)
        weight = self.weight * self.wscale
        bias = self.bias * self.bscale if self.bias is not None else None
        if self.use_stride:
            # Fused `conv2d` + `downsample`: zero-pad the 3x3 kernel to 5x5,
            # then average the four shifted 4x4 crops. A stride-2 convolution
            # with the resulting 4x4 kernel is (up to boundary effects)
            # equivalent to the stride-1 convolution followed by 2x2 average
            # pooling.
            weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0)
            weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] +
                      weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1]) * 0.25
        x = F.conv2d(x,
                     weight=weight,
                     bias=bias,
                     stride=self.stride,
                     padding=self.padding)
        x = self.downsample(x)
        x = self.activate(x)
        return x
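
    # Shape sketch (illustrative values): with `downsample=True` and
    # `fused_scale=True`, a [N, 256, 64, 64] input keeps its shape through
    # the blur layer, and the stride-2 convolution with the smoothed 4x4
    # kernel directly yields [N, out_channels, 32, 32]. With
    # `fused_scale=False`, the same shape results from a stride-1
    # convolution followed by 2x2 average pooling.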


class DenseBlock(nn.Module):
    """Implements the dense block.

    Basically, this block executes a fully-connected layer and an activation
    layer in sequence.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 add_bias=True,
                 use_wscale=True,
                 wscale_gain=_WSCALE_GAIN,
                 lr_mul=1.0,
                 activation_type='lrelu'):
        """Initializes with block settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            add_bias: Whether to add bias onto the fully-connected result.
                (default: True)
            use_wscale: Whether to use weight scaling. (default: True)
            wscale_gain: Gain factor for weight scaling.
                (default: _WSCALE_GAIN)
            lr_mul: Learning rate multiplier for both weight and bias.
                (default: 1.0)
            activation_type: Type of activation. Support `linear` and `lrelu`.
                (default: `lrelu`)

        Raises:
            NotImplementedError: If the `activation_type` is not supported.
        """
        super().__init__()
        weight_shape = (out_channels, in_channels)
        wscale = wscale_gain / np.sqrt(in_channels)
        if use_wscale:
            # Equalized learning rate: store unit-variance weights and apply
            # the scale at runtime.
            self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
            self.wscale = wscale * lr_mul
        else:
            self.weight = nn.Parameter(
                torch.randn(*weight_shape) * wscale / lr_mul)
            self.wscale = lr_mul

        if add_bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
            self.bscale = lr_mul
        else:
            self.bias = None

        if activation_type == 'linear':
            self.activate = nn.Identity()
        elif activation_type == 'lrelu':
            self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            raise NotImplementedError(f'Not implemented activation function: '
                                      f'`{activation_type}`!')

    def forward(self, x):
        if x.ndim != 2:
            # Flatten spatial dimensions before the fully-connected layer.
            x = x.view(x.shape[0], -1)
        bias = self.bias * self.bscale if self.bias is not None else None
        x = F.linear(x, weight=self.weight * self.wscale, bias=bias)
        x = self.activate(x)
        return x
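

# A minimal usage sketch (an illustrative addition, not part of the original
# module): build a discriminator and score a batch of random images. The
# resolution and batch size below are arbitrary example values.
if __name__ == '__main__':
    D = StyleGANDiscriminator(resolution=256)
    images = torch.randn(4, 3, 256, 256)  # RGB images assumed in [-1, 1].
    scores = D(images)                    # Uses the `lod` buffer (0 here).
    print(scores.shape)                   # torch.Size([4, 1])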