| from typing import List |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from torch.nn.utils import parametrize |
|
|
| from TTS.vocoder.layers.lvc_block import LVCBlock |
|
|
# Negative slope passed to `F.leaky_relu` before each output convolution of the generator.
LRELU_SLOPE = 0.1
|
|
|
|
class UnivnetGenerator(torch.nn.Module):
    """UnivNet generator network.

    The generator draws Gaussian noise with the same batch size and number of
    frames as the conditioning features, projects it with a first convolution,
    passes it through a stack of LVC blocks conditioned on those features and
    finishes with a leaky-ReLU/convolution head followed by ``tanh``.

    Paper: https://arxiv.org/pdf/2106.07889.pdf
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        cond_channels: int,
        upsample_factors: List[int],
        lvc_layers_each_block: int,
        lvc_kernel_size: int,
        kpnet_hidden_channels: int,
        kpnet_conv_size: int,
        dropout: float,
        use_weight_norm=True,
    ):
        """Build the generator layers.

        Args:
            in_channels (int): Number of channels of the sampled input noise.
            out_channels (int): Number of channels of the output tensor.
            hidden_channels (int): Number of hidden network channels.
            cond_channels (int): Number of channels of the conditioning tensors.
            upsample_factors (List[int]): Upsample factor of each LVC block; one
                block is created per entry.
            lvc_layers_each_block (int): Number of LVC layers in each block.
            lvc_kernel_size (int): Kernel size of the LVC layers.
            kpnet_hidden_channels (int): Forwarded to every LVC block as
                ``kpnet_hidden_channels``.
            kpnet_conv_size (int): Forwarded to every LVC block as
                ``kpnet_conv_size``.
            dropout (float): Dropout rate, forwarded to every LVC block as
                ``kpnet_dropout``.
            use_weight_norm (bool, optional): Enable/disable weight norm.
                Defaults to True.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.cond_channels = cond_channels
        self.upsample_scale = np.prod(upsample_factors)
        self.lvc_block_nums = len(upsample_factors)

        # Projects the input noise to the hidden width; padding keeps the length.
        self.first_conv = torch.nn.Conv1d(
            in_channels, hidden_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True
        )

        # One LVC block per upsample factor. `cond_hop_length` is the running
        # product of the factors up to and including the current block.
        self.lvc_blocks = torch.nn.ModuleList()
        cond_hop_length = 1
        for factor in upsample_factors:
            cond_hop_length = cond_hop_length * factor
            self.lvc_blocks.append(
                LVCBlock(
                    in_channels=hidden_channels,
                    cond_channels=cond_channels,
                    upsample_ratio=factor,
                    conv_layers=lvc_layers_each_block,
                    conv_kernel_size=lvc_kernel_size,
                    cond_hop_length=cond_hop_length,
                    kpnet_hidden_channels=kpnet_hidden_channels,
                    kpnet_conv_size=kpnet_conv_size,
                    kpnet_dropout=dropout,
                )
            )

        # Output head: each layer is applied after a leaky ReLU in `forward`.
        self.last_conv_layers = torch.nn.ModuleList(
            [
                torch.nn.Conv1d(
                    hidden_channels, out_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True
                ),
            ]
        )

        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, c):
        """Calculate forward propagation.

        Args:
            c (Tensor): Local conditioning auxiliary features (B, C, T').

        Returns:
            Tensor: Output tensor (B, out_channels, T). The relation between T
                and T' is determined by the LVC blocks (presumably
                T = T' * prod(upsample_factors) -- confirm against LVCBlock).
        """
        # Sample the input noise with one frame per conditioning frame, then
        # align it with the parameters' device *and* dtype so it matches the
        # conditioning tensor prepared by `inference`.
        noise = torch.randn([c.shape[0], self.in_channels, c.shape[2]])
        x = self.first_conv(noise.to(self.first_conv.bias))

        for lvc_block in self.lvc_blocks:
            x = lvc_block(x, c)

        for conv in self.last_conv_layers:
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = conv(x)
        return torch.tanh(x)

    def remove_weight_norm(self):
        """Remove the weight normalization parametrization from all of the layers."""

        def _remove_weight_norm(m):
            try:
                parametrize.remove_parametrizations(m, "weight")
            except ValueError:
                # Raised for modules that carry no "weight" parametrization; skip them.
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization to all of the convolution layers."""

        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.parametrizations.weight_norm(m)

        self.apply(_apply_weight_norm)

    @staticmethod
    def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
        """Return the receptive field of `layers` dilated convolutions.

        The layers are split evenly into `stacks` cycles; the dilation of a
        layer is `dilation(i)` where `i` is its index inside its cycle.
        """
        assert layers % stacks == 0
        layers_per_cycle = layers // stacks
        dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
        return (kernel_size - 1) * sum(dilations) + 1

    @property
    def receptive_field_size(self):
        """Return receptive field size."""
        # NOTE(review): `layers`, `stacks` and `kernel_size` are not assigned
        # anywhere in this class, so this property fails with AttributeError
        # unless they are set externally -- confirm whether it is still needed.
        return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)

    @torch.no_grad()
    def inference(self, c):
        """Perform inference.

        Args:
            c (Tensor): Local conditioning auxiliary features :math:`(B, C, T)`.

        Returns:
            Tensor: Output of `forward`, i.e. a tensor (B, out_channels, T).
        """
        # `forward` samples its own noise; sampling here as well would only
        # waste work and advance the RNG without affecting the result.
        c = c.to(next(self.parameters()))
        return self.forward(c)
|
|