# Adapted from:
# Vocos: https://github.com/gemelo-ai/vocos/blob/main/vocos/feature_extractors.py
# BigVGAN: https://github.com/NVIDIA/BigVGAN/blob/main/meldataset.py (Also used by HiFT)
import torch
import torchaudio
from librosa.filters import mel as librosa_mel_fn
from torch import nn


def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
    """Return ``log(x)`` with ``x`` clipped to at least ``clip_val`` to avoid ``log(0) = -inf``."""
    return torch.log(torch.clip(x, min=clip_val))


class MelSpectrogramFeature(nn.Module):
    """Log-mel spectrogram extractor supporting both Vocos- and BigVGAN-style features.

    Args:
        sample_rate: Audio sample rate in Hz.
        n_fft: FFT size (also used as window size).
        hop_length: Hop size between STFT frames.
        n_mels: Number of mel filterbank channels.
        padding: "center" or "same" (Vocos style only).
        fmin: Lowest mel filterbank frequency in Hz.
        fmax: Highest mel filterbank frequency in Hz, or None for sample_rate / 2.
        bigvgan_style_mel: If True, compute BigVGAN-style mels (same padding,
            Slaney mel scale with Slaney normalization); otherwise Vocos-style
            mels (HTK mel scale, no normalization).
    """

    def __init__(
        self,
        sample_rate: int = 24000,
        n_fft: int = 1024,
        hop_length: int = 256,
        n_mels: int = 100,
        padding: str = "center",
        fmin: int = 0,
        fmax: int | None = None,
        bigvgan_style_mel: bool = False,
    ):
        super().__init__()
        self.bigvgan_style_mel = bigvgan_style_mel
        if bigvgan_style_mel:
            # BigVGAN style: same padding, Slaney mel scale, with normalization
            self.n_fft = n_fft
            self.win_size = n_fft
            self.hop_size = hop_length
            # (n_mels, n_fft // 2 + 1)
            mel_basis = librosa_mel_fn(
                sr=sample_rate, n_fft=n_fft, n_mels=n_mels, norm="slaney", htk=False, fmin=fmin, fmax=fmax
            )
            mel_basis = torch.from_numpy(mel_basis).float()
            hann_window = torch.hann_window(n_fft)
            # Registered as buffers so they follow .to(device) / state_dict.
            self.register_buffer("mel_basis", mel_basis)
            self.register_buffer("hann_window", hann_window)
        else:
            # Vocos style: center padding, HTK mel scale, without normalization
            if padding not in ["center", "same"]:
                raise ValueError("Padding must be 'center' or 'same'.")
            self.padding = padding
            self.mel_spec = torchaudio.transforms.MelSpectrogram(
                sample_rate=sample_rate,
                n_fft=n_fft,
                hop_length=hop_length,
                n_mels=n_mels,
                center=padding == "center",
                power=1,
                # BUGFIX: torchaudio's MelSpectrogram uses `f_min`/`f_max`,
                # not `fmin`/`fmax`; the old keywords raised a TypeError.
                f_min=fmin,
                f_max=fmax,
            )

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """Compute the log-mel spectrogram of `audio`.

        Returns:
            mel_specgram (Tensor): Mel spectrogram of the input audio, (B, C, L).
        """
        if self.bigvgan_style_mel:
            return self.bigvgan_mel(audio)
        else:
            return self.vocos_mel(audio)

    def vocos_mel(self, audio: torch.Tensor) -> torch.Tensor:
        """Vocos-style log-mel: torchaudio spectrogram + HTK mel scale + safe log."""
        if self.padding == "same":
            # Reflect-pad so output frame count matches "same" framing.
            pad = self.mel_spec.win_length - self.mel_spec.hop_length
            audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect")
        specgram = self.mel_spec.spectrogram(audio)
        mel_specgram = self.mel_spec.mel_scale(specgram)
        # Convert to log scale
        mel_specgram = safe_log(mel_specgram)
        return mel_specgram

    def bigvgan_mel(self, audio: torch.Tensor) -> torch.Tensor:
        """BigVGAN-style log-mel: manual reflect pad + torch.stft + Slaney mel basis."""
        # Pad so that the output length T = L // hop_length
        padding = (self.n_fft - self.hop_size) // 2
        audio = torch.nn.functional.pad(audio, (padding, padding), mode="reflect")
        # Flatten leading dims: torch.stft expects (batch, time).
        audio = audio.reshape(-1, audio.shape[-1])
        spec = torch.stft(
            audio,
            n_fft=self.n_fft,
            hop_length=self.hop_size,
            win_length=self.win_size,
            window=self.hann_window,
            center=False,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = spec.reshape(audio.shape[:-1] + spec.shape[-2:])
        # Magnitude with a small epsilon for numerical stability of the sqrt gradient.
        spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
        mel_spec = torch.matmul(self.mel_basis, spec)
        # BigVGAN clamps at 1e-5 (dynamic-range compression), unlike safe_log's 1e-7.
        mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
        return mel_spec