# Copyright (c) 2024 NVIDIA CORPORATION.
#   Licensed under the MIT license.

# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
#   LICENSE is in incl_licenses directory.

# When this file is executed directly, extend sys.path with the parent
# directory -- presumably so that `from config import config` below resolves
# (TODO confirm against the project layout).
if __name__ == "__main__":
    import os
    import sys

    sys.path.append("../")

import math
import os
import pathlib
import random
from typing import Optional

import numpy as np
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from librosa.util import normalize
from scipy.io.wavfile import read
from tqdm import tqdm

from config import config

MAX_WAV_VALUE = 32767.0  # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases)


def load_wav(full_path, sr_target):
    """Read a WAV file and verify that its sampling rate matches `sr_target`.

    Args:
        full_path: Path of the WAV file, passed to `scipy.io.wavfile.read`.
        sr_target: Sampling rate (Hz) the caller requires.

    Returns:
        Tuple `(data, sampling_rate)` exactly as returned by
        `scipy.io.wavfile.read`, in swapped order.

    Raises:
        RuntimeError: If the file's sampling rate differs from `sr_target`.
    """
    sampling_rate, data = read(full_path)
    if sampling_rate != sr_target:
        raise RuntimeError(
            f"Sampling rate of the file {full_path} is {sampling_rate} Hz, but the model requires {sr_target} Hz"
        )
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Return `log(max(x, clip_val) * C)` computed with NumPy."""
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    """Return `exp(x) / C` computed with NumPy."""
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Return `log(clamp(x, min=clip_val) * C)` computed with PyTorch."""
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    """Return `exp(x) / C` computed with PyTorch."""
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    """Log-compress `magnitudes` using the default `C` and `clip_val`."""
    return dynamic_range_compression_torch(magnitudes)


def spectral_de_normalize_torch(magnitudes):
    """Undo `spectral_normalize_torch` (up to the clamp) using the default `C`."""
    return dynamic_range_decompression_torch(magnitudes)


# Both caches are filled together under the same key, so membership in
# `mel_basis_cache` implies membership in `hann_window_cache`.
mel_basis_cache = {}
hann_window_cache = {}


def mel_spectrogram(
    y: torch.Tensor,
    n_fft: int,
    num_mels: int,
    sampling_rate: int,
    hop_size: int,
    win_size: int,
    fmin: int,
    fmax: Optional[int] = None,
    center: bool = False,
) -> torch.Tensor:
    """
    Calculate the mel spectrogram of an input signal.
    This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).

    Args:
        y (torch.Tensor): Input signal. A warning is printed if any value lies
            outside [-1, 1]. The padding step inserts and removes an axis at
            dim 1, so a batched [B, T] waveform is presumably expected
            (TODO confirm).
        n_fft (int): FFT size.
        num_mels (int): Number of mel bins.
        sampling_rate (int): Sampling rate of the input signal.
        hop_size (int): Hop size for STFT.
        win_size (int): Window size for STFT.
        fmin (int): Minimum frequency for mel filterbank.
        fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
        center (bool): Whether to pad the input to center the frames. Default is False.

    Returns:
        torch.Tensor: Log-compressed mel spectrogram (natural log of the
        mel-projected magnitude, clamped from below at 1e-5).
    """
    # Compute each extreme once and reuse it for both the test and the message.
    y_min = torch.min(y)
    if y_min < -1.0:
        print(f"[WARNING] Min value of input waveform signal is {y_min}")
    y_max = torch.max(y)
    if y_max > 1.0:
        print(f"[WARNING] Max value of input waveform signal is {y_max}")

    device = y.device
    key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"

    # Build the mel filterbank and Hann window once per (parameters, device)
    # combination and reuse them on subsequent calls.
    if key not in mel_basis_cache:
        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
        hann_window_cache[key] = torch.hann_window(win_size).to(device)

    mel_basis = mel_basis_cache[key]
    hann_window = hann_window_cache[key]

    # Reflect-pad (n_fft - hop_size) // 2 samples on each side before the STFT.
    padding = (n_fft - hop_size) // 2
    y = torch.nn.functional.pad(
        y.unsqueeze(1), (padding, padding), mode="reflect"
    ).squeeze(1)

    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    # Magnitude of the complex STFT; the 1e-9 keeps the sqrt argument
    # strictly positive.
    spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)

    mel_spec = torch.matmul(mel_basis, spec)
    mel_spec = spectral_normalize_torch(mel_spec)

    return mel_spec


def get_mel_spectrogram(wav, sr):
    """
    Generate mel spectrogram from a waveform using the module-level `config`.

    The STFT/mel hyperparameters are read from `config`: filter_length,
    n_mel_channels, sampling_rate, hop_length, win_length, mel_fmin, mel_fmax.

    Args:
        wav (torch.Tensor): Input waveform.
        sr (int): Sampling rate of `wav`; must equal `config.sampling_rate`.

    Returns:
        torch.Tensor: Mel spectrogram.

    Raises:
        AssertionError: If `sr` differs from `config.sampling_rate`.
    """
    # Raised explicitly (rather than via `assert`) so the check is not stripped
    # when Python runs with -O; the exception type and message are unchanged.
    if sr != config.sampling_rate:
        raise AssertionError(
            f"Given SR : {sr}, Required SR: {config.sampling_rate}"
        )

    return mel_spectrogram(
        wav,
        config.filter_length,
        config.n_mel_channels,
        config.sampling_rate,
        config.hop_length,
        config.win_length,
        config.mel_fmin,
        config.mel_fmax,
    )


if __name__ == "__main__":
    import sys

    import torchaudio

    # An optional first CLI argument overrides the built-in sample path.
    default_path = "/delta/NeuralSpeak_cfm_conv/Samples/IITM_cfm_bigv_harsh/S2A/orig/0_test.wav"
    path = sys.argv[1] if len(sys.argv) > 1 else default_path

    wav, sr = torchaudio.load(path)
    wav = wav[:, :sr]  # keep at most the first `sr` samples (one second)
    print(wav.shape)
    mel_spec = get_mel_spectrogram(wav, sr)
    duration = wav.shape[-1] / sr
    print(duration, mel_spec.shape)