import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm

from indextts.utils.maskgct.models.codec.amphion_codec.quantize import (
    ResidualVQ,
    VectorQuantize,
    FactorizedVectorQuantize,
    LookupFreeQuantize,
)

from indextts.utils.maskgct.models.codec.amphion_codec.vocos import Vocos


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


# Snake activation: x + (1 / alpha) * sin^2(alpha * x), applied element-wise;
# the 1e-9 keeps the reciprocal finite when alpha is zero.
@torch.jit.script
def snake(x, alpha):
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


class Snake1d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)


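# Usage sketch (shapes assumed for illustration): Snake1d acts channel-wise on
# (B, C, T) tensors with one learnable alpha per channel and preserves the shape:
#   act = Snake1d(channels=16)
#   y = act(torch.randn(2, 16, 100))  # y.shape == (2, 16, 100)

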
def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)


class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        y = self.block(x)
        # Center-crop the skip connection if the convolutions shortened the sequence.
        pad = (x.shape[-1] - y.shape[-1]) // 2
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y


class EncoderBlock(nn.Module):
    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            ResidualUnit(dim // 2, dilation=1),
            ResidualUnit(dim // 2, dilation=3),
            ResidualUnit(dim // 2, dilation=9),
            Snake1d(dim // 2),
            WNConv1d(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        )

    def forward(self, x):
        return self.block(x)


class CodecEncoder(nn.Module):
    def __init__(
        self,
        d_model: int = 64,
        up_ratios: list = [4, 5, 5, 6],
        out_channels: int = 256,
        use_tanh: bool = False,
        cfg=None,
    ):
        super().__init__()

        d_model = cfg.d_model if cfg is not None else d_model
        up_ratios = cfg.up_ratios if cfg is not None else up_ratios
        out_channels = cfg.out_channels if cfg is not None else out_channels
        use_tanh = cfg.use_tanh if cfg is not None else use_tanh

        # First convolution: waveform (B, 1, T) -> (B, d_model, T).
        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        # Encoder blocks that double the channels while downsampling by `stride`.
        for stride in up_ratios:
            d_model *= 2
            self.block += [EncoderBlock(d_model, stride=stride)]

        # Final convolution projecting to the latent dimension.
        self.block += [
            Snake1d(d_model),
            WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
        ]

        if use_tanh:
            self.block += [nn.Tanh()]

        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

        self.reset_parameters()

    def forward(self, x):
        return self.block(x)

    def reset_parameters(self):
        self.apply(init_weights)


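# Shape sketch (an illustration, assuming T is divisible by the encoder hop):
# with the defaults above, CodecEncoder maps a mono waveform (B, 1, T) to latents
# (B, out_channels, T // prod(up_ratios)); e.g. for up_ratios = [4, 5, 5, 6]
# (hop 600), a (1, 1, 24000) input yields a (1, 256, 40) output.

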
class DecoderBlock(nn.Module):
    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            Snake1d(input_dim),
            WNConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=stride // 2 + stride % 2,
                output_padding=stride % 2,
            ),
            ResidualUnit(output_dim, dilation=1),
            ResidualUnit(output_dim, dilation=3),
            ResidualUnit(output_dim, dilation=9),
        )

    def forward(self, x):
        return self.block(x)


class CodecDecoder(nn.Module):
    def __init__(
        self,
        in_channels: int = 256,
        upsample_initial_channel: int = 1536,
        up_ratios: list = [5, 5, 4, 2],
        num_quantizers: int = 8,
        codebook_size: int = 1024,
        codebook_dim: int = 256,
        quantizer_type: str = "vq",
        quantizer_dropout: float = 0.5,
        commitment: float = 0.25,
        codebook_loss_weight: float = 1.0,
        use_l2_normlize: bool = False,
        codebook_type: str = "euclidean",
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.8,
        eps: float = 1e-5,
        threshold_ema_dead_code: int = 2,
        weight_init: bool = False,
        use_vocos: bool = False,
        vocos_dim: int = 384,
        vocos_intermediate_dim: int = 1152,
        vocos_num_layers: int = 8,
        n_fft: int = 800,
        hop_size: int = 200,
        padding: str = "same",
        cfg=None,
    ):
        super().__init__()

        # Let attributes on `cfg`, when present, override the constructor defaults.
        def _cfg_get(name, default):
            if cfg is not None and hasattr(cfg, name):
                return getattr(cfg, name)
            return default

        in_channels = _cfg_get("in_channels", in_channels)
        upsample_initial_channel = _cfg_get(
            "upsample_initial_channel", upsample_initial_channel
        )
        up_ratios = _cfg_get("up_ratios", up_ratios)
        num_quantizers = _cfg_get("num_quantizers", num_quantizers)
        codebook_size = _cfg_get("codebook_size", codebook_size)
        codebook_dim = _cfg_get("codebook_dim", codebook_dim)
        quantizer_type = _cfg_get("quantizer_type", quantizer_type)
        quantizer_dropout = _cfg_get("quantizer_dropout", quantizer_dropout)
        commitment = _cfg_get("commitment", commitment)
        codebook_loss_weight = _cfg_get("codebook_loss_weight", codebook_loss_weight)
        use_l2_normlize = _cfg_get("use_l2_normlize", use_l2_normlize)
        codebook_type = _cfg_get("codebook_type", codebook_type)
        kmeans_init = _cfg_get("kmeans_init", kmeans_init)
        kmeans_iters = _cfg_get("kmeans_iters", kmeans_iters)
        decay = _cfg_get("decay", decay)
        eps = _cfg_get("eps", eps)
        threshold_ema_dead_code = _cfg_get(
            "threshold_ema_dead_code", threshold_ema_dead_code
        )
        weight_init = _cfg_get("weight_init", weight_init)
        use_vocos = _cfg_get("use_vocos", use_vocos)
        vocos_dim = _cfg_get("vocos_dim", vocos_dim)
        vocos_intermediate_dim = _cfg_get(
            "vocos_intermediate_dim", vocos_intermediate_dim
        )
        vocos_num_layers = _cfg_get("vocos_num_layers", vocos_num_layers)
        n_fft = _cfg_get("n_fft", n_fft)
        hop_size = _cfg_get("hop_size", hop_size)
        padding = _cfg_get("padding", padding)

        # All branches use ResidualVQ; `quantizer_type` selects the codebook kind
        # ("vq", "fvq" for factorized VQ, "lfq" for lookup-free quantization).
        if quantizer_type == "vq":
            self.quantizer = ResidualVQ(
                input_dim=in_channels,
                num_quantizers=num_quantizers,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                quantizer_type=quantizer_type,
                quantizer_dropout=quantizer_dropout,
                commitment=commitment,
                codebook_loss_weight=codebook_loss_weight,
                use_l2_normlize=use_l2_normlize,
                codebook_type=codebook_type,
                kmeans_init=kmeans_init,
                kmeans_iters=kmeans_iters,
                decay=decay,
                eps=eps,
                threshold_ema_dead_code=threshold_ema_dead_code,
                weight_init=weight_init,
            )
        elif quantizer_type == "fvq":
            self.quantizer = ResidualVQ(
                input_dim=in_channels,
                num_quantizers=num_quantizers,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                quantizer_type=quantizer_type,
                quantizer_dropout=quantizer_dropout,
                commitment=commitment,
                codebook_loss_weight=codebook_loss_weight,
                use_l2_normlize=use_l2_normlize,
            )
        elif quantizer_type == "lfq":
            self.quantizer = ResidualVQ(
                input_dim=in_channels,
                num_quantizers=num_quantizers,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                quantizer_type=quantizer_type,
            )
        else:
            raise ValueError(f"Unknown quantizer type {quantizer_type}")

        if not use_vocos:
            # First convolution from the quantized latents.
            channels = upsample_initial_channel
            layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]

            # Upsampling blocks that halve the channel count at every stage.
            for i, stride in enumerate(up_ratios):
                input_dim = channels // 2**i
                output_dim = channels // 2 ** (i + 1)
                layers += [DecoderBlock(input_dim, output_dim, stride)]

            # Final convolution projecting back to a mono waveform.
            layers += [
                Snake1d(output_dim),
                WNConv1d(output_dim, 1, kernel_size=7, padding=3),
                nn.Tanh(),
            ]

            self.model = nn.Sequential(*layers)

        if use_vocos:
            self.model = Vocos(
                input_channels=in_channels,
                dim=vocos_dim,
                intermediate_dim=vocos_intermediate_dim,
                num_layers=vocos_num_layers,
                adanorm_num_embeddings=None,
                n_fft=n_fft,
                hop_size=hop_size,
                padding=padding,
            )

        self.reset_parameters()

    def forward(self, x=None, vq=False, eval_vq=False, n_quantizers=None):
        """
        If vq is True, x is the encoder output and the quantized representation
        (plus indices and losses) is returned; otherwise, x is a quantized
        representation and the decoded waveform is returned.
        """
        if vq is True:
            if eval_vq:
                self.quantizer.eval()
            (
                quantized_out,
                all_indices,
                all_commit_losses,
                all_codebook_losses,
                all_quantized,
            ) = self.quantizer(x, n_quantizers=n_quantizers)
            return (
                quantized_out,
                all_indices,
                all_commit_losses,
                all_codebook_losses,
                all_quantized,
            )

        return self.model(x)

    def quantize(self, x, n_quantizers=None):
        self.quantizer.eval()
        quantized_out, vq, _, _, _ = self.quantizer(x, n_quantizers=n_quantizers)
        return quantized_out, vq

    def vq2emb(self, vq, n_quantizers=None):
        return self.quantizer.vq2emb(vq, n_quantizers=n_quantizers)

    def decode(self, x):
        return self.model(x)

    def latent2dist(self, x, n_quantizers=None):
        return self.quantizer.latent2dist(x, n_quantizers=n_quantizers)

    def reset_parameters(self):
        self.apply(init_weights)
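

if __name__ == "__main__":
    # Minimal smoke-test sketch. It assumes the ResidualVQ / Vocos imports above
    # resolve in the current environment and uses only the constructor defaults;
    # the numbers in the comments follow from those defaults (encoder hop
    # 4 * 5 * 5 * 6 = 600, decoder hop 5 * 5 * 4 * 2 = 200) and are not a claim
    # about any released checkpoint or config.
    encoder = CodecEncoder()
    decoder = CodecDecoder()

    wav = torch.randn(1, 1, 24000)  # (batch, 1, samples), divisible by the encoder hop
    latent = encoder(wav)  # expected shape: (1, 256, 40)

    # vq=True routes the encoder output through the residual quantizer.
    quantized, indices, commit_losses, codebook_losses, quantized_per_layer = decoder(
        latent, vq=True, eval_vq=True
    )

    # With vq left False, the decoder treats its input as quantized latents and
    # synthesizes a waveform; with these default hops the output should be
    # 40 * 200 = 8000 samples, i.e. (1, 1, 8000).
    recon = decoder(quantized)
    print(latent.shape, quantized.shape, recon.shape)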