| | import math |
| | import torch |
| | import torch.nn as nn |
| |
|
| |
|
class Deltas(torch.nn.Module):
    """Computes delta coefficients (time derivatives) of feature streams.

    A grouped 1-D convolution with a fixed ramp kernel implements the
    standard regression-based delta formula independently per feature
    channel.

    Arguments
    ---------
    input_size : int
        Size of the last (feature) dimension of the inputs.
    window_length : int
        Length of the window used to compute the time derivatives.

    Example
    -------
    >>> inputs = torch.randn([10, 101, 20])
    >>> compute_deltas = Deltas(input_size=inputs.size(-1))
    >>> features = compute_deltas(inputs)
    >>> features.shape
    torch.Size([10, 101, 20])
    """

    def __init__(
        self, input_size, window_length=5,
    ):
        super().__init__()
        # Half window size and the normalization term of the regression
        # delta formula: sum of k^2 for k in [-n, n] equals n(n+1)(2n+1)/3.
        self.n = (window_length - 1) // 2
        self.denom = self.n * (self.n + 1) * (2 * self.n + 1) / 3

        # One copy of the ramp [-n, ..., n] per feature channel, so the
        # grouped conv below treats each channel independently.
        self.register_buffer(
            "kernel",
            torch.arange(-self.n, self.n + 1, dtype=torch.float32).repeat(
                input_size, 1, 1
            ),
        )

    def forward(self, x):
        """Returns the delta coefficients.

        Arguments
        ---------
        x : tensor
            A batch of tensors; time on dim 1, features on the last dim
            (an optional extra channel dim is supported as in Fbank).
        """
        # Move time to the last axis so the conv runs along time.
        x = x.transpose(1, 2).transpose(2, -1)
        shape_in = x.shape
        has_channels = x.dim() == 4

        # Fold the channel axis into the batch axis for the grouped conv.
        if has_channels:
            x = x.reshape(shape_in[0] * shape_in[2], shape_in[1], shape_in[3])

        # Replicate-pad the time axis so the output keeps the same length.
        padded = torch.nn.functional.pad(x, (self.n, self.n), mode="replicate")

        delta = (
            torch.nn.functional.conv1d(
                padded, self.kernel.to(padded.device), groups=padded.shape[1]
            )
            / self.denom
        )

        if has_channels:
            delta = delta.reshape(shape_in)

        # Restore the original axis order.
        return delta.transpose(1, -1).transpose(2, -1)
| |
|
| |
|
class Filterbank(torch.nn.Module):
    """computes filter bank (FBANK) features given spectral magnitudes.

    Arguments
    ---------
    n_mels : int
        Number of Mel filters used to average the spectrogram.
    log_mel : bool
        If True, it computes the log of the FBANKs.
    filter_shape : str
        Shape of the filters ('triangular', 'rectangular', 'gaussian').
    f_min : int
        Lowest frequency for the Mel filters.
    f_max : int
        Highest frequency for the Mel filters.
    n_fft : int
        Number of fft points of the STFT. It defines the frequency resolution
        (n_fft should be <= than win_len).
    sample_rate : int
        Sample rate of the input audio signal (e.g, 16000)
    power_spectrogram : float
        Exponent used for spectrogram computation.
    amin : float
        Minimum amplitude (used for numerical stability).
    ref_value : float
        Reference value used for the dB scale.
    top_db : float
        Minimum negative cut-off in decibels.
    freeze : bool
        If False, the central frequency and the band of each filter are
        added into nn.parameters. If True, the standard frozen features
        are computed.
    param_change_factor : float
        If freeze=False, this parameter affects the speed at which the filter
        parameters (i.e., central_freqs and bands) can be changed. When high
        (e.g., param_change_factor=1) the filters change a lot during training.
        When low (e.g. param_change_factor=0.1) the filter parameters are more
        stable during training.
    param_rand_factor : float
        This parameter can be used to randomly change the filter parameters
        (i.e, central frequencies and bands) during training. It is thus a
        sort of regularization. param_rand_factor=0 does not affect, while
        param_rand_factor=0.15 allows random variations within +-15% of the
        standard values of the filter parameters (e.g., if the central freq
        is 100 Hz, we can randomly change it from 85 Hz to 115 Hz).

    Raises
    ------
    ValueError
        If f_min is not strictly lower than f_max.

    Example
    -------
    >>> import torch
    >>> compute_fbanks = Filterbank()
    >>> inputs = torch.randn([10, 101, 201])
    >>> features = compute_fbanks(inputs)
    >>> features.shape
    torch.Size([10, 101, 40])
    """

    def __init__(
        self,
        n_mels=40,
        log_mel=True,
        filter_shape="triangular",
        f_min=0,
        f_max=8000,
        n_fft=400,
        sample_rate=16000,
        power_spectrogram=2,
        amin=1e-10,
        ref_value=1.0,
        top_db=80.0,
        param_change_factor=1.0,
        param_rand_factor=0.0,
        freeze=True,
    ):
        super().__init__()
        self.n_mels = n_mels
        self.log_mel = log_mel
        self.filter_shape = filter_shape
        self.f_min = f_min
        self.f_max = f_max
        self.n_fft = n_fft
        self.sample_rate = sample_rate
        self.power_spectrogram = power_spectrogram
        self.amin = amin
        self.ref_value = ref_value
        self.top_db = top_db
        self.freeze = freeze
        # Number of frequency bins of a one-sided STFT.
        self.n_stft = self.n_fft // 2 + 1
        self.db_multiplier = math.log10(max(self.amin, self.ref_value))
        self.device_inp = torch.device("cpu")
        self.param_change_factor = param_change_factor
        self.param_rand_factor = param_rand_factor

        # 10*log10 for power spectrograms (exponent 2), 20*log10 otherwise.
        if self.power_spectrogram == 2:
            self.multiplier = 10
        else:
            self.multiplier = 20

        # Fail fast on an invalid frequency range. Previously this was only
        # printed, letting execution continue with a degenerate mel scale.
        if self.f_min >= self.f_max:
            raise ValueError(
                "Require f_min: %f < f_max: %f" % (self.f_min, self.f_max)
            )

        # Filter cutoffs are equally spaced on the mel scale; converting
        # back to Hz yields n_mels + 2 edge frequencies.
        mel = torch.linspace(
            self._to_mel(self.f_min), self._to_mel(self.f_max), self.n_mels + 2
        )
        hz = self._to_hz(mel)

        # Band (width) and central frequency of each filter, in Hz.
        band = hz[1:] - hz[:-1]
        self.band = band[:-1]
        self.f_central = hz[1:-1]

        # When learnable, store the parameters rescaled so that gradient
        # updates translate into a controlled change in Hz.
        if not self.freeze:
            self.f_central = torch.nn.Parameter(
                self.f_central / (self.sample_rate * self.param_change_factor)
            )
            self.band = torch.nn.Parameter(
                self.band / (self.sample_rate * self.param_change_factor)
            )

        # Frequency axis of the spectrogram, replicated once per filter.
        # NOTE(review): this is a plain attribute (not a buffer), so it stays
        # on CPU even if the module is moved to another device — confirm this
        # is intended for the freeze=False path.
        all_freqs = torch.linspace(0, self.sample_rate // 2, self.n_stft)
        self.all_freqs_mat = all_freqs.repeat(self.f_central.shape[0], 1)

    def forward(self, spectrogram):
        """Returns the FBANks.

        Arguments
        ---------
        spectrogram : tensor
            A batch of spectrogram tensors.
        """
        # Expand central frequencies and bands to (n_mels, n_stft) so they
        # broadcast against all_freqs_mat.
        f_central_mat = self.f_central.repeat(
            self.all_freqs_mat.shape[1], 1
        ).transpose(0, 1)
        band_mat = self.band.repeat(self.all_freqs_mat.shape[1], 1).transpose(
            0, 1
        )

        # Learnable filters: undo the rescaling applied in __init__ to get
        # back to Hz before building the filterbank matrix.
        if not self.freeze:
            f_central_mat = f_central_mat * (
                self.sample_rate
                * self.param_change_factor
                * self.param_change_factor
            )
            band_mat = band_mat * (
                self.sample_rate
                * self.param_change_factor
                * self.param_change_factor
            )

        # Regularization: randomly perturb the (frozen) filter parameters
        # within +-param_rand_factor during training only.
        elif self.param_rand_factor != 0 and self.training:
            rand_change = (
                1.0
                + torch.rand(2) * 2 * self.param_rand_factor
                - self.param_rand_factor
            )
            f_central_mat = f_central_mat * rand_change[0]
            band_mat = band_mat * rand_change[1]

        fbank_matrix = self._create_fbank_matrix(f_central_mat, band_mat).to(
            spectrogram.device
        )

        sp_shape = spectrogram.shape

        # Fold an optional channel axis into the batch axis.
        if len(sp_shape) == 4:
            spectrogram = spectrogram.permute(0, 3, 1, 2)
            spectrogram = spectrogram.reshape(
                sp_shape[0] * sp_shape[3], sp_shape[1], sp_shape[2]
            )

        # Average the spectrogram with the mel filters.
        fbanks = torch.matmul(spectrogram, fbank_matrix)
        if self.log_mel:
            fbanks = self._amplitude_to_DB(fbanks)

        # Restore the original (batch, time, features, channels) layout.
        if len(sp_shape) == 4:
            fb_shape = fbanks.shape
            fbanks = fbanks.reshape(
                sp_shape[0], sp_shape[3], fb_shape[1], fb_shape[2]
            )
            fbanks = fbanks.permute(0, 2, 3, 1)

        return fbanks

    @staticmethod
    def _to_mel(hz):
        """Returns mel-frequency value corresponding to the input
        frequency value in Hz.

        Arguments
        ---------
        hz : float
            The frequency point in Hz.
        """
        return 2595 * math.log10(1 + hz / 700)

    @staticmethod
    def _to_hz(mel):
        """Returns hz-frequency value corresponding to the input
        mel-frequency value.

        Arguments
        ---------
        mel : float
            The frequency point in the mel-scale.
        """
        return 700 * (10 ** (mel / 2595) - 1)

    def _triangular_filters(self, all_freqs, f_central, band):
        """Returns fbank matrix using triangular filters.

        Arguments
        ---------
        all_freqs : Tensor
            Tensor gathering all the frequency points.
        f_central : Tensor
            Tensor gathering central frequencies of each filter.
        band : Tensor
            Tensor gathering the bands of each filter.
        """
        # Rising edge (left_side) and falling edge (right_side) of each
        # triangle, clipped at zero outside the filter support.
        slope = (all_freqs - f_central) / band
        left_side = slope + 1.0
        right_side = -slope + 1.0

        zero = torch.zeros(1, device=self.device_inp)
        fbank_matrix = torch.max(
            zero, torch.min(left_side, right_side)
        ).transpose(0, 1)

        return fbank_matrix

    def _rectangular_filters(self, all_freqs, f_central, band):
        """Returns fbank matrix using rectangular filters.

        Arguments
        ---------
        all_freqs : Tensor
            Tensor gathering all the frequency points.
        f_central : Tensor
            Tensor gathering central frequencies of each filter.
        band : Tensor
            Tensor gathering the bands of each filter.
        """
        low_hz = f_central - band
        high_hz = f_central + band

        # A frequency bin belongs to a filter iff it lies in [low_hz, high_hz].
        # (Previously written as a confusing chained assignment with a
        # 'right_size' typo whose first value was immediately discarded.)
        left_side = all_freqs.ge(low_hz)
        right_side = all_freqs.le(high_hz)

        fbank_matrix = (left_side * right_side).float().transpose(0, 1)

        return fbank_matrix

    def _gaussian_filters(
        self, all_freqs, f_central, band, smooth_factor=torch.tensor(2)
    ):
        """Returns fbank matrix using gaussian filters.

        Arguments
        ---------
        all_freqs : Tensor
            Tensor gathering all the frequency points.
        f_central : Tensor
            Tensor gathering central frequencies of each filter.
        band : Tensor
            Tensor gathering the bands of each filter.
        smooth_factor : Tensor
            Smoothing factor of the gaussian filter. It can be used to employ
            sharper or flatter filters.
        """
        fbank_matrix = torch.exp(
            -0.5 * ((all_freqs - f_central) / (band / smooth_factor)) ** 2
        ).transpose(0, 1)

        return fbank_matrix

    def _create_fbank_matrix(self, f_central_mat, band_mat):
        """Returns fbank matrix to use for averaging the spectrum with
        the set of filter-banks.

        Arguments
        ---------
        f_central_mat : Tensor
            Tensor gathering central frequencies of each filter.
        band_mat : Tensor
            Tensor gathering the bands of each filter.
        """
        if self.filter_shape == "triangular":
            fbank_matrix = self._triangular_filters(
                self.all_freqs_mat, f_central_mat, band_mat
            )

        elif self.filter_shape == "rectangular":
            fbank_matrix = self._rectangular_filters(
                self.all_freqs_mat, f_central_mat, band_mat
            )

        # Any other value falls through to gaussian filters.
        else:
            fbank_matrix = self._gaussian_filters(
                self.all_freqs_mat, f_central_mat, band_mat
            )

        return fbank_matrix

    def _amplitude_to_DB(self, x):
        """Converts linear-FBANKs to log-FBANKs.

        Arguments
        ---------
        x : Tensor
            A batch of linear FBANK tensors.
        """
        # Clamp at amin before the log for numerical stability, then shift
        # by the reference value.
        x_db = self.multiplier * torch.log10(torch.clamp(x, min=self.amin))
        x_db -= self.multiplier * self.db_multiplier

        # Limit the dynamic range: nothing below (per-example max - top_db).
        new_x_db_max = x_db.amax(dim=(-2, -1)) - self.top_db
        x_db = torch.max(x_db, new_x_db_max.view(x_db.shape[0], 1, 1))

        return x_db
| |
|
| |
|
class STFT(torch.nn.Module):
    """computes the Short-Term Fourier Transform (STFT).

    This class computes the Short-Term Fourier Transform of an audio signal.
    It supports multi-channel audio inputs (batch, time, channels).

    Arguments
    ---------
    sample_rate : int
        Sample rate of the input audio signal (e.g 16000).
    win_length : float
        Length (in ms) of the sliding window used to compute the STFT.
    hop_length : float
        Length (in ms) of the hope of the sliding window used to compute
        the STFT.
    n_fft : int
        Number of fft point of the STFT. It defines the frequency resolution
        (n_fft should be <= than win_len).
    window_fn : function
        A function that takes an integer (number of samples) and outputs a
        tensor to be multiplied with each window before fft.
    normalized_stft : bool
        If True, the function returns the normalized STFT results,
        i.e., multiplied by win_length^-0.5 (default is False).
    center : bool
        If True (default), the input will be padded on both sides so that the
        t-th frame is centered at time t×hop_length. Otherwise, the t-th frame
        begins at time t×hop_length.
    pad_mode : str
        It can be 'constant', 'reflect', 'replicate', or 'circular'
        (default 'constant'). 'constant' pads the input tensor boundaries
        with a constant value; 'reflect' uses the reflection of the input
        boundary; 'replicate' replicates the input boundary; 'circular'
        pads using circular replication.
    onesided : bool
        If True (default) only returns nfft/2 values. Note that the other
        samples are redundant due to the Fourier transform conjugate symmetry.

    Example
    -------
    >>> import torch
    >>> compute_STFT = STFT(
    ...     sample_rate=16000, win_length=25, hop_length=10, n_fft=400
    ... )
    >>> inputs = torch.randn([10, 16000])
    >>> features = compute_STFT(inputs)
    >>> features.shape
    torch.Size([10, 101, 201, 2])
    """

    def __init__(
        self,
        sample_rate,
        win_length=25,
        hop_length=10,
        n_fft=400,
        window_fn=torch.hamming_window,
        normalized_stft=False,
        center=True,
        pad_mode="constant",
        onesided=True,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.win_length = win_length
        self.hop_length = hop_length
        self.n_fft = n_fft
        self.normalized_stft = normalized_stft
        self.center = center
        self.pad_mode = pad_mode
        self.onesided = onesided

        # Convert window and hop sizes from milliseconds to samples.
        ms_to_samples = self.sample_rate / 1000.0
        self.win_length = int(round(ms_to_samples * self.win_length))
        self.hop_length = int(round(ms_to_samples * self.hop_length))

        self.window = window_fn(self.win_length)

    def forward(self, x):
        """Returns the STFT generated from the input waveforms.

        Arguments
        ---------
        x : tensor
            A batch of audio signals to transform, shaped (batch, time)
            or (batch, time, channels).
        """
        input_shape = x.shape
        multi_channel = x.dim() == 3

        # Fold the channel axis into the batch axis: (B, T, C) -> (B*C, T).
        if multi_channel:
            x = x.transpose(1, 2).reshape(
                input_shape[0] * input_shape[2], input_shape[1]
            )

        stft = torch.view_as_real(
            torch.stft(
                x,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                window=self.window.to(x.device),
                center=self.center,
                pad_mode=self.pad_mode,
                normalized=self.normalized_stft,
                onesided=self.onesided,
                return_complex=True,
            )
        )

        if multi_channel:
            # (B*C, F, T', 2) -> (B, C, F, T', 2) -> (B, T', F, 2, C)
            stft = stft.reshape(
                input_shape[0], input_shape[2], *stft.shape[1:]
            ).permute(0, 3, 2, 4, 1)
        else:
            # (B, F, T', 2) -> (B, T', F, 2)
            stft = stft.transpose(2, 1)

        return stft
| |
|
| |
|
def spectral_magnitude(
    stft, power: int = 1, log: bool = False, eps: float = 1e-14
):
    """Returns the magnitude of a complex spectrogram.

    Arguments
    ---------
    stft : torch.Tensor
        A tensor, output from the stft function, with the real and
        imaginary parts stacked on the last dimension.
    power : int
        What power to use in computing the magnitude.
        Use power=1 for the power spectrogram.
        Use power=0.5 for the magnitude spectrogram.
    log : bool
        Whether to apply log to the spectral features.
    eps : float
        Small constant added for numerical stability.

    Example
    -------
    >>> a = torch.Tensor([[3, 4]])
    >>> spectral_magnitude(a, power=0.5)
    tensor([5.])
    """
    # |z|^2 = re^2 + im^2, summed over the trailing (real, imag) axis.
    magnitude = stft.pow(2).sum(-1)

    # Stabilize fractional exponents before raising to the power.
    if power < 1:
        magnitude = magnitude + eps
    magnitude = magnitude.pow(power)

    return torch.log(magnitude + eps) if log else magnitude
| |
|
| |
|
class ContextWindow(torch.nn.Module):
    """Computes the context window.

    This class applies a context window by gathering multiple time steps
    in a single feature vector. The operation is performed with a
    convolutional layer based on a fixed kernel designed for that.

    Arguments
    ---------
    left_frames : int
        Number of left frames (i.e, past frames) to collect.
    right_frames : int
        Number of right frames (i.e, future frames) to collect.

    Example
    -------
    >>> import torch
    >>> compute_cw = ContextWindow(left_frames=5, right_frames=5)
    >>> inputs = torch.randn([10, 101, 20])
    >>> features = compute_cw(inputs)
    >>> features.shape
    torch.Size([10, 101, 220])
    """

    def __init__(
        self, left_frames=0, right_frames=0,
    ):
        super().__init__()
        self.left_frames = left_frames
        self.right_frames = right_frames
        self.context_len = self.left_frames + self.right_frames + 1
        self.kernel_len = 2 * max(self.left_frames, self.right_frames) + 1

        # Each identity row selects one time offset; the grouped conv in
        # forward uses these rows to gather the surrounding frames.
        self.kernel = torch.eye(self.context_len, self.kernel_len)

        # For an asymmetric context, shift the selection window toward the
        # future side.
        lag = self.right_frames - self.left_frames
        if lag > 0:
            self.kernel = torch.roll(self.kernel, lag, 1)

        self.first_call = True

    def forward(self, x):
        """Returns the tensor with the surrounding context.

        Arguments
        ---------
        x : tensor
            A batch of tensors; time on dim 1, features on dim 2.
        """
        x = x.transpose(1, 2)

        # The channel count is unknown until the first input arrives, so the
        # kernel is expanded lazily (once) to one copy per channel.
        if self.first_call:
            self.first_call = False
            num_channels = x.shape[1]
            self.kernel = (
                self.kernel.repeat(num_channels, 1, 1)
                .view(num_channels * self.context_len, self.kernel_len)
                .unsqueeze(1)
            )

        input_shape = x.shape
        if len(input_shape) == 4:
            x = x.reshape(
                input_shape[0] * input_shape[2], input_shape[1], input_shape[3]
            )

        # Gather left/right frames for every channel with a grouped conv.
        cw_x = torch.nn.functional.conv1d(
            x,
            self.kernel.to(x.device),
            groups=x.shape[1],
            padding=max(self.left_frames, self.right_frames),
        )

        if len(input_shape) == 4:
            cw_x = cw_x.reshape(
                input_shape[0], cw_x.shape[1], input_shape[2], cw_x.shape[-1]
            )

        return cw_x.transpose(1, 2)
| |
|
| |
|
class Fbank(torch.nn.Module):
    """Computes Mel filterbank (FBANK) features from raw waveforms.

    Pipeline: STFT -> spectral magnitude -> Mel filterbank, optionally
    followed by delta / delta-delta coefficients and a context window.

    Arguments
    ---------
    deltas : bool
        If True, appends delta and delta-delta coefficients to the features.
    context : bool
        If True, applies a context window to the final features.
    requires_grad : bool
        If True, the central frequencies and bands of the Mel filters are
        learnable (the Filterbank is not frozen).
    sample_rate : int
        Sample rate of the input audio signal (e.g., 16000).
    f_min : int
        Lowest frequency for the Mel filters.
    f_max : int
        Highest frequency for the Mel filters. Defaults to sample_rate / 2.
    n_fft : int
        Number of fft points of the STFT.
    n_mels : int
        Number of Mel filters.
    filter_shape : str
        Shape of the filters ('triangular', 'rectangular', 'gaussian').
    param_change_factor : float
        Speed at which learnable filter parameters may change
        (see Filterbank).
    param_rand_factor : float
        Random perturbation applied to the filter parameters
        (see Filterbank).
    left_frames : int
        Number of past frames collected by the context window.
    right_frames : int
        Number of future frames collected by the context window.
    win_length : float
        Length (in ms) of the STFT sliding window.
    hop_length : float
        Length (in ms) of the hop of the STFT sliding window.

    Example
    -------
    >>> import torch
    >>> inputs = torch.randn([10, 16000])
    >>> feature_maker = Fbank()
    >>> feats = feature_maker(inputs)
    >>> feats.shape
    torch.Size([10, 101, 40])
    """

    def __init__(
        self,
        deltas=False,
        context=False,
        requires_grad=False,
        sample_rate=16000,
        f_min=0,
        f_max=None,
        n_fft=400,
        n_mels=40,
        filter_shape="triangular",
        param_change_factor=1.0,
        param_rand_factor=0.0,
        left_frames=5,
        right_frames=5,
        win_length=25,
        hop_length=10,
    ):
        super().__init__()
        self.deltas = deltas
        self.context = context
        self.requires_grad = requires_grad

        # Default the upper frequency to the Nyquist frequency.
        if f_max is None:
            f_max = sample_rate / 2

        self.compute_STFT = STFT(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
        )
        self.compute_fbanks = Filterbank(
            sample_rate=sample_rate,
            n_fft=n_fft,
            n_mels=n_mels,
            f_min=f_min,
            f_max=f_max,
            freeze=not requires_grad,
            filter_shape=filter_shape,
            param_change_factor=param_change_factor,
            param_rand_factor=param_rand_factor,
        )
        self.compute_deltas = Deltas(input_size=n_mels)
        self.context_window = ContextWindow(
            left_frames=left_frames, right_frames=right_frames,
        )

    def forward(self, wav):
        """Returns a set of features generated from the input waveforms.

        Arguments
        ---------
        wav : tensor
            A batch of audio signals to transform to features.
        """
        # Lowercase local name: the previous code used `STFT`, which
        # shadowed the module-level STFT class inside this function.
        stft = self.compute_STFT(wav)
        mag = spectral_magnitude(stft)
        fbanks = self.compute_fbanks(mag)
        if self.deltas:
            delta1 = self.compute_deltas(fbanks)
            delta2 = self.compute_deltas(delta1)
            fbanks = torch.cat([fbanks, delta1, delta2], dim=2)
        if self.context:
            fbanks = self.context_window(fbanks)
        return fbanks