Spaces: Running on Zero
| # Modified from https://github.com/echocatzh/conv-stft/blob/master/conv_stft/conv_stft.py | |
| # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # MIT License | |
| # Copyright (c) 2020 Shimin Zhang | |
| # Permission is hereby granted, free of charge, to any person obtaining a copy | |
| # of this software and associated documentation files (the "Software"), to deal | |
| # in the Software without restriction, including without limitation the rights | |
| # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
| # copies of the Software, and to permit persons to whom the Software is | |
| # furnished to do so, subject to the following conditions: | |
| # The above copyright notice and this permission notice shall be included in all | |
| # copies or substantial portions of the Software. | |
| # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
| # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
| # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
| # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
| # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
| # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
| # SOFTWARE. | |
| import torch as th | |
| import torch.nn.functional as F | |
| from scipy.signal import check_COLA, get_window | |
# Feature-detect the complex-valued FFT API (torch >= 1.7) instead of comparing
# version strings: lexicographic comparison misorders versions, e.g.
# "1.10.0" >= "1.7.0" is False as a string comparison.
support_clp_op = None
try:
    # torch >= 1.7: torch.fft.rfft returns a complex tensor.
    from torch.fft import rfft as fft

    support_clp_op = True
except ImportError:
    # Legacy torch: torch.rfft returns a real tensor with a trailing
    # real/imag dimension. support_clp_op stays None (falsy), as before.
    from torch import rfft as fft
class STFT(th.nn.Module):
    """STFT/iSTFT implemented with 1D (transpose) convolutions.

    The analysis transform is expressed as framing (conv1d with an identity
    kernel) followed by a linear layer holding the real DFT matrix; the
    synthesis transform uses conv_transpose1d for overlap-add. All kernels are
    precomputed once and registered as buffers.
    """

    def __init__(
        self,
        win_len=1024,
        win_hop=512,
        fft_len=1024,
        enframe_mode="continue",
        win_type="hann",
        win_sqrt=False,
        pad_center=True,
    ):
        """
        Implement of STFT using 1D convolution and 1D transpose convolutions.
        Implement of framing the signal in 2 ways, `break` and `continue`.
        `break` method is a kaldi-like framing.
        `continue` method is a librosa-like framing.

        More information about `perfect reconstruction`:
        1. https://ww2.mathworks.cn/help/signal/ref/stft.html
        2. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.get_window.html

        Args:
            win_len (int): Number of points in one frame. Defaults to 1024.
            win_hop (int): Number of framing stride. Defaults to 512.
            fft_len (int): Number of DFT points. Defaults to 1024.
            enframe_mode (str, optional): `break` or `continue`. Defaults to 'continue'.
            win_type (str, optional): The type of window to create (any name
                accepted by scipy.signal.get_window). Defaults to 'hann'.
            win_sqrt (bool, optional): apply the square root of the window at
                both analysis and synthesis. Defaults to False.
            pad_center (bool, optional): reflect-pad the input by fft_len // 2
                on both sides (`perfect reconstruction` opts). Defaults to True.
        """
        super(STFT, self).__init__()
        assert enframe_mode in ["break", "continue"]
        # Frames are (zero-)padded up to fft_len, so fft_len must not be smaller.
        assert fft_len >= win_len
        self.win_len = win_len
        self.win_hop = win_hop
        self.fft_len = fft_len
        self.mode = enframe_mode
        self.win_type = win_type
        self.win_sqrt = win_sqrt
        self.pad_center = pad_center
        self.pad_amount = self.fft_len // 2
        # Precompute the four kernels and register them as buffers so they
        # follow the module through .to(device) and state_dict().
        en_k, fft_k, ifft_k, ola_k = self.__init_kernel__()
        self.register_buffer("en_k", en_k)
        self.register_buffer("fft_k", fft_k)
        self.register_buffer("ifft_k", ifft_k)
        self.register_buffer("ola_k", ola_k)

    def __init_kernel__(self):
        """
        Generate enframe_kernel, fft_kernel, ifft_kernel and overlap-add kernel.
        ** enframe_kernel: Using conv1d layer and identity matrix.
        ** fft_kernel: Using linear layer for matrix multiplication. In fact,
        enframe_kernel and fft_kernel can be combined, But for the sake of
        readability, I took the two apart.
        ** ifft_kernel, pinv of fft_kernel.
        ** overlap-add kernel, just like enframe_kernel, but transposed.

        Returns:
            tuple: four kernels.
        """
        # Identity matrix as conv1d weight: each output channel selects one
        # sample of the frame, so conv1d(x, en_k, stride=hop) slices frames.
        enframed_kernel = th.eye(self.fft_len)[:, None, :]
        if support_clp_op:
            # torch >= 1.7: rfft returns a complex tensor; split real/imag.
            tmp = fft(th.eye(self.fft_len))
            fft_kernel = th.stack([tmp.real, tmp.imag], dim=2)
        else:
            # Legacy rfft: real/imag already stacked on the last dimension.
            fft_kernel = fft(th.eye(self.fft_len), 1)
        if self.mode == "break":
            # Kaldi-like framing: frames are win_len samples, no centering.
            enframed_kernel = th.eye(self.win_len)[:, None, :]
            fft_kernel = fft_kernel[: self.win_len]
        # Concatenate real and imaginary parts along the frequency axis:
        # columns become [real bins | imag bins], 2 * (fft_len // 2 + 1) wide.
        fft_kernel = th.cat((fft_kernel[:, :, 0], fft_kernel[:, :, 1]), dim=1)
        ifft_kernel = th.pinverse(fft_kernel)[:, None, :]
        window = get_window(self.win_type, self.win_len)

        # Record whether (window, hop) satisfies the constant-overlap-add
        # (COLA) condition needed for perfect reconstruction.
        self.perfect_reconstruct = check_COLA(
            window, self.win_len, self.win_len - self.win_hop
        )
        window = th.FloatTensor(window)
        if self.mode == "continue":
            # Librosa-like framing: center the win_len window inside fft_len.
            left_pad = (self.fft_len - self.win_len) // 2
            right_pad = left_pad + (self.fft_len - self.win_len) % 2
            window = F.pad(window, (left_pad, right_pad))
        if self.win_sqrt:
            # sqrt(window) is applied at analysis and synthesis; their product
            # is the full window, which is the overlap-add weighting to remove.
            self.padded_window = window
            window = th.sqrt(window)
        else:
            # Full window applied twice -> window**2 is the OLA weighting.
            self.padded_window = window**2

        fft_kernel = fft_kernel.T * window
        ifft_kernel = ifft_kernel * window
        ola_kernel = th.eye(self.fft_len)[: self.win_len, None, :]
        if self.mode == "continue":
            ola_kernel = th.eye(self.fft_len)[:, None, : self.fft_len]
        return enframed_kernel, fft_kernel, ifft_kernel, ola_kernel

    def is_perfect(self):
        """
        Whether the parameters win_len, win_hop and win_sqrt
        obey constants overlap-add(COLA)

        Returns:
            bool: Return true if parameters obey COLA.
        """
        return self.perfect_reconstruct and self.pad_center

    def transform(self, inputs, return_type="complex"):
        """Take input data (audio) to STFT domain.

        Args:
            inputs (tensor): Tensor of floats, with shape (num_batch, num_samples)
            return_type (str, optional): return (mag, phase) when `magphase`,
            return (real, imag) when `realimag` and complex(real, imag) when `complex`.
            Defaults to 'complex'.

        Returns:
            tuple: (mag, phase) when `magphase`, return (real, imag) when
            `realimag`. Defaults to 'complex', each elements with shape
            [num_batch, num_frequencies, num_frames]
        """
        assert return_type in ["magphase", "realimag", "complex"]
        if inputs.dim() == 2:
            # Add a channel dimension: (batch, samples) -> (batch, 1, samples).
            inputs = th.unsqueeze(inputs, 1)
        # Side effect: remember the (possibly unpadded) input length.
        self.num_samples = inputs.size(-1)
        if self.pad_center:
            inputs = F.pad(inputs, (self.pad_amount, self.pad_amount), mode="reflect")
        # Frame the signal, then apply the real DFT matrix per frame.
        enframe_inputs = F.conv1d(inputs, self.en_k, stride=self.win_hop)
        outputs = th.transpose(enframe_inputs, 1, 2)
        outputs = F.linear(outputs, self.fft_k)
        outputs = th.transpose(outputs, 1, 2)
        # First fft_len // 2 + 1 channels hold the real part, the rest imag.
        dim = self.fft_len // 2 + 1
        real = outputs[:, :dim, :]
        imag = outputs[:, dim:, :]
        if return_type == "realimag":
            return real, imag
        elif return_type == "complex":
            # th.complex requires the complex-capable torch build (>= 1.7).
            assert support_clp_op
            return th.complex(real, imag)
        else:
            mags = th.sqrt(real**2 + imag**2)
            phase = th.atan2(imag, real)
            return mags, phase

    def inverse(self, input1, input2=None, input_type="magphase"):
        """Call the inverse STFT (iSTFT), given tensors produced
        by the `transform` function.

        Args:
            input1 (tensors): Magnitude/Real-part of STFT with shape
            [num_batch, num_frequencies, num_frames]
            input2 (tensors): Phase/Imag-part of STFT with shape
            [num_batch, num_frequencies, num_frames]
            input_type (str, optional): Mathematical meaning of input tensor's.
            Defaults to 'magphase'.

        Returns:
            tensors: Reconstructed audio given magnitude and phase. Of
            shape [num_batch, num_samples]
        """
        assert input_type in ["magphase", "realimag"]
        if input_type == "realimag":
            real, imag = None, None
            if support_clp_op and th.is_complex(input1):
                # A single complex tensor was passed; input2 is ignored.
                real, imag = input1.real, input1.imag
            else:
                real, imag = input1, input2
        else:
            # Convert polar (magnitude, phase) to rectangular form.
            real = input1 * th.cos(input2)
            imag = input1 * th.sin(input2)
        inputs = th.cat([real, imag], dim=1)
        # Overlap-add synthesis via transpose convolution with the iDFT kernel.
        outputs = F.conv_transpose1d(inputs, self.ifft_k, stride=self.win_hop)
        # Build the per-sample OLA weight by overlap-adding the (squared)
        # window across all frames, to normalize the synthesis output.
        t = (self.padded_window[None, :, None]).repeat(1, 1, inputs.size(-1))
        t = t.to(inputs.device)
        coff = F.conv_transpose1d(t, self.ola_k, stride=self.win_hop)

        num_frames = input1.size(-1)
        num_samples = num_frames * self.win_hop
        # Trim the center padding added in transform().
        rm_start, rm_end = self.pad_amount, self.pad_amount + num_samples
        outputs = outputs[..., rm_start:rm_end]
        coff = coff[..., rm_start:rm_end]
        # Only divide where the OLA weight is non-negligible to avoid
        # amplifying numerical noise at near-zero-weight samples.
        coffidx = th.where(coff > 1e-8)
        outputs[coffidx] = outputs[coffidx] / (coff[coffidx])
        return outputs.squeeze(dim=1)

    def forward(self, inputs):
        """Take input data (audio) to STFT domain and then back to audio.

        Args:
            inputs (tensor): Tensor of floats, with shape [num_batch, num_samples]

        Returns:
            tensor: Reconstructed audio given magnitude and phase.
            Of shape [num_batch, num_samples]
        """
        mag, phase = self.transform(inputs)
        rec_wav = self.inverse(mag, phase)
        return rec_wav