from functools import reduce
# NOTE(review): `Triple` looks like an accidental IDE auto-import (it is an
# internal regex fragment of the stdlib tokenizer). Kept because another chunk
# of this file may reference it — confirm and drop if unused file-wide.
from tokenize import Triple
from torch import nn
from torch.autograd import Function
from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm
from torch.utils.checkpoint import checkpoint
import einops
import math
import numpy as np
import os
import random
import torch
import torchaudio
import typing as tp
import warnings
from .audio import TorchMelSpectrogram
from .ecapa_tdnn import ECAPA_TDNN
from .hubert import HuBERT
from ..acoustic_codec.vector_quantization import VectorQuantization

# Normalization strategies accepted by the Norm*Conv wrappers below.
CONV_NORMALIZATIONS = frozenset(
    [
        "none",
        "weight_norm",
        "spectral_norm",
        "time_layer_norm",
        "layer_norm",
        "time_group_norm",
    ]
)
# Default normalization used by every conv in this module.
NORM = "weight_norm"


def get_mask_from_lengths(lengths, max_len=None):
    """Build a boolean padding mask from per-item lengths.

    Args:
        lengths: (B,) tensor of valid lengths.
        max_len: mask width; defaults to ``lengths.max()``.

    Returns:
        (B, max_len) bool tensor that is True at PADDED positions.
    """
    max_len = torch.max(lengths).item() if max_len is None else max_len
    ids = torch.arange(0, max_len).to(lengths.device)
    mask = ~(ids < lengths.unsqueeze(1)).bool()
    return mask


class ConvLayerNorm(nn.LayerNorm):
    """LayerNorm for channel-first convolutional activations.

    Moves the trailing time axis next to the batch axis, applies LayerNorm
    over the remaining (channel) dims, then moves time back to the end.
    """

    def __init__(
        self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs
    ):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        x = einops.rearrange(x, "b ... t -> b t ...")
        x = super().forward(x)
        x = einops.rearrange(x, "b t ... -> b ... t")
        # BUG FIX: the original ended with a bare ``return``, so this module
        # yielded ``None`` for every input instead of the normalized tensor.
        return x


def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
    """Wrap ``module`` with weight/spectral norm reparametrization if requested."""
    assert norm in CONV_NORMALIZATIONS
    if norm == "weight_norm":
        return weight_norm(module)
    elif norm == "spectral_norm":
        return spectral_norm(module)
    else:
        # We already checked ``norm`` is in CONV_NORMALIZATIONS, so any other
        # choice doesn't need reparametrization.
        return module


def get_norm_module(
    module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
) -> nn.Module:
    """Return the normalization layer to apply AFTER ``module`` (or Identity)."""
    assert norm in CONV_NORMALIZATIONS
    if norm == "layer_norm":
        assert isinstance(module, nn.modules.conv._ConvNd)
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    elif norm == "time_group_norm":
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    else:
        return nn.Identity()


def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """Extra right-padding needed so a conv1d consumes the full signal length."""
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length


def pad_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
):
    """Right-pad ``x`` so that a conv1d with the given geometry wastes no samples."""
    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, extra_padding))


def pad1d(
    x: torch.Tensor,
    paddings: tp.Tuple[int, int],
    mode: str = "zero",
    value: float = 0.0,
):
    """``F.pad`` wrapper that tolerates reflect-padding of very short signals.

    Reflect padding requires the input to be longer than the pad amount, so
    short inputs are first zero-extended, reflected, then trimmed back.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)


def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove ``paddings = (left, right)`` samples from the last axis."""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left:end]


class NormConv1d(nn.Module):
    """Conv1d + optional weight parametrization + optional post-normalization."""

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConv2d(nn.Module):
    """Conv2d + optional weight parametrization + optional post-normalization."""

    def __init__(
        self,
        *args,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConvTranspose1d(nn.Module):
    """ConvTranspose1d + optional weight parametrization + post-normalization."""

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        self.convtr = apply_parametrization_norm(
            nn.ConvTranspose1d(*args, **kwargs), norm
        )
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class NormConvTranspose2d(nn.Module):
    """ConvTranspose2d + optional weight parametrization + post-normalization."""

    def __init__(
        self,
        *args,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        self.convtr = apply_parametrization_norm(
            nn.ConvTranspose2d(*args, **kwargs), norm
        )
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
        # Consistency fix: the three sibling wrappers all record the norm
        # choice; this one originally did not.
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class SConv1d(nn.Module):
    """Conv1d with automatic padding for causal or same-length non-causal use."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        causal: bool = False,
        norm: str = "weight_norm",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        pad_mode: str = "reflect",
    ):
        super().__init__()
        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
            warnings.warn(
                "SConv1d has been initialized with stride > 1 and dilation > 1"
                f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
            )
        self.conv = NormConv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
        )
        self.causal = causal
        self.pad_mode = pad_mode

    def forward(self, x):
        B, C, T = x.shape
        kernel_size = self.conv.conv.kernel_size[0]
        stride = self.conv.conv.stride[0]
        dilation = self.conv.conv.dilation[0]
        kernel_size = (
            kernel_size - 1
        ) * dilation + 1  # effective kernel size with dilations
        padding_total = kernel_size - stride
        extra_padding = get_extra_padding_for_conv1d(
            x, kernel_size, stride, padding_total
        )
        if self.causal:
            # Left padding for causal
            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(
                x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
            )
        return self.conv(x)


class SConvTranspose1d(nn.Module):
    """ConvTranspose1d that trims its own padding (causal or symmetric)."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        causal: bool = False,
        norm: str = "weight_norm",
        trim_right_ratio: float = 1.0,
        norm_kwargs: tp.Dict[str, tp.Any] = {},
    ):
        super().__init__()
        self.convtr = NormConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
        )
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert (
            self.causal or self.trim_right_ratio == 1.0
        ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0

    def forward(self, x):
        kernel_size = self.convtr.convtr.kernel_size[0]
        stride = self.convtr.convtr.stride[0]
        padding_total = kernel_size - stride
        y = self.convtr(x)
        # We will only trim fixed padding. Extra padding from `pad_for_conv1d`
        # would be removed at the very end, when keeping only the right length
        # for the output, as removing it here would require also passing the
        # length at the matching layer in the encoder.
        if self.causal:
            # Trim the padding on the right according to the specified ratio
            # if trim_right_ratio = 1.0, trim everything from right
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        return y


class SLSTM(nn.Module):
    """LSTM over (B, C, T) input with optional residual skip connection.

    Input is permuted to the (T, B, C) layout nn.LSTM expects, and the output
    comes back as (B, C, T).
    """

    def __init__(
        self,
        dimension: int,
        num_layers: int = 2,
        bidirectional: bool = False,
        skip: bool = True,
    ):
        super().__init__()
        self.bidirectional = bidirectional
        self.skip = skip
        if bidirectional:
            # Halve the hidden size so the concatenated directions keep
            # the channel dimension unchanged.
            self.lstm = nn.LSTM(
                dimension, dimension // 2, num_layers, bidirectional=bidirectional
            )
        else:
            self.lstm = nn.LSTM(dimension, dimension, num_layers)

    def forward(self, x):
        x = x.permute(2, 0, 1)
        y, _ = self.lstm(x)
        if self.skip:
            y = y + x
        y = y.permute(1, 2, 0)
        return y


class Swish(nn.Module):
    """Swish / SiLU activation: x * sigmoid(x)."""

    def forward(self, x):
        return x * torch.sigmoid(x)


class ResidualUnit(nn.Module):
    """Two SConv1d layers with a bottleneck (out_channels // 2) and a residual add."""

    def __init__(self, in_channels, out_channels, kernel_size=3, groups=1):
        super().__init__()
        self.layers = nn.Sequential(
            SConv1d(
                in_channels=in_channels,
                out_channels=out_channels // 2,
                kernel_size=kernel_size,
                groups=groups,
                norm=NORM,
            ),
            Swish(),
            SConv1d(
                in_channels=out_channels // 2,
                out_channels=out_channels,
                kernel_size=kernel_size,
                groups=groups,
                norm=NORM,
            ),
        )

    def forward(self, x):
        return x + self.layers(x)


class EncoderBlock(nn.Module):
    """Two residual units followed by a strided SConv1d downsampler."""

    def __init__(self, out_channels, stride):
        super().__init__()
        self.layers = nn.Sequential(
            ResidualUnit(in_channels=out_channels, out_channels=out_channels),
            Swish(),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels),
            Swish(),
            SConv1d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                norm=NORM,
            ),
        )

    def forward(self, x):
        return self.layers(x)


class DecoderBlock(nn.Module):
    """Strided transposed-conv upsampler followed by two residual units."""

    def __init__(self, in_channels, stride):
        super().__init__()
        out_channels = in_channels
        self.layers = nn.Sequential(
            SConvTranspose1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                norm=NORM,
            ),
            Swish(),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels),
            Swish(),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels),
        )

    def forward(self, x):
        return self.layers(x)


class Encoder(nn.Module):
    """Convolutional encoder: (B, T, C) -> (B, T / prod(strides), D).

    Args:
        C: input feature dimension.
        D: model/latent dimension.
        strides: per-block downsampling factors.
        checkpointing: use activation checkpointing to save memory.
    """

    def __init__(self, C, D, strides=[2, 2], checkpointing=True):
        super().__init__()
        self.checkpointing = checkpointing
        # Total temporal downsampling factor of the whole stack.
        self.downsample_scale = np.cumprod(np.asarray(strides))[-1]
        self.layers = [
            SConv1d(in_channels=C, out_channels=D, kernel_size=3, norm=NORM),
            Swish(),
        ]
        for stride in strides:
            self.layers += [
                EncoderBlock(out_channels=D, stride=stride),
                Swish(),
            ]
        self.layers += [
            SConv1d(in_channels=D, out_channels=D, kernel_size=3, norm=NORM),
            SLSTM(D, num_layers=1, bidirectional=True),
        ]
        self.layers = nn.Sequential(*self.layers)

    def forward(self, x):
        # Convs work channel-first, so transpose in and out of (B, C, T).
        if self.checkpointing:
            x = checkpoint(
                self.layers, x.transpose(1, 2), use_reentrant=False
            ).transpose(1, 2)
        else:
            x = self.layers(x.transpose(1, 2)).transpose(1, 2)
        return x


class Decoder(nn.Module):
    """Convolutional decoder that injects a global conditioning vector ``g``.

    Returns both the output features and the last hidden activations,
    each transposed back to (B, T, C).
    """

    def __init__(self, C, D, H, strides=[2, 2], checkpointing=True):
        super().__init__()
        self.checkpointing = checkpointing
        self.in_layer = nn.Sequential(
            SConv1d(in_channels=D, out_channels=H, kernel_size=3, norm=NORM),
            SLSTM(H, num_layers=1, bidirectional=True),
        )
        self.layers = nn.ModuleList()
        for stride in strides:
            self.layers.append(
                nn.Sequential(DecoderBlock(in_channels=H, stride=stride), Swish())
            )
        self.out_layer = SConv1d(
            in_channels=H, out_channels=C, kernel_size=3, norm=NORM
        )

    def forward(self, x, g=None):
        if self.checkpointing:
            y = checkpoint(self._forward, x, g, use_reentrant=False)
        else:
            y = self._forward(x, g)
        return y

    def _forward(self, x, g=None):
        h = self.in_layer(x.transpose(1, 2))
        for layer in self.layers:
            # NOTE(review): ``g`` is dereferenced unconditionally, so despite
            # the ``g=None`` default a conditioning vector is required here —
            # confirm callers always pass one.
            up_g = g.unsqueeze(-1).repeat(1, 1, h.shape[-1])
            h = h + up_g
            h = layer(h)
        y = self.out_layer(h)
        return y.transpose(1, 2), h.transpose(1, 2)


class TimeRegulator(nn.Module):
    """Average-pool downsampling / repeat-interleave upsampling along time.

    ``forward(..., downsample=True)`` returns ``(x, x_len)``;
    ``forward(..., downsample=False)`` returns only ``x``.
    """

    def __init__(self, in_dim, scale, learnable=False):
        super().__init__()
        self.scale = scale
        self.learnable = learnable

    def forward(self, x, x_len, downsample=True):
        if downsample:
            x = self.downsample(x, x_len)
        else:
            x = self.upsample(x, x_len)
        return x

    def downsample(self, x, x_len):
        x = torch.nn.functional.avg_pool1d(
            x.transpose(1, 2), self.scale, stride=self.scale, ceil_mode=True
        ).transpose(1, 2)
        x_len = (x_len / self.scale).ceil()
        return x, x_len

    def upsample(self, x, x_len):
        if self.learnable:
            # NOTE(review): ``self.upsampler`` is never defined in this class,
            # so ``learnable=True`` raises AttributeError here — either the
            # module is attached externally or this path is dead; confirm.
            x = self.upsampler(x.transpose(1, 2)).transpose(1, 2)
        else:
            x = torch.repeat_interleave(x, self.scale, dim=1)
        return x


class TreeVectorQuantization(nn.Module):
    """Multi-scale residual vector quantization.

    Each entry of ``tree_config`` adds one (quantizer, time-regulator) pair:
    the running residual is downsampled, quantized, upsampled back, and
    subtracted before the next (finer) scale.
    """

    def __init__(
        self,
        in_dim,
        vq_class="VectorQuantization",
        vq_config={},
        tree_config={},
    ):
        super().__init__()
        # NOTE(review): ``vq_class`` is accepted but VectorQuantization is
        # instantiated unconditionally below — confirm whether dynamic class
        # selection was intended.
        self.vq_config = vq_config
        self.tree_config = tree_config
        self.quantizers = nn.ModuleList()
        self.time_regulators = nn.ModuleList()
        for config in self.tree_config:
            vq_config = self.vq_config.copy()
            # Replicate codebook size/dim per group.
            if not isinstance(vq_config["codebook_size"], (tuple, list)):
                vq_config["codebook_size"] = [vq_config["codebook_size"]]
                vq_config["codebook_dim"] = [vq_config["codebook_dim"]]
            vq_config["codebook_size"] = vq_config["codebook_size"] * config["n_groups"]
            vq_config["codebook_dim"] = vq_config["codebook_dim"] * config["n_groups"]
            self.quantizers.append(
                VectorQuantization(
                    in_dim,
                    n_groups=config.get("n_groups", 1),
                    dropout_rate_per_group=config.get("dropout_rate_per_group", 0),
                    ordered=config.get("ordered", False),
                    **vq_config,
                )
            )
            self.time_regulators.append(
                TimeRegulator(
                    in_dim,
                    config["downsample_rate"],
                    config.get("learnable_time_regulator", False),
                )
            )

    def forward(
        self, inp, inp_len, enable_vq=True, update_codebook=True, return_pre_quant=False
    ):
        """Quantize ``inp`` and average the per-scale commitment losses."""
        output, (quants, losses, embed_inds) = self.quantize(
            inp,
            inp_len,
            enable_vq=enable_vq,
            update_codebook=update_codebook,
            return_pre_quant=return_pre_quant,
        )
        loss = sum(losses) / len(losses)
        return output, (quants, loss, embed_inds)

    def quantize(
        self, inp, inp_len, enable_vq=True, update_codebook=True, return_pre_quant=False
    ):
        """Run the residual quantization cascade over all scales."""
        quants, losses, embed_inds = [], [], []
        pre_quant_output, quant_output, residual = 0, 0, inp
        for tree_config, quantizer, regulator in zip(
            self.tree_config, self.quantizers, self.time_regulators
        ):
            # Downsample
            x, x_len = regulator(residual, inp_len, True)
            # Quantization
            q, diff, embed_ind = quantizer(
                x,
                x_len,
                enable_vq=enable_vq,
                update_codebook=update_codebook,
                return_pre_quant=return_pre_quant,
            )
            if return_pre_quant:
                pq, q = q
            # Upsample (trim to the residual's length, since ceil-mode
            # pooling can overshoot by up to one frame).
            x = regulator(q, x_len, False)[:, : residual.shape[1]]
            residual = residual - x
            quant_output = quant_output + x
            if return_pre_quant:
                pq = regulator(pq, x_len, False)[:, : residual.shape[1]]
                pre_quant_output = pre_quant_output + pq
            quants.append(q)
            losses.append(diff)
            embed_inds.append(embed_ind)
        if return_pre_quant:
            return (pre_quant_output, quant_output), (quants, losses, embed_inds)
        return quant_output, (quants, losses, embed_inds)

    def decode(self, seqs, seq_lens=None):
        """Decode token sequences (serialized or per-scale) back to latents."""
        if not isinstance(seqs, (tuple, list)):
            tokens, token_lens = self.deserialize(seqs, seq_lens)
        else:
            tokens, token_lens = seqs, seq_lens
        quant_output = 0
        for token, quantizer, regulator in zip(
            tokens, self.quantizers, self.time_regulators
        ):
            x = quantizer.decode(token).transpose(1, 2)
            x = regulator(x, None, False)
            # Align subsequent scales to the first scale's length.
            if torch.is_tensor(quant_output):
                x = x[:, : quant_output.size(1)]
            quant_output = quant_output + x
        return quant_output, token_lens

    def serialize(self, tokens, token_lens):
        """Interleave 1- or 2-scale token streams into one flat sequence."""
        assert len(tokens) <= 2, "we only support 1 or 2-scale sequences now..."
        scale = self.tree_config[0]["downsample_rate"]
        # Round lengths up to a multiple of the coarse scale.
        token_lens = ((token_lens.float() / scale).ceil() * scale).int()
        seq1 = tokens[0].unsqueeze(-1)
        if len(tokens) == 1:
            seq_cat = seq1.view(seq1.shape[0], -1)
            seq_cat_lens = (token_lens / scale * seq1.shape[2]).int()
        elif len(tokens) == 2:
            seq2 = F.pad(
                tokens[1], (0, token_lens.max() - tokens[1].size(1)), "replicate"
            )
            # Group `scale` fine tokens behind each coarse token.
            seq2 = torch.stack([seq2[:, i::scale] for i in range(scale)], dim=-1)
            seq_cat = torch.cat((seq1, seq2), dim=-1).view(seq1.shape[0], -1)
            seq_cat_lens = (token_lens / scale + token_lens).int()
        return seq_cat, seq_cat_lens

    def deserialize(self, seqs, seq_lens):
        """Inverse of :meth:`serialize`: split a flat stream back per scale."""
        if len(self.tree_config) == 1:
            return [seqs], seq_lens
        max_scale = max(config["downsample_rate"] for config in self.tree_config)
        total_scale = sum(config["downsample_rate"] for config in self.tree_config)
        # Cut for aligning
        if seq_lens is None:
            seq_lens = torch.full([seqs.shape[0]], seqs.shape[1]).to(seqs.device)
        seq_lens = (seq_lens / total_scale).int() * total_scale
        token_lens = (seq_lens / total_scale).int() * max_scale
        seqs = seqs[:, : seq_lens.max()]
        # Separate
        tokens = torch.stack(
            [seqs[:, i::total_scale] for i in range(total_scale)], dim=-1
        )
        seq1 = tokens[..., 0]
        seq2 = tokens[..., 1:].contiguous().view(tokens.shape[0], -1)
        return [seq1, seq2], token_lens


class SemanticVQVAE(nn.Module):
    """VQ-VAE over SSL speech features with a speaker-embedding branch.

    Encodes SSL features, quantizes them with a multi-scale tree quantizer,
    and extracts an ECAPA-TDNN speaker embedding from the mel spectrogram.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        n_model_size,
        downsample_scales=[1, 2],
        upsample_scales=[[2, 1], [2, 1]],
        mel_config={},
        ssl_config={},
        # Quantization
        vq_class="VectorQuantization",
        vq_config={},
        tree_config={},
        # Training
        checkpointing=True,
        dual_decoding=False,
        n_samples_per_token=640,
        online_extraction=True,
        ssl_extractor=None,
    ):
        super(SemanticVQVAE, self).__init__()
        # NOTE(review): ``ssl_extractor`` is accepted but never stored, yet
        # ``forward`` calls ``self.ssl_extractor`` when no precomputed SSL
        # features are provided — presumably it is attached externally to
        # avoid registering its parameters; confirm.
        self.in_dim = in_dim
        self.n_model_size = n_model_size
        self.mel_config = mel_config
        self.dual_decoding = dual_decoding
        self.vq_config = vq_config
        self.tree_config = tree_config
        self.output_feature = "mel"
        self.n_samples_per_token = n_samples_per_token
        self.checkpointing = checkpointing
        self.mel_spectrogram = TorchMelSpectrogram(**mel_config)
        # Speaker encoder
        self.speaker_encoder = ECAPA_TDNN(
            out_dim,
            n_model_size,
            channels=[512, 512, 512, 512, 1536],
            kernel_sizes=[5, 3, 3, 3, 1],
            dilations=[1, 2, 3, 4, 1],
            attention_channels=128,
            res2net_scale=4,
            se_channels=128,
            global_context=True,
            batch_norm=True,
        )
        # Encoder & decoder
        self.encoder = Encoder(
            in_dim, n_model_size, downsample_scales, checkpointing=checkpointing
        )
        # Quantization
        self.quantizer = TreeVectorQuantization(
            n_model_size,
            vq_class=vq_class,
            vq_config=vq_config,
            tree_config=tree_config,
        )

    def forward(
        self,
        wav,
        wav_length,
        enable_vq=True,
        decode=True,
        extract_spk=True,
        shuffle=False,
        **kwargs,
    ):
        """Encode + quantize a waveform batch; optionally extract speaker embedding.

        Precomputed ``mel``/``ssl`` features (and their lengths) may be passed
        through ``kwargs`` to skip on-the-fly extraction.
        """
        output_dict = {}
        with torch.no_grad():
            # Pad waveform to a whole number of tokens.
            if wav.shape[1] % self.n_samples_per_token > 0:
                pad_size = (
                    self.n_samples_per_token - wav.shape[1] % self.n_samples_per_token
                )
                wav = F.pad(wav, (0, pad_size), value=0)
                wav_length += pad_size
            # Extract mel & ssl features.
            mel, mel_length = kwargs.get("mel", None), kwargs.get("mel_length", None)
            if mel is None:
                mel, mel_length = self.mel_spectrogram(wav, wav_length)
            output_dict.update({"mel": mel, "mel_length": mel_length})
            ssl, ssl_length = kwargs.get("ssl", None), kwargs.get("ssl_length", None)
            if ssl is None:
                ssl, ssl_length = self.ssl_extractor(wav, wav_length)
            output_dict.update({"ssl": ssl.float(), "ssl_length": ssl_length})
        input, input_length = ssl, ssl_length
        output, output_length = mel, mel_length
        encoder_outputs = self.encoder(input)
        quant_length = torch.ceil(input_length / self.encoder.downsample_scale)
        quant_length = quant_length.clamp(max=encoder_outputs.shape[1])
        quant, (quants, diff, embed_ind) = self.quantizer(
            encoder_outputs,
            quant_length,
            enable_vq=enable_vq,
            update_codebook=True,
            return_pre_quant=self.dual_decoding,
        )
        output_dict.update(
            {
                "quants": quants,
                "token": embed_ind,
                "token_length": quant_length.int(),
                "encoder_diffs": diff,
            }
        )
        # Speaker embedding from the mel features.
        if extract_spk:
            cond, cond_length = output, output_length
            speaker_embedding = self.speaker_encoder(cond, cond_length)
            output_dict["spk"] = speaker_embedding
        return output_dict

    @torch.no_grad()
    def extract_speech_tokens(
        self, wav, wav_length, serialize=True, extract_spk=True, shuffle=False
    ):
        """Tokenize a waveform batch, length-aligning (and optionally serializing) scales."""
        output_dict = self.forward(
            wav, wav_length, True, False, extract_spk=extract_spk, shuffle=shuffle
        )
        token_seqs, token_length = output_dict["token"], output_dict["token_length"]
        # Align sequences to a multiple of the coarsest scale.
        scale = self.tree_config[0]["downsample_rate"]
        token_length = (torch.ceil(token_length / scale) * scale).int()
        new_token_seqs, new_token_lens = [], []
        for i, token_seq in enumerate(token_seqs):
            # discrete-continuous tokens
            residual = None
            if isinstance(token_seq, (tuple, list)):
                token_seq, residual = token_seq
            scale = self.tree_config[i]["downsample_rate"]
            new_token_len = token_length // scale
            pad = int(new_token_len.max()) - token_seq.shape[1]
            token_seq = F.pad(
                token_seq,
                (0, pad) if len(token_seq.shape) == 2 else (0, 0, 0, pad),
                "replicate",
            )
            if residual is not None:
                token_seq = (token_seq, residual)
            new_token_seqs.append(token_seq)
            new_token_lens.append(new_token_len)
        if len(new_token_seqs) == 1:
            new_token_seqs, new_token_lens = new_token_seqs[0], new_token_lens[0]
        elif serialize:
            new_token_seqs, new_token_lens = self.quantizer.serialize(
                new_token_seqs, new_token_lens
            )
        output_dict.update(
            {
                "embed": output_dict["quants"],
                "token": new_token_seqs,
                "token_length": new_token_lens,
            }
        )
        return output_dict

    @torch.no_grad()
    def code_to_latent(self, token, mel=None):
        """Decode tokens to latents and add a broadcast speaker embedding."""
        quant, _ = self.quantizer.decode(token, None)
        speaker_embedding = self.speaker_encoder(mel)
        latents = quant + speaker_embedding.unsqueeze(1).repeat(1, quant.shape[1], 1)
        return {
            "latents": latents,
        }