| from torch import nn |
|
|
| from .normalization import LayerNorm |
|
|
|
|
class GatedConvBlock(nn.Module):
    """Stack of residual gated convolutional layers, as in https://arxiv.org/pdf/1612.08083.pdf

    Each layer computes ``x + GLU(Norm(Conv1d(Dropout(x) * x_mask)))``. The
    convolution doubles the channel count and the GLU halves it back, so the
    residual sum keeps ``in_out_channels`` channels.

    Args:
        in_out_channels (int): number of input/output channels.
        kernel_size (int): convolution kernel size. Padding is
            ``kernel_size // 2``. For even sizes, this makes the convolution
            emit one extra trailing frame, which is trimmed so the residual sum
            lines up with the input.
        dropout_p (float): dropout rate applied to each layer's input.
        num_layers (int): number of stacked gated conv layers.
    """

    def __init__(self, in_out_channels, kernel_size, dropout_p, num_layers):
        super().__init__()
        # class arguments
        self.dropout_p = dropout_p
        self.num_layers = num_layers
        # define layers; conv and norm are created interleaved per layer so the
        # parameter-initialization order (and thus seeded RNG draws) is stable.
        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        for _ in range(num_layers):
            self.conv_layers.append(
                nn.Conv1d(in_out_channels, 2 * in_out_channels, kernel_size, padding=kernel_size // 2)
            )
            self.norm_layers.append(LayerNorm(2 * in_out_channels))

    def forward(self, x, x_mask):
        """Run the gated conv stack.

        Args:
            x (Tensor): input, expected ``[B, C, T]`` with ``C == in_out_channels``
                (GLU splits along dim 1).
            x_mask (Tensor): mask multiplied into every conv input; presumably
                ``[B, 1, T]`` with 1 for valid frames and 0 for padding — confirm
                against callers.

        Returns:
            Tensor: same shape as ``x``. The final output is not re-multiplied
            by ``x_mask``.
        """
        o = x
        seq_len = x.size(-1)
        for conv, norm in zip(self.conv_layers, self.norm_layers):
            res = o
            o = nn.functional.dropout(o, p=self.dropout_p, training=self.training)
            # Mask the conv input so padded frames cannot leak into valid ones.
            o = conv(o * x_mask)
            # With padding=k//2, an even kernel yields T + 1 frames; trim back to T
            # so the residual add below matches (no-op for odd kernels).
            o = o[..., :seq_len]
            o = norm(o)
            o = nn.functional.glu(o, dim=1)
            o = res + o
        return o
|
|