| """MelGAN Modules."""
|
|
|
| import logging
|
|
|
| import numpy as np
|
| import torch
|
|
|
| from modules.parallel_wavegan.layers import CausalConv1d
|
| from modules.parallel_wavegan.layers import CausalConvTranspose1d
|
| from modules.parallel_wavegan.layers import ResidualStack
|
|
class MelGANGenerator(torch.nn.Module):
    """MelGAN generator module."""

    def __init__(self,
                 in_channels=80,
                 out_channels=1,
                 kernel_size=7,
                 channels=512,
                 bias=True,
                 upsample_scales=[8, 8, 2, 2],
                 stack_kernel_size=3,
                 stacks=3,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 pad="ReflectionPad1d",
                 pad_params={},
                 use_final_nonlinear_activation=True,
                 use_weight_norm=True,
                 use_causal_conv=False,
                 ):
        """Initialize MelGANGenerator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of initial and final conv layer.
            channels (int): Initial number of channels for conv layer.
            bias (bool): Whether to add bias parameter in convolution layers.
            upsample_scales (list): List of upsampling scales.
            stack_kernel_size (int): Kernel size of dilated conv layers in residual stack.
            stacks (int): Number of stacks in a single residual stack.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            pad (str): Padding function module name before dilated convolution layer.
            pad_params (dict): Hyperparameters for padding function.
            use_final_nonlinear_activation (bool): Whether to apply a Tanh activation after the final layer.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.
            use_causal_conv (bool): Whether to use causal convolution.

        """
        super(MelGANGenerator, self).__init__()

        # check hyper parameters are valid
        assert channels >= np.prod(upsample_scales)
        assert channels % (2 ** len(upsample_scales)) == 0
        if not use_causal_conv:
            assert (kernel_size - 1) % 2 == 0, "Even-number kernel sizes are not supported."
|
        # add initial layer
        layers = []
        if not use_causal_conv:
            layers += [
                getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params),
                torch.nn.Conv1d(in_channels, channels, kernel_size, bias=bias),
            ]
        else:
            layers += [
                CausalConv1d(in_channels, channels, kernel_size,
                             bias=bias, pad=pad, pad_params=pad_params),
            ]

        for i, upsample_scale in enumerate(upsample_scales):
            # add upsampling layer
            layers += [getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)]
            if not use_causal_conv:
                layers += [
                    torch.nn.ConvTranspose1d(
                        channels // (2 ** i),
                        channels // (2 ** (i + 1)),
                        upsample_scale * 2,
                        stride=upsample_scale,
                        padding=upsample_scale // 2 + upsample_scale % 2,
                        output_padding=upsample_scale % 2,
                        bias=bias,
                    )
                ]
            else:
                layers += [
                    CausalConvTranspose1d(
                        channels // (2 ** i),
                        channels // (2 ** (i + 1)),
                        upsample_scale * 2,
                        stride=upsample_scale,
                        bias=bias,
                    )
                ]

            # add residual stacks
            for j in range(stacks):
                layers += [
                    ResidualStack(
                        kernel_size=stack_kernel_size,
                        channels=channels // (2 ** (i + 1)),
                        dilation=stack_kernel_size ** j,
                        bias=bias,
                        nonlinear_activation=nonlinear_activation,
                        nonlinear_activation_params=nonlinear_activation_params,
                        pad=pad,
                        pad_params=pad_params,
                        use_causal_conv=use_causal_conv,
                    )
                ]

        # add final layer
        layers += [getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)]
        if not use_causal_conv:
            layers += [
                getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params),
                torch.nn.Conv1d(channels // (2 ** (i + 1)), out_channels, kernel_size, bias=bias),
            ]
        else:
            layers += [
                CausalConv1d(channels // (2 ** (i + 1)), out_channels, kernel_size,
                             bias=bias, pad=pad, pad_params=pad_params),
            ]
        if use_final_nonlinear_activation:
            layers += [torch.nn.Tanh()]

        # define the model as a single function
        self.melgan = torch.nn.Sequential(*layers)

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

        # reset parameters
        self.reset_parameters()
|
    def forward(self, c):
        """Calculate forward propagation.

        Args:
            c (Tensor): Input tensor (B, channels, T).

        Returns:
            Tensor: Output tensor (B, 1, T * prod(upsample_scales)).

        """
        return self.melgan(c)
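    # A minimal usage sketch (an illustrative addition, not part of the original
    # code); the shapes assume the default hyperparameters above:
    #
    #     generator = MelGANGenerator()   # 80 mel bins -> 1 waveform channel
    #     mel = torch.randn(4, 80, 100)   # (B, in_channels, T)
    #     wav = generator(mel)            # (B, 1, T * prod([8, 8, 2, 2])) -> (4, 1, 25600)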
|
    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""
        def _remove_weight_norm(m):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)
|
    def apply_weight_norm(self):
        """Apply weight normalization module to all of the layers."""
        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)
|
    def reset_parameters(self):
        """Reset parameters.

        This initialization follows the manner of the official implementation.
        https://github.com/descriptinc/melgan-neurips/blob/master/spec2wav/modules.py

        """
        def _reset_parameters(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
                m.weight.data.normal_(0.0, 0.02)
                logging.debug(f"Reset parameters in {m}.")

        self.apply(_reset_parameters)
|
|
class MelGANDiscriminator(torch.nn.Module):
    """MelGAN discriminator module."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_sizes=[5, 3],
                 channels=16,
                 max_downsample_channels=1024,
                 bias=True,
                 downsample_scales=[4, 4, 4, 4],
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 pad="ReflectionPad1d",
                 pad_params={},
                 ):
        """Initialize MelGAN discriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_sizes (list): List of two kernel sizes. Their product will be used for the first conv layer,
                and the first and the second kernel sizes will be used for the last two layers.
                For example, if kernel_sizes = [5, 3], the first layer kernel size will be 5 * 3 = 15
                and the last two layers' kernel sizes will be 5 and 3, respectively.
            channels (int): Initial number of channels for conv layer.
            max_downsample_channels (int): Maximum number of channels for downsampling layers.
            bias (bool): Whether to add bias parameter in convolution layers.
            downsample_scales (list): List of downsampling scales.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            pad (str): Padding function module name before dilated convolution layer.
            pad_params (dict): Hyperparameters for padding function.

        """
        super(MelGANDiscriminator, self).__init__()
        self.layers = torch.nn.ModuleList()

        # check kernel sizes are valid (both must be odd)
        assert len(kernel_sizes) == 2
        assert kernel_sizes[0] % 2 == 1
        assert kernel_sizes[1] % 2 == 1
|
        # add first layer
        self.layers += [
            torch.nn.Sequential(
                getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params),
                torch.nn.Conv1d(in_channels, channels, np.prod(kernel_sizes), bias=bias),
                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
            )
        ]

        # add downsample layers
        in_chs = channels
        for downsample_scale in downsample_scales:
            out_chs = min(in_chs * downsample_scale, max_downsample_channels)
            self.layers += [
                torch.nn.Sequential(
                    torch.nn.Conv1d(
                        in_chs, out_chs,
                        kernel_size=downsample_scale * 10 + 1,
                        stride=downsample_scale,
                        padding=downsample_scale * 5,
                        groups=in_chs // 4,
                        bias=bias,
                    ),
                    getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
                )
            ]
            in_chs = out_chs

        # add final layers
        out_chs = min(in_chs * 2, max_downsample_channels)
        self.layers += [
            torch.nn.Sequential(
                torch.nn.Conv1d(
                    in_chs, out_chs, kernel_sizes[0],
                    padding=(kernel_sizes[0] - 1) // 2,
                    bias=bias,
                ),
                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
            )
        ]
        self.layers += [
            torch.nn.Conv1d(
                out_chs, out_channels, kernel_sizes[1],
                padding=(kernel_sizes[1] - 1) // 2,
                bias=bias,
            ),
        ]
|
    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input signal (B, 1, T).

        Returns:
            List: List of output tensors of each layer.

        """
        outs = []
        for f in self.layers:
            x = f(x)
            outs += [x]

        return outs
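    # A minimal usage sketch (an illustrative addition, not part of the original
    # code); with the defaults above there are 7 layers: one input conv, four
    # strided downsample convs (each dividing T by 4), and two output convs:
    #
    #     discriminator = MelGANDiscriminator()
    #     wav = torch.randn(4, 1, 25600)   # (B, 1, T)
    #     feats = discriminator(wav)       # list of 7 tensors, one per layer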
|
|
class MelGANMultiScaleDiscriminator(torch.nn.Module):
    """MelGAN multi-scale discriminator module."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 scales=3,
                 downsample_pooling="AvgPool1d",
                 downsample_pooling_params={
                     "kernel_size": 4,
                     "stride": 2,
                     "padding": 1,
                     "count_include_pad": False,
                 },
                 kernel_sizes=[5, 3],
                 channels=16,
                 max_downsample_channels=1024,
                 bias=True,
                 downsample_scales=[4, 4, 4, 4],
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 pad="ReflectionPad1d",
                 pad_params={},
                 use_weight_norm=True,
                 ):
        """Initialize MelGAN multi-scale discriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            scales (int): Number of discriminator scales.
            downsample_pooling (str): Pooling module name for downsampling of the inputs.
            downsample_pooling_params (dict): Parameters for the above pooling module.
            kernel_sizes (list): List of two kernel sizes. Their product will be used for the first conv layer,
                and the first and the second kernel sizes will be used for the last two layers.
            channels (int): Initial number of channels for conv layer.
            max_downsample_channels (int): Maximum number of channels for downsampling layers.
            bias (bool): Whether to add bias parameter in convolution layers.
            downsample_scales (list): List of downsampling scales.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            pad (str): Padding function module name before dilated convolution layer.
            pad_params (dict): Hyperparameters for padding function.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.

        """
        super(MelGANMultiScaleDiscriminator, self).__init__()
        self.discriminators = torch.nn.ModuleList()

        # add discriminators
        for _ in range(scales):
            self.discriminators += [
                MelGANDiscriminator(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_sizes=kernel_sizes,
                    channels=channels,
                    max_downsample_channels=max_downsample_channels,
                    bias=bias,
                    downsample_scales=downsample_scales,
                    nonlinear_activation=nonlinear_activation,
                    nonlinear_activation_params=nonlinear_activation_params,
                    pad=pad,
                    pad_params=pad_params,
                )
            ]
        self.pooling = getattr(torch.nn, downsample_pooling)(**downsample_pooling_params)

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

        # reset parameters
        self.reset_parameters()
|
    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input signal (B, 1, T).

        Returns:
            List: List of each discriminator's outputs, where each entry is the list of
                per-layer output tensors from one scale.

        """
        outs = []
        for f in self.discriminators:
            outs += [f(x)]
            x = self.pooling(x)

        return outs
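    # A minimal usage sketch (an illustrative addition, not part of the original
    # code); each successive scale sees the input downsampled 2x by the pooling:
    #
    #     discriminator = MelGANMultiScaleDiscriminator()
    #     wav = torch.randn(4, 1, 25600)   # (B, 1, T)
    #     outs = discriminator(wav)        # 3 lists (one per scale) of 7 tensors each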
|
    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""
        def _remove_weight_norm(m):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)
|
    def apply_weight_norm(self):
        """Apply weight normalization module to all of the layers."""
        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)
|
    def reset_parameters(self):
        """Reset parameters.

        This initialization follows the manner of the official implementation.
        https://github.com/descriptinc/melgan-neurips/blob/master/spec2wav/modules.py

        """
        def _reset_parameters(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
                m.weight.data.normal_(0.0, 0.02)
                logging.debug(f"Reset parameters in {m}.")

        self.apply(_reset_parameters)
|
|
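# A quick smoke-test sketch (an illustrative addition, not part of the original
# module); it assumes the modules.parallel_wavegan.layers imports above resolve.
if __name__ == "__main__":
    generator = MelGANGenerator()
    discriminator = MelGANMultiScaleDiscriminator()
    mel = torch.randn(2, 80, 50)   # (B, mel bins, frames)
    wav = generator(mel)           # frames upsampled by prod([8, 8, 2, 2]) = 256
    assert wav.shape == (2, 1, 50 * 256)
    outs = discriminator(wav)      # 3 scales, 7 per-layer outputs each
    assert len(outs) == 3 and len(outs[0]) == 7
    print("smoke test passed:", wav.shape)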