# Scraped page header (not part of the module):
# NeoPy's picture
# Update infer/lib/predictors/FCPE/wav2mel.py
# e25f2a2 verified
import os
import sys
import torch
from torchaudio.transforms import Resample
sys.path.append(os.getcwd())
from infer.lib.predictors.FCPE.stft import STFT
class Wav2Mel:
    """Turn waveform tensors into mel spectrograms via the project ``STFT``.

    Input audio is resampled to ``self.sample_rate`` (16000) when it arrives
    at a different rate, then passed to ``STFT.get_mel``. The resulting mel is
    padded or trimmed along dim 1 to ``samples // hop_size + 1`` frames.
    """

    def __init__(
        self,
        device=None,
        dtype=torch.float32
    ):
        """Set up the STFT helper and an empty resampler cache.

        Args:
            device: Device audio and resamplers are moved to. When ``None``,
                "cuda" is used if available, otherwise "cpu".
            dtype: Dtype audio and resamplers are cast to.
        """
        self.sample_rate = 16000
        self.hop_size = 160
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.dtype = dtype
        # NOTE(review): positional args presumably are (sample rate, mel bins,
        # FFT size, window size, hop size, fmin, fmax) - confirm against STFT.
        self.stft = STFT(16000, 128, 1024, 1024, 160, 0, 8000)
        # Resample modules are cached per input rate, keyed by str(rate).
        self.resample_kernel = {}

    def extract_nvstft(
        self,
        audio,
        keyshift=0,
        train=False
    ):
        """Return ``STFT.get_mel`` output with dims 1 and 2 swapped.

        ``keyshift`` and ``train`` are forwarded unchanged to ``get_mel``.
        """
        mel = self.stft.get_mel(audio, keyshift=keyshift, train=train)
        # Presumably (batch, bins, frames) -> (batch, frames, bins); the
        # frame axis must end up on dim 1 for the padding in extract_mel.
        return mel.transpose(1, 2)

    def extract_mel(
        self,
        audio,
        sample_rate,
        keyshift=0,
        train=False
    ):
        """Compute a mel spectrogram for ``audio`` recorded at ``sample_rate``.

        Args:
            audio: Tensor whose dim 1 is the sample axis.
            sample_rate: Rate of ``audio``; anything other than
                ``self.sample_rate`` triggers resampling.
            keyshift: Forwarded to the STFT helper.
            train: Forwarded to the STFT helper.

        Returns:
            The mel tensor, with one copy of the last frame appended if it is
            shorter than the expected frame count, or trimmed to that count
            if it is longer.
        """
        audio = audio.to(self.dtype).to(self.device)

        if sample_rate == self.sample_rate:
            audio_res = audio
        else:
            key_str = str(sample_rate)
            if key_str not in self.resample_kernel:
                self.resample_kernel[key_str] = Resample(
                    sample_rate,
                    self.sample_rate,
                    lowpass_filter_width=128
                )
            # Re-applied on every call so a cached kernel follows the current
            # self.dtype / self.device.
            self.resample_kernel[key_str] = (
                self.resample_kernel[key_str].to(self.dtype).to(self.device)
            )
            audio_res = self.resample_kernel[key_str](audio)

        mel = self.extract_nvstft(
            audio_res,
            keyshift=keyshift,
            train=train
        )

        # hop_size is expressed in samples at self.sample_rate, so the frame
        # count must come from the resampled signal. Using the pre-resampling
        # length truncated the mel for inputs below 16 kHz.
        n_frames = audio_res.shape[1] // self.hop_size + 1
        current = int(mel.shape[1])
        if n_frames > current:
            # Only a single frame is ever appended (a copy of the last one).
            mel = torch.cat((mel, mel[:, -1:, :]), 1)
        if n_frames < int(mel.shape[1]):
            mel = mel[:, :n_frames, :]
        return mel

    def __call__(self, audio, sample_rate, keyshift=0, train=False):
        """Alias for :meth:`extract_mel`."""
        return self.extract_mel(audio, sample_rate, keyshift=keyshift, train=train)