| import torch |
|
|
| if torch.cuda.is_available() and torch.cuda.get_device_name().endswith("[ZLUDA]"): |
|
|
| class STFT: |
| def __init__(self): |
| self.device = "cuda" |
| self.fourier_bases = {} |
|
|
| def _get_fourier_basis(self, n_fft): |
| |
| if n_fft in self.fourier_bases: |
| return self.fourier_bases[n_fft] |
| fourier_basis = torch.fft.fft(torch.eye(n_fft, device="cpu")).to( |
| self.device |
| ) |
| |
| cutoff = n_fft // 2 + 1 |
| fourier_basis = torch.cat( |
| [fourier_basis.real[:cutoff], fourier_basis.imag[:cutoff]], dim=0 |
| ) |
| |
| self.fourier_bases[n_fft] = fourier_basis |
| return fourier_basis |
|
|
| def transform(self, input, n_fft, hop_length, window): |
| |
| fourier_basis = self._get_fourier_basis(n_fft) |
| |
| fourier_basis = fourier_basis * window |
| |
| pad_amount = n_fft // 2 |
| input = torch.nn.functional.pad( |
| input, (pad_amount, pad_amount), mode="reflect" |
| ) |
| |
| input_frames = input.unfold(1, n_fft, hop_length).permute(0, 2, 1) |
| |
| fourier_transform = torch.matmul(fourier_basis, input_frames) |
| cutoff = n_fft // 2 + 1 |
| return torch.complex( |
| fourier_transform[:, :cutoff, :], fourier_transform[:, cutoff:, :] |
| ) |
|
|
| stft = STFT() |
| _torch_stft = torch.stft |
|
|
| def z_stft(input: torch.Tensor, window: torch.Tensor, *args, **kwargs): |
| |
| if ( |
| kwargs.get("win_length") == None |
| and kwargs.get("center") == None |
| and kwargs.get("return_complex") == True |
| ): |
| |
| return stft.transform( |
| input, kwargs.get("n_fft"), kwargs.get("hop_length"), window |
| ) |
| else: |
| |
| return _torch_stft( |
| input=input.cpu(), window=window.cpu(), *args, **kwargs |
| ).to(input.device) |
|
|
| def z_jit(f, *_, **__): |
| f.graph = torch._C.Graph() |
| return f |
|
|
| |
| torch.stft = z_stft |
| torch.jit.script = z_jit |
| |
| torch.backends.cudnn.enabled = False |
| torch.backends.cuda.enable_flash_sdp(False) |
| torch.backends.cuda.enable_math_sdp(True) |
| torch.backends.cuda.enable_mem_efficient_sdp(False) |
|
|