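"""Feature extractors: a shared ``FeatureExtractor`` base class plus a log-mel
spectrogram extractor and an EnCodec-style extractor (SEANet encoder with a
residual vector quantizer)."""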
import math
from typing import List

import torch
import torchaudio
from torch import nn

from decoder.modules import safe_log
from encoder import EncodecModel
from encoder.modules import SEANetEncoder, SEANetDecoder
from encoder.quantization import ResidualVectorQuantizer

class FeatureExtractor(nn.Module):
    """Base class for feature extractors."""

    def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Extract features from the given audio.

        Args:
            audio (Tensor): Input audio waveform.

        Returns:
            Tensor: Extracted features of shape (B, C, L), where B is the batch size,
                C is the number of output feature channels, and L is the sequence length.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")

class MelSpectrogramFeatures(FeatureExtractor):
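    """Log-mel spectrogram features computed with torchaudio's MelSpectrogram."""
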
    def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100, padding="center"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            center=padding == "center",
            power=1,
        )
    def forward(self, audio, **kwargs):
        if self.padding == "same":
            # The transform was built with center=False, so emulate "same" padding
            # by reflect-padding the waveform before computing the spectrogram.
            pad = self.mel_spec.win_length - self.mel_spec.hop_length
            audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect")
        mel = self.mel_spec(audio)
        features = safe_log(mel)
        return features

class EncodecFeatures(FeatureExtractor):
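    """EnCodec-style features: a SEANet encoder followed by a residual vector
    quantizer (RVQ). ``forward`` and ``infer`` return the quantized latents,
    the discrete code indices, and the VQ commitment loss."""
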
    def __init__(
        self,
        encodec_model: str = "encodec_24khz",
        bandwidths: List[float] = [1.5, 3.0, 6.0, 12.0],
        train_codebooks: bool = False,
        num_quantizers: int = 1,
        dowmsamples: List[int] = [6, 5, 5, 4],
        vq_bins: int = 16384,
        vq_kmeans: int = 800,
    ):
        super().__init__()
        self.frame_rate = 25  # not used
        # n_q = int(bandwidths[-1] * 1000 / (math.log2(2048) * self.frame_rate))
        n_q = num_quantizers  # important: number of residual quantizers
        encoder = SEANetEncoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2,
                                dimension=512, channels=1, n_filters=32, ratios=dowmsamples, activation='ELU',
                                kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2,
                                true_skip=False, compress=2)
        decoder = SEANetDecoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2,
                                dimension=512, channels=1, n_filters=32, ratios=[8, 5, 4, 2], activation='ELU',
                                kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2,
                                true_skip=False, compress=2)
        quantizer = ResidualVectorQuantizer(dimension=512, n_q=n_q, bins=vq_bins, kmeans_iters=vq_kmeans,
                                            decay=0.99, kmeans_init=True)
        if encodec_model == "encodec_24khz":
            self.encodec = EncodecModel(encoder=encoder, decoder=decoder, quantizer=quantizer,
                                        target_bandwidths=bandwidths, sample_rate=24000, channels=1)
        else:
            raise ValueError(
                f"Unsupported encodec_model: {encodec_model}. Supported options are 'encodec_24khz'."
            )
        for param in self.encodec.parameters():
            param.requires_grad = True
        # self.num_q = n_q
        # codebook_weights = torch.cat([vq.codebook for vq in self.encodec.quantizer.vq.layers[: self.num_q]], dim=0)
        # self.codebook_weights = torch.nn.Parameter(codebook_weights, requires_grad=train_codebooks)
        self.bandwidths = bandwidths
    # @torch.no_grad()
    # def get_encodec_codes(self, audio):
    #     audio = audio.unsqueeze(1)
    #     emb = self.encodec.encoder(audio)
    #     codes = self.encodec.quantizer.encode(emb, self.encodec.frame_rate, self.encodec.bandwidth)
    #     return codes
    def forward(self, audio: torch.Tensor, bandwidth_id: torch.Tensor):
        if self.training:
            self.encodec.train()
        audio = audio.unsqueeze(1)  # (B, T) -> (B, 1, T), e.g. (16, 24000) -> (16, 1, 24000)
        emb = self.encodec.encoder(audio)
        q_res = self.encodec.quantizer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id])
        quantized = q_res.quantized
        codes = q_res.codes
        commit_loss = q_res.penalty  # e.g. codes (8, 16, 75), features (16, 128, 75)
        return quantized, codes, commit_loss
        # codes = self.get_encodec_codes(audio)
        # # Instead of summing in the loop, it stores subsequent VQ dictionaries in a single `self.codebook_weights`
        # # with offsets given by the number of bins, and finally summed in a vectorized operation.
        # offsets = torch.arange(
        #     0, self.encodec.quantizer.bins * len(codes), self.encodec.quantizer.bins, device=audio.device
        # )
        # embeddings_idxs = codes + offsets.view(-1, 1, 1)
        # features = torch.nn.functional.embedding(embeddings_idxs, self.codebook_weights).sum(dim=0)
        # return features.transpose(1, 2)
    def infer(self, audio: torch.Tensor, bandwidth_id: torch.Tensor):
        if self.training:
            self.encodec.train()
        audio = audio.unsqueeze(1)  # (B, T) -> (B, 1, T)
        emb = self.encodec.encoder(audio)
        q_res = self.encodec.quantizer.infer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id])
        quantized = q_res.quantized
        codes = q_res.codes
        commit_loss = q_res.penalty  # e.g. codes (8, 16, 75), features (16, 128, 75)
        return quantized, codes, commit_loss
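
# Minimal usage sketch for EncodecFeatures (illustrative only; num_quantizers=8 and
# the batch size are assumptions chosen to match the shape comments above):
#
#   extractor = EncodecFeatures(num_quantizers=8)
#   audio = torch.randn(16, 24000)     # (B, T) waveform batch at 24 kHz
#   bandwidth_id = torch.tensor(0)     # index into `bandwidths`, here 1.5 kbps
#   quantized, codes, commit_loss = extractor(audio, bandwidth_id)
#   # quantized: (B, D, L) latent features, codes: (n_q, B, L) code indices,
#   # commit_loss: scalar commitment penalty from the residual VQ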