Spaces:
Paused
Paused
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from librosa.util import pad_center | |
| from scipy.signal import get_window | |
class TorchSTFT(nn.Module):
    """Magnitude spectrogram front-end built on ``torch.stft``."""

    def __init__(
        self, filter_length=1024, hop_length=512, win_length=None, window="hann"
    ):
        """
        This module implements an STFT using PyTorch's stft function.
        Keyword Arguments:
        filter_length {int} -- FFT size handed to torch.stft as n_fft (default: {1024})
        hop_length {int} -- Hop length of STFT (default: {512})
        win_length {int} -- Length of the window function applied to each frame (if not specified, it
        equals the filter length). (default: {None})
        window {str} -- Type of window to use; any name understood by scipy.signal.get_window
        (e.g. bartlett, hann, hamming, blackman, blackmanharris) (default: {'hann'})
        """
        super().__init__()
        self.n_fft_new = filter_length
        self.hop_length_new = hop_length
        # A falsy win_length (None or 0) falls back to the FFT size.
        self.win_length_new = win_length if win_length else filter_length
        self.center = True

        if window == "hann":
            # Default path: built directly in torch so the default behaviour
            # is bit-identical to the previous hard-coded Hann window.
            analysis_window = torch.hann_window(self.win_length_new)
        else:
            # Previously the `window` argument was silently ignored and a Hann
            # window was always used; honour the caller's request instead.
            # fftbins=True yields the periodic form, matching torch.hann_window.
            analysis_window = torch.from_numpy(
                get_window(window, self.win_length_new, fftbins=True)
            ).float()

        # The buffer keeps its historical name so code that reads
        # `hann_window_0` keeps working; it is non-persistent, so it never
        # appears in the state_dict.
        self.register_buffer("hann_window_0", analysis_window, persistent=False)

    def forward(self, input_data):
        """
        Return the magnitude of the STFT of ``input_data``.

        ``input_data`` is passed unchanged to ``torch.stft`` together with the
        FFT size, hop length, window length and window configured in the
        constructor. Only the magnitude of the complex result is returned;
        phase is discarded.
        """
        fft = torch.stft(
            input_data,
            n_fft=self.n_fft_new,
            hop_length=self.hop_length_new,
            win_length=self.win_length_new,
            window=self.hann_window_0,
            center=self.center,
            return_complex=True,
        )
        # Magnitude is computed explicitly from the real and imaginary parts.
        magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
        return magnitude
class STFT(nn.Module):
    """Magnitude STFT computed as a strided 1D convolution with a windowed Fourier basis."""

    def __init__(
        self, filter_length=1024, hop_length=512, win_length=None, window="hann"
    ):
        """
        Build the windowed analysis basis (and its pseudo-inverse) used for a
        convolution-based STFT.

        Matching sizes before and after overlap-add is delicate, so not every
        configuration is expected to work. The intended setup is a hop length
        of half the filter length (50% overlap between frames).
        Keyword Arguments:
        filter_length {int} -- Length of filters used (default: {1024})
        hop_length {int} -- Hop length of STFT (restrict to 50% overlap between frames) (default: {512})
        win_length {int} -- Length of the window function applied to each frame (if not specified, it
        equals the filter length). (default: {None})
        window {str} -- Type of window to use (options are bartlett, hann, hamming, blackman, blackmanharris)
        (default: {'hann'})
        """
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length if win_length else filter_length
        self.window = window
        self.forward_transform = None
        # Half a frame of padding is added to each side of the signal in forward().
        self.pad_amount = self.filter_length // 2

        # Keep only the first filter_length // 2 + 1 DFT rows and stack their
        # real parts above their imaginary parts.
        n_bins = self.filter_length // 2 + 1
        dft_rows = np.fft.fft(np.eye(self.filter_length))[:n_bins, :]
        stacked_basis = np.vstack([dft_rows.real, dft_rows.imag])

        analysis = torch.FloatTensor(stacked_basis)
        synthesis = torch.FloatTensor(np.linalg.pinv(stacked_basis))

        assert filter_length >= self.win_length
        # Build the window, then zero-pad it symmetrically up to filter_length.
        taper = get_window(window, self.win_length, fftbins=True)
        taper = pad_center(taper, size=filter_length)
        taper = torch.from_numpy(taper).float()

        # Apply the window along the time axis of both bases.
        analysis = analysis * taper
        synthesis = (synthesis.T * taper).T

        # Non-persistent: these are derived constants and are not saved in the state_dict.
        self.register_buffer("forward_basis", analysis, persistent=False)
        self.register_buffer("inverse_basis", synthesis, persistent=False)
        self.register_buffer("fft_window", taper, persistent=False)

    def forward(self, input_data):
        """
        Take input data (audio) to the STFT domain using convolution and
        return the magnitude spectrogram.

        NOTE(review): the unsqueeze below suggests input_data is expected as
        (batch, samples) -- confirm against callers.
        """
        # Reflect-pad both ends by half a frame.
        padded = F.pad(
            input_data, (self.pad_amount, self.pad_amount), mode="reflect"
        )

        # Add a channel axis; each row of the basis acts as one conv filter.
        spectrum = F.conv1d(
            padded.unsqueeze(1),
            self.forward_basis.unsqueeze(1),
            stride=self.hop_length,
        )

        # Output channels hold the real parts first, then the imaginary parts.
        n_bins = self.filter_length // 2 + 1
        re = spectrum[:, :n_bins, :]
        im = spectrum[:, n_bins:, :]
        return torch.sqrt(re.pow(2) + im.pow(2))