# (Removed scraped page-header artifact "Spaces: Running on Zero" — not part of the source file.)
| # Copyright (c) 2024 Amphion. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from concurrent.futures import ALL_COMPLETED | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| from models.codec.amphion_codec.quantize import ResidualVQ | |
| from models.codec.amphion_codec.vocos import VocosBackbone | |
def init_weights(m):
    """Initialize ``Conv1d``/``Linear`` modules for use with ``nn.Module.apply``.

    Weights get a truncated normal init (std=0.02); biases are zeroed.
    Modules constructed with ``bias=False`` have ``m.bias is None``, so the
    bias init is guarded to avoid a crash on such layers.
    """
    if isinstance(m, (nn.Conv1d, nn.Linear)):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
def compute_codebook_perplexity(indices, codebook_size):
    """Return the perplexity of codebook usage for a batch of VQ indices.

    Args:
        indices: integer tensor of codebook indices (any shape; flattened).
        codebook_size: number of entries in the codebook.

    Returns:
        Scalar tensor in [1, codebook_size]: 1 when a single code is used,
        codebook_size when usage is perfectly uniform.
    """
    flat = indices.flatten()
    counts = torch.bincount(flat, minlength=codebook_size).float()
    probs = counts / flat.size(0)
    # Perplexity = exp(entropy); epsilon keeps log finite for unused codes.
    entropy = -torch.sum(probs * torch.log(probs + 1e-10))
    return torch.exp(entropy)
class CocoContentStyle(nn.Module):
    """VQ autoencoder over combined whisper (content) and chromagram (style) features.

    Both feature streams are projected into a shared hidden space and summed,
    optionally downsampled by strided convolutions, encoded by a Vocos backbone,
    quantized by a residual VQ, and — unless ``construct_only_for_quantizer`` —
    decoded back to both feature spaces for reconstruction training.
    """

    def __init__(
        self,
        codebook_size=8192,
        hidden_size=1024,
        codebook_dim=8,
        num_quantizers=1,
        quantizer_type="fvq",
        use_whisper=True,
        use_chromagram=True,
        construct_only_for_quantizer=False,
        cfg=None,
    ):
        super().__init__()
        assert cfg is not None
        self.cfg = cfg

        # Values present on cfg override the keyword defaults.
        codebook_size = getattr(cfg, "codebook_size", codebook_size)
        hidden_size = getattr(cfg, "hidden_size", hidden_size)
        codebook_dim = getattr(cfg, "codebook_dim", codebook_dim)
        num_quantizers = getattr(cfg, "num_quantizers", num_quantizers)
        quantizer_type = getattr(cfg, "quantizer_type", quantizer_type)

        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.hidden_size = hidden_size
        self.num_quantizers = num_quantizers
        self.quantizer_type = quantizer_type

        # Input projections are created only for the enabled feature streams;
        # subclasses disable one of them and override forward() accordingly.
        if use_whisper:
            self.whisper_input_layer = nn.Linear(self.cfg.whisper_dim, hidden_size)
        if use_chromagram:
            self.chromagram_input_layer = nn.Linear(
                self.cfg.chromagram_dim, hidden_size
            )

        # Optional temporal down/upsampling; the rate must be a power of two so
        # it factors into stride-2 conv stages.
        downsample_rate = getattr(cfg, "downsample_rate", 1)
        if downsample_rate > 1:
            self.do_downsample = True
            assert np.log2(downsample_rate).is_integer()
            down_layers = []
            up_layers = []
            for _ in range(int(np.log2(downsample_rate))):
                down_layers.extend(
                    [
                        nn.Conv1d(
                            hidden_size,
                            hidden_size,
                            kernel_size=3,
                            stride=2,
                            padding=1,
                        ),
                        nn.GELU(),
                    ]
                )
                up_layers.extend(
                    [
                        nn.ConvTranspose1d(
                            hidden_size, hidden_size, kernel_size=4, stride=2, padding=1
                        ),
                        nn.GELU(),
                    ]
                )
            self.downsample_layers = nn.Sequential(*down_layers)
            self.upsample_layers = nn.Sequential(*up_layers)
        else:
            self.do_downsample = False

        self.encoder = nn.Sequential(
            VocosBackbone(
                input_channels=self.hidden_size,
                dim=self.cfg.encoder.vocos_dim,
                intermediate_dim=self.cfg.encoder.vocos_intermediate_dim,
                num_layers=self.cfg.encoder.vocos_num_layers,
                adanorm_num_embeddings=None,
            ),
            nn.Linear(self.cfg.encoder.vocos_dim, self.hidden_size),
        )

        self.quantizer = ResidualVQ(
            input_dim=hidden_size,
            num_quantizers=num_quantizers,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_type=quantizer_type,
            quantizer_dropout=0.0,
            commitment=0.15,
            codebook_loss_weight=1.0,
            use_l2_normlize=True,
        )

        # Decoder and output heads are skipped when only token extraction is needed.
        if not construct_only_for_quantizer:
            self.decoder = nn.Sequential(
                VocosBackbone(
                    input_channels=self.hidden_size,
                    dim=self.cfg.decoder.vocos_dim,
                    intermediate_dim=self.cfg.decoder.vocos_intermediate_dim,
                    num_layers=self.cfg.decoder.vocos_num_layers,
                    adanorm_num_embeddings=None,
                ),
                nn.Linear(self.cfg.decoder.vocos_dim, self.hidden_size),
            )
            if use_whisper:
                self.whisper_output_layer = nn.Linear(
                    self.hidden_size, self.cfg.whisper_dim
                )
            if use_chromagram:
                self.chromagram_output_layer = nn.Linear(
                    self.hidden_size, self.cfg.chromagram_dim
                )

        self.reset_parameters()

    def forward(
        self,
        whisper_feats,
        chromagram_feats,
        return_for_quantizer=False,
    ):
        """
        Args:
            whisper_feats: [B, T, 1024]
            chromagram_feats: [B, T, 24]
        Returns:
            whisper_rec: [B, T, 1024]
            chromagram_rec: [B, T, 24]
            codebook_loss: float
            all_indices: [N, B, T] or [B, T] if num_of_quantizers == 1
        """
        T = whisper_feats.shape[1]

        # Project both streams to the hidden space and fuse by summation. [B, T, D]
        x = self.whisper_input_layer(whisper_feats) + self.chromagram_input_layer(
            chromagram_feats
        )

        # ====== Downsample ======
        if self.do_downsample:
            x = self.downsample_layers(x.transpose(1, 2)).transpose(1, 2)

        # ====== Encoder ====== (backbone consumes [B, D, T]; result back to [B, T, D])
        x = self.encoder(x.transpose(1, 2)).transpose(1, 2)

        # ====== Quantizer ======
        (
            quantized_out,  # [B, D, T]
            all_indices,  # [num_of_quantizers, B, T]
            all_commit_losses,  # [num_of_quantizers]
            all_codebook_losses,  # [num_of_quantizers]
            _,
        ) = self.quantizer(x)

        if return_for_quantizer:
            # Drop the quantizer axis for the common single-quantizer case.
            if all_indices.shape[0] == 1:
                return all_indices.squeeze(0), quantized_out.transpose(1, 2)
            return all_indices, quantized_out.transpose(1, 2)

        # ====== Decoder ======
        x_rec = self.decoder(quantized_out)  # [B, T, D]

        # ====== Upsample ======
        if self.do_downsample:
            x_rec = self.upsample_layers(x_rec.transpose(1, 2)).transpose(1, 2)

        # Downsample/upsample rounding can change T; trim, or pad by repeating
        # the last frame, so the reconstruction matches the input length.
        if x_rec.shape[1] >= T:
            x_rec = x_rec[:, :T, :]
        else:
            padding_frames = T - x_rec.shape[1]
            last_frame = x_rec[:, -1:, :]
            padding = last_frame.repeat(1, padding_frames, 1)
            x_rec = torch.cat([x_rec, padding], dim=1)

        # ====== Loss ======
        whisper_rec = self.whisper_output_layer(x_rec)  # [B, T, 1024]
        chromagram_rec = self.chromagram_output_layer(x_rec)  # [B, T, 24]

        codebook_loss = (all_codebook_losses + all_commit_losses).mean()

        return whisper_rec, chromagram_rec, codebook_loss, all_indices

    def quantize(self, whisper_feats, chromagram_feats):
        """
        Args:
            whisper_feats: [B, T, 1024]
            chromagram_feats: [B, T, 24]
        Returns:
            all_indices: [N, B, T], or [B, T] if num_of_quantizers == 1
            quantized_out: [B, T, D]
        """
        all_indices, quantized_out = self.forward(
            whisper_feats,
            chromagram_feats,
            return_for_quantizer=True,
        )
        return all_indices, quantized_out

    def reset_parameters(self):
        self.apply(init_weights)
class CocoContent(CocoContentStyle):
    """Content-only variant: quantizes whisper features alone (no chromagram stream)."""

    def __init__(
        self,
        cfg,
        use_whisper=True,
        use_chromagram=False,
        construct_only_for_quantizer=False,
    ):
        super().__init__(
            cfg=cfg,
            use_whisper=use_whisper,
            use_chromagram=use_chromagram,
            construct_only_for_quantizer=construct_only_for_quantizer,
        )

    def forward(
        self,
        whisper_feats,
        return_for_quantizer=False,
    ):
        """
        Args:
            whisper_feats: [B, T, 1024]
        Returns:
            whisper_rec: [B, T, 1024]
            codebook_loss: float
            all_indices: [N, B, T]
        """
        T = whisper_feats.shape[1]

        # [B, T, D]
        x = self.whisper_input_layer(whisper_feats)

        # ====== Downsample ======
        if self.do_downsample:
            x = self.downsample_layers(x.transpose(1, 2)).transpose(1, 2)

        # ====== Encoder ====== (backbone consumes [B, D, T]; result back to [B, T, D])
        x = self.encoder(x.transpose(1, 2)).transpose(1, 2)

        # ====== Quantizer ======
        (
            quantized_out,  # [B, D, T]
            all_indices,  # [num_of_quantizers, B, T]
            all_commit_losses,  # [num_of_quantizers]
            all_codebook_losses,  # [num_of_quantizers]
            _,
        ) = self.quantizer(x)

        if return_for_quantizer:
            # Drop the quantizer axis for the common single-quantizer case.
            if all_indices.shape[0] == 1:
                return all_indices.squeeze(0), quantized_out.transpose(1, 2)
            return all_indices, quantized_out.transpose(1, 2)

        # ====== Decoder ======
        x_rec = self.decoder(quantized_out)  # [B, T, D]

        # ====== Upsample ======
        if self.do_downsample:
            x_rec = self.upsample_layers(x_rec.transpose(1, 2)).transpose(1, 2)

        # Downsample/upsample rounding can change T; trim, or pad by repeating
        # the last frame, so the reconstruction matches the input length.
        if x_rec.shape[1] >= T:
            x_rec = x_rec[:, :T, :]
        else:
            padding_frames = T - x_rec.shape[1]
            last_frame = x_rec[:, -1:, :]
            padding = last_frame.repeat(1, padding_frames, 1)
            x_rec = torch.cat([x_rec, padding], dim=1)

        # ====== Loss ======
        whisper_rec = self.whisper_output_layer(x_rec)  # [B, T, 1024]

        codebook_loss = (all_codebook_losses + all_commit_losses).mean()

        return whisper_rec, codebook_loss, all_indices

    def quantize(self, whisper_feats):
        """
        Args:
            whisper_feats: [B, T, 1024]
        Returns:
            all_indices: [N, B, T], or [B, T] if num_of_quantizers == 1
            quantized_out: [B, T, D]
        """
        all_indices, quantized_out = self.forward(
            whisper_feats, return_for_quantizer=True
        )
        return all_indices, quantized_out
class CocoStyle(CocoContentStyle):
    """Style-only variant: quantizes chromagram features alone (no whisper stream)."""

    def __init__(
        self,
        cfg,
        use_whisper=False,
        use_chromagram=True,
        construct_only_for_quantizer=False,
    ):
        super().__init__(
            cfg=cfg,
            use_whisper=use_whisper,
            use_chromagram=use_chromagram,
            construct_only_for_quantizer=construct_only_for_quantizer,
        )

    def forward(
        self,
        chromagram_feats,
        return_for_quantizer=False,
    ):
        """
        Args:
            chromagram_feats: [B, T, 24]
        Returns:
            chromagram_rec: [B, T, 24]
            codebook_loss: float
            all_indices: [N, B, T]
        """
        T = chromagram_feats.shape[1]

        # [B, T, D]
        x = self.chromagram_input_layer(chromagram_feats)

        # ====== Downsample ======
        if self.do_downsample:
            x = self.downsample_layers(x.transpose(1, 2)).transpose(1, 2)

        # ====== Encoder ====== (backbone consumes [B, D, T]; result back to [B, T, D])
        x = self.encoder(x.transpose(1, 2)).transpose(1, 2)

        # ====== Quantizer ======
        (
            quantized_out,  # [B, D, T]
            all_indices,  # [num_of_quantizers, B, T]
            all_commit_losses,  # [num_of_quantizers]
            all_codebook_losses,  # [num_of_quantizers]
            _,
        ) = self.quantizer(x)

        if return_for_quantizer:
            # Drop the quantizer axis for the common single-quantizer case.
            if all_indices.shape[0] == 1:
                return all_indices.squeeze(0), quantized_out.transpose(1, 2)
            return all_indices, quantized_out.transpose(1, 2)

        # ====== Decoder ======
        x_rec = self.decoder(quantized_out)  # [B, T, D]

        # ====== Upsample ======
        if self.do_downsample:
            x_rec = self.upsample_layers(x_rec.transpose(1, 2)).transpose(1, 2)

        # Downsample/upsample rounding can change T; trim, or pad by repeating
        # the last frame, so the reconstruction matches the input length.
        if x_rec.shape[1] >= T:
            x_rec = x_rec[:, :T, :]
        else:
            padding_frames = T - x_rec.shape[1]
            last_frame = x_rec[:, -1:, :]
            padding = last_frame.repeat(1, padding_frames, 1)
            x_rec = torch.cat([x_rec, padding], dim=1)

        # ====== Loss ======
        chromagram_rec = self.chromagram_output_layer(x_rec)  # [B, T, 24]

        codebook_loss = (all_codebook_losses + all_commit_losses).mean()

        return chromagram_rec, codebook_loss, all_indices

    def quantize(self, chromagram_feats):
        """
        Args:
            chromagram_feats: [B, T, 24]
        Returns:
            all_indices: [N, B, T], or [B, T] if num_of_quantizers == 1
            quantized_out: [B, T, D]
        """
        all_indices, quantized_out = self.forward(
            chromagram_feats, return_for_quantizer=True
        )
        return all_indices, quantized_out
| # if __name__ == "__main__": | |
| # from utils.util import JsonHParams | |
| # cfg = JsonHParams( | |
| # **{ | |
| # "whisper_dim": 1024, | |
| # "chromagram_dim": 24, | |
| # "global_speaker_encoder": { | |
| # "input_dim": 128, # Eg: n_mels | |
| # "hidden_size": 512, # 768 for emilia298k | |
| # "num_hidden_layers": 4, # 6 for emilia298k | |
| # "num_attention_heads": 8, | |
| # }, | |
| # } | |
| # ) | |
| # model = Coco(cfg=cfg) | |
| # x = torch.randn(2, 150, 1024) | |
| # tone_height = torch.randn(2) | |
| # mels = torch.randn(2, 150, 128) | |
| # mel_mask = torch.ones(2, 150) | |
| # x_rec, codebook_loss, all_indices, auxillary_pred_outputs = model( | |
| # x, tone_height, mels, mel_mask | |
| # ) | |
| # print(x_rec.shape, codebook_loss, all_indices.shape) | |
| # for k, v in auxillary_pred_outputs.items(): | |
| # print(k, v.shape) | |