| import torch |
| from torch import nn |
|
|
| from TTS.tts.layers.glow_tts.decoder import Decoder as GlowDecoder |
| from TTS.tts.utils.helpers import sequence_mask |
|
|
|
|
class Decoder(nn.Module):
    """Thin wrapper around the Glow-TTS flow decoder.

    Before running the wrapped decoder, the time axis of the input is trimmed
    and the sequence lengths are floored so both are multiples of
    ``num_squeeze`` (see :meth:`preprocess`).
    ::

        Squeeze -> ActNorm -> InvertibleConv1x1 -> AffineCoupling -> Unsqueeze

    Args:
        in_channels (int): channels of input tensor.
        hidden_channels (int): hidden decoder channels.
        kernel_size (int): Coupling block kernel size. (Wavenet filter kernel size.)
        dilation_rate (int): rate to increase dilation by each layer in a decoder block.
        num_flow_blocks (int): number of decoder blocks.
        num_coupling_layers (int): number coupling layers. (number of wavenet layers.)
        dropout_p (float): wavenet dropout rate.
        num_splits (int): forwarded unchanged to the glow decoder.
            NOTE(review): presumably the split count of the invertible 1x1
            convolution -- confirm against ``GlowDecoder``.
        num_squeeze (int): squeeze factor. Forwarded to the glow decoder and
            used here to floor time lengths to a multiple of this value.
        sigmoid_scale (bool): enable/disable sigmoid scaling in coupling layer.
        c_in_channels (int): forwarded unchanged to the glow decoder.
            NOTE(review): presumably the channel count of the conditioning
            input ``g`` -- confirm against ``GlowDecoder``.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        num_flow_blocks,
        num_coupling_layers,
        dropout_p=0.0,
        num_splits=4,
        num_squeeze=2,
        sigmoid_scale=False,
        c_in_channels=0,
    ):
        super().__init__()

        # Arguments are passed positionally, in the same order as received.
        self.glow_decoder = GlowDecoder(
            in_channels,
            hidden_channels,
            kernel_size,
            dilation_rate,
            num_flow_blocks,
            num_coupling_layers,
            dropout_p,
            num_splits,
            num_squeeze,
            sigmoid_scale,
            c_in_channels,
        )
        # Kept locally so preprocess() can align lengths to the squeeze factor.
        self.n_sqz = num_squeeze

    def forward(self, x, x_len, g=None, reverse=False):
        """Run the glow decoder on a squeeze-aligned version of the input.

        Input shapes:
            - x: :math:`[B, C, T]`
            - x_len :math:`[B]`
            - g: :math:`[B, C]`

        Output shapes:
            - x: :math:`[B, C, T]` (``T`` trimmed to a multiple of ``num_squeeze``)
            - x_len :math:`[B]` (each length floored to a multiple of ``num_squeeze``)
            - logdet_tot :math:`[B]`
        """
        x, x_len, x_max_len = self.preprocess(x, x_len, x_len.max())
        # Mask gets a channel axis and the input dtype before being handed on.
        x_mask = torch.unsqueeze(sequence_mask(x_len, x_max_len), 1).to(x.dtype)
        x, logdet_tot = self.glow_decoder(x, x_mask, g, reverse)
        return x, x_len, logdet_tot

    def _floor_to_squeeze(self, length):
        """Round ``length`` down to the nearest multiple of ``self.n_sqz``.

        Tensors go through ``torch.div(..., rounding_mode="floor")``; any other
        value (e.g. a plain ``int``) uses Python floor division, because
        ``torch.div`` only accepts a tensor as its first argument.
        """
        if torch.is_tensor(length):
            return torch.div(length, self.n_sqz, rounding_mode="floor") * self.n_sqz
        return (length // self.n_sqz) * self.n_sqz

    def preprocess(self, y, y_lengths, y_max_length):
        """Trim ``y`` and floor the lengths to multiples of the squeeze factor.

        Args:
            y (torch.Tensor): tensor sliced along its third axis.
            y_lengths (torch.Tensor or int): per-item lengths to be floored.
            y_max_length (torch.Tensor, int or None): maximum length. When
                ``None``, ``y`` is not trimmed and ``None`` is returned as is.

        Returns:
            tuple: ``(y, y_lengths, y_max_length)`` after alignment.
        """
        if y_max_length is not None:
            y_max_length = self._floor_to_squeeze(y_max_length)
        # A ``None`` bound makes this a full slice, i.e. no trimming.
        y = y[:, :, :y_max_length]
        y_lengths = self._floor_to_squeeze(y_lengths)
        return y, y_lengths, y_max_length

    def store_inverse(self):
        """Delegate to the wrapped glow decoder's ``store_inverse``."""
        self.glow_decoder.store_inverse()
|
|