| """Library implementing convolutional neural networks. |
| |
| Authors |
| * Mirco Ravanelli 2020 |
| * Jianyuan Zhong 2020 |
| * Cem Subakan 2021 |
| * Davide Borra 2021 |
| * Andreas Nautsch 2022 |
| * Sarthak Yadav 2022 |
| """ |
|
|
| import logging |
| import math |
| from typing import Tuple |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchaudio |
|
|
| class SincConv(nn.Module): |
| """This function implements SincConv (SincNet). |
| |
| M. Ravanelli, Y. Bengio, "Speaker Recognition from raw waveform with |
| SincNet", in Proc. of SLT 2018 (https://arxiv.org/abs/1808.00158) |
| |
| Arguments |
| --------- |
| out_channels : int |
| It is the number of output channels. |
| kernel_size : int |
| Kernel size of the convolutional filters. |
| input_shape : tuple |
| The shape of the input. Alternatively use ``in_channels``. |
| in_channels : int |
| The number of input channels. Alternatively use ``input_shape``. |
| stride : int |
| Stride factor of the convolutional filters. When the stride factor > 1, |
| a decimation in time is performed. |
| dilation : int |
| Dilation factor of the convolutional filters. |
| padding : str |
| (same, valid, causal). If "valid", no padding is performed. |
| If "same" and stride is 1, output shape is the same as the input shape. |
| "causal" results in causal (dilated) convolutions. |
| padding_mode : str |
| This flag specifies the type of padding. See torch.nn documentation |
| for more information. |
| sample_rate : int |
| Sampling rate of the input signals. |
| min_low_hz : float |
| Lowest possible frequency (in Hz) for a filter. |
| min_band_hz : float |
| Lowest possible value (in Hz) for a filter bandwidth. |
| |
| Example |
| ------- |
| >>> inp_tensor = torch.rand([10, 16000]) |
| >>> conv = SincConv(input_shape=inp_tensor.shape, out_channels=25, kernel_size=11) |
| >>> out_tensor = conv(inp_tensor) |
| >>> out_tensor.shape |
| torch.Size([10, 16000, 25]) |
| """ |
|
|
| def __init__( |
| self, |
| out_channels, |
| kernel_size, |
| input_shape=None, |
| in_channels=None, |
| stride=1, |
| dilation=1, |
| padding="same", |
| padding_mode="reflect", |
| sample_rate=16000, |
| min_low_hz=50, |
| min_band_hz=50, |
| ): |
| super().__init__() |
| self.in_channels = in_channels |
| self.out_channels = out_channels |
| self.kernel_size = kernel_size |
| self.stride = stride |
| self.dilation = dilation |
| self.padding = padding |
| self.padding_mode = padding_mode |
| self.sample_rate = sample_rate |
| self.min_low_hz = min_low_hz |
| self.min_band_hz = min_band_hz |
|
|
| |
| if input_shape is None and self.in_channels is None: |
| raise ValueError("Must provide one of input_shape or in_channels") |
|
|
| if self.in_channels is None: |
| self.in_channels = self._check_input_shape(input_shape) |
|
|
| if self.out_channels % self.in_channels != 0: |
| raise ValueError( |
| "Number of output channels must be divisible by in_channels" |
| ) |
|
|
| |
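| # Initialize the learnable cutoff frequencies, window, and time axis |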
| self._init_sinc_conv() |
|
|
| def forward(self, x): |
| """Returns the output of the convolution. |
| |
| Arguments |
| --------- |
| x : torch.Tensor (batch, time, channel) |
| Input to convolve. 2d or 3d tensors are expected. |
| |
| Returns |
| ------- |
| wx : torch.Tensor |
| The convolved outputs. |
| """ |
| x = x.transpose(1, -1) |
| self.device = x.device |
|
|
| unsqueeze = x.ndim == 2 |
| if unsqueeze: |
| x = x.unsqueeze(1) |
|
|
| if self.padding == "same": |
| x = self._manage_padding( |
| x, self.kernel_size, self.dilation, self.stride |
| ) |
|
|
| elif self.padding == "causal": |
| num_pad = (self.kernel_size - 1) * self.dilation |
| x = F.pad(x, (num_pad, 0)) |
|
|
| elif self.padding == "valid": |
| pass |
|
|
| else: |
| raise ValueError( |
| "Padding must be 'same', 'valid' or 'causal'. Got %s." |
| % (self.padding) |
| ) |
|
|
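| # Build the sinc filters from the current learnable cutoff frequencies |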
| sinc_filters = self._get_sinc_filters() |
|
|
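| # Grouped convolution: each input channel is filtered by its own set of filters |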
| wx = F.conv1d( |
| x, |
| sinc_filters, |
| stride=self.stride, |
| padding=0, |
| dilation=self.dilation, |
| groups=self.in_channels, |
| ) |
|
|
| if unsqueeze: |
| wx = wx.squeeze(1) |
|
|
| wx = wx.transpose(1, -1) |
|
|
| return wx |
|
|
| def _check_input_shape(self, shape): |
| """Checks the input shape and returns the number of input channels.""" |
|
|
| if len(shape) == 2: |
| in_channels = 1 |
| elif len(shape) == 3: |
| in_channels = shape[-1] |
| else: |
| raise ValueError( |
| "sincconv expects 2d or 3d inputs. Got " + str(len(shape)) |
| ) |
|
|
| |
| if self.kernel_size % 2 == 0: |
| raise ValueError( |
| "The field kernel size must be an odd number. Got %s." |
| % (self.kernel_size) |
| ) |
| return in_channels |
|
|
| def _get_sinc_filters(self): |
| """This functions creates the sinc-filters to used for sinc-conv.""" |
| |
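| # Low cutoff frequencies of the filters, kept at least min_low_hz |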
| low = self.min_low_hz + torch.abs(self.low_hz_) |
|
|
| |
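| # High cutoffs: at least min_band_hz above the low cutoff, clamped to Nyquist |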
| high = torch.clamp( |
| low + self.min_band_hz + torch.abs(self.band_hz_), |
| self.min_low_hz, |
| self.sample_rate / 2, |
| ) |
| band = (high - low)[:, 0] |
|
|
| |
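| # Move the cached time axis and window to the current device |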
| self.n_ = self.n_.to(self.device) |
| self.window_ = self.window_.to(self.device) |
| f_times_t_low = torch.matmul(low, self.n_) |
| f_times_t_high = torch.matmul(high, self.n_) |
|
|
| |
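| # Left half of each band-pass filter: difference of two windowed sincs |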
| band_pass_left = ( |
| (torch.sin(f_times_t_high) - torch.sin(f_times_t_low)) |
| / (self.n_ / 2) |
| ) * self.window_ |
|
|
| |
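| # Central sample of each filter (t = 0), equal to twice the bandwidth |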
| band_pass_center = 2 * band.view(-1, 1) |
|
|
| |
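| # Right half mirrors the left half (the filters are symmetric) |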
| band_pass_right = torch.flip(band_pass_left, dims=[1]) |
|
|
| |
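| # Concatenate left half, center, and right half into the full filters |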
| band_pass = torch.cat( |
| [band_pass_left, band_pass_center, band_pass_right], dim=1 |
| ) |
|
|
| |
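| # Amplitude normalization by twice the bandwidth |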
| band_pass = band_pass / (2 * band[:, None]) |
|
|
| |
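| # Shape (out_channels, 1, kernel_size), as expected by the grouped conv1d |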
| filters = band_pass.view(self.out_channels, 1, self.kernel_size) |
|
|
| return filters |
|
|
| def _init_sinc_conv(self): |
| """Initializes the parameters of the sinc_conv layer.""" |
|
|
| |
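| # Highest cutoff admissible given the minimum frequency and bandwidth margins |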
| high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz) |
|
|
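| # Initial cutoff frequencies equally spaced on the mel scale |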
| mel = torch.linspace( |
| self._to_mel(self.min_low_hz), |
| self._to_mel(high_hz), |
| self.out_channels + 1, |
| ) |
|
|
| hz = self._to_hz(mel) |
|
|
| |
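| # Per-filter low cutoff and bandwidth, each of shape (out_channels, 1) |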
| self.low_hz_ = hz[:-1].unsqueeze(1) |
| self.band_hz_ = (hz[1:] - hz[:-1]).unsqueeze(1) |
|
|
| |
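| # Register cutoffs and bandwidths as learnable parameters |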
| self.low_hz_ = nn.Parameter(self.low_hz_) |
| self.band_hz_ = nn.Parameter(self.band_hz_) |
|
|
| |
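| # Half of a Hamming window; only half is needed since the filters are symmetric |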
| n_lin = torch.linspace( |
| 0, (self.kernel_size / 2) - 1, steps=int(self.kernel_size / 2) |
| ) |
| self.window_ = 0.54 - 0.46 * torch.cos( |
| 2 * math.pi * n_lin / self.kernel_size |
| ) |
|
|
| |
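| # Left-half time axis, pre-scaled by 2*pi/sample_rate so matmul with Hz gives radians |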
| n = (self.kernel_size - 1) / 2.0 |
| self.n_ = ( |
| 2 * math.pi * torch.arange(-n, 0).view(1, -1) / self.sample_rate |
| ) |
|
|
| def _to_mel(self, hz): |
| """Converts frequency in Hz to the mel scale.""" |
| return 2595 * np.log10(1 + hz / 700) |
|
|
| def _to_hz(self, mel): |
| """Converts frequency in the mel scale to Hz.""" |
| return 700 * (10 ** (mel / 2595) - 1) |
|
|
| def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int): |
| """This function performs zero-padding on the time axis |
| such that their lengths is unchanged after the convolution. |
| |
| Arguments |
| --------- |
| x : torch.Tensor |
| Input tensor. |
| kernel_size : int |
| Size of kernel. |
| dilation : int |
| Dilation used. |
| stride : int |
| Stride. |
| |
| Returns |
| ------- |
| x : torch.Tensor |
| The padded outputs. |
| """ |
|
|
| |
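| # The padding amount depends only on kernel_size, dilation, and stride, |
| # so in_channels works as a stand-in for the true input length here |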
| L_in = self.in_channels |
|
|
| |
| padding = get_padding_elem(L_in, stride, kernel_size, dilation) |
|
|
| |
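| # Pad the time axis with the selected padding_mode |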
| x = F.pad(x, padding, mode=self.padding_mode) |
|
|
| return x |
|
|
|
|
| class Conv1d(nn.Module): |
| """This function implements 1d convolution. |
| |
| Arguments |
| --------- |
| out_channels : int |
| It is the number of output channels. |
| kernel_size : int |
| Kernel size of the convolutional filters. |
| input_shape : tuple |
| The shape of the input. Alternatively use ``in_channels``. |
| in_channels : int |
| The number of input channels. Alternatively use ``input_shape``. |
| stride : int |
| Stride factor of the convolutional filters. When the stride factor > 1, |
| a decimation in time is performed. |
| dilation : int |
| Dilation factor of the convolutional filters. |
| padding : str |
| (same, valid, causal). If "valid", no padding is performed. |
| If "same" and stride is 1, output shape is the same as the input shape. |
| "causal" results in causal (dilated) convolutions. |
| groups : int |
| Number of blocked connections from input channels to output channels. |
| bias : bool |
| Whether to add a bias term to convolution operation. |
| padding_mode : str |
| This flag specifies the type of padding. See torch.nn documentation |
| for more information. |
| skip_transpose : bool |
| If False, uses the batch x time x channel convention of SpeechBrain. |
| If True, uses the batch x channel x time convention. |
| weight_norm : bool |
| If True, uses weight normalization, which can be removed with |
| ``self.remove_weight_norm()`` at inference time. |
| conv_init : str |
| Weight initialization for the convolution layer: one of |
| "kaiming", "zero", or "normal". |
| default_padding : str or int |
| This sets the default padding that is passed to the PyTorch Conv1d backend. |
| |
| Example |
| ------- |
| >>> inp_tensor = torch.rand([10, 40, 16]) |
| >>> cnn_1d = Conv1d( |
| ... input_shape=inp_tensor.shape, out_channels=8, kernel_size=5 |
| ... ) |
| >>> out_tensor = cnn_1d(inp_tensor) |
| >>> out_tensor.shape |
| torch.Size([10, 40, 8]) |
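| >>> # A causal variant also preserves the time dimension: |
| >>> cnn_1d = Conv1d( |
| ... input_shape=inp_tensor.shape, out_channels=8, kernel_size=5, padding="causal" |
| ... ) |
| >>> cnn_1d(inp_tensor).shape |
| torch.Size([10, 40, 8]) |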
| """ |
|
|
| def __init__( |
| self, |
| out_channels, |
| kernel_size, |
| input_shape=None, |
| in_channels=None, |
| stride=1, |
| dilation=1, |
| padding="same", |
| groups=1, |
| bias=True, |
| padding_mode="reflect", |
| skip_transpose=False, |
| weight_norm=False, |
| conv_init=None, |
| default_padding=0, |
| ): |
| super().__init__() |
| self.kernel_size = kernel_size |
| self.stride = stride |
| self.dilation = dilation |
| self.padding = padding |
| self.padding_mode = padding_mode |
| self.unsqueeze = False |
| self.skip_transpose = skip_transpose |
|
|
| if input_shape is None and in_channels is None: |
| raise ValueError("Must provide one of input_shape or in_channels") |
|
|
| if in_channels is None: |
| in_channels = self._check_input_shape(input_shape) |
|
|
| self.in_channels = in_channels |
|
|
| self.conv = nn.Conv1d( |
| in_channels, |
| out_channels, |
| self.kernel_size, |
| stride=self.stride, |
| dilation=self.dilation, |
| padding=default_padding, |
| groups=groups, |
| bias=bias, |
| ) |
|
|
| if conv_init == "kaiming": |
| nn.init.kaiming_normal_(self.conv.weight) |
| elif conv_init == "zero": |
| nn.init.zeros_(self.conv.weight) |
| elif conv_init == "normal": |
| nn.init.normal_(self.conv.weight, std=1e-6) |
|
|
| if weight_norm: |
| self.conv = nn.utils.weight_norm(self.conv) |
|
|
| def forward(self, x): |
| """Returns the output of the convolution. |
| |
| Arguments |
| --------- |
| x : torch.Tensor (batch, time, channel) |
| Input to convolve. 2d or 3d tensors are expected. |
| |
| Returns |
| ------- |
| wx : torch.Tensor |
| The convolved outputs. |
| """ |
| if not self.skip_transpose: |
| x = x.transpose(1, -1) |
|
|
| if self.unsqueeze: |
| x = x.unsqueeze(1) |
|
|
| if self.padding == "same": |
| x = self._manage_padding( |
| x, self.kernel_size, self.dilation, self.stride |
| ) |
|
|
| elif self.padding == "causal": |
| num_pad = (self.kernel_size - 1) * self.dilation |
| x = F.pad(x, (num_pad, 0)) |
|
|
| elif self.padding == "valid": |
| pass |
|
|
| else: |
| raise ValueError( |
| "Padding must be 'same', 'valid' or 'causal'. Got " |
| + self.padding |
| ) |
|
|
| wx = self.conv(x) |
|
|
| if self.unsqueeze: |
| wx = wx.squeeze(1) |
|
|
| if not self.skip_transpose: |
| wx = wx.transpose(1, -1) |
|
|
| return wx |
|
|
| def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int): |
| """This function performs zero-padding on the time axis |
| such that their lengths is unchanged after the convolution. |
| |
| Arguments |
| --------- |
| x : torch.Tensor |
| Input tensor. |
| kernel_size : int |
| Size of kernel. |
| dilation : int |
| Dilation used. |
| stride : int |
| Stride. |
| |
| Returns |
| ------- |
| x : torch.Tensor |
| The padded outputs. |
| """ |
|
|
| |
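| # The padding amount depends only on kernel_size, dilation, and stride, |
| # so in_channels works as a stand-in for the true input length here |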
| L_in = self.in_channels |
|
|
| |
| padding = get_padding_elem(L_in, stride, kernel_size, dilation) |
|
|
| |
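| # Pad the time axis with the selected padding_mode |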
| x = F.pad(x, padding, mode=self.padding_mode) |
|
|
| return x |
|
|
| def _check_input_shape(self, shape): |
| """Checks the input shape and returns the number of input channels.""" |
|
|
| if len(shape) == 2: |
| self.unsqueeze = True |
| in_channels = 1 |
| elif self.skip_transpose: |
| in_channels = shape[1] |
| elif len(shape) == 3: |
| in_channels = shape[2] |
| else: |
| raise ValueError( |
| "conv1d expects 2d, 3d inputs. Got " + str(len(shape)) |
| ) |
|
|
| |
| if self.padding != "valid" and self.kernel_size % 2 == 0: |
| raise ValueError( |
| "Kernel size must be an odd number. Got %s." |
| % (self.kernel_size) |
| ) |
|
|
| return in_channels |
|
|
| def remove_weight_norm(self): |
| """Removes weight normalization at inference if used during training.""" |
| self.conv = nn.utils.remove_weight_norm(self.conv) |
|
|
|
|
| def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int): |
| """This function computes the number of elements to add for zero-padding. |
| |
| Arguments |
| --------- |
| L_in : int |
| Length of the input on the time axis. |
| stride : int |
| Stride of the convolution. |
| kernel_size : int |
| Size of the convolution kernel. |
| dilation : int |
| Dilation of the convolution. |
| |
| Returns |
| ------- |
| padding : list |
| A two-element list with the left and right padding sizes. |
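| |
| Example |
| ------- |
| >>> get_padding_elem(100, 1, 5, 1) |
| [2, 2] |
| >>> get_padding_elem(100, 1, 5, 2) |
| [4, 4] |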
| """ |
| if stride > 1: |
| padding = [math.floor(kernel_size / 2), math.floor(kernel_size / 2)] |
|
|
| else: |
| L_out = ( |
| math.floor((L_in - dilation * (kernel_size - 1) - 1) / stride) + 1 |
| ) |
| padding = [ |
| math.floor((L_in - L_out) / 2), |
| math.floor((L_in - L_out) / 2), |
| ] |
| return padding |
|
|
|
|