| import os | |
| import sys | |
| import torch | |
| import numpy as np | |
| import torch.nn as nn | |
| sys.path.append(os.getcwd()) | |
| class Spectrogram(nn.Module): | |
| def __init__( | |
| self, | |
| hop_length, | |
| win_length, | |
| n_fft=None, | |
| clamp=1e-10 | |
| ): | |
| super(Spectrogram, self).__init__() | |
| self.n_fft = win_length if n_fft is None else n_fft | |
| self.hop_length = hop_length | |
| self.win_length = win_length | |
| self.clamp = clamp | |
| self.register_buffer("window", torch.hann_window(win_length), persistent=False) | |
| def forward(self, audio, center=True): | |
| bs, c, segment_samples = audio.shape | |
| audio = audio.reshape(bs * c, segment_samples) | |
| if str(audio.device).startswith(("ocl", "privateuseone")): | |
| if not hasattr(self, "stft"): | |
| from main.library.backends.utils import STFT | |
| self.stft = STFT( | |
| filter_length=self.n_fft, | |
| hop_length=self.hop_length, | |
| win_length=self.win_length | |
| ).to(audio.device) | |
| magnitude = self.stft.transform(audio, 1e-9) | |
| else: | |
| fft = torch.stft( | |
| audio, | |
| n_fft=self.n_fft, | |
| hop_length=self.hop_length, | |
| win_length=self.win_length, | |
| window=self.window, | |
| center=center, | |
| pad_mode="reflect", | |
| return_complex=True | |
| ) | |
| magnitude = (fft.real.pow(2) + fft.imag.pow(2)).sqrt() | |
| mag = magnitude.transpose(1, 2).clamp(self.clamp, np.inf) | |
| mag = mag.reshape(bs, c, mag.shape[1], mag.shape[2]) | |
| return mag |