import torch
from packaging import version

from .commons import fused_add_tanh_sigmoid_multiply

is_pytorch2_1 = version.parse(torch.__version__) >= version.parse("2.1.0")

# torch.nn.utils.weight_norm is deprecated since PyTorch 2.1 in favor of the
# parametrization-based variant; resolve the constructor once here instead of
# branching at every call site.
weight_norm = (
    torch.nn.utils.parametrizations.weight_norm
    if is_pytorch2_1
    else torch.nn.utils.weight_norm
)


class WaveNet(torch.nn.Module):
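    """Non-causal WaveNet: a stack of gated, dilated 1D convolutions with
    residual connections and accumulated skip outputs, optionally conditioned
    on a global embedding when ``gin_channels`` > 0."""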

    def __init__(
        self,
        hidden_channels: int,
        kernel_size: int,
        dilation_rate: int,
        n_layers: int,
        gin_channels: int = 0,
        p_dropout: float = 0.0,
    ):
        super().__init__()
        assert kernel_size % 2 == 1, "Kernel size must be odd for proper padding."

        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout
        # Passed through to fused_add_tanh_sigmoid_multiply in forward().
        self.n_channels_tensor = torch.IntTensor([hidden_channels])

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = torch.nn.Dropout(p_dropout)

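        # A single 1x1 conv produces the conditioning signal for every layer
        # at once; forward() slices out 2 * hidden_channels per layer.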
        if gin_channels:
            self.cond_layer = weight_norm(
                torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1),
                name="weight",
            )

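        # "Same" padding for odd kernels: every dilated conv preserves the
        # time dimension, so residual and skip tensors line up across layers.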
        dilations = [dilation_rate**i for i in range(n_layers)]
        paddings = [(kernel_size * d - d) // 2 for d in dilations]

        for i in range(n_layers):
            self.in_layers.append(
                weight_norm(
                    torch.nn.Conv1d(
                        hidden_channels,
                        2 * hidden_channels,
                        kernel_size,
                        dilation=dilations[i],
                        padding=paddings[i],
                    ),
                    name="weight",
                )
            )

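            # The final layer feeds only the skip path, so it needs half the
            # channels of the earlier residual + skip layers.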
            res_skip_channels = (
                hidden_channels if i == n_layers - 1 else 2 * hidden_channels
            )
            self.res_skip_layers.append(
                weight_norm(
                    torch.nn.Conv1d(hidden_channels, res_skip_channels, 1),
                    name="weight",
                )
            )

    def forward(self, x, x_mask, g=None):
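        """Run the gated dilated-convolution stack.

        Args:
            x: input of shape ``(batch, hidden_channels, time)``.
            x_mask: binary mask broadcastable to ``x``, e.g. ``(batch, 1, time)``.
            g: optional conditioning of shape ``(batch, gin_channels, 1)``
                or ``(batch, gin_channels, time)``.
        """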
        output = torch.zeros_like(x)

        # Compute the conditioning for every layer in one pass; it is sliced
        # per layer inside the loop below.
        g = self.cond_layer(g) if g is not None else None

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            g_l = (
                g[
                    :,
                    i * 2 * self.hidden_channels : (i + 1) * 2 * self.hidden_channels,
                    :,
                ]
                if g is not None
                else 0
            )
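
            # Gated activation: tanh of one half times sigmoid of the other.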
            acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, self.n_channels_tensor)
            acts = self.drop(acts)
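
            # Split the 1x1 conv output: the first half is the residual added
            # back into x, the second half is the skip contribution.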
            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                res_acts = res_skip_acts[:, : self.hidden_channels, :]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else:
                output = output + res_skip_acts

        return output * x_mask

    def remove_weight_norm(self):
        def _remove(layer):
            # The parametrization-based weight norm (PyTorch >= 2.1) must be
            # removed via the parametrize API; the legacy hook-based variant
            # keeps its own removal helper.
            if is_pytorch2_1:
                torch.nn.utils.parametrize.remove_parametrizations(layer, "weight")
            else:
                torch.nn.utils.remove_weight_norm(layer)

        if self.gin_channels:
            _remove(self.cond_layer)
        for layer in self.in_layers:
            _remove(layer)
        for layer in self.res_skip_layers:
            _remove(layer)
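

# Minimal smoke-test sketch, not part of the module proper. It assumes this
# file sits in a package next to commons.py (run via `python -m <pkg>.<mod>`,
# both placeholders); all sizes below are illustrative.
if __name__ == "__main__":
    net = WaveNet(hidden_channels=192, kernel_size=5, dilation_rate=1, n_layers=4)
    x = torch.randn(2, 192, 100)    # (batch, hidden_channels, time)
    x_mask = torch.ones(2, 1, 100)  # all frames valid
    y = net(x, x_mask)
    assert y.shape == x.shape
    net.remove_weight_norm()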