| import torch |
| from torch.nn import functional as F |
|
|
|
|
class ResidualBlock(torch.nn.Module):
    """Residual block module in WaveNet.

    The block applies dropout and a dilated 1-D convolution to the input,
    optionally adds a 1x1-projected auxiliary (conditioning) signal, passes the
    result through a ``tanh``/``sigmoid`` gated activation, and produces two
    outputs through separate 1x1 convolutions: a residual output and a skip
    output.

    Args:
        kernel_size: Kernel size of the dilated convolution. Must be odd
            unless ``use_causal_conv`` is True.
        res_channels: Number of channels of the input and of the residual
            output.
        gate_channels: Number of output channels of the dilated convolution.
            It is split in half for the gated activation, so it must be even.
        skip_channels: Number of channels of the skip output.
        aux_channels: Number of channels of the auxiliary input. When it is
            not positive, no auxiliary projection is created.
        dropout: Dropout probability applied to the input of the block.
        dilation: Dilation factor of the convolution.
        bias: Whether the dilated convolution and the two output 1x1
            convolutions use a bias term (the auxiliary projection never does).
        use_causal_conv: If True, pad by ``(kernel_size - 1) * dilation`` and
            trim the convolution output back to the input length.

    Raises:
        ValueError: If ``gate_channels`` is odd.
        AssertionError: If ``kernel_size`` is even and ``use_causal_conv`` is
            False.
    """

    def __init__(
        self,
        kernel_size=3,
        res_channels=64,
        gate_channels=128,
        skip_channels=64,
        aux_channels=80,
        dropout=0.0,
        dilation=1,
        bias=True,
        use_causal_conv=False,
    ):
        super().__init__()
        self.dropout = dropout
        self.use_causal_conv = use_causal_conv

        # The gated activation splits the conv output into two equal halves;
        # an odd channel count would only fail later, inside ``forward``.
        if gate_channels % 2 != 0:
            raise ValueError(f"gate_channels must be even, got {gate_channels}.")

        if use_causal_conv:
            # Conv1d pads both sides; the surplus frames at the end are
            # trimmed in ``forward``.
            padding = (kernel_size - 1) * dilation
        else:
            # Raised explicitly (instead of a bare ``assert``) so the check is
            # not stripped under ``python -O``; the exception type and message
            # are kept for backward compatibility.
            if (kernel_size - 1) % 2 != 0:
                raise AssertionError("Not support even number kernel size.")
            padding = (kernel_size - 1) // 2 * dilation

        # Dilated convolution.
        self.conv = torch.nn.Conv1d(
            res_channels,
            gate_channels,
            kernel_size,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

        # 1x1 projection of the auxiliary input onto the gate channels.
        if aux_channels > 0:
            self.conv1x1_aux = torch.nn.Conv1d(aux_channels, gate_channels, 1, bias=False)
        else:
            self.conv1x1_aux = None

        # 1x1 convolutions producing the residual and the skip outputs.
        gate_out_channels = gate_channels // 2
        self.conv1x1_out = torch.nn.Conv1d(gate_out_channels, res_channels, 1, bias=bias)
        self.conv1x1_skip = torch.nn.Conv1d(gate_out_channels, skip_channels, 1, bias=bias)

    def forward(self, x, c=None):
        """Run the residual block.

        Args:
            x: Input tensor, B x D_res x T.
            c: Optional auxiliary tensor, B x D_aux x T. May be omitted (or
                None) to skip the auxiliary conditioning.

        Returns:
            Tuple ``(x, s)``: the residual output (``res_channels`` channels)
            and the skip output (``skip_channels`` channels).

        Raises:
            AssertionError: If ``c`` is given but the block was built without
                an auxiliary projection (``aux_channels <= 0``).
        """
        residual = x
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv(x)

        # Drop the extra trailing frames introduced by the causal padding.
        if self.use_causal_conv:
            x = x[:, :, : residual.size(-1)]

        # Split the channels into the tanh half and the sigmoid half.
        splitdim = 1
        xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)

        # Local (auxiliary) conditioning.
        if c is not None:
            if self.conv1x1_aux is None:
                raise AssertionError(
                    "Auxiliary input given but the block was built without an auxiliary projection."
                )
            c = self.conv1x1_aux(c)
            ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
            xa, xb = xa + ca, xb + cb

        x = torch.tanh(xa) * torch.sigmoid(xb)

        # Skip connection output.
        s = self.conv1x1_skip(x)

        # Residual connection output.
        # NOTE(review): 0.5**2 evaluates to 0.25; a sqrt(0.5) scale (0.5**0.5)
        # may have been intended here - confirm. The value is left unchanged
        # because altering it would change the output of already-trained weights.
        x = (self.conv1x1_out(x) + residual) * (0.5**2)

        return x, s
|
|