| |
| """Contains the implementation of discriminator described in StyleGAN2. |
| |
| Compared to that of StyleGAN, the discriminator in StyleGAN2 mainly adds skip |
| connections, increases model size and disables progressive growth. This script |
| ONLY supports config F in the original paper. |
| |
| Paper: https://arxiv.org/pdf/1912.04958.pdf |
| |
| Official TensorFlow implementation: https://github.com/NVlabs/stylegan2 |
| """ |
|
|
| import numpy as np |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from third_party.stylegan2_official_ops import bias_act |
| from third_party.stylegan2_official_ops import upfirdn2d |
| from third_party.stylegan2_official_ops import conv2d_gradfix |
|
|
# Public API of this module.
__all__ = ['StyleGAN2Discriminator']

# Input resolutions accepted by `StyleGAN2Discriminator.__init__()`; any other
# value raises `ValueError` there.
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]

# Architecture names accepted by `StyleGAN2Discriminator.__init__()` (compared
# after lower-casing the user-provided string).
_ARCHITECTURES_ALLOWED = ['resnet', 'skip', 'origin']
|
|
| |
|
|
class StyleGAN2Discriminator(nn.Module):
    """Defines the discriminator network in StyleGAN2.

    NOTE: The discriminator takes images with `RGB` channel order and pixel
    range [-1, 1] as inputs.

    Settings for the backbone:

    (1) resolution: The resolution of the input image. (default: -1)
    (2) init_res: Smallest resolution of the convolutional backbone. Must be
        a power of 2 that is no larger than `resolution`. (default: 4)
    (3) image_channels: Number of channels of the input image. (default: 3)
    (4) architecture: Type of architecture. Support `origin`, `skip`, and
        `resnet`. (default: `resnet`)
    (5) use_wscale: Whether to use weight scaling. (default: True)
    (6) wscale_gain: The factor to control weight scaling. (default: 1.0)
    (7) lr_mul: Learning rate multiplier for backbone. (default: 1.0)
    (8) mbstd_groups: Group size for the minibatch standard deviation layer.
        `0` means disable (any value `<= 1` disables it). (default: 4)
    (9) mbstd_channels: Number of new channels (appended to the original feature
        map) after the minibatch standard deviation layer. (default: 1)
    (10) fmaps_base: Factor to control number of feature maps for each layer.
        (default: 32 << 10)
    (11) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
    (12) filter_kernel: Kernel used for filtering (e.g., downsampling).
        (default: (1, 3, 3, 1))
    (13) conv_clamp: A threshold to clamp the output of convolution layers to
        avoid overflow under FP16 training. (default: None)
    (14) eps: A small value to avoid divide overflow. (default: 1e-8)

    Settings for conditional model:

    (1) label_dim: Dimension of the additional label for conditional generation.
        In one-hot conditioning case, it is equal to the number of classes. If
        set to 0, conditioning training will be disabled. (default: 0)
    (2) embedding_dim: Dimension of the embedding space, if needed.
        (default: 512)
    (3) embedding_bias: Whether to add bias to embedding learning.
        (default: True)
    (4) embedding_use_wscale: Whether to use weight scaling for embedding
        learning. (default: True)
    (5) embedding_lr_mul: Learning rate multiplier for the embedding learning.
        (default: 1.0)
    (6) normalize_embedding: Whether to normalize the embedding. (default: True)
    (7) mapping_layers: Number of layers of the additional mapping network after
        embedding. (default: 0)
    (8) mapping_fmaps: Number of hidden channels of the additional mapping
        network after embedding. (default: 512)
    (9) mapping_use_wscale: Whether to use weight scaling for the additional
        mapping network. (default: True)
    (10) mapping_lr_mul: Learning rate multiplier for the additional mapping
        network after embedding. (default: 0.1)

    Runtime settings:

    (1) fp16_res: Layers at resolution higher than (or equal to) this field will
        use `float16` precision for computation. This is merely used for
        acceleration. If set as `None`, all layers will use `float32` by
        default. (default: None)
    (2) impl: Implementation mode of some particular ops, e.g., `filtering`,
        `bias_act`, etc. `cuda` means using the official CUDA implementation
        from StyleGAN2, while `ref` means using the native PyTorch ops.
        (default: `cuda`)
    """

    def __init__(self,
                 # Settings for backbone.
                 resolution=-1,
                 init_res=4,
                 image_channels=3,
                 architecture='resnet',
                 use_wscale=True,
                 wscale_gain=1.0,
                 lr_mul=1.0,
                 mbstd_groups=4,
                 mbstd_channels=1,
                 fmaps_base=32 << 10,
                 fmaps_max=512,
                 filter_kernel=(1, 3, 3, 1),
                 conv_clamp=None,
                 eps=1e-8,
                 # Settings for conditional model.
                 label_dim=0,
                 embedding_dim=512,
                 embedding_bias=True,
                 embedding_use_wscale=True,
                 embedding_lr_mul=1.0,
                 normalize_embedding=True,
                 mapping_layers=0,
                 mapping_fmaps=512,
                 mapping_use_wscale=True,
                 mapping_lr_mul=0.1):
        """Initializes with basic settings.

        Besides building the layers, this fills `self.pth_to_tf_var_mapping`,
        which maps each PyTorch parameter name to a TensorFlow-style variable
        name.

        Raises:
            ValueError: If the `resolution` is not supported, `architecture`
                is not supported, or `init_res` is not a power of 2 within
                `[1, resolution]`.
        """
        super().__init__()

        if resolution not in _RESOLUTIONS_ALLOWED:
            raise ValueError(f'Invalid resolution: `{resolution}`!\n'
                             f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
        architecture = architecture.lower()
        if architecture not in _ARCHITECTURES_ALLOWED:
            raise ValueError(f'Invalid architecture: `{architecture}`!\n'
                             f'Architectures allowed: '
                             f'{_ARCHITECTURES_ALLOWED}.')
        # Without this check, a bad `init_res` silently builds a network that
        # lacks the final block and only fails later inside `forward()`.
        if (init_res < 1 or init_res > resolution or
                2 ** int(np.log2(init_res)) != init_res):
            raise ValueError(f'Invalid initial resolution: `{init_res}`!\n'
                             f'It should be a power of 2 that is no larger '
                             f'than `resolution` ({resolution}).')

        # Settings for backbone.
        self.init_res = init_res
        self.init_res_log2 = int(np.log2(init_res))
        self.resolution = resolution
        self.final_res_log2 = int(np.log2(resolution))
        self.image_channels = image_channels
        self.architecture = architecture
        self.use_wscale = use_wscale
        self.wscale_gain = wscale_gain
        self.lr_mul = lr_mul
        self.mbstd_groups = mbstd_groups
        self.mbstd_channels = mbstd_channels
        self.fmaps_base = fmaps_base
        self.fmaps_max = fmaps_max
        self.filter_kernel = filter_kernel
        self.conv_clamp = conv_clamp
        self.eps = eps

        # Settings for conditional model.
        self.label_dim = label_dim
        self.embedding_dim = embedding_dim
        self.embedding_bias = embedding_bias
        self.embedding_use_wscale = embedding_use_wscale
        self.embedding_lr_mul = embedding_lr_mul
        self.normalize_embedding = normalize_embedding
        self.mapping_layers = mapping_layers
        self.mapping_fmaps = mapping_fmaps
        self.mapping_use_wscale = mapping_use_wscale
        self.mapping_lr_mul = mapping_lr_mul

        # PyTorch parameter name -> TensorFlow-style variable name.
        self.pth_to_tf_var_mapping = {}

        # Embedding for conditional discrimination.
        self.use_embedding = label_dim > 0 and embedding_dim > 0
        if self.use_embedding:
            self._add_layer(
                'embedding', 'LabelEmbed',
                DenseLayer(in_channels=label_dim,
                           out_channels=embedding_dim,
                           add_bias=embedding_bias,
                           init_bias=0.0,
                           use_wscale=embedding_use_wscale,
                           wscale_gain=wscale_gain,
                           lr_mul=embedding_lr_mul,
                           activation_type='linear'),
                has_bias=embedding_bias)

            if self.normalize_embedding:
                self.norm = PixelNormLayer(dim=1, eps=eps)

            # Optional mapping network; the last layer maps back to
            # `embedding_dim` so it can be multiplied with the output feature.
            for i in range(mapping_layers):
                in_channels = (embedding_dim if i == 0 else mapping_fmaps)
                out_channels = (embedding_dim if i == (mapping_layers - 1) else
                                mapping_fmaps)
                self._add_layer(
                    f'mapping{i}', f'Mapping{i}',
                    DenseLayer(in_channels=in_channels,
                               out_channels=out_channels,
                               add_bias=True,
                               init_bias=0.0,
                               use_wscale=mapping_use_wscale,
                               wscale_gain=wscale_gain,
                               lr_mul=mapping_lr_mul,
                               activation_type='lrelu'))

        # Keyword arguments shared by every backbone layer.
        shared = dict(use_wscale=use_wscale,
                      wscale_gain=wscale_gain,
                      lr_mul=lr_mul)

        # `MiniBatchSTDLayer.forward()` returns its input untouched when
        # `groups <= 1` or `new_channels < 1`, so the following convolution
        # must only expect the extra channels when the layer is active.
        mbstd_active = mbstd_groups > 1 and mbstd_channels >= 1
        mbstd_extra_channels = mbstd_channels if mbstd_active else 0

        # Convolutional backbone, from the highest resolution downwards.
        for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
            res = 2 ** res_log2
            in_channels = self.get_nf(res)
            out_channels = self.get_nf(res // 2)
            block_idx = self.final_res_log2 - res_log2
            tf_scope = f'{res}x{res}'

            # Input convolutional layer for each resolution (if needed): always
            # for the first block, and for every block with `skip`.
            if res_log2 == self.final_res_log2 or self.architecture == 'skip':
                self._add_layer(
                    f'input{block_idx}', f'{tf_scope}/FromRGB',
                    ConvLayer(in_channels=image_channels,
                              out_channels=in_channels,
                              kernel_size=1,
                              add_bias=True,
                              scale_factor=1,
                              filter_kernel=None,
                              activation_type='lrelu',
                              conv_clamp=conv_clamp,
                              **shared))

            # Convolution block for each resolution (except the last one).
            if res != self.init_res:
                # First layer (kernel 3x3) without downsampling.
                self._add_layer(
                    f'layer{2 * block_idx}', f'{tf_scope}/Conv0',
                    ConvLayer(in_channels=in_channels,
                              out_channels=in_channels,
                              kernel_size=3,
                              add_bias=True,
                              scale_factor=1,
                              filter_kernel=None,
                              activation_type='lrelu',
                              conv_clamp=conv_clamp,
                              **shared))

                # Second layer (kernel 3x3) with downsampling.
                self._add_layer(
                    f'layer{2 * block_idx + 1}', f'{tf_scope}/Conv1_down',
                    ConvLayer(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=3,
                              add_bias=True,
                              scale_factor=2,
                              filter_kernel=filter_kernel,
                              activation_type='lrelu',
                              conv_clamp=conv_clamp,
                              **shared))

                # Residual branch (kernel 1x1) with downsampling, without
                # bias, with linear activation and without clamping.
                if self.architecture == 'resnet':
                    self._add_layer(
                        f'residual{block_idx}', f'{tf_scope}/Skip',
                        ConvLayer(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=1,
                                  add_bias=False,
                                  scale_factor=2,
                                  filter_kernel=filter_kernel,
                                  activation_type='linear',
                                  conv_clamp=None,
                                  **shared),
                        has_bias=False)

            # Convolution block for last resolution.
            else:
                self.mbstd = MiniBatchSTDLayer(
                    groups=mbstd_groups, new_channels=mbstd_channels, eps=eps)

                # First layer (kernel 3x3) without downsampling.
                self._add_layer(
                    f'layer{2 * block_idx}', f'{tf_scope}/Conv',
                    ConvLayer(in_channels=in_channels + mbstd_extra_channels,
                              out_channels=in_channels,
                              kernel_size=3,
                              add_bias=True,
                              scale_factor=1,
                              filter_kernel=None,
                              activation_type='lrelu',
                              conv_clamp=conv_clamp,
                              **shared))

                # Second layer, as a fully-connected layer over the flattened
                # feature map.
                self._add_layer(
                    f'layer{2 * block_idx + 1}', f'{tf_scope}/Dense0',
                    DenseLayer(in_channels=in_channels * res * res,
                               out_channels=in_channels,
                               add_bias=True,
                               init_bias=0.0,
                               activation_type='lrelu',
                               **shared))

                # Final dense layer to output score (or a feature to be
                # projected onto the label / label embedding).
                self._add_layer(
                    'output', 'Output',
                    DenseLayer(in_channels=in_channels,
                               out_channels=(embedding_dim
                                             if self.use_embedding
                                             else max(label_dim, 1)),
                               add_bias=True,
                               init_bias=0.0,
                               activation_type='linear',
                               **shared))

        # Used for downsampling input image for `skip` architecture.
        if self.architecture == 'skip':
            self.register_buffer(
                'filter', upfirdn2d.setup_filter(filter_kernel))

    def _add_layer(self, name, tf_name, layer, has_bias=True):
        """Registers `layer` as `name` and records its TF variable names."""
        self.add_module(name, layer)
        self.pth_to_tf_var_mapping[f'{name}.weight'] = f'{tf_name}/weight'
        if has_bias:
            self.pth_to_tf_var_mapping[f'{name}.bias'] = f'{tf_name}/bias'

    def get_nf(self, res):
        """Gets number of feature maps according to the given resolution."""
        return min(self.fmaps_base // res, self.fmaps_max)

    def forward(self, image, label=None, fp16_res=None, impl='cuda'):
        """Scores a batch of images.

        Args:
            image: Tensor with shape `[batch_size, image_channels, resolution,
                resolution]`.
            label: Tensor with shape `[batch_size, label_dim]`. Required when
                `label_dim > 0`. (default: None)
            fp16_res: Blocks at resolution higher than (or equal to) this
                value run in `float16`; `None` disables casting.
                (default: None)
            impl: Implementation mode forwarded to the underlying ops.
                (default: `cuda`)

        Returns:
            A dictionary with `score` (shape `[batch_size, 1]` for
            unconditional models), `label` (cast to `float32` when
            `label_dim > 0`, otherwise as given), and, if label embedding is
            used, `embedding`.

        Raises:
            ValueError: If `image` or `label` has an unexpected shape, or the
                label is missing for a conditional model.
        """
        # Check shape.
        expected_shape = (self.image_channels, self.resolution, self.resolution)
        if image.ndim != 4 or image.shape[1:] != expected_shape:
            raise ValueError(f'The input tensor should be with shape '
                             f'[batch_size, channel, height, width], where '
                             f'`channel` equals to {self.image_channels}, '
                             f'`height`, `width` equal to {self.resolution}!\n'
                             f'But `{image.shape}` is received!')

        embed = None
        if self.label_dim > 0:
            if label is None:
                raise ValueError(f'Model requires an additional label '
                                 f'(with dimension {self.label_dim}) as input, '
                                 f'but no label is received!')
            batch_size = image.shape[0]
            if label.ndim != 2 or label.shape != (batch_size, self.label_dim):
                raise ValueError(f'Input label should be with shape '
                                 f'[batch_size, label_dim], where '
                                 f'`batch_size` equals to that of '
                                 f'images ({image.shape[0]}) and '
                                 f'`label_dim` equals to {self.label_dim}!\n'
                                 f'But `{label.shape}` is received!')
            label = label.to(dtype=torch.float32)
            if self.use_embedding:
                embed = self.embedding(label, impl=impl)
                if self.normalize_embedding:
                    embed = self.norm(embed)
                for i in range(self.mapping_layers):
                    embed = getattr(self, f'mapping{i}')(embed, impl=impl)

        # Cast to `torch.float16` if needed.
        if fp16_res is not None and self.resolution >= fp16_res:
            image = image.to(torch.float16)

        x = self.input0(image, impl=impl)

        for res_log2 in range(self.final_res_log2, self.init_res_log2, -1):
            res = 2 ** res_log2
            # Cast to `torch.float16` if needed.
            if fp16_res is not None and res >= fp16_res:
                x = x.to(torch.float16)
            else:
                x = x.to(torch.float32)

            idx = self.final_res_log2 - res_log2  # Block index.

            # `skip`: re-inject the (downsampled) image at every block after
            # the first one.
            if self.architecture == 'skip' and idx > 0:
                image = upfirdn2d.downsample2d(image, self.filter, impl=impl)
                # Cast to `torch.float16` if needed.
                if fp16_res is not None and res >= fp16_res:
                    image = image.to(torch.float16)
                else:
                    image = image.to(torch.float32)
                y = getattr(self, f'input{idx}')(image, impl=impl)
                x = x + y

            if self.architecture == 'resnet':
                # Both branches are scaled by sqrt(0.5) before being summed.
                residual = getattr(self, f'residual{idx}')(
                    x, runtime_gain=np.sqrt(0.5), impl=impl)
                x = getattr(self, f'layer{2 * idx}')(x, impl=impl)
                x = getattr(self, f'layer{2 * idx + 1}')(
                    x, runtime_gain=np.sqrt(0.5), impl=impl)
                x = x + residual
            else:
                x = getattr(self, f'layer{2 * idx}')(x, impl=impl)
                x = getattr(self, f'layer{2 * idx + 1}')(x, impl=impl)

        # Final output. The index of the last block is derived directly from
        # the resolutions (instead of from the loop variable) so that it is
        # also defined when `resolution == init_res` and the loop is skipped.
        idx = self.final_res_log2 - self.init_res_log2
        if fp16_res is not None:
            x = x.to(torch.float32)
        # With a single block (`idx == 0`), `input0` has already consumed the
        # image at the final resolution, so nothing is left to inject.
        if self.architecture == 'skip' and idx > 0:
            image = upfirdn2d.downsample2d(image, self.filter, impl=impl)
            if fp16_res is not None:
                image = image.to(torch.float32)
            y = getattr(self, f'input{idx}')(image, impl=impl)
            x = x + y
        x = self.mbstd(x)
        x = getattr(self, f'layer{2 * idx}')(x, impl=impl)
        x = getattr(self, f'layer{2 * idx + 1}')(x, impl=impl)
        x = self.output(x, impl=impl)

        # Conditional models project the output feature onto the (embedded)
        # label to obtain a single score per sample.
        if self.use_embedding:
            x = (x * embed).sum(dim=1, keepdim=True)
            x = x / np.sqrt(self.embedding_dim)
        elif self.label_dim > 0:
            x = (x * label).sum(dim=1, keepdim=True)

        results = {
            'score': x,
            'label': label
        }
        if self.use_embedding:
            results['embedding'] = embed
        return results
|
|
|
|
class PixelNormLayer(nn.Module):
    """Rescales features to unit root-mean-square along a given dimension."""

    def __init__(self, dim, eps):
        """Initializes the layer.

        Args:
            dim: Dimension along which the mean of squares is computed.
            eps: Small constant added before the reciprocal square root.
        """
        super().__init__()
        self.dim = dim
        self.eps = eps

    def extra_repr(self):
        """Returns the settings shown when printing the module."""
        return f'dim={self.dim}, epsilon={self.eps}'

    def forward(self, x):
        """Divides `x` by `sqrt(mean(x ** 2, dim) + eps)`."""
        mean_square = x.square().mean(dim=self.dim, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)
|
|
|
|
class MiniBatchSTDLayer(nn.Module):
    """Appends per-group standard-deviation statistics as extra channels."""

    def __init__(self, groups, new_channels, eps):
        """Initializes the layer.

        Args:
            groups: Number of samples per statistics group (capped by the
                batch size at runtime).
            new_channels: Number of statistic channels to append.
            eps: Small constant added to the variance before the square root.
        """
        super().__init__()
        self.groups = groups
        self.new_channels = new_channels
        self.eps = eps

    def extra_repr(self):
        """Returns the settings shown when printing the module."""
        return (f'groups={self.groups}, '
                f'new_channels={self.new_channels}, '
                f'epsilon={self.eps}')

    def forward(self, x):
        """Returns `x` with `new_channels` statistic channels concatenated.

        The input is returned unchanged when the layer is disabled, i.e.
        `groups <= 1` or `new_channels < 1`.
        """
        layer_enabled = self.groups > 1 and self.new_channels >= 1
        if not layer_enabled:
            return x

        in_dtype = x.dtype

        batch, channels, height, width = x.shape
        group_size = min(self.groups, batch)
        num_stats = self.new_channels
        channels_per_stat = channels // num_stats

        # [G, M, nC, c, H, W]: split the batch into groups and the channels
        # into `num_stats` chunks.
        grouped = x.reshape(
            group_size, -1, num_stats, channels_per_stat, height, width)
        # Standard deviation over the group dimension.
        centered = grouped - grouped.mean(dim=0)
        variance = centered.square().mean(dim=0)
        stddev = (variance + self.eps).sqrt()
        # Average over channels and pixels, one value per (group, statistic).
        stats = stddev.mean(dim=(2, 3, 4)).reshape(-1, num_stats, 1, 1)
        # Broadcast the statistics back to every sample and pixel.
        stats = stats.repeat(group_size, 1, height, width)
        out = torch.cat((x, stats), dim=1)

        assert out.dtype == in_dtype
        return out
|
|
|
|
class ConvLayer(nn.Module):
    """Implements the convolutional layer.

    When `scale_factor > 1`, the input is filtered with `filter_kernel` and
    downsampled as part of the convolution.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 add_bias,
                 scale_factor,
                 filter_kernel,
                 use_wscale,
                 wscale_gain,
                 lr_mul,
                 activation_type,
                 conv_clamp):
        """Initializes with layer settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            kernel_size: Size of the convolutional kernels.
            add_bias: Whether to add bias onto the convolutional result.
            scale_factor: Scale factor for downsampling. `1` means skip
                downsampling.
            filter_kernel: Kernel used for filtering. Must not be `None` when
                `scale_factor > 1`.
            use_wscale: Whether to use weight scaling.
            wscale_gain: Gain factor for weight scaling.
            lr_mul: Learning multiplier for both weight and bias.
            activation_type: Type of activation.
            conv_clamp: A threshold to clamp the output of convolution layers to
                avoid overflow under FP16 training.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.add_bias = add_bias
        self.scale_factor = scale_factor
        self.filter_kernel = filter_kernel
        self.use_wscale = use_wscale
        self.wscale_gain = wscale_gain
        self.lr_mul = lr_mul
        self.activation_type = activation_type
        self.conv_clamp = conv_clamp

        # With weight scaling, the parameter is stored unscaled and the
        # scale is applied at runtime; otherwise it is baked into the
        # initialization and only `lr_mul` is applied at runtime.
        fan_in = kernel_size * kernel_size * in_channels
        init_scale = wscale_gain / np.sqrt(fan_in)
        init_weight = torch.randn(
            out_channels, in_channels, kernel_size, kernel_size)
        if use_wscale:
            self.weight = nn.Parameter(init_weight / lr_mul)
            self.wscale = init_scale * lr_mul
        else:
            self.weight = nn.Parameter(init_weight * init_scale / lr_mul)
            self.wscale = lr_mul

        if not add_bias:
            self.bias = None
        else:
            self.bias = nn.Parameter(torch.zeros(out_channels))
            self.bscale = lr_mul
        self.act_gain = bias_act.activation_funcs[activation_type].def_gain

        if scale_factor > 1:
            assert filter_kernel is not None
            self.register_buffer(
                'filter', upfirdn2d.setup_filter(filter_kernel))
            filter_h, filter_w = self.filter.shape
            half_kernel = kernel_size // 2
            # Padding for the filtering op, combining the convolution's
            # half-kernel padding with the filter's own extent.
            self.filter_padding = (
                half_kernel + (filter_w - scale_factor + 1) // 2,
                half_kernel + (filter_w - scale_factor) // 2,
                half_kernel + (filter_h - scale_factor + 1) // 2,
                half_kernel + (filter_h - scale_factor) // 2)

    def extra_repr(self):
        """Returns the settings shown when printing the module."""
        fields = (
            f'in_ch={self.in_channels}',
            f'out_ch={self.out_channels}',
            f'ksize={self.kernel_size}',
            f'wscale_gain={self.wscale_gain:.3f}',
            f'bias={self.add_bias}',
            f'lr_mul={self.lr_mul:.3f}',
            f'downsample={self.scale_factor}',
            f'downsample_filter={self.filter_kernel}',
            f'act={self.activation_type}',
            f'clamp={self.conv_clamp}',
        )
        return ', '.join(fields)

    def forward(self, x, runtime_gain=1.0, impl='cuda'):
        """Applies (filtered, strided) convolution, bias and activation.

        Args:
            x: Input feature map.
            runtime_gain: Extra gain multiplied onto the activation gain and
                the clamping threshold. (default: 1.0)
            impl: Implementation mode forwarded to the underlying ops.
                (default: `cuda`)

        Returns:
            The output feature map, with the same dtype as `x`.
        """
        in_dtype = x.dtype

        # Runtime scaling happens before casting to the input dtype.
        kernel = self.weight
        if self.wscale != 1.0:
            kernel = kernel * self.wscale
        kernel = kernel.to(in_dtype)

        bias = self.bias
        if bias is not None:
            bias = bias.to(in_dtype)
            if self.bscale != 1.0:
                bias = bias * self.bscale

        if self.scale_factor == 1:
            # Plain same-padding convolution.
            x = conv2d_gradfix.conv2d(
                x, kernel, stride=1, padding=self.kernel_size // 2, impl=impl)
        elif self.kernel_size == 1:
            # 1x1 kernel: filter and downsample first, then convolve.
            x = upfirdn2d.upfirdn2d(x, self.filter,
                                    down=self.scale_factor,
                                    padding=self.filter_padding,
                                    impl=impl)
            x = conv2d_gradfix.conv2d(
                x, kernel, stride=1, padding=0, impl=impl)
        else:
            # Larger kernel: filter only, then downsample via the stride.
            x = upfirdn2d.upfirdn2d(x, self.filter,
                                    down=1,
                                    padding=self.filter_padding,
                                    impl=impl)
            x = conv2d_gradfix.conv2d(
                x, kernel, stride=self.scale_factor, padding=0, impl=impl)

        if self.conv_clamp is None:
            clamp = None
        else:
            clamp = self.conv_clamp * runtime_gain
        x = bias_act.bias_act(x, bias,
                              act=self.activation_type,
                              gain=self.act_gain * runtime_gain,
                              clamp=clamp,
                              impl=impl)

        assert x.dtype == in_dtype
        return x
|
|
|
|
class DenseLayer(nn.Module):
    """Implements the dense (fully-connected) layer."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 add_bias,
                 init_bias,
                 use_wscale,
                 wscale_gain,
                 lr_mul,
                 activation_type):
        """Initializes with layer settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            add_bias: Whether to add bias onto the fully-connected result.
            init_bias: The initial bias value before training.
            use_wscale: Whether to use weight scaling.
            wscale_gain: Gain factor for weight scaling.
            lr_mul: Learning multiplier for both weight and bias.
            activation_type: Type of activation.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.add_bias = add_bias
        self.init_bias = init_bias
        self.use_wscale = use_wscale
        self.wscale_gain = wscale_gain
        self.lr_mul = lr_mul
        self.activation_type = activation_type

        # With weight scaling, the parameter is stored unscaled and the
        # scale is applied at runtime; otherwise it is baked into the
        # initialization and only `lr_mul` is applied at runtime.
        init_scale = wscale_gain / np.sqrt(in_channels)
        init_weight = torch.randn(out_channels, in_channels)
        if use_wscale:
            self.weight = nn.Parameter(init_weight / lr_mul)
            self.wscale = init_scale * lr_mul
        else:
            self.weight = nn.Parameter(init_weight * init_scale / lr_mul)
            self.wscale = lr_mul

        if not add_bias:
            self.bias = None
        else:
            # Stored divided by `lr_mul`, since `forward()` multiplies the
            # bias by `bscale` (= `lr_mul`) again.
            init_bias = np.float32(init_bias) / lr_mul
            self.bias = nn.Parameter(torch.full([out_channels], init_bias))
            self.bscale = lr_mul

    def extra_repr(self):
        """Returns the settings shown when printing the module."""
        fields = (
            f'in_ch={self.in_channels}',
            f'out_ch={self.out_channels}',
            f'wscale_gain={self.wscale_gain:.3f}',
            f'bias={self.add_bias}',
            f'init_bias={self.init_bias}',
            f'lr_mul={self.lr_mul:.3f}',
            f'act={self.activation_type}',
        )
        return ', '.join(fields)

    def forward(self, x, impl='cuda'):
        """Applies the linear map, bias and activation.

        Inputs with more than two dimensions are flattened from dimension 1
        onwards first. The output has the same dtype as `x`.
        """
        in_dtype = x.dtype

        if x.ndim != 2:
            x = x.flatten(start_dim=1)

        scaled_weight = self.weight.to(in_dtype) * self.wscale
        bias = self.bias
        if bias is not None:
            bias = bias.to(in_dtype)
            if self.bscale != 1.0:
                bias = bias * self.bscale

        # A linear layer with bias is done in one fused `addmm` call; every
        # other case goes through `bias_act`.
        fused = self.activation_type == 'linear' and bias is not None
        if fused:
            x = torch.addmm(bias.unsqueeze(0), x, scaled_weight.t())
        else:
            x = x.matmul(scaled_weight.t())
            x = bias_act.bias_act(x, bias, act=self.activation_type, impl=impl)

        assert x.dtype == in_dtype
        return x
|
|
| |
|
|