| | |
| | |
| | |
| | |
| |
|
| | from concurrent.futures import ALL_COMPLETED |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| |
|
| | from torch.nn import functional as F |
| | from einops import rearrange, repeat |
| |
|
| | from indextts.utils.maskgct.models.codec.amphion_codec.quantize import ResidualVQ |
| | from indextts.utils.maskgct.models.codec.kmeans.vocos import VocosBackbone |
| |
|
| |
|
def init_weights(m):
    """Initialize Conv1d/Linear weights for use with ``nn.Module.apply``.

    Weights get a truncated-normal init (std=0.02); biases are zeroed.
    Modules of any other type are left untouched.

    Args:
        m: a submodule passed in by ``nn.Module.apply``.
    """
    if isinstance(m, (nn.Conv1d, nn.Linear)):
        nn.init.trunc_normal_(m.weight, std=0.02)
        # Layers built with bias=False have m.bias is None; the previous
        # code wrote to it unconditionally and crashed on such layers.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
| |
|
| |
|
def compute_codebook_perplexity(indices, codebook_size):
    """Return exp(entropy) of the empirical code-usage distribution.

    A perplexity close to ``codebook_size`` means the codes are used
    uniformly; a value near 1 means the quantizer has collapsed onto
    a few entries.

    Args:
        indices: integer tensor of code indices (any shape).
        codebook_size: total number of codebook entries.

    Returns:
        Scalar tensor holding the perplexity.
    """
    flat = indices.flatten()
    counts = torch.bincount(flat, minlength=codebook_size)
    probs = counts.float() / flat.size(0)
    # 1e-10 keeps log finite for unused codes (prob == 0).
    entropy = -(probs * torch.log(probs + 1e-10)).sum()
    return torch.exp(entropy)
| |
|
| |
|
class RepCodec(nn.Module):
    """Residual-VQ autoencoder over frame-level hidden representations.

    Pipeline: optional stride-2 temporal downsampling -> Vocos encoder ->
    ResidualVQ quantization -> Vocos decoder -> optional 2x nearest-neighbor
    upsampling back to the input rate.

    Inputs to ``forward``/``quantize`` are assumed to be shaped
    (batch, time, hidden_size) — TODO confirm against callers.
    Any keyword default can be overridden by a matching attribute on ``cfg``.
    """

    def __init__(
        self,
        codebook_size=8192,
        hidden_size=1024,
        codebook_dim=8,
        vocos_dim=384,
        vocos_intermediate_dim=2048,
        vocos_num_layers=12,
        num_quantizers=1,
        downsample_scale=1,
        cfg=None,
    ):
        super().__init__()

        def _from_cfg(name, default):
            # A cfg attribute, when present, overrides the keyword default.
            return getattr(cfg, name, default) if cfg is not None else default

        codebook_size = _from_cfg("codebook_size", codebook_size)
        codebook_dim = _from_cfg("codebook_dim", codebook_dim)
        hidden_size = _from_cfg("hidden_size", hidden_size)
        vocos_dim = _from_cfg("vocos_dim", vocos_dim)
        # BUGFIX: the next two overrides previously keyed their hasattr()
        # check on "vocos_dim", so cfg.vocos_intermediate_dim /
        # cfg.vocos_num_layers were only honored when cfg also defined
        # vocos_dim — and raised AttributeError when it defined vocos_dim
        # but not these fields.
        vocos_intermediate_dim = _from_cfg(
            "vocos_intermediate_dim", vocos_intermediate_dim
        )
        vocos_num_layers = _from_cfg("vocos_num_layers", vocos_num_layers)
        num_quantizers = _from_cfg("num_quantizers", num_quantizers)
        downsample_scale = _from_cfg("downsample_scale", downsample_scale)

        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.hidden_size = hidden_size
        self.vocos_dim = vocos_dim
        self.vocos_intermediate_dim = vocos_intermediate_dim
        self.vocos_num_layers = vocos_num_layers
        self.num_quantizers = num_quantizers
        self.downsample_scale = downsample_scale

        if self.downsample_scale is not None and self.downsample_scale > 1:
            # Stride-2 conv halves the time axis before encoding; `up`
            # refines the nearest-neighbor-interpolated output after decoding.
            self.down = nn.Conv1d(
                self.hidden_size, self.hidden_size, kernel_size=3, stride=2, padding=1
            )
            self.up = nn.Conv1d(
                self.hidden_size, self.hidden_size, kernel_size=3, stride=1, padding=1
            )

        self.encoder = nn.Sequential(
            VocosBackbone(
                input_channels=self.hidden_size,
                dim=self.vocos_dim,
                intermediate_dim=self.vocos_intermediate_dim,
                num_layers=self.vocos_num_layers,
                adanorm_num_embeddings=None,
            ),
            nn.Linear(self.vocos_dim, self.hidden_size),
        )
        self.decoder = nn.Sequential(
            VocosBackbone(
                input_channels=self.hidden_size,
                dim=self.vocos_dim,
                intermediate_dim=self.vocos_intermediate_dim,
                num_layers=self.vocos_num_layers,
                adanorm_num_embeddings=None,
            ),
            nn.Linear(self.vocos_dim, self.hidden_size),
        )

        self.quantizer = ResidualVQ(
            input_dim=hidden_size,
            num_quantizers=num_quantizers,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_type="fvq",
            quantizer_dropout=0.0,
            commitment=0.15,
            codebook_loss_weight=1.0,
            use_l2_normlize=True,
        )

        self.reset_parameters()

    def forward(self, x):
        """Encode, quantize, and reconstruct `x`.

        Args:
            x: input features, assumed (batch, time, hidden_size).

        Returns:
            Tuple of (reconstruction, codebook_loss, all_indices) where
            codebook_loss is the mean of commitment + codebook losses
            across quantizers.
        """
        # Optionally halve the temporal resolution.
        if self.downsample_scale is not None and self.downsample_scale > 1:
            x = x.transpose(1, 2)
            x = self.down(x)
            x = F.gelu(x)
            x = x.transpose(1, 2)

        # VocosBackbone consumes channels-first input, hence the transposes.
        x = self.encoder(x.transpose(1, 2)).transpose(1, 2)

        (
            quantized_out,
            all_indices,
            all_commit_losses,
            all_codebook_losses,
            _,
        ) = self.quantizer(x)

        x = self.decoder(quantized_out)

        if self.downsample_scale is not None and self.downsample_scale > 1:
            # Restore the original time resolution (2x nearest-neighbor),
            # then smooth with the `up` conv.
            x = x.transpose(1, 2)
            x = F.interpolate(x, scale_factor=2, mode="nearest")
            x_rec = self.up(x).transpose(1, 2)
        else:
            # BUGFIX: without downsampling, x_rec was previously never
            # assigned and forward() raised NameError. The decoder output
            # is already at the input rate, so use it directly.
            x_rec = x

        codebook_loss = (all_codebook_losses + all_commit_losses).mean()

        return x_rec, codebook_loss, all_indices

    def quantize(self, x):
        """Return the code indices and quantized embeddings for `x`.

        Args:
            x: input features, assumed (batch, time, hidden_size).

        Returns:
            Tuple of (indices, embeddings). The leading quantizer axis of
            the indices is squeezed away when there is a single quantizer;
            embeddings are transposed to (batch, time, dim) layout —
            TODO confirm expected layout against callers.
        """
        if self.downsample_scale is not None and self.downsample_scale > 1:
            x = x.transpose(1, 2)
            x = self.down(x)
            x = F.gelu(x)
            x = x.transpose(1, 2)

        x = self.encoder(x.transpose(1, 2)).transpose(1, 2)

        (
            quantized_out,
            all_indices,
            all_commit_losses,
            all_codebook_losses,
            _,
        ) = self.quantizer(x)

        if all_indices.shape[0] == 1:
            return all_indices.squeeze(0), quantized_out.transpose(1, 2)
        return all_indices, quantized_out.transpose(1, 2)

    def reset_parameters(self):
        """Re-initialize all Conv1d/Linear submodules via ``init_weights``."""
        self.apply(init_weights)
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test: build a model, run a reconstruction pass and a
    # quantization pass, and report shapes plus parameter count (in M).
    model = RepCodec(vocos_dim=1024, downsample_scale=2)
    print(model)
    param_count_m = sum(p.numel() for p in model.parameters()) / 1e6
    print(param_count_m)
    features = torch.randn(5, 10, 1024)
    recon, vq_loss, indices = model(features)
    print(recon.shape, vq_loss, indices.shape)
    token_ids, embeddings = model.quantize(features)
    print(token_ids.shape, embeddings.shape)
| |
|