Spaces:
Paused
Paused
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/conv_stft.py | |
| """ | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from scipy.signal import get_window | |
def init_kernels(nfft: int, win_size: int, hop_size: int, win_type: str = None, inverse=False):
    """Build Conv1d kernels implementing a windowed real DFT (or its pseudo-inverse).

    :param nfft: DFT size; the kernel covers ``nfft // 2 + 1`` frequency bins.
    :param win_size: frame/window length in samples; must not exceed ``nfft``.
    :param hop_size: not used by this function; kept so callers can pass their
        full STFT configuration.
    :param win_type: window name passed to ``scipy.signal.get_window``; ``None`` or
        the string ``"None"`` selects a rectangular (all-ones) window.
    :param inverse: if True, build the synthesis kernel from the pseudo-inverse of
        the analysis DFT basis instead of the analysis basis itself.
    :return: tuple ``(kernel, window)``. ``kernel`` is a float32 tensor of shape
        ``[2 * (nfft // 2 + 1), 1, win_size]`` (real-part rows first, then
        imaginary-part rows). ``window`` is the float32 window of shape ``[1, win_size, 1]``.
    :raises ValueError: if ``win_size > nfft``.
    """
    if win_size > nfft:
        # rfft(eye(nfft)) has only nfft rows, so a longer window cannot be applied.
        # Fail with a clear message rather than a numpy broadcasting error further down.
        raise ValueError(f"win_size ({win_size}) must not exceed nfft ({nfft})")
    if win_type == "None" or win_type is None:
        window = np.ones(win_size)
    else:
        # Square root of the window: the analysis and the synthesis kernels are both
        # built here and both multiplied by it, so together they apply the full window.
        window = get_window(win_type, win_size, fftbins=True) ** 0.5
    # Row n of rfft(eye(nfft)) is the DFT of a unit impulse at sample n. The first
    # win_size rows therefore map a zero-padded frame of win_size samples to its
    # nfft-point DFT bins.
    fourier_basis = np.fft.rfft(np.eye(nfft))[:win_size]
    kernel = np.concatenate([np.real(fourier_basis), np.imag(fourier_basis)], 1).T
    if inverse:
        # Synthesis basis: pinv maps the stacked real/imag bins back to frame samples.
        kernel = np.linalg.pinv(kernel).T
    kernel = kernel * window
    # Add the Conv1d in_channels axis: [out_channels, 1, kernel_size].
    kernel = kernel[:, None, :]
    return (
        torch.from_numpy(kernel.astype(np.float32)),
        torch.from_numpy(window[None, :, None].astype(np.float32)),
    )
class ConvSTFT(nn.Module):
    """Short-time Fourier transform computed as a strided 1-D convolution.

    The convolution weight is the (windowed) DFT kernel from ``init_kernels``: its
    first half of output channels yields real parts, its second half imaginary parts.

    :param nfft: DFT size; if ``None``, the next power of two >= ``win_size`` is used.
    :param win_size: frame length in samples (the convolution kernel size).
    :param hop_size: frame advance in samples (the convolution stride).
    :param win_type: window name forwarded to ``init_kernels``.
    :param feature_type: ``"complex"`` returns the stacked real/imag spectrum;
        any other value returns ``(magnitude, phase)``.
    :param requires_grad: whether the DFT kernel is a trainable parameter.
    """

    def __init__(self,
                 nfft: int,
                 win_size: int,
                 hop_size: int,
                 win_type: str = "hamming",
                 feature_type: str = "real",
                 requires_grad: bool = False):
        super().__init__()
        self.nfft = nfft if nfft is not None else int(2 ** np.ceil(np.log2(win_size)))
        analysis_kernel, _unused_window = init_kernels(self.nfft, win_size, hop_size, win_type)
        self.weight = nn.Parameter(analysis_kernel, requires_grad=requires_grad)
        self.win_size = win_size
        self.hop_size = hop_size
        # Frames are taken every hop_size samples.
        self.stride = hop_size
        self.dim = self.nfft
        self.feature_type = feature_type

    def forward(self, inputs: torch.Tensor):
        """Transform a waveform batch into STFT features.

        :param inputs: ``[batch, num_samples]`` or ``[batch, 1, num_samples]``
            (the kernel has a single input channel).
        :return: if ``feature_type == "complex"``, a tensor
            ``[batch, 2 * (nfft // 2 + 1), num_frames]`` with real rows above
            imaginary rows; otherwise ``(mags, phase)``, each
            ``[batch, nfft // 2 + 1, num_frames]``.
        """
        signal = inputs.unsqueeze(1) if inputs.dim() == 2 else inputs
        spectrum = F.conv1d(signal, self.weight, stride=self.stride)
        if self.feature_type == "complex":
            return spectrum
        n_bins = self.dim // 2 + 1
        re_part = spectrum[:, :n_bins, :]
        im_part = spectrum[:, n_bins:, :]
        magnitude = (re_part ** 2 + im_part ** 2).sqrt()
        angle = torch.atan2(im_part, re_part)
        return magnitude, angle
class ConviSTFT(nn.Module):
    """Inverse STFT computed as a strided transposed 1-D convolution (overlap-add).

    The weight is the pseudo-inverse DFT kernel from ``init_kernels(..., inverse=True)``.
    The overlap-added output is divided by the overlap-added squared window.

    :param win_size: frame length in samples.
    :param hop_size: frame advance in samples (the transposed-convolution stride).
    :param nfft: DFT size; if ``None``, the next power of two >= ``win_size`` is used.
    :param win_type: window name forwarded to ``init_kernels``.
    :param feature_type: stored on the module; not read by ``forward``.
    :param requires_grad: whether the synthesis kernel is a trainable parameter.
    """

    def __init__(self,
                 win_size: int,
                 hop_size: int,
                 nfft: int = None,
                 win_type: str = "hamming",
                 feature_type: str = "real",
                 requires_grad: bool = False):
        super().__init__()
        if nfft is None:
            nfft = int(2 ** np.ceil(np.log2(win_size)))
        self.nfft = nfft
        synthesis_kernel, window = init_kernels(nfft, win_size, hop_size, win_type, inverse=True)
        self.weight = nn.Parameter(synthesis_kernel, requires_grad=requires_grad)
        self.win_size = win_size
        self.hop_size = hop_size
        self.win_type = win_type
        self.stride = hop_size
        self.dim = nfft
        self.feature_type = feature_type
        # Buffers follow the module across devices; registration order is kept stable.
        self.register_buffer("window", window)
        # Identity kernel [win_size, 1, win_size]: transposed-convolving per-frame
        # window values with it scatters them back to their sample positions.
        self.register_buffer("enframe", torch.eye(win_size)[:, None, :])

    def forward(self,
                inputs: torch.Tensor,
                phase: torch.Tensor = None):
        """Reconstruct a waveform from STFT features.

        :param inputs: ``[b, 2 * (nfft // 2 + 1), t]`` stacked real/imag spectrum
            (``nfft + 2`` rows for even ``nfft``), or ``[b, nfft // 2 + 1, t]``
            magnitudes when ``phase`` is given.
        :param phase: optional ``[b, nfft // 2 + 1, t]`` phase matching ``inputs``.
        :return: waveform of shape ``[b, 1, (t - 1) * hop_size + win_size]``.
        """
        if phase is not None:
            # Polar -> Cartesian; real rows are stacked above imaginary rows.
            inputs = torch.cat([inputs * torch.cos(phase), inputs * torch.sin(phase)], 1)
        num_frames = inputs.size(-1)
        overlap_added = F.conv_transpose1d(inputs, self.weight, stride=self.stride)
        # this is from torch-stft: https://github.com/pseeth/torch-stft
        # Sum of the squared stored window over all frames covering each output sample.
        squared_frames = self.window.repeat(1, 1, num_frames) ** 2
        window_sum = F.conv_transpose1d(squared_frames, self.enframe, stride=self.stride)
        # Small epsilon guards against division by zero where no window energy lands.
        return overlap_added / (window_sum + 1e-8)
def main():
    """Demo: run a random waveform through ConvSTFT and ConviSTFT and print the shapes."""
    stft = ConvSTFT(nfft=512, win_size=512, hop_size=200, feature_type="complex")
    istft = ConviSTFT(nfft=512, win_size=512, hop_size=200, feature_type="complex")
    mixture = torch.rand(size=(1, 8000 * 40), dtype=torch.float32)
    # Call the modules (not .forward()) so nn.Module.__call__ runs any registered hooks.
    spec = stft(mixture)
    # shape: [batch_size, 2 * (nfft // 2 + 1), num_frames]; real rows first, then imaginary rows
    print(spec.shape)
    waveform = istft(spec)
    # shape: [batch_size, 1, (num_frames - 1) * hop_size + win_size]
    print(waveform.shape)


if __name__ == "__main__":
    main()