| import math |
|
|
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class Conv1d(nn.Module):
    """This class implements 1d convolution.

    Arguments
    ---------
    out_channels : int
        It is the number of output channels.
    kernel_size : int
        Kernel size of the convolutional filters.
    input_shape : tuple
        The shape of the input. Alternatively use ``in_channels``.
    in_channels : int
        The number of input channels. Alternatively use ``input_shape``.
    stride : int
        Stride factor of the convolutional filters. When the stride factor > 1,
        a decimation in time is performed.
    dilation : int
        Dilation factor of the convolutional filters.
    padding : str
        (same, valid, causal). If "valid", no padding is performed.
        If "same" and stride is 1, output shape is the same as the input shape;
        the padding is applied with ``padding_mode``.
        "causal" results in causal (dilated) convolutions: the input is
        zero-padded on the left side of the time axis only.
        NOTE: an odd ``kernel_size`` is required for "same"/"causal", but this
        is only checked when ``input_shape`` is given (not with ``in_channels``).
    groups : int
        Number of blocked connections from input channels to output channels.
    bias : bool
        Whether to add a bias term to convolution operation.
    padding_mode : str
        Mode passed to ``torch.nn.functional.pad`` for "same" padding
        (e.g. "reflect", "constant"). See torch.nn documentation for more
        information.
    skip_transpose : bool
        If False, uses batch x time x channel convention of speechbrain.
        If True, uses batch x channel x time convention.
    weight_norm : bool
        If True, use weight normalization,
        to be removed with self.remove_weight_norm() at inference
    conv_init : str
        Weight initialization for the convolution network: "kaiming",
        "zero" or "normal" (std 1e-6). Any other value, including None,
        keeps the PyTorch default initialization.
    default_padding : str or int
        Padding argument forwarded to the underlying ``torch.nn.Conv1d``.
        It is applied in addition to any padding added by ``padding``.

    Example
    -------
    >>> import torch
    >>> inp_tensor = torch.rand([10, 40, 16])
    >>> cnn_1d = Conv1d(
    ...     input_shape=inp_tensor.shape, out_channels=8, kernel_size=5
    ... )
    >>> out_tensor = cnn_1d(inp_tensor)
    >>> out_tensor.shape
    torch.Size([10, 40, 8])
    """

    def __init__(
        self,
        out_channels,
        kernel_size,
        input_shape=None,
        in_channels=None,
        stride=1,
        dilation=1,
        padding="same",
        groups=1,
        bias=True,
        padding_mode="reflect",
        skip_transpose=False,
        weight_norm=False,
        conv_init=None,
        default_padding=0,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.padding_mode = padding_mode
        self.unsqueeze = False
        self.skip_transpose = skip_transpose

        if input_shape is None and in_channels is None:
            raise ValueError("Must provide one of input_shape or in_channels")

        if in_channels is None:
            # Also flags 2d inputs for unsqueezing and validates kernel_size;
            # relies on skip_transpose / padding / kernel_size being set above.
            in_channels = self._check_input_shape(input_shape)

        self.in_channels = in_channels

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            padding=default_padding,
            groups=groups,
            bias=bias,
        )

        if conv_init == "kaiming":
            nn.init.kaiming_normal_(self.conv.weight)
        elif conv_init == "zero":
            nn.init.zeros_(self.conv.weight)
        elif conv_init == "normal":
            nn.init.normal_(self.conv.weight, std=1e-6)

        if weight_norm:
            self.conv = nn.utils.weight_norm(self.conv)

    def forward(self, x, *args, **kwargs):
        """Returns the output of the convolution.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channel)
            input to convolve. 2d or 3d tensors are expected.
            With ``skip_transpose=True`` the layout is (batch, channel, time).

        Returns
        -------
        wx : torch.Tensor
            The convolved outputs, in the same layout convention as the input.

        Raises
        ------
        ValueError
            If ``self.padding`` is not one of "same", "valid" or "causal".
        """
        if not self.skip_transpose:
            # (batch, time, channel) -> (batch, channel, time) for nn.Conv1d.
            x = x.transpose(1, -1)

        if self.unsqueeze:
            # 2d input: add a singleton channel dimension.
            x = x.unsqueeze(1)

        if self.padding == "same":
            x = self._manage_padding(
                x, self.kernel_size, self.dilation, self.stride
            )
        elif self.padding == "causal":
            # Left-only padding so each output depends on current/past samples.
            num_pad = (self.kernel_size - 1) * self.dilation
            x = F.pad(x, (num_pad, 0))
        elif self.padding != "valid":
            # str() so a non-string padding still yields a ValueError rather
            # than a TypeError from string concatenation.
            raise ValueError(
                "Padding must be 'same', 'valid' or 'causal'. Got "
                + str(self.padding)
            )

        wx = self.conv(x)

        if self.unsqueeze:
            wx = wx.squeeze(1)

        if not self.skip_transpose:
            wx = wx.transpose(1, -1)

        return wx

    def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int):
        """Pads the time axis so that, at stride 1, its length is unchanged
        by the convolution.

        The amounts come from ``get_padding_elem`` and are applied with
        ``F.pad`` using ``self.padding_mode``, so the padded values are not
        necessarily zeros (the default mode is "reflect").

        Arguments
        ---------
        x : torch.Tensor
            Input tensor, laid out as (batch, channel, time).
        kernel_size : int
            Size of kernel.
        dilation : int
            Dilation used.
        stride : int
            Stride.

        Returns
        -------
        x : torch.Tensor
            The padded outputs.
        """
        # Time is the last dimension here; use its length rather than the
        # channel count.
        L_in = x.shape[-1]

        padding = get_padding_elem(L_in, stride, kernel_size, dilation)

        return F.pad(x, padding, mode=self.padding_mode)

    def _check_input_shape(self, shape):
        """Checks the input shape and returns the number of input channels.

        A 2d shape sets ``self.unsqueeze`` so that ``forward`` adds a
        singleton channel dimension. Also rejects even kernel sizes unless
        padding is "valid".

        Raises
        ------
        ValueError
            On an unsupported rank or an even kernel size with non-"valid"
            padding.
        """
        if len(shape) == 2:
            self.unsqueeze = True
            in_channels = 1
        elif self.skip_transpose:
            in_channels = shape[1]
        elif len(shape) == 3:
            in_channels = shape[2]
        else:
            raise ValueError(
                "conv1d expects 2d, 3d inputs. Got " + str(len(shape))
            )

        # Kernel size must be odd for the padding computations to be symmetric.
        if self.padding != "valid" and self.kernel_size % 2 == 0:
            raise ValueError(
                "The field kernel size must be an odd number. Got %s."
                % (self.kernel_size)
            )

        return in_channels

    def remove_weight_norm(self):
        """Removes weight normalization at inference if used during training."""
        self.conv = nn.utils.remove_weight_norm(self.conv)
|
|
|
|
def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
    """This function computes the number of elements to add for padding.

    Arguments
    ---------
    L_in : int
        Length of the input along the padded (time) axis.
    stride : int
        Stride of the convolution.
    kernel_size : int
        Size of the convolutional kernel.
    dilation : int
        Dilation of the convolutional kernel.

    Returns
    -------
    padding : list of int
        ``[left, right]`` number of elements to pad on each side of the axis;
        both entries are always equal.

    Example
    -------
    >>> get_padding_elem(100, 1, 5, 2)
    [4, 4]
    """
    if stride > 1:
        # Strided case: pad half the kernel on each side.
        # NOTE: dilation is not taken into account in this branch.
        half = kernel_size // 2
        return [half, half]

    # Output length of an unpadded convolution (torch.nn.Conv1d formula).
    # Integer floor division keeps this exact; float division would lose
    # precision for very large lengths.
    L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
    # Split the lost length evenly between the two sides.
    half = (L_in - L_out) // 2
    return [half, half]