File size: 2,724 Bytes
dfd1909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import warnings

import torch
import torch.nn as nn
from librosa.filters import mel as librosa_mel_fn
from torch import Tensor

class MelSpectrogram(nn.Module):
    """Compute linear, log, or dynamic-range-compressed mel spectrograms.

    The mel filterbank is built once with ``librosa.filters.mel`` and the STFT
    uses a Hann window of length ``nfft``. Both are registered as
    non-persistent buffers: they follow the module across ``.to(...)`` calls
    but are not saved in the ``state_dict``.
    """

    def __init__(self,
                 sample_rate: int,
                 nfft: int,
                 hop_size: int,
                 mel_size: int,
                 frequency_min: float,
                 frequency_max: float) -> None:
        """Build the mel filterbank and the STFT analysis window.

        Args:
            sample_rate: sampling rate of the input audio, in Hz.
            nfft: FFT size; also used as the Hann window length.
            hop_size: STFT hop length, in samples.
            mel_size: number of mel bands.
            frequency_min: lowest filterbank frequency (librosa ``fmin``).
            frequency_max: highest filterbank frequency (librosa ``fmax``).
        """
        super().__init__()
        self.nfft, self.hop_size, self.mel_size = nfft, hop_size, mel_size
        # [mel_size, nfft // 2 + 1]
        self.register_buffer(
            'mel_filterbank',
            torch.from_numpy(librosa_mel_fn(sr=sample_rate,
                                            n_fft=nfft,
                                            n_mels=mel_size,
                                            fmin=frequency_min,
                                            fmax=frequency_max)).float(),
            persistent=False)
        # [nfft]
        self.register_buffer(
            'hann_window', torch.hann_window(nfft), persistent=False)

    def forward(self, audio: Tensor) -> Tensor:
        """Return the magnitude mel spectrogram of ``audio``.

        Args:
            audio: [torch.float32; [B, T]] audio signal, expected in [-1, 1].
                Out-of-range input is reported with a ``UserWarning`` and is
                still processed unchanged.

        Returns:
            [torch.float32; [B, mel_size, T // hop_size + 1]] mel spectrogram
            (the frame count follows from ``center=True`` padding).
        """
        # Reduce once per bound and report through `warnings` rather than
        # printing to stdout from inside a forward pass.
        min_value = torch.min(audio)
        max_value = torch.max(audio)
        if min_value < -1.:
            warnings.warn(f'min value is {min_value.item()}')
        if max_value > 1.:
            warnings.warn(f'max value is {max_value.item()}')

        # [B, nfft // 2 + 1, T // hop_size + 1], complex
        spec: Tensor = torch.stft(audio,
                                  n_fft=self.nfft,
                                  hop_length=self.hop_size,
                                  window=self.hann_window,
                                  center=True, pad_mode='reflect',
                                  return_complex=True)
        # [B, nfft // 2 + 1, T // hop_size + 1], real magnitude
        mag: Tensor = torch.abs(spec)
        # [mel_size, F] @ [B, F, frames] -> [B, mel_size, frames]
        return torch.matmul(self.mel_filterbank, mag)

    def get_log_mel(self, audio: Tensor) -> Tensor:
        """Return ``log(mel + 1e-7)`` for ``audio`` ([B, T] in [-1, 1]).

        The small offset keeps silent frames finite.
        """
        mel_spec: Tensor = self.forward(audio)
        return torch.log(mel_spec + 1e-7)

    def get_dynamic_range_compresed_mel(self, audio: Tensor) -> Tensor:
        """Return the dynamic-range-compressed mel spectrogram (HiFi-GAN style).

        Args:
            audio: [torch.float32; [B, T]] audio signal, expected in [-1, 1].

        Returns:
            ``log(clamp(mel, min=1e-5))`` with shape [B, mel_size, frames].
        """
        mel_spec: Tensor = self.forward(audio)
        return self.dynamic_range_compression(mel_spec)

    # Correctly spelled alias; the misspelled name above is kept for callers.
    get_dynamic_range_compressed_mel = get_dynamic_range_compresed_mel

    def dynamic_range_compression(self, x, C=1, clip_val=1e-5):
        """Return ``log(clamp(x, min=clip_val) * C)``.

        The clamp floors values at ``clip_val`` so the log stays finite;
        ``C`` scales the clamped input before the log.
        """
        return torch.log(torch.clamp(x, min=clip_val) * C)