|
|
|
|
| """Upsampling module.
|
|
|
| This code is modified from https://github.com/r9y9/wavenet_vocoder.
|
|
|
| """
|
|
|
| import numpy as np
|
| import torch
|
| import torch.nn.functional as F
|
|
|
| from . import Conv1d
|
|
|
|
|
class Stretch2d(torch.nn.Module):
    """Stretch the last two axes of a tensor by fixed scaling factors."""

    def __init__(self, x_scale, y_scale, mode="nearest"):
        """Store the scaling factors and the interpolation mode.

        Args:
            x_scale (int): Scaling factor of the last axis (time axis in spectrogram).
            y_scale (int): Scaling factor of the second-to-last axis
                (frequency axis in spectrogram).
            mode (str): Interpolation mode handed to ``F.interpolate``.

        """
        super(Stretch2d, self).__init__()
        self.x_scale, self.y_scale, self.mode = x_scale, y_scale, mode

    def forward(self, x):
        """Interpolate the input along its last two axes.

        Args:
            x (Tensor): Input tensor (B, C, F, T).

        Returns:
            Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale).

        """
        # F.interpolate expects the factors in axis order: frequency first, then time.
        factors = (self.y_scale, self.x_scale)
        return F.interpolate(x, scale_factor=factors, mode=self.mode)
|
|
|
|
|
class Conv2d(torch.nn.Conv2d):
    """``torch.nn.Conv2d`` variant whose kernel starts as a uniform averaging filter."""

    def __init__(self, *args, **kwargs):
        """Build the layer; every argument is forwarded to ``torch.nn.Conv2d``."""
        super(Conv2d, self).__init__(*args, **kwargs)

    def reset_parameters(self):
        """Set every weight to ``1 / prod(kernel_size)`` and zero the bias, if present."""
        uniform_value = 1. / np.prod(self.kernel_size)
        self.weight.data.fill_(uniform_value)
        if self.bias is None:
            return
        torch.nn.init.constant_(self.bias, 0.0)
|
|
|
|
|
class UpsampleNetwork(torch.nn.Module):
    """Upsampling network module.

    For every scale in ``upsample_scales`` the network stacks a ``Stretch2d``
    interpolation along the time axis, a bias-free ``Conv2d`` layer and,
    optionally, a nonlinear activation.

    """

    def __init__(self,
                 upsample_scales,
                 nonlinear_activation=None,
                 nonlinear_activation_params=None,
                 interpolate_mode="nearest",
                 freq_axis_kernel_size=1,
                 use_causal_conv=False,
                 ):
        """Initialize upsampling network module.

        Args:
            upsample_scales (list): List of upsampling scales.
            nonlinear_activation (str): Activation function name (looked up in ``torch.nn``).
            nonlinear_activation_params (dict): Arguments for specified activation function.
                ``None`` (the default) means no arguments.
            interpolate_mode (str): Interpolation mode.
            freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
            use_causal_conv (bool): Whether to use causal structure.

        Raises:
            AssertionError: If ``freq_axis_kernel_size`` is an even number.

        """
        super(UpsampleNetwork, self).__init__()
        # A fresh dict per call instead of a mutable default shared across calls.
        if nonlinear_activation_params is None:
            nonlinear_activation_params = {}

        # The frequency-axis settings do not depend on the scale, so they are
        # validated and computed once instead of on every loop iteration.
        assert (freq_axis_kernel_size - 1) % 2 == 0, "Not support even number freq axis kernel size."
        freq_axis_padding = (freq_axis_kernel_size - 1) // 2

        self.use_causal_conv = use_causal_conv
        self.up_layers = torch.nn.ModuleList()
        for scale in upsample_scales:
            # interpolation layer: stretch the time axis only
            self.up_layers.append(Stretch2d(scale, 1, interpolate_mode))

            # conv layer: kernel spans 2 * scale + 1 frames along the time axis
            kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
            # The causal variant pads a full kernel width (minus one) on each
            # side; forward() then trims the tail so no future frame is used.
            time_axis_padding = scale * 2 if use_causal_conv else scale
            padding = (freq_axis_padding, time_axis_padding)
            self.up_layers.append(
                Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False))

            # nonlinear layer (optional)
            if nonlinear_activation is not None:
                nonlinear = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)
                self.up_layers.append(nonlinear)

    def forward(self, c):
        """Calculate forward propagation.

        Args:
            c : Input tensor (B, C, T).

        Returns:
            Tensor: Upsampled tensor (B, C, T'), where T' = T * prod(upsample_scales).

        """
        c = c.unsqueeze(1)  # add a singleton channel axis for the 2-D layers
        for f in self.up_layers:
            if self.use_causal_conv and isinstance(f, Conv2d):
                # The causal conv output is 2 * scale frames longer than its
                # input; keeping only the first c.size(-1) frames drops the
                # positions that would look at future frames.
                c = f(c)[..., :c.size(-1)]
            else:
                c = f(c)
        return c.squeeze(1)
|
|
|
|
|
class ConvInUpsampleNetwork(torch.nn.Module):
    """Convolution + upsampling network module.

    A 1-D convolution over the auxiliary features is applied first, followed
    by an ``UpsampleNetwork``.

    """

    def __init__(self,
                 upsample_scales,
                 nonlinear_activation=None,
                 nonlinear_activation_params=None,
                 interpolate_mode="nearest",
                 freq_axis_kernel_size=1,
                 aux_channels=80,
                 aux_context_window=0,
                 use_causal_conv=False
                 ):
        """Initialize convolution + upsampling network module.

        Args:
            upsample_scales (list): List of upsampling scales.
            nonlinear_activation (str): Activation function name.
            nonlinear_activation_params (dict): Arguments for specified activation function.
                ``None`` (the default) means no arguments.
            interpolate_mode (str): Interpolation mode.
            freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
            aux_channels (int): Number of channels of pre-convolutional layer.
            aux_context_window (int): Context window size of the pre-convolutional layer.
            use_causal_conv (bool): Whether to use causal structure.

        """
        super(ConvInUpsampleNetwork, self).__init__()
        # A fresh dict per call instead of a mutable default shared across calls;
        # it is normalized here so a plain dict is always forwarded below.
        if nonlinear_activation_params is None:
            nonlinear_activation_params = {}

        self.aux_context_window = aux_context_window
        # Only trim in forward() when there is a context window to trim:
        # with a window of 0 the slice ``[:-0]`` would yield an empty tensor.
        self.use_causal_conv = use_causal_conv and aux_context_window > 0

        # To capture wide-context information in the conditional features:
        # one-sided window when causal, two-sided window otherwise.
        kernel_size = aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1

        # No padding argument is passed here.
        # NOTE(review): presumably the project-local Conv1d does not pad by
        # default because the input already includes the context frames — confirm.
        self.conv_in = Conv1d(aux_channels, aux_channels, kernel_size=kernel_size, bias=False)
        self.upsample = UpsampleNetwork(
            upsample_scales=upsample_scales,
            nonlinear_activation=nonlinear_activation,
            nonlinear_activation_params=nonlinear_activation_params,
            interpolate_mode=interpolate_mode,
            freq_axis_kernel_size=freq_axis_kernel_size,
            # The raw flag is forwarded: the upsampling layers are causal
            # regardless of the context window size.
            use_causal_conv=use_causal_conv,
        )

    def forward(self, c):
        """Calculate forward propagation.

        Args:
            c : Input tensor (B, C, T').

        Returns:
            Tensor: Upsampled tensor (B, C, T),
                where T = (T' - aux_context_window * 2) * prod(upsample_scales).

        Note:
            The length of inputs considers the context window size.

        """
        c_ = self.conv_in(c)
        if self.use_causal_conv:
            # Drop the last aux_context_window frames of the causal conv output.
            c_ = c_[:, :, :-self.aux_context_window]
        return self.upsample(c_)
|
|
|