File size: 2,079 Bytes
0a0615c
 
 
 
 
 
 
 
e25f2a2
0a0615c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import sys
import torch

from torchaudio.transforms import Resample

sys.path.append(os.getcwd())

from infer.lib.predictors.FCPE.stft import STFT

class Wav2Mel:
    def __init__(
        self, 
        device=None, 
        dtype=torch.float32
    ):
        self.sample_rate = 16000
        self.hop_size = 160

        if device is None: 
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.device = device
        self.dtype = dtype
        self.stft = STFT(16000, 128, 1024, 1024, 160, 0, 8000)
        self.resample_kernel = {}

    def extract_nvstft(
        self, 
        audio, 
        keyshift=0, 
        train=False
    ):
        return self.stft.get_mel(
            audio, 
            keyshift=keyshift, 
            train=train
        ).transpose(1, 2)

    def extract_mel(
        self, 
        audio, 
        sample_rate, 
        keyshift=0, 
        train=False
    ):
        audio = audio.to(self.dtype).to(self.device)

        if sample_rate == self.sample_rate: 
            audio_res = audio
        else:
            key_str = str(sample_rate)
            if key_str not in self.resample_kernel: 
                self.resample_kernel[key_str] = Resample(
                    sample_rate, 
                    self.sample_rate, 
                    lowpass_filter_width=128
                )

            self.resample_kernel[key_str] = (
                self.resample_kernel[key_str].to(self.dtype).to(self.device)
            )

            audio_res = self.resample_kernel[key_str](audio)

        mel = self.extract_nvstft(
            audio_res, 
            keyshift=keyshift, 
            train=train
        ) 

        n_frames = int(audio.shape[1] // self.hop_size) + 1
        mel = (torch.cat((mel, mel[:, -1:, :]), 1) if n_frames > int(mel.shape[1]) else mel)

        return mel[:, :n_frames, :] if n_frames < int(mel.shape[1]) else mel

    def __call__(self, audio, sample_rate, keyshift=0, train=False):
        return self.extract_mel(audio, sample_rate, keyshift=keyshift, train=train)