| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | from torch import nn |
| | from torch.nn import functional as F |
| |
|
| |
|
class LayerNorm(nn.Module):
    """Layer normalization over dimension 1 of the input.

    ``F.layer_norm`` normalizes over the trailing dimension, so ``forward``
    swaps dimension 1 with the last dimension, normalizes, and swaps back.
    The output therefore has the same shape as the input.

    Args:
        channels: Size of dimension 1 of the tensors passed to ``forward``.
        eps: Value forwarded to ``F.layer_norm`` as its ``eps`` argument.
    """

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.channels = channels

        # Learnable affine parameters: scale starts at one, shift at zero.
        # Names and registration order are part of the state_dict layout.
        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        """Return ``x`` normalized along dimension 1, then scaled and shifted."""
        # Bring dimension 1 to the end, where F.layer_norm operates.
        channels_last = torch.transpose(x, 1, -1)
        normalized = F.layer_norm(
            channels_last,
            normalized_shape=(self.channels,),
            weight=self.gamma,
            bias=self.beta,
            eps=self.eps,
        )
        # Restore the caller's original dimension order.
        return torch.transpose(normalized, 1, -1)
| |
|
| |
|
class ConvReluNorm(nn.Module):
    """Residual stack of masked Conv1d -> LayerNorm -> ReLU -> Dropout layers.

    The input is passed through ``n_layers`` convolution blocks, projected
    with a 1x1 convolution and added back onto the original input. The
    projection is zero-initialised, so a freshly constructed module returns
    ``x * x_mask``.

    Args:
        in_channels: Channel count of the input to the first convolution.
        hidden_channels: Channel count produced by every convolution block.
        out_channels: Channel count produced by the final 1x1 projection.
            The residual sum ``x + proj(...)`` in ``forward`` requires this
            to be compatible with the channel count of ``x``.
        kernel_size: Kernel size of every convolution block. Padding is
            ``kernel_size // 2``, which keeps the sequence length unchanged
            only for odd kernel sizes.
        n_layers: Number of convolution blocks; must be at least 1.
        p_dropout: Dropout probability applied after each ReLU.

    Raises:
        AssertionError: If ``n_layers`` is not larger than 0.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        kernel_size,
        n_layers,
        p_dropout,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        # A single block is a valid configuration (only the input convolution
        # is built), so the check now agrees with its message.
        assert n_layers > 0, "Number of layers should be larger than 0."

        # The first block maps in_channels -> hidden_channels; the rest keep
        # hidden_channels. Convolutions are created in order so that random
        # initialisation and state_dict keys match the layer index.
        input_sizes = [in_channels] + [hidden_channels] * (n_layers - 1)
        self.conv_layers = nn.ModuleList(
            [
                nn.Conv1d(
                    conv_in,
                    hidden_channels,
                    kernel_size,
                    padding=kernel_size // 2,
                )
                for conv_in in input_sizes
            ]
        )
        self.norm_layers = nn.ModuleList(
            [LayerNorm(hidden_channels) for _ in range(n_layers)]
        )
        # One shared activation/dropout module is reused after every block.
        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))

        # Zero-initialised projection: the residual branch contributes
        # nothing until training moves these weights.
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x, x_mask):
        """Apply the convolution stack and add the result to ``x``.

        Args:
            x: Input tensor fed to ``nn.Conv1d``; presumably shaped
                ``(batch, in_channels, time)`` -- confirm against callers.
            x_mask: Tensor multiplied element-wise with ``x``; it must
                broadcast against ``x``. Presumably 1 for valid positions
                and 0 for padding -- confirm against callers.

        Returns:
            ``(x + proj(stack(x))) * x_mask``.
        """
        residual = x
        for conv, norm in zip(self.conv_layers, self.norm_layers):
            # Re-mask before every convolution so masked-out positions do
            # not leak into their neighbours through the kernel.
            x = conv(x * x_mask)
            x = norm(x)
            x = self.relu_drop(x)
        x = residual + self.proj(x)
        return x * x_mask
| |
|