import os
import sys

import torch
from torchaudio.transforms import Resample

# Make the project root importable when run as a script.
sys.path.append(os.getcwd())

from infer.lib.predictors.FCPE.stft import STFT


class Wav2Mel:
    """Convert raw waveforms to mel spectrograms at a fixed 16 kHz rate.

    Audio arriving at any sample rate is resampled to 16 kHz (one cached
    ``Resample`` transform per source rate), then run through an
    STFT-based mel extractor. The output is shaped ``(batch, frames, mels)``.
    """

    def __init__(self, device=None, dtype=torch.float32):
        """Create the extractor.

        Args:
            device: Target device; defaults to CUDA when available, else CPU.
            dtype: Tensor dtype used for audio and resampling kernels.
        """
        self.sample_rate = 16000
        self.hop_size = 160
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.dtype = dtype
        # Fixed analysis config for 16 kHz input; argument meanings are
        # presumably (sr, n_mels, n_fft, win_size, hop, fmin, fmax) —
        # TODO confirm against FCPE.stft.STFT's signature.
        self.stft = STFT(16000, 128, 1024, 1024, 160, 0, 8000)
        # Cache of Resample transforms keyed by str(source sample rate).
        self.resample_kernel = {}

    def extract_nvstft(self, audio, keyshift=0, train=False):
        """Run the mel extractor and return (batch, frames, mels)."""
        # get_mel yields (batch, mels, frames); swap to frames-major.
        return self.stft.get_mel(audio, keyshift=keyshift, train=train).transpose(1, 2)

    def extract_mel(self, audio, sample_rate, keyshift=0, train=False):
        """Resample *audio* to 16 kHz if needed and extract a mel spectrogram.

        Args:
            audio: Waveform tensor of shape (batch, samples) —
                presumably; TODO confirm with callers.
            sample_rate: Sample rate of *audio* in Hz.
            keyshift: Pitch shift in semitones, forwarded to the STFT.
            train: Training-mode flag forwarded to the STFT.

        Returns:
            Mel tensor of shape (batch, n_frames, mels).
        """
        audio = audio.to(self.dtype).to(self.device)
        if sample_rate == self.sample_rate:
            audio_res = audio
        else:
            key_str = str(sample_rate)
            if key_str not in self.resample_kernel:
                # Build and move the kernel once per source rate; the
                # original re-applied .to() on every call, which is
                # redundant after the first use.
                self.resample_kernel[key_str] = (
                    Resample(sample_rate, self.sample_rate, lowpass_filter_width=128)
                    .to(self.dtype)
                    .to(self.device)
                )
            audio_res = self.resample_kernel[key_str](audio)

        mel = self.extract_nvstft(audio_res, keyshift=keyshift, train=train)
        # NOTE(review): frame count is derived from the PRE-resample length
        # with the 16 kHz hop size; for non-16 kHz input this overshoots and
        # only the single-frame pad below ever fires. Preserved as-is —
        # confirm intent before changing to audio_res.shape[1].
        n_frames = int(audio.shape[1] // self.hop_size) + 1
        if n_frames > int(mel.shape[1]):
            # Pad by repeating the final frame once.
            mel = torch.cat((mel, mel[:, -1:, :]), 1)
        if n_frames < int(mel.shape[1]):
            mel = mel[:, :n_frames, :]
        return mel

    def __call__(self, audio, sample_rate, keyshift=0, train=False):
        """Alias for :meth:`extract_mel`."""
        return self.extract_mel(audio, sample_rate, keyshift=keyshift, train=train)