FlashSR_One-step_Versatile_Audio_Super-resolution / TorchJaekwon /Model /AudioModule /FeatureExtract /MelSpectrogram.py
| from torch import Tensor | |
| import torch | |
| import torch.nn as nn | |
| from librosa.filters import mel as librosa_mel_fn | |
class MelSpectrogram(nn.Module):
    """Turn raw waveforms into (optionally log-compressed) mel spectrograms.

    The mel filterbank and the Hann analysis window are built once at
    construction time and stored as non-persistent buffers: they follow the
    module across ``.to(device)`` calls but are not written to the state dict.
    """

    def __init__(self,
                 sample_rate: int,
                 nfft: int,
                 hop_size: int,
                 mel_size: int,
                 frequency_min: float,
                 frequency_max: float) -> None:
        """Precompute the mel filterbank and the STFT window.

        Args:
            sample_rate: sampling rate handed to librosa's mel filter builder.
            nfft: FFT size; also used as the Hann window length.
            hop_size: hop length (in samples) between STFT frames.
            mel_size: number of mel bands.
            frequency_min: lower edge of the mel filterbank (``fmin``).
            frequency_max: upper edge of the mel filterbank (``fmax``).
        """
        super().__init__()
        self.nfft = nfft
        self.hop_size = hop_size
        self.mel_size = mel_size

        # Filterbank layout: [mel_size, nfft // 2 + 1]
        mel_basis = librosa_mel_fn(sr=sample_rate,
                                   n_fft=nfft,
                                   n_mels=mel_size,
                                   fmin=frequency_min,
                                   fmax=frequency_max)
        filterbank = torch.from_numpy(mel_basis).float()
        self.register_buffer('mel_filterbank', filterbank, persistent=False)

        analysis_window = torch.hann_window(nfft)
        self.register_buffer('hann_window', analysis_window, persistent=False)

    def forward(self, audio: Tensor) -> Tensor:
        """Compute the linear-magnitude mel spectrogram of ``audio``.

        Args:
            audio: [torch.float32; [B, T]] audio signal, expected to lie in
                [-1, 1]. Samples outside that range are only reported on
                stdout; the signal itself is processed unchanged.

        Returns:
            [torch.float32; [B, mel, T / strides]] mel spectrogram.
        """
        # Diagnostic only: report out-of-range samples without altering them.
        lowest = torch.min(audio)
        if lowest < -1.:
            print('min value is ', lowest)
        highest = torch.max(audio)
        if highest > 1.:
            print('max value is ', highest)

        # Complex STFT: [BatchSize, nfft // 2 + 1, T / hop_size]
        stft_out: Tensor = torch.stft(
            audio,
            n_fft=self.nfft,
            hop_length=self.hop_size,
            window=self.hann_window,
            center=True,
            pad_mode='reflect',
            return_complex=True,
        )
        # Magnitude keeps the same layout as the complex spectrum.
        magnitude: Tensor = stft_out.abs()
        # Project onto the mel bands: [BatchSize, mel_size, T / hop_size]
        return self.mel_filterbank @ magnitude

    def get_log_mel(self, audio: Tensor) -> Tensor:
        """Return the natural log of the mel spectrogram, offset by 1e-7.

        Args:
            audio: [torch.float32; [B, T]] audio signal, [-1, 1]-ranged.

        Returns:
            [torch.float32; [B, mel, T / strides]] log-mel spectrogram.
        """
        linear_mel: Tensor = self.forward(audio)
        # The small offset keeps log() finite where the mel energy is zero.
        return (linear_mel + 1e-7).log()

    def get_dynamic_range_compresed_mel(self, audio: Tensor) -> Tensor:
        """Return the mel spectrogram after dynamic range compression.

        This is the variant used in HiFi-GAN.

        Args:
            audio: [torch.float32; [B, T]] audio signal, [-1, 1]-ranged.

        Returns:
            [torch.float32; [B, mel, T / strides]] compressed mel spectrogram.
        """
        linear_mel: Tensor = self.forward(audio)
        return self.dynamic_range_compression(linear_mel)

    def dynamic_range_compression(self, x, C=1, clip_val=1e-5):
        """Log-compress ``x`` after flooring it at ``clip_val``.

        Args:
            x: tensor to compress.
            C: multiplicative factor applied before the logarithm.
            clip_val: lower bound applied to ``x`` so the log stays finite.

        Returns:
            ``log(clamp(x, min=clip_val) * C)`` computed element-wise.
        """
        floored = torch.clamp(x, min=clip_val)
        return torch.log(floored * C)