Spaces:
Running
Running
| # Copyright (c) 2023 Amphion. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import math | |
| from torch import nn | |
| from torch.nn import functional as F | |
| from .modules import Conv1d1x1, ResidualConv1dGLU | |
| from .upsample import ConvInUpsampleNetwork | |
def receptive_field_size(
    total_layers, num_cycles, kernel_size, dilation=lambda x: 2**x
):
    """Return the receptive field, in samples, of a dilated conv stack.

    The layers are split evenly into ``num_cycles`` cycles; the dilation of
    each layer is obtained by calling ``dilation`` with the layer's position
    inside its cycle.

    Args:
        total_layers (int): Total number of convolution layers.
        num_cycles (int): Number of dilation cycles. Must divide
            ``total_layers`` evenly.
        kernel_size (int): Kernel size shared by all layers.
        dilation (callable): Maps the position of a layer inside its cycle
            to a dilation factor. Pass ``lambda x: 1`` to disable dilated
            convolution.

    Returns:
        int: ``(kernel_size - 1) * sum(dilations) + 1``.

    Raises:
        AssertionError: If ``total_layers`` is not a multiple of
            ``num_cycles``.
    """
    assert total_layers % num_cycles == 0
    cycle_length = total_layers // num_cycles

    # Accumulate the dilation of every layer; the position restarts at 0
    # at the beginning of each cycle.
    dilation_total = 0
    for layer_index in range(total_layers):
        dilation_total += dilation(layer_index % cycle_length)

    return (kernel_size - 1) * dilation_total + 1
class WaveNet(nn.Module):
    """The WaveNet model with local conditioning on upsampled features.

    A first 1x1 convolution maps the input to ``residual_channels``; a stack
    of dilated ``ResidualConv1dGLU`` layers follows, whose skip outputs are
    summed, rescaled, and projected to ``out_channels`` by
    ReLU -> 1x1 conv -> ReLU -> 1x1 conv.

    Args:
        cfg: Configuration object. The following ``cfg.VOCODER`` fields are
            read:

            * ``SCALAR_INPUT`` (bool): If True, scalar input ([-1, 1]) with
              a single channel is expected, otherwise a quantized one-hot
              vector with ``OUT_CHANNELS`` channels is expected.
            * ``OUT_CHANNELS`` (int): Output channels. If the input is a
              mu-law quantized one-hot vector, this must equal the number
              of quantize channels. Otherwise num_mixtures x 3
              (pi, mu, log_scale).
            * ``INPUT_DIM`` (int): Number of mel-spec dimensions (channels
              of the local conditioning features).
            * ``RESIDUAL_CHANNELS`` (int): Residual input / output channels.
            * ``LAYERS`` (int): Number of total layers. Must be a multiple
              of ``STACKS``.
            * ``STACKS`` (int): Number of dilation cycles.
            * ``GATE_CHANNELS`` (int): Gated activation channels.
            * ``KERNEL_SIZE`` (int): Kernel size of convolution layers.
            * ``SKIP_OUT_CHANNELS`` (int): Skip connection channels.
            * ``DROPOUT`` (float): Dropout probability passed to every
              residual layer.
            * ``UPSAMPLE_SCALES`` (list): List of upsample scales;
              ``np.prod(upsample_scales)`` is expected to equal the hop
              size.
            * ``MEL_FRAME_PAD`` (int): Forwarded as ``cin_pad`` to the
              upsampling network.

    Attributes:
        receptive_field (int): Receptive field of the dilated stack in
            samples, as computed by ``receptive_field_size``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.scalar_input = self.cfg.VOCODER.SCALAR_INPUT
        self.out_channels = self.cfg.VOCODER.OUT_CHANNELS
        self.cin_channels = self.cfg.VOCODER.INPUT_DIM
        self.residual_channels = self.cfg.VOCODER.RESIDUAL_CHANNELS
        self.layers = self.cfg.VOCODER.LAYERS
        self.stacks = self.cfg.VOCODER.STACKS
        self.gate_channels = self.cfg.VOCODER.GATE_CHANNELS
        self.kernel_size = self.cfg.VOCODER.KERNEL_SIZE
        self.skip_out_channels = self.cfg.VOCODER.SKIP_OUT_CHANNELS
        self.dropout = self.cfg.VOCODER.DROPOUT
        self.upsample_scales = self.cfg.VOCODER.UPSAMPLE_SCALES
        self.mel_frame_pad = self.cfg.VOCODER.MEL_FRAME_PAD

        assert self.layers % self.stacks == 0
        layers_per_stack = self.layers // self.stacks

        # Scalar waveform input has one channel; one-hot input has as many
        # channels as the output distribution.
        if self.scalar_input:
            self.first_conv = Conv1d1x1(1, self.residual_channels)
        else:
            self.first_conv = Conv1d1x1(self.out_channels, self.residual_channels)

        self.conv_layers = nn.ModuleList()
        for layer in range(self.layers):
            # Dilation doubles per layer and restarts at 1 for every stack.
            dilation = 2 ** (layer % layers_per_stack)
            conv = ResidualConv1dGLU(
                self.residual_channels,
                self.gate_channels,
                kernel_size=self.kernel_size,
                skip_out_channels=self.skip_out_channels,
                bias=True,
                dilation=dilation,
                dropout=self.dropout,
                cin_channels=self.cin_channels,
            )
            self.conv_layers.append(conv)

        self.last_conv_layers = nn.ModuleList(
            [
                nn.ReLU(inplace=True),
                Conv1d1x1(self.skip_out_channels, self.skip_out_channels),
                nn.ReLU(inplace=True),
                Conv1d1x1(self.skip_out_channels, self.out_channels),
            ]
        )

        self.upsample_net = ConvInUpsampleNetwork(
            upsample_scales=self.upsample_scales,
            cin_pad=self.mel_frame_pad,
            cin_channels=self.cin_channels,
        )

        self.receptive_field = receptive_field_size(
            self.layers, self.stacks, self.kernel_size
        )

    def forward(self, x, mel, softmax=False):
        """Forward step

        Args:
            x (Tensor): Audio signal, shape (B x C x T). C is 1 for scalar
                input and ``out_channels`` for one-hot input.
            mel (Tensor): Local conditioning features before upsampling.
                They are passed through ``upsample_net``; the result must
                have the same last (time) dimension as ``x``.
            softmax (bool): Whether applies softmax or not.

        Returns:
            Tensor: output, shape B x out_channels x T

        Raises:
            ValueError: If ``x`` is not a 3-D tensor.
            AssertionError: If the upsampled ``mel`` and ``x`` differ in
                their last dimension.
        """
        if x.dim() != 3:
            raise ValueError(
                "expected x of shape (B x C x T), got {} dimension(s)".format(
                    x.dim()
                )
            )

        mel = self.upsample_net(mel)
        assert mel.shape[-1] == x.shape[-1], (
            "upsampled mel length {} does not match input length {}".format(
                mel.shape[-1], x.shape[-1]
            )
        )

        x = self.first_conv(x)

        # Each residual layer returns (residual output, skip output); the
        # skip outputs of all layers are summed.
        skips = 0
        for f in self.conv_layers:
            x, h = f(x, mel)
            skips += h
        # Rescale the sum by sqrt(1 / num_layers) -- presumably so that its
        # magnitude does not grow with network depth.
        skips *= math.sqrt(1.0 / len(self.conv_layers))

        x = skips
        for f in self.last_conv_layers:
            x = f(x)

        if softmax:
            x = F.softmax(x, dim=1)
        return x

    def clear_buffer(self):
        """Call ``clear_buffer`` on every sub-layer that provides it.

        ``first_conv`` and all ``conv_layers`` are required to implement
        ``clear_buffer``. Entries of ``last_conv_layers`` that lack the
        method (e.g. the ReLU activations) are skipped.
        """
        self.first_conv.clear_buffer()
        for f in self.conv_layers:
            f.clear_buffer()
        for f in self.last_conv_layers:
            # Look the method up first instead of wrapping the call in
            # try/except, so that an AttributeError raised *inside* a
            # layer's clear_buffer is not silently swallowed.
            clear = getattr(f, "clear_buffer", None)
            if clear is not None:
                clear()

    def make_generation_fast_(self):
        """Remove weight normalization from every submodule that has it."""

        def remove_weight_norm(m):
            try:
                nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(remove_weight_norm)