| |
|
|
| |
| |
|
|
"""Causal convolution layer modules."""
|
|
|
|
| import torch |
|
|
|
|
class CausalConv1d(torch.nn.Module):
    """Causal 1D convolution built from a padding layer followed by ``Conv1d``.

    The input is padded by ``(kernel_size - 1) * dilation`` frames and the
    convolution output is truncated to the input length, so output frame ``t``
    depends only on input frames ``<= t``.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 dilation=1, bias=True, pad="ConstantPad1d", pad_params=None):
        """Initialize CausalConv1d module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Size of the convolution kernel.
            dilation (int): Dilation factor of the convolution.
            bias (bool): Whether the convolution has a learnable bias.
            pad (str): Name of a padding module class in ``torch.nn``
                (e.g. ``"ConstantPad1d"``, ``"ReplicationPad1d"``).
            pad_params (dict, optional): Extra keyword arguments passed to the
                padding module. If omitted, ``{"value": 0.0}`` is used for
                ``"ConstantPad1d"`` and no extra arguments for other padding
                modules.

        """
        super(CausalConv1d, self).__init__()
        if pad_params is None:
            # Build a fresh dict per call (no shared mutable default). Only
            # ConstantPad1d takes ``value``; other pads reject that keyword.
            pad_params = {"value": 0.0} if pad == "ConstantPad1d" else {}
        self.pad = getattr(torch.nn, pad)((kernel_size - 1) * dilation, **pad_params)
        self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size,
                                    dilation=dilation, bias=bias)

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T).

        Returns:
            Tensor: Output tensor (B, out_channels, T).

        """
        # An int padding pads both ends, so the conv yields T + pad frames.
        # Keeping the first T frames discards outputs that would see the
        # right-hand (future) padding.
        return self.conv(self.pad(x))[:, :, :x.size(2)]
|
|
|
|
class CausalConvTranspose1d(torch.nn.Module):
    """Transposed 1D convolution with its trailing ``stride`` frames removed.

    Wraps ``torch.nn.ConvTranspose1d`` using its default padding, dilation and
    output padding. It then drops the last ``stride`` output frames. The result
    has ``T_out = (T_in - 2) * stride + kernel_size`` frames, which equals
    ``T_in * stride`` when ``kernel_size == 2 * stride``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True):
        """Initialize CausalConvTranspose1d module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Size of the transposed-convolution kernel.
            stride (int): Upsampling stride; also the number of trailing
                output frames removed in ``forward``.
            bias (bool): Whether the transposed convolution has a learnable bias.

        """
        super().__init__()
        self.deconv = torch.nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            bias=bias,
        )
        # Kept so forward() knows how many tail frames to cut.
        self.stride = stride

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T_in).

        Returns:
            Tensor: Output tensor (B, out_channels, T_out).

        """
        upsampled = self.deconv(x)
        # ConvTranspose1d emits (T_in - 1) * stride + kernel_size frames here;
        # the final ``stride`` of them are discarded.
        return upsampled[:, :, :-self.stride]
|
|