| | import librosa |
| | import torch |
| | from torch import nn |
| |
|
| |
|
class TorchSTFT(nn.Module):
    """Audio processing functions implemented with Torch for faster batch processing.

    Computes (optionally mel-scaled, optionally log-compressed) magnitude spectrograms
    with ``torch.stft``.

    Args:

        n_fft (int):
            FFT window size for STFT.

        hop_length (int):
            Number of audio samples between adjacent STFT columns.

        win_length (int):
            STFT window length; also the length of the window tensor.

        pad_wav (bool, optional):
            If True, reflect-pad the audio on both sides by ``(n_fft - hop_length) // 2``
            samples before the STFT. Defaults to False.

        window (str, optional):
            Name of a ``torch`` function used to create the window tensor that is
            applied/multiplied to each frame (called as ``getattr(torch, window)(win_length)``).
            Defaults to "hann_window".

        sample_rate (int, optional):
            Target audio sampling rate; passed as ``sr`` when building the mel filterbank.
            Defaults to None.

        mel_fmin (int, optional):
            Minimum filter frequency for computing melspectrograms. Defaults to 0.

        mel_fmax (int, optional):
            Maximum filter frequency for computing melspectrograms. Defaults to None.

        n_mels (int, optional):
            Number of melspectrogram dimensions. Defaults to 80.

        use_mel (bool, optional):
            If True, project the spectrogram onto the mel filterbank; otherwise return the
            linear spectrogram. Defaults to False.

        do_amp_to_db (bool, optional):
            Enable/disable amplitude-to-log conversion of the output spectrogram, computed as
            ``log(clamp(S, min=1e-5) * spec_gain)`` (natural log). Defaults to False.

        spec_gain (float, optional):
            Gain applied when converting amplitude to the log scale. Defaults to 1.0.

        power (float, optional):
            Exponent applied to the magnitude spectrogram, e.g. 1 for magnitude, 2 for power.
            Defaults to None (magnitude, no exponent applied).

        use_htk (bool, optional):
            Use HTK formula in mel filter instead of Slaney. Defaults to False.

        mel_norm (None, 'slaney', or number, optional):
            If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization).

            If numeric, use `librosa.util.normalize` to normalize each filter to unit l_p norm.
            See `librosa.util.normalize` for a full description of supported norm values
            (including `+-np.inf`).

            Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney".

        normalized (bool, optional):
            Forwarded to ``torch.stft`` as its ``normalized`` flag. Defaults to False.
    """

    def __init__(
        self,
        n_fft,
        hop_length,
        win_length,
        pad_wav=False,
        window="hann_window",
        sample_rate=None,
        mel_fmin=0,
        mel_fmax=None,
        n_mels=80,
        use_mel=False,
        do_amp_to_db=False,
        spec_gain=1.0,
        power=None,
        use_htk=False,
        mel_norm="slaney",
        normalized=False,
    ):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.pad_wav = pad_wav
        self.sample_rate = sample_rate
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.n_mels = n_mels
        self.use_mel = use_mel
        self.do_amp_to_db = do_amp_to_db
        self.spec_gain = spec_gain
        self.power = power
        self.use_htk = use_htk
        self.mel_norm = mel_norm
        # Non-trainable Parameter: follows the module through .to()/.cuda() and is saved in
        # the state_dict, but receives no gradient.
        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
        # Plain tensor attribute (not a registered buffer); forward() moves it to the
        # input's device/dtype on every call.
        self.mel_basis = None
        self.normalized = normalized
        if use_mel:
            self._build_mel_basis()

    def forward(self, x):
        """Compute spectrogram frames by torch based stft.

        Implemented as ``forward`` (not a ``__call__`` override) so that ``module(x)`` goes
        through ``nn.Module.__call__`` and registered hooks run.

        Args:
            x (Tensor): input waveform.

        Returns:
            Tensor: spectrogram frames (mel bins if ``use_mel`` else ``n_fft // 2 + 1``
            linear frequency bins).

        Shapes:
            x: :math:`[B, T]` or :math:`[B, 1, T]`
        """
        if x.ndim == 2:
            # Add a channel dim: reflect padding below operates on [B, C, T].
            x = x.unsqueeze(1)
        if self.pad_wav:
            padding = (self.n_fft - self.hop_length) // 2
            x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")

        spec = torch.stft(
            x.squeeze(1),
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=True,
            pad_mode="reflect",
            normalized=self.normalized,
            onesided=True,
            # return_complex=False is deprecated; take real/imag parts explicitly instead.
            return_complex=True,
        )
        # Magnitude. Clamp the squared magnitude first so sqrt never sees 0
        # (sqrt's gradient is infinite at 0).
        S = torch.sqrt(torch.clamp(spec.real**2 + spec.imag**2, min=1e-8))

        if self.power is not None:
            S = S**self.power

        if self.use_mel:
            if self.mel_basis is None:
                # use_mel may have been switched on after construction.
                self._build_mel_basis()
            S = torch.matmul(self.mel_basis.to(x), S)
        if self.do_amp_to_db:
            S = self._amp_to_db(S, spec_gain=self.spec_gain)
        return S

    def _build_mel_basis(self):
        """Create the mel filterbank with ``librosa.filters.mel``.

        The result is stored as a float32 tensor in ``self.mel_basis``.
        """
        mel_basis = librosa.filters.mel(
            sr=self.sample_rate,
            n_fft=self.n_fft,
            n_mels=self.n_mels,
            fmin=self.mel_fmin,
            fmax=self.mel_fmax,
            htk=self.use_htk,
            norm=self.mel_norm,
        )
        self.mel_basis = torch.from_numpy(mel_basis).float()

    @staticmethod
    def _amp_to_db(x, spec_gain=1.0):
        """Return ``log(clamp(x, min=1e-5) * spec_gain)``.

        This uses the natural log, despite the "dB" name. The clamp avoids ``log(0)``.
        """
        return torch.log(torch.clamp(x, min=1e-5) * spec_gain)

    @staticmethod
    def _db_to_amp(x, spec_gain=1.0):
        """Inverse of ``_amp_to_db`` (ignoring its clamp): ``exp(x) / spec_gain``."""
        return torch.exp(x) / spec_gain
| |
|