File size: 6,092 Bytes
f12fa11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from typing import List

import torch
import torchaudio
from torch import nn
import math
from decoder.modules import safe_log
from encoder.modules import SEANetEncoder, SEANetDecoder
from encoder import EncodecModel
from encoder.quantization import ResidualVectorQuantizer


class FeatureExtractor(nn.Module):
    """Base class for feature extractors."""

    def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Extract features from the given audio.

        Args:
            audio (Tensor): Input audio waveform.

        Returns:
            Tensor: Extracted features of shape (B, C, L), where B is the batch size,
                    C denotes output features, and L is the sequence length.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")


class MelSpectrogramFeatures(FeatureExtractor):
    def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100, padding="center"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            center=padding == "center",
            power=1,
        )

    def forward(self, audio, **kwargs):
        if self.padding == "same":
            pad = self.mel_spec.win_length - self.mel_spec.hop_length
            audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect")
        mel = self.mel_spec(audio)
        features = safe_log(mel)
        return features


class EncodecFeatures(FeatureExtractor):
    def __init__(
        self,
        encodec_model: str = "encodec_24khz",
        bandwidths: List[float] = [1.5, 3.0, 6.0, 12.0],
        train_codebooks: bool = False,
        num_quantizers: int = 1, 
        dowmsamples: List[int] = [6, 5, 5, 4],
        vq_bins: int = 16384,
        vq_kmeans: int = 800,
    ):
        super().__init__()

        # breakpoint()
        self.frame_rate = 25  # not use
        # n_q = int(bandwidths[-1]*1000/(math.log2(2048) * self.frame_rate))
        n_q = num_quantizers   # important
        encoder = SEANetEncoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2,
                                dimension=512, channels=1, n_filters=32, ratios=dowmsamples, activation='ELU',
                                kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2,
                                true_skip=False, compress=2)
        decoder = SEANetDecoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2,
                                dimension=512, channels=1, n_filters=32, ratios=[8, 5, 4, 2], activation='ELU',
                                kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2,
                                true_skip=False, compress=2)
        quantizer = ResidualVectorQuantizer(dimension=512, n_q=n_q, bins=vq_bins, kmeans_iters=vq_kmeans,
                                            decay=0.99, kmeans_init=True)

        # breakpoint()
        if encodec_model == "encodec_24khz":
            self.encodec = EncodecModel(encoder=encoder, decoder=decoder, quantizer=quantizer,
                                        target_bandwidths=bandwidths, sample_rate=24000, channels=1)
        else:
            raise ValueError(
                f"Unsupported encodec_model: {encodec_model}. Supported options are 'encodec_24khz'."
            )
        for param in self.encodec.parameters():
            param.requires_grad = True
        # self.num_q = n_q
        # codebook_weights = torch.cat([vq.codebook for vq in self.encodec.quantizer.vq.layers[: self.num_q]], dim=0)
        # self.codebook_weights = torch.nn.Parameter(codebook_weights, requires_grad=train_codebooks)
        self.bandwidths = bandwidths

    # @torch.no_grad()
    # def get_encodec_codes(self, audio):
    #     audio = audio.unsqueeze(1)
    #     emb = self.encodec.encoder(audio)
    #     codes = self.encodec.quantizer.encode(emb, self.encodec.frame_rate, self.encodec.bandwidth)
    #     return codes

    def forward(self, audio: torch.Tensor, bandwidth_id: torch.Tensor):
        if self.training:
            self.encodec.train()

        audio = audio.unsqueeze(1)                  # audio(16,24000)

        # breakpoint()

        emb = self.encodec.encoder(audio)
        q_res = self.encodec.quantizer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id])
        quantized = q_res.quantized
        codes = q_res.codes
        commit_loss = q_res.penalty                 # codes(8,16,75),features(16,128,75)

        return quantized, codes, commit_loss

        # codes = self.get_encodec_codes(audio)
        # # Instead of summing in the loop, it stores subsequent VQ dictionaries in a single `self.codebook_weights`
        # # with offsets given by the number of bins, and finally summed in a vectorized operation.
        # offsets = torch.arange(
        #     0, self.encodec.quantizer.bins * len(codes), self.encodec.quantizer.bins, device=audio.device
        # )
        # embeddings_idxs = codes + offsets.view(-1, 1, 1)
        # features = torch.nn.functional.embedding(embeddings_idxs, self.codebook_weights).sum(dim=0)
        # return features.transpose(1, 2)

    def infer(self, audio: torch.Tensor, bandwidth_id: torch.Tensor):
        if self.training:
            self.encodec.train()

        audio = audio.unsqueeze(1)                  # audio(16,24000)
        emb = self.encodec.encoder(audio)
        q_res = self.encodec.quantizer.infer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id])
        quantized = q_res.quantized
        codes = q_res.codes
        commit_loss = q_res.penalty                 # codes(8,16,75),features(16,128,75)

        return quantized, codes, commit_loss