Spaces:
Sleeping
Sleeping
| from functools import reduce | |
| from tokenize import Triple | |
| from torch import nn | |
| from torch.autograd import Function | |
| from torch.nn import functional as F | |
| from torch.nn.utils import spectral_norm, weight_norm | |
| from torch.utils.checkpoint import checkpoint | |
| import einops | |
| import math | |
| import numpy as np | |
| import os | |
| import random | |
| import torch | |
| import torchaudio | |
| import typing as tp | |
| import warnings | |
| from .audio import TorchMelSpectrogram | |
| from .ecapa_tdnn import ECAPA_TDNN | |
| from .hubert import HuBERT | |
| from ..acoustic_codec.vector_quantization import VectorQuantization | |
# Accepted values for the `norm` argument of the convolution wrappers below.
CONV_NORMALIZATIONS = frozenset(
    [
        "none",
        "weight_norm",
        "spectral_norm",
        "time_layer_norm",
        "layer_norm",
        "time_group_norm",
    ]
)
# Default normalization used by the encoder/decoder convolutions in this file.
NORM = "weight_norm"
def get_mask_from_lengths(lengths, max_len=None):
    """Build a boolean padding mask from per-sequence lengths.

    Args:
        lengths: 1-D tensor of sequence lengths, shape (batch,).
        max_len: optional mask width; defaults to ``lengths.max()``.

    Returns:
        Bool tensor of shape (batch, max_len); True marks padded positions.
    """
    if max_len is None:
        max_len = torch.max(lengths).item()
    positions = torch.arange(0, max_len, device=lengths.device)
    valid = (positions < lengths.unsqueeze(1)).bool()
    return ~valid
class ConvLayerNorm(nn.LayerNorm):
    """LayerNorm for channel-first convolutional inputs.

    Moves the trailing time axis next to the batch axis so normalization runs
    over the channel (and any remaining) dimensions, then restores the
    original layout.
    """

    def __init__(
        self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs
    ):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        # (b, ..., t) -> (b, t, ...): equivalent to the previous einops
        # rearrange "b ... t -> b t ...", without the extra dependency.
        x = x.movedim(-1, 1)
        x = super().forward(x)
        x = x.movedim(1, -1)
        # Bug fix: the original ended with a bare `return`, so the normalized
        # tensor was discarded and the module returned None.
        return x
def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
    """Wrap ``module`` with a parametrization-based norm when requested."""
    assert norm in CONV_NORMALIZATIONS
    wrappers = {"weight_norm": weight_norm, "spectral_norm": spectral_norm}
    wrapper = wrappers.get(norm)
    if wrapper is None:
        # Membership in CONV_NORMALIZATIONS was already checked; any other
        # choice needs no reparametrization.
        return module
    return wrapper(module)
def get_norm_module(
    module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
) -> nn.Module:
    """Return the normalization layer matching ``norm`` for ``module``."""
    assert norm in CONV_NORMALIZATIONS
    if norm == "layer_norm":
        assert isinstance(module, nn.modules.conv._ConvNd)
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    if norm == "time_group_norm":
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    # Parametrization-based ("weight_norm"/"spectral_norm") and "none" need
    # no extra module here.
    return nn.Identity()
def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """Extra right padding so the last conv window covers the whole signal."""
    length = x.shape[-1]
    frames = (length + padding_total - kernel_size) / stride + 1
    # Length that would make the frame count a whole number.
    ideal_length = (math.ceil(frames) - 1) * stride + kernel_size - padding_total
    return ideal_length - length
def pad_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
):
    """Right-pad ``x`` so a conv1d with these settings drops no samples."""
    extra = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, extra))
def pad1d(
    x: torch.Tensor,
    paddings: tp.Tuple[int, int],
    mode: str = "zero",
    value: float = 0.0,
):
    """``F.pad`` on the last axis, tolerating reflect pads wider than the input.

    ``F.pad`` with ``mode='reflect'`` requires the signal to be longer than
    the pad; when it is not, the input is temporarily zero-extended on the
    right before reflecting, and the extension is trimmed off afterwards.
    """
    pad_left, pad_right = paddings
    assert pad_left >= 0 and pad_right >= 0, (pad_left, pad_right)
    if mode != "reflect":
        return F.pad(x, paddings, mode, value)
    length = x.shape[-1]
    extra = max(max(pad_left, pad_right) - length + 1, 0)
    if extra > 0:
        x = F.pad(x, (0, extra))
    padded = F.pad(x, paddings, mode, value)
    return padded[..., : padded.shape[-1] - extra]
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Strip ``paddings`` samples from each end of the last axis."""
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    assert left + right <= x.shape[-1]
    return x[..., left : x.shape[-1] - right]
class NormConv1d(nn.Module):
    """Conv1d bundled with its normalization (parametrized and/or module)."""

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        conv = nn.Conv1d(*args, **kwargs)
        self.conv = apply_parametrization_norm(conv, norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        return self.norm(self.conv(x))
class NormConv2d(nn.Module):
    """Conv2d bundled with its normalization; see ``NormConv1d``."""

    def __init__(
        self,
        *args,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        conv = nn.Conv2d(*args, **kwargs)
        self.conv = apply_parametrization_norm(conv, norm)
        # 2-D convolutions are never treated as causal here.
        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        return self.norm(self.conv(x))
class NormConvTranspose1d(nn.Module):
    """ConvTranspose1d bundled with its normalization."""

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        convtr = nn.ConvTranspose1d(*args, **kwargs)
        self.convtr = apply_parametrization_norm(convtr, norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        return self.norm(self.convtr(x))
class NormConvTranspose2d(nn.Module):
    """ConvTranspose2d bundled with its normalization."""

    def __init__(
        self,
        *args,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        convtr = nn.ConvTranspose2d(*args, **kwargs)
        self.convtr = apply_parametrization_norm(convtr, norm)
        # 2-D transposed convolutions are never causal here.
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)

    def forward(self, x):
        return self.norm(self.convtr(x))
class SConv1d(nn.Module):
    """Conv1d with built-in smart padding.

    Works out how much padding the (possibly dilated) kernel needs and
    applies it fully on the left for causal convolutions, or split around
    the signal otherwise, plus any extra right padding so the last samples
    are never dropped.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        causal: bool = False,
        norm: str = "weight_norm",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        pad_mode: str = "reflect",
    ):
        super().__init__()
        # Warn the user on an unusual combination of dilation and stride.
        if stride > 1 and dilation > 1:
            warnings.warn(
                "SConv1d has been initialized with stride > 1 and dilation > 1"
                f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
            )
        self.conv = NormConv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
        )
        self.causal = causal
        self.pad_mode = pad_mode

    def forward(self, x):
        batch, channels, time = x.shape  # enforce 3-D (B, C, T) input
        inner = self.conv.conv
        stride = inner.stride[0]
        # Effective receptive field of the dilated kernel.
        effective_kernel = (inner.kernel_size[0] - 1) * inner.dilation[0] + 1
        padding_total = effective_kernel - stride
        extra = get_extra_padding_for_conv1d(
            x, effective_kernel, stride, padding_total
        )
        if self.causal:
            # Causal: all fixed padding on the left, extra on the right.
            x = pad1d(x, (padding_total, extra), mode=self.pad_mode)
        else:
            # Centered padding; asymmetric when padding_total is odd.
            pad_right = padding_total // 2
            pad_left = padding_total - pad_right
            x = pad1d(x, (pad_left, pad_right + extra), mode=self.pad_mode)
        return self.conv(x)
class SConvTranspose1d(nn.Module):
    """ConvTranspose1d that trims the implicit padding from its output.

    For causal convolutions ``trim_right_ratio`` controls how much of the
    fixed padding is removed from the right (1.0 = all of it); non-causal
    convolutions trim symmetrically, mirroring ``SConv1d``'s padding.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        causal: bool = False,
        norm: str = "weight_norm",
        trim_right_ratio: float = 1.0,
        norm_kwargs: tp.Dict[str, tp.Any] = {},
    ):
        super().__init__()
        self.convtr = NormConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
        )
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert (
            self.causal or self.trim_right_ratio == 1.0
        ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert 0.0 <= self.trim_right_ratio <= 1.0

    def forward(self, x):
        kernel_size = self.convtr.convtr.kernel_size[0]
        stride = self.convtr.convtr.stride[0]
        padding_total = kernel_size - stride
        y = self.convtr(x)
        # Only the fixed padding is trimmed here. Extra padding from
        # `pad_for_conv1d` is removed at the very end, when keeping only the
        # right length for the output, as removing it here would require also
        # passing the length at the matching layer in the encoder.
        if self.causal:
            # Trim the right side according to trim_right_ratio
            # (1.0 trims everything from the right).
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
        else:
            # Symmetric trimming; asymmetric when padding_total is odd.
            padding_right = padding_total // 2
        padding_left = padding_total - padding_right
        return unpad1d(y, (padding_left, padding_right))
class SLSTM(nn.Module):
    """LSTM over channel-first sequences, with an optional residual skip.

    Input/output layout is (batch, channels, time); internally the data is
    permuted to the (time, batch, channels) layout expected by ``nn.LSTM``.
    """

    def __init__(
        self,
        dimension: int,
        num_layers: int = 2,
        bidirectional: bool = False,
        skip: bool = True,
    ):
        super().__init__()
        self.bidirectional = bidirectional
        self.skip = skip
        if bidirectional:
            # Halve the hidden size so the concatenated forward/backward
            # outputs keep the input dimension.
            self.lstm = nn.LSTM(
                dimension, dimension // 2, num_layers, bidirectional=bidirectional
            )
        else:
            self.lstm = nn.LSTM(dimension, dimension, num_layers)

    def forward(self, x):
        seq = x.permute(2, 0, 1)
        out, _ = self.lstm(seq)
        if self.skip:
            out = out + seq
        return out.permute(1, 2, 0)
class Swish(nn.Module):
    """Swish / SiLU activation: ``x * sigmoid(x)``."""

    def forward(self, x):
        return torch.sigmoid(x).mul(x)
class ResidualUnit(nn.Module):
    """Two-conv bottleneck (channels halved in the middle) with a skip add."""

    def __init__(self, in_channels, out_channels, kernel_size=3, groups=1):
        super().__init__()
        mid_channels = out_channels // 2
        squeeze = SConv1d(
            in_channels=in_channels,
            out_channels=mid_channels,
            kernel_size=kernel_size,
            groups=groups,
            norm=NORM,
        )
        expand = SConv1d(
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            groups=groups,
            norm=NORM,
        )
        self.layers = nn.Sequential(squeeze, Swish(), expand)

    def forward(self, x):
        # Residual add requires in_channels == out_channels.
        return x + self.layers(x)
class EncoderBlock(nn.Module):
    """Two residual units followed by a strided downsampling convolution."""

    def __init__(self, out_channels, stride):
        super().__init__()
        blocks = []
        for _ in range(2):
            blocks.append(
                ResidualUnit(in_channels=out_channels, out_channels=out_channels)
            )
            blocks.append(Swish())
        # Downsample by `stride`; kernel is twice the stride for overlap.
        blocks.append(
            SConv1d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                norm=NORM,
            )
        )
        self.layers = nn.Sequential(*blocks)

    def forward(self, x):
        return self.layers(x)
class DecoderBlock(nn.Module):
    """Transposed-conv upsampling followed by two residual units."""

    def __init__(self, in_channels, stride):
        super().__init__()
        channels = in_channels  # channel count is preserved through the block
        blocks = [
            SConvTranspose1d(
                in_channels=in_channels,
                out_channels=channels,
                kernel_size=2 * stride,
                stride=stride,
                norm=NORM,
            )
        ]
        for _ in range(2):
            blocks.append(Swish())
            blocks.append(ResidualUnit(in_channels=channels, out_channels=channels))
        self.layers = nn.Sequential(*blocks)

    def forward(self, x):
        return self.layers(x)
class Encoder(nn.Module):
    """Convolutional encoder: stem conv, strided blocks, conv + BiLSTM head.

    Accepts (batch, time, C) input and returns (batch, time', D), where the
    time axis is reduced by the product of ``strides``.
    """

    def __init__(self, C, D, strides=[2, 2], checkpointing=True):
        super().__init__()
        self.checkpointing = checkpointing
        # Overall time reduction factor of the stack.
        self.downsample_scale = np.cumprod(np.asarray(strides))[-1]
        modules = [
            SConv1d(in_channels=C, out_channels=D, kernel_size=3, norm=NORM),
            Swish(),
        ]
        for stride in strides:
            modules.append(EncoderBlock(out_channels=D, stride=stride))
            modules.append(Swish())
        modules.append(
            SConv1d(in_channels=D, out_channels=D, kernel_size=3, norm=NORM)
        )
        modules.append(SLSTM(D, num_layers=1, bidirectional=True))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        # The stack is channel-first internally; I/O is (batch, time, chans).
        inp = x.transpose(1, 2)
        if self.checkpointing:
            # Gradient checkpointing trades compute for activation memory.
            out = checkpoint(self.layers, inp, use_reentrant=False)
        else:
            out = self.layers(inp)
        return out.transpose(1, 2)
class Decoder(nn.Module):
    """Convolutional decoder with per-step conditioning.

    ``_forward`` takes (batch, time, D) latents plus a per-utterance
    conditioning vector ``g`` and returns ``(output, hidden)``, both in
    (batch, time, channels) layout.
    """

    def __init__(self, C, D, H, strides=[2, 2], checkpointing=True):
        # C: output feature size, D: latent size, H: hidden size.
        super().__init__()
        self.checkpointing = checkpointing
        self.in_layer = nn.Sequential(
            SConv1d(in_channels=D, out_channels=H, kernel_size=3, norm=NORM),
            SLSTM(H, num_layers=1, bidirectional=True),
        )
        self.layers = nn.ModuleList()
        for stride in strides:
            self.layers.append(
                nn.Sequential(DecoderBlock(in_channels=H, stride=stride), Swish())
            )
        self.out_layer = SConv1d(
            in_channels=H, out_channels=C, kernel_size=3, norm=NORM
        )

    def forward(self, x, g=None):
        # Gradient checkpointing trades compute for activation memory.
        if self.checkpointing:
            y = checkpoint(self._forward, x, g, use_reentrant=False)
        else:
            y = self._forward(x, g)
        return y

    def _forward(self, x, g=None):
        # Work channel-first internally; x arrives as (batch, time, D).
        h = self.in_layer(x.transpose(1, 2))
        for layer in self.layers:
            # Broadcast the conditioning vector over time and add it before
            # each upsampling block.
            # NOTE(review): `g` defaults to None but is used unconditionally
            # whenever `self.layers` is non-empty — confirm callers always
            # pass a conditioning tensor.
            up_g = g.unsqueeze(-1).repeat(1, 1, h.shape[-1])
            h = h + up_g
            h = layer(h)
        y = self.out_layer(h)
        return y.transpose(1, 2), h.transpose(1, 2)
class TimeRegulator(nn.Module):
    """Downsample (average-pool) or upsample (repeat) features along time.

    NOTE(review): ``downsample`` returns a ``(x, x_len)`` tuple while
    ``upsample`` returns only ``x`` — callers rely on this asymmetry, so
    ``forward``'s return type depends on the ``downsample`` flag.
    """

    def __init__(self, in_dim, scale, learnable=False):
        super().__init__()
        self.scale = scale
        self.learnable = learnable
        # NOTE(review): no `self.upsampler` module is created here even when
        # `learnable=True`, yet `upsample` calls it — confirm whether the
        # learnable path is wired up elsewhere or simply unused.

    def forward(self, x, x_len, downsample=True):
        if downsample:
            x = self.downsample(x, x_len)
        else:
            x = self.upsample(x, x_len)
        return x

    def downsample(self, x, x_len):
        # ceil_mode keeps the trailing partial frame instead of dropping it,
        # matching the ceil() applied to the lengths below.
        x = torch.nn.functional.avg_pool1d(
            x.transpose(1, 2), self.scale, stride=self.scale, ceil_mode=True
        ).transpose(1, 2)
        x_len = (x_len / self.scale).ceil()
        return x, x_len

    def upsample(self, x, x_len):
        if self.learnable:
            x = self.upsampler(x.transpose(1, 2)).transpose(1, 2)
        else:
            # Nearest-neighbour upsampling by integer repetition along time.
            x = torch.repeat_interleave(x, self.scale, dim=1)
        return x
class TreeVectorQuantization(nn.Module):
    """Multi-scale ("tree") residual vector quantization.

    Each entry of ``tree_config`` describes one scale: the running residual
    is downsampled by ``downsample_rate``, quantized, upsampled back, and
    subtracted from the residual before the next scale is processed.
    """

    def __init__(
        self,
        in_dim,
        vq_class="VectorQuantization",
        vq_config={},
        tree_config={},
    ):
        # NOTE(review): `vq_class` is accepted but never used — the
        # VectorQuantization class is instantiated directly below; confirm
        # whether dynamic class selection was intended.
        super().__init__()
        self.vq_config = vq_config
        self.tree_config = tree_config
        self.quantizers = nn.ModuleList()
        self.time_regulators = nn.ModuleList()
        for config in self.tree_config:
            # Per-scale copy so the shared base config is not mutated.
            vq_config = self.vq_config.copy()
            # Replicate codebook size/dim per group when scalars were given.
            if not isinstance(vq_config["codebook_size"], (tuple, list)):
                vq_config["codebook_size"] = [vq_config["codebook_size"]]
                vq_config["codebook_dim"] = [vq_config["codebook_dim"]]
            vq_config["codebook_size"] = vq_config["codebook_size"] * config["n_groups"]
            vq_config["codebook_dim"] = vq_config["codebook_dim"] * config["n_groups"]
            self.quantizers.append(
                VectorQuantization(
                    in_dim,
                    n_groups=config.get("n_groups", 1),
                    dropout_rate_per_group=config.get("dropout_rate_per_group", 0),
                    ordered=config.get("ordered", False),
                    **vq_config,
                )
            )
            self.time_regulators.append(
                TimeRegulator(
                    in_dim,
                    config["downsample_rate"],
                    config.get("learnable_time_regulator", False),
                )
            )

    def forward(
        self, inp, inp_len, enable_vq=True, update_codebook=True, return_pre_quant=False
    ):
        """Quantize ``inp`` and average the per-scale commitment losses."""
        output, (quants, losses, embed_inds) = self.quantize(
            inp,
            inp_len,
            enable_vq=enable_vq,
            update_codebook=update_codebook,
            return_pre_quant=return_pre_quant,
        )
        loss = sum(losses) / len(losses)
        return output, (quants, loss, embed_inds)

    def quantize(
        self, inp, inp_len, enable_vq=True, update_codebook=True, return_pre_quant=False
    ):
        """Run residual quantization over every configured scale.

        Returns ``(quant_output, (quants, losses, embed_inds))``; when
        ``return_pre_quant`` is set, the first element is the pair
        ``(pre_quant_output, quant_output)``.
        """
        quants, losses, embed_inds = [], [], []
        pre_quant_output, quant_output, residual = 0, 0, inp
        for tree_config, quantizer, regulator in zip(
            self.tree_config, self.quantizers, self.time_regulators
        ):
            # Downsample
            x, x_len = regulator(residual, inp_len, True)
            # Quantization
            q, diff, embed_ind = quantizer(
                x,
                x_len,
                enable_vq=enable_vq,
                update_codebook=update_codebook,
                return_pre_quant=return_pre_quant,
            )
            if return_pre_quant:
                pq, q = q
            # Upsample; trim in case ceil-mode pooling overshot the length.
            x = regulator(q, x_len, False)[:, : residual.shape[1]]
            residual = residual - x
            quant_output = quant_output + x
            if return_pre_quant:
                pq = regulator(pq, x_len, False)[:, : residual.shape[1]]
                pre_quant_output = pre_quant_output + pq
            quants.append(q)
            losses.append(diff)
            embed_inds.append(embed_ind)
        if return_pre_quant:
            return (pre_quant_output, quant_output), (quants, losses, embed_inds)
        return quant_output, (quants, losses, embed_inds)

    def decode(self, seqs, seq_lens=None):
        """Decode per-scale token sequences back into the summed latent."""
        # A single serialized tensor is first split back into per-scale tokens.
        if not isinstance(seqs, (tuple, list)):
            tokens, token_lens = self.deserialize(seqs, seq_lens)
        else:
            tokens, token_lens = seqs, seq_lens
        quant_output = 0
        for token, quantizer, regulator in zip(
            tokens, self.quantizers, self.time_regulators
        ):
            x = quantizer.decode(token).transpose(1, 2)
            x = regulator(x, None, False)
            # Align lengths with the running sum before accumulating.
            if torch.is_tensor(quant_output):
                x = x[:, : quant_output.size(1)]
            quant_output = quant_output + x
        return quant_output, token_lens

    def serialize(self, tokens, token_lens):
        """Interleave per-scale token streams into one flat sequence.

        NOTE(review): assumes ``tree_config[0]`` is the coarse scale whose
        rate relates the two streams — confirm against config conventions.
        """
        assert len(tokens) <= 2, "we only support 1 or 2-scale sequences now..."
        scale = self.tree_config[0]["downsample_rate"]
        # Round lengths up to a whole multiple of the coarse scale.
        token_lens = ((token_lens.float() / scale).ceil() * scale).int()
        seq1 = tokens[0].unsqueeze(-1)
        if len(tokens) == 1:
            seq_cat = seq1.view(seq1.shape[0], -1)
            seq_cat_lens = (token_lens / scale * seq1.shape[2]).int()
        elif len(tokens) == 2:
            # Pad the fine stream to the rounded length, then interleave:
            # one coarse token followed by `scale` fine tokens per frame.
            seq2 = F.pad(
                tokens[1], (0, token_lens.max() - tokens[1].size(1)), "replicate"
            )
            seq2 = torch.stack([seq2[:, i::scale] for i in range(scale)], dim=-1)
            seq_cat = torch.cat((seq1, seq2), dim=-1).view(seq1.shape[0], -1)
            seq_cat_lens = (token_lens / scale + token_lens).int()
        return seq_cat, seq_cat_lens

    def deserialize(self, seqs, seq_lens):
        """Inverse of :meth:`serialize`: split a flat sequence per scale."""
        if len(self.tree_config) == 1:
            return [seqs], seq_lens
        max_scale = max(config["downsample_rate"] for config in self.tree_config)
        total_scale = sum(config["downsample_rate"] for config in self.tree_config)
        # Cut for aligning
        if seq_lens is None:
            seq_lens = torch.full([seqs.shape[0]], seqs.shape[1]).to(seqs.device)
        seq_lens = (seq_lens / total_scale).int() * total_scale
        token_lens = (seq_lens / total_scale).int() * max_scale
        seqs = seqs[:, : seq_lens.max()]
        # Separate: column 0 of each group is the coarse token, the rest are
        # the interleaved fine tokens.
        tokens = torch.stack(
            [seqs[:, i::total_scale] for i in range(total_scale)], dim=-1
        )
        seq1 = tokens[..., 0]
        seq2 = tokens[..., 1:].contiguous().view(tokens.shape[0], -1)
        return [seq1, seq2], token_lens
class SemanticVQVAE(nn.Module):
    """VQ-VAE over self-supervised speech features.

    Pipeline: waveform -> (mel, SSL features) -> Encoder -> tree-structured
    residual VQ.  A speaker embedding is extracted from the mel spectrogram
    with an ECAPA-TDNN and can be added back onto decoded latents.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        n_model_size,
        downsample_scales=[1, 2],
        upsample_scales=[[2, 1], [2, 1]],
        mel_config={},
        ssl_config={},
        # Quantization
        vq_class="VectorQuantization",
        vq_config={},
        tree_config={},
        # Training
        checkpointing=True,
        dual_decoding=False,
        n_samples_per_token=640,
        online_extraction=True,
        ssl_extractor=None,
    ):
        super().__init__()
        self.in_dim = in_dim
        self.n_model_size = n_model_size
        self.mel_config = mel_config
        self.dual_decoding = dual_decoding
        self.vq_config = vq_config
        self.tree_config = tree_config
        self.output_feature = "mel"
        self.n_samples_per_token = n_samples_per_token
        self.checkpointing = checkpointing
        # Bug fix: `ssl_extractor` was accepted but never stored, so `forward`
        # raised AttributeError whenever the caller did not supply precomputed
        # `ssl` features.
        self.ssl_extractor = ssl_extractor
        self.mel_spectrogram = TorchMelSpectrogram(**mel_config)
        # Speaker encoder
        self.speaker_encoder = ECAPA_TDNN(
            out_dim,
            n_model_size,
            channels=[512, 512, 512, 512, 1536],
            kernel_sizes=[5, 3, 3, 3, 1],
            dilations=[1, 2, 3, 4, 1],
            attention_channels=128,
            res2net_scale=4,
            se_channels=128,
            global_context=True,
            batch_norm=True,
        )
        # Encoder & decoder
        self.encoder = Encoder(
            in_dim, n_model_size, downsample_scales, checkpointing=checkpointing
        )
        # Quantization
        self.quantizer = TreeVectorQuantization(
            n_model_size,
            vq_class=vq_class,
            vq_config=vq_config,
            tree_config=tree_config,
        )

    def forward(
        self,
        wav,
        wav_length,
        enable_vq=True,
        decode=True,
        extract_spk=True,
        shuffle=False,
        **kwargs,
    ):
        """Encode and quantize a batch of waveforms.

        Precomputed features may be supplied via ``kwargs`` (``mel``,
        ``mel_length``, ``ssl``, ``ssl_length``) to skip extraction.
        ``decode`` and ``shuffle`` are currently unused in this method.

        Returns a dict with mel/SSL features, tokens and their lengths,
        commitment losses, and (optionally) the speaker embedding.
        """
        output_dict = {}
        # Feature extraction needs no gradients; the trainable path starts at
        # the encoder below.
        with torch.no_grad():
            # Pad the waveform to a whole number of tokens.
            if wav.shape[1] % self.n_samples_per_token > 0:
                pad_size = (
                    self.n_samples_per_token - wav.shape[1] % self.n_samples_per_token
                )
                wav = F.pad(wav, (0, pad_size), value=0)
                wav_length += pad_size
            # Extract mel & SSL features unless provided by the caller.
            mel, mel_length = kwargs.get("mel", None), kwargs.get("mel_length", None)
            if mel is None:
                mel, mel_length = self.mel_spectrogram(wav, wav_length)
            output_dict.update({"mel": mel, "mel_length": mel_length})
            ssl, ssl_length = kwargs.get("ssl", None), kwargs.get("ssl_length", None)
            if ssl is None:
                ssl, ssl_length = self.ssl_extractor(wav, wav_length)
            output_dict.update({"ssl": ssl.float(), "ssl_length": ssl_length})
        input, input_length = ssl, ssl_length
        output, output_length = mel, mel_length
        encoder_outputs = self.encoder(input)
        quant_length = torch.ceil(input_length / self.encoder.downsample_scale)
        quant_length = quant_length.clamp(max=encoder_outputs.shape[1])
        quant, (quants, diff, embed_ind) = self.quantizer(
            encoder_outputs,
            quant_length,
            enable_vq=enable_vq,
            update_codebook=True,
            return_pre_quant=self.dual_decoding,
        )
        output_dict.update(
            {
                "quants": quants,
                "token": embed_ind,
                "token_length": quant_length.int(),
                "encoder_diffs": diff,
            }
        )
        # Speaker embedding, computed from the mel spectrogram.
        if extract_spk:
            cond, cond_length = output, output_length
            speaker_embedding = self.speaker_encoder(cond, cond_length)
            output_dict["spk"] = speaker_embedding
        return output_dict

    def extract_speech_tokens(
        self, wav, wav_length, serialize=True, extract_spk=True, shuffle=False
    ):
        """Tokenize waveforms, aligning and (optionally) serializing scales."""
        output_dict = self.forward(
            wav, wav_length, True, False, extract_spk=extract_spk, shuffle=shuffle
        )
        token_seqs, token_length = output_dict["token"], output_dict["token_length"]
        # Align: round lengths up to a whole multiple of the coarsest scale.
        scale = self.tree_config[0]["downsample_rate"]
        token_length = (torch.ceil(token_length / scale) * scale).int()
        new_token_seqs, new_token_lens = [], []
        for i, token_seq in enumerate(token_seqs):
            # Discrete-continuous tokens arrive as (tokens, residual) pairs.
            residual = None
            if isinstance(token_seq, (tuple, list)):
                token_seq, residual = token_seq
            scale = self.tree_config[i]["downsample_rate"]
            new_token_len = token_length // scale
            pad = int(new_token_len.max()) - token_seq.shape[1]
            token_seq = F.pad(
                token_seq,
                (0, pad) if len(token_seq.shape) == 2 else (0, 0, 0, pad),
                "replicate",
            )
            if residual is not None:
                token_seq = (token_seq, residual)
            new_token_seqs.append(token_seq)
            new_token_lens.append(new_token_len)
        if len(new_token_seqs) == 1:
            new_token_seqs, new_token_lens = new_token_seqs[0], new_token_lens[0]
        elif serialize:
            new_token_seqs, new_token_lens = self.quantizer.serialize(
                new_token_seqs, new_token_lens
            )
        output_dict.update(
            {
                "embed": output_dict["quants"],
                "token": new_token_seqs,
                "token_length": new_token_lens,
            }
        )
        return output_dict

    def code_to_latent(self, token, mel=None):
        """Decode tokens to latents and add the speaker embedding from ``mel``."""
        quant, _ = self.quantizer.decode(token, None)
        speaker_embedding = self.speaker_encoder(mel)
        # Broadcast the utterance-level embedding over every time step.
        latents = quant + speaker_embedding.unsqueeze(1).repeat(1, quant.shape[1], 1)
        return {
            "latents": latents,
        }