File size: 3,787 Bytes
2cba492
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# Adapted from:
# Vocos: https://github.com/gemelo-ai/vocos/blob/main/vocos/feature_extractors.py
# BigVGAN: https://github.com/NVIDIA/BigVGAN/blob/main/meldataset.py (Also used by HiFT)

import torch
import torchaudio
from librosa.filters import mel as librosa_mel_fn
from torch import nn


def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
    """Natural logarithm of ``x`` after flooring it at ``clip_val``.

    Flooring first keeps zeros and negative values from turning into
    ``-inf`` / ``nan`` in the log.

    Args:
        x: Input tensor.
        clip_val: Lower bound applied elementwise before the logarithm.

    Returns:
        ``log(max(x, clip_val))`` elementwise, same shape as ``x``.
    """
    floored = x.clamp(min=clip_val)
    return floored.log()


class MelSpectrogramFeature(nn.Module):
    def __init__(
        self,
        sample_rate: int = 24000,
        n_fft: int = 1024,
        hop_length: int = 256,
        n_mels: int = 100,
        padding: str = "center",
        fmin: int = 0,
        fmax: int | None = None,
        bigvgan_style_mel: bool = False,
    ):
        super().__init__()

        self.bigvgan_style_mel = bigvgan_style_mel
        if bigvgan_style_mel:
            # BigVGAN style: same padding, Slaney mel scale, with normalization
            self.n_fft = n_fft
            self.win_size = n_fft
            self.hop_size = hop_length
            # (n_mels, n_fft // 2 + 1)
            mel_basis = librosa_mel_fn(
                sr=sample_rate, n_fft=n_fft, n_mels=n_mels, norm="slaney", htk=False, fmin=fmin, fmax=fmax
            )
            mel_basis = torch.from_numpy(mel_basis).float()
            hann_window = torch.hann_window(n_fft)
            self.register_buffer("mel_basis", mel_basis)
            self.register_buffer("hann_window", hann_window)
        else:
            # Vocos style: center padding, HTK mel scale, without normalization
            if padding not in ["center", "same"]:
                raise ValueError("Padding must be 'center' or 'same'.")

            self.padding = padding
            self.mel_spec = torchaudio.transforms.MelSpectrogram(
                sample_rate=sample_rate,
                n_fft=n_fft,
                hop_length=hop_length,
                n_mels=n_mels,
                center=padding == "center",
                power=1,
                fmin=fmin,
                fmax=fmax,
            )

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Returns:
            mel_specgram (Tensor): Mel spectrogram of the input audio. (B, C, L)
        """
        if self.bigvgan_style_mel:
            return self.bigvgan_mel(audio)
        else:
            return self.vocos_mel(audio)

    def vocos_mel(self, audio: torch.Tensor) -> torch.Tensor:
        if self.padding == "same":
            pad = self.mel_spec.win_length - self.mel_spec.hop_length
            audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect")

        specgram = self.mel_spec.spectrogram(audio)
        mel_specgram = self.mel_spec.mel_scale(specgram)

        # Convert to log scale
        mel_specgram = safe_log(mel_specgram)
        return mel_specgram

    def bigvgan_mel(self, audio: torch.Tensor) -> torch.Tensor:
        # Pad so that the output length T = L // hop_length
        padding = (self.n_fft - self.hop_size) // 2
        audio = torch.nn.functional.pad(audio, (padding, padding), mode="reflect")
        audio = audio.reshape(-1, audio.shape[-1])

        spec = torch.stft(
            audio,
            n_fft=self.n_fft,
            hop_length=self.hop_size,
            win_length=self.win_size,
            window=self.hann_window,
            center=False,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = spec.reshape(audio.shape[:-1] + spec.shape[-2:])

        spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
        mel_spec = torch.matmul(self.mel_basis, spec)
        mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
        return mel_spec