| |
| |
| |
| |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
|
|
| from torch.nn import functional as F |
|
|
| from models.codec.amphion_codec.quantize import ResidualVQ |
| from models.codec.amphion_codec.vocos import VocosBackbone |
|
|
|
|
| |
|
|
|
|
def init_weights(m):
    """Initialise ``nn.Conv1d`` / ``nn.Linear`` layers in place.

    Intended to be passed to ``nn.Module.apply``. Weights are drawn from
    ``nn.init.trunc_normal_`` with ``std=0.02`` and biases are zeroed. Note
    that ``trunc_normal_``'s default cut-offs (``a=-2``, ``b=2``) are absolute
    values, so at this std the truncation practically never triggers. Any
    other module type (e.g. ``nn.ConvTranspose1d``, which is not an
    ``nn.Conv1d`` subclass) is left untouched.

    Args:
        m: Any ``nn.Module``.
    """
    if isinstance(m, (nn.Conv1d, nn.Linear)):
        nn.init.trunc_normal_(m.weight, std=0.02)
        # Layers constructed with ``bias=False`` expose ``bias`` as None;
        # calling ``constant_`` on it would raise.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
|
|
|
|
class CocoContentStyle(nn.Module):
    """Vector-quantised autoencoder over Whisper and chromagram frame features.

    The two input streams are linearly projected to ``hidden_size`` and
    summed, optionally downsampled in time, encoded by a ``VocosBackbone``
    stack and discretised by a ``ResidualVQ``. Unless
    ``construct_only_for_quantizer`` is set, a second ``VocosBackbone`` stack
    decodes the quantised sequence and per-stream linear heads reconstruct
    both inputs.
    """

    def __init__(
        self,
        codebook_size=8192,
        hidden_size=1024,
        codebook_dim=8,
        num_quantizers=1,
        quantizer_type="fvq",
        use_whisper=True,
        use_chromagram=True,
        construct_only_for_quantizer=False,
        cfg=None,
    ):
        """Build the model.

        Args:
            codebook_size: Entries per codebook.
            hidden_size: Width of the projected features fed to the encoder,
                quantizer and decoder.
            codebook_dim: Dimension of each codebook entry.
            num_quantizers: Number of residual quantisation stages.
            quantizer_type: Quantizer variant passed to ``ResidualVQ``.
            use_whisper: Build the Whisper input (and output) projection.
            use_chromagram: Build the chromagram input (and output) projection.
            construct_only_for_quantizer: Skip the decoder and output heads;
                only ``quantize`` / ``forward(..., return_for_quantizer=True)``
                can then be used.
            cfg: Required config. Attributes ``codebook_size``,
                ``hidden_size``, ``codebook_dim``, ``num_quantizers`` and
                ``quantizer_type`` override the keyword arguments when
                present. Also read: ``whisper_dim`` / ``chromagram_dim`` (for
                enabled streams), ``encoder.vocos_*``, ``decoder.vocos_*``
                (unless ``construct_only_for_quantizer``) and the optional
                ``downsample_rate`` (default 1).

        Raises:
            ValueError: If ``cfg`` is None, or ``downsample_rate`` > 1 is not a
                power of two.
        """
        super().__init__()

        # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
        if cfg is None:
            raise ValueError("CocoContentStyle requires a config object `cfg`.")
        self.cfg = cfg

        # Values present on cfg take precedence over the keyword arguments.
        codebook_size = getattr(cfg, "codebook_size", codebook_size)
        hidden_size = getattr(cfg, "hidden_size", hidden_size)
        codebook_dim = getattr(cfg, "codebook_dim", codebook_dim)
        num_quantizers = getattr(cfg, "num_quantizers", num_quantizers)
        quantizer_type = getattr(cfg, "quantizer_type", quantizer_type)

        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.hidden_size = hidden_size
        self.num_quantizers = num_quantizers
        self.quantizer_type = quantizer_type

        if use_whisper:
            self.whisper_input_layer = nn.Linear(self.cfg.whisper_dim, hidden_size)
        if use_chromagram:
            self.chromagram_input_layer = nn.Linear(
                self.cfg.chromagram_dim, hidden_size
            )

        # Optional temporal down/up-sampling around the encoder/decoder. Each
        # stage is one stride-2 conv (down) and one stride-2 transposed conv
        # (up), so the overall rate must be a power of two.
        downsample_rate = getattr(cfg, "downsample_rate", 1)
        self.do_downsample = downsample_rate > 1
        if self.do_downsample:
            num_stages = float(np.log2(downsample_rate))
            if not num_stages.is_integer():
                raise ValueError(
                    f"downsample_rate must be a power of two, got {downsample_rate}."
                )

            down_layers = []
            up_layers = []
            for _ in range(int(num_stages)):
                down_layers.extend(
                    [
                        nn.Conv1d(
                            hidden_size,
                            hidden_size,
                            kernel_size=3,
                            stride=2,
                            padding=1,
                        ),
                        nn.GELU(),
                    ]
                )
                up_layers.extend(
                    [
                        nn.ConvTranspose1d(
                            hidden_size, hidden_size, kernel_size=4, stride=2, padding=1
                        ),
                        nn.GELU(),
                    ]
                )
            self.downsample_layers = nn.Sequential(*down_layers)
            self.upsample_layers = nn.Sequential(*up_layers)

        self.encoder = nn.Sequential(
            VocosBackbone(
                input_channels=self.hidden_size,
                dim=self.cfg.encoder.vocos_dim,
                intermediate_dim=self.cfg.encoder.vocos_intermediate_dim,
                num_layers=self.cfg.encoder.vocos_num_layers,
                adanorm_num_embeddings=None,
            ),
            nn.Linear(self.cfg.encoder.vocos_dim, self.hidden_size),
        )

        self.quantizer = ResidualVQ(
            input_dim=hidden_size,
            num_quantizers=num_quantizers,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_type=quantizer_type,
            quantizer_dropout=0.0,
            commitment=0.15,
            codebook_loss_weight=1.0,
            use_l2_normlize=True,
        )

        if not construct_only_for_quantizer:
            self.decoder = nn.Sequential(
                VocosBackbone(
                    input_channels=self.hidden_size,
                    dim=self.cfg.decoder.vocos_dim,
                    intermediate_dim=self.cfg.decoder.vocos_intermediate_dim,
                    num_layers=self.cfg.decoder.vocos_num_layers,
                    adanorm_num_embeddings=None,
                ),
                nn.Linear(self.cfg.decoder.vocos_dim, self.hidden_size),
            )

            # Output heads are only used after decoding, so they live with it.
            if use_whisper:
                self.whisper_output_layer = nn.Linear(
                    self.hidden_size, self.cfg.whisper_dim
                )
            if use_chromagram:
                self.chromagram_output_layer = nn.Linear(
                    self.hidden_size, self.cfg.chromagram_dim
                )

        self.reset_parameters()

    def _encode_and_quantize(self, x):
        """Optionally downsample, encode and quantise ``x`` ([B, T, hidden_size]).

        Returns:
            ``(quantized_out, all_indices, all_commit_losses,
            all_codebook_losses)`` exactly as produced by ``self.quantizer``;
            ``quantized_out`` keeps the channels-first layout the quantizer
            was given.
        """
        if self.do_downsample:
            x = self.downsample_layers(x.transpose(1, 2)).transpose(1, 2)

        # The encoder is fed channels-first; its trailing nn.Linear acts on the
        # last axis (channels-last), so transpose back to channels-first for
        # the quantizer.
        x = self.encoder(x.transpose(1, 2)).transpose(1, 2)

        quantized_out, all_indices, all_commit_losses, all_codebook_losses, _ = (
            self.quantizer(x)
        )
        return quantized_out, all_indices, all_commit_losses, all_codebook_losses

    @staticmethod
    def _format_quantizer_outputs(all_indices, quantized_out):
        """Squeeze the leading quantizer axis when N == 1 and make features channels-last."""
        if all_indices.shape[0] == 1:
            all_indices = all_indices.squeeze(0)
        return all_indices, quantized_out.transpose(1, 2)

    def _decode_to_length(self, quantized_out, T):
        """Decode ``quantized_out`` and force the result to exactly ``T`` frames.

        Returns:
            Channels-last features of shape [B, T, hidden_size]. Longer decoder
            outputs are truncated; shorter ones are padded by repeating the
            last frame.
        """
        x_rec = self.decoder(quantized_out)

        if self.do_downsample:
            x_rec = self.upsample_layers(x_rec.transpose(1, 2)).transpose(1, 2)

        # Down/up-sampling does not necessarily round-trip the frame count.
        if x_rec.shape[1] >= T:
            return x_rec[:, :T, :]
        padding_frames = T - x_rec.shape[1]
        padding = x_rec[:, -1:, :].repeat(1, padding_frames, 1)
        return torch.cat([x_rec, padding], dim=1)

    def forward(
        self,
        whisper_feats,
        chromagram_feats,
        return_for_quantizer=False,
    ):
        """Encode, quantise and (by default) reconstruct both feature streams.

        Args:
            whisper_feats: [B, T, cfg.whisper_dim].
            chromagram_feats: [B, T, cfg.chromagram_dim], same B and T as
                ``whisper_feats``.
            return_for_quantizer: If True, stop after quantisation and return
                the same pair as ``quantize``.

        Returns:
            If ``return_for_quantizer``: ``(all_indices, quantized_out)``, see
            ``quantize``.
            Otherwise ``(whisper_rec, chromagram_rec, codebook_loss, all_indices)``:
                whisper_rec: [B, T, cfg.whisper_dim]
                chromagram_rec: [B, T, cfg.chromagram_dim]
                codebook_loss: scalar tensor, mean of codebook + commitment losses
                all_indices: codebook indices with the quantizer axis first
                    ([N, B, T']); NOT squeezed even when N == 1. T' is the
                    frame count after optional downsampling.
        """
        T = whisper_feats.shape[1]

        x = self.whisper_input_layer(whisper_feats) + self.chromagram_input_layer(
            chromagram_feats
        )

        quantized_out, all_indices, all_commit_losses, all_codebook_losses = (
            self._encode_and_quantize(x)
        )

        if return_for_quantizer:
            return self._format_quantizer_outputs(all_indices, quantized_out)

        x_rec = self._decode_to_length(quantized_out, T)

        whisper_rec = self.whisper_output_layer(x_rec)
        chromagram_rec = self.chromagram_output_layer(x_rec)

        codebook_loss = (all_codebook_losses + all_commit_losses).mean()

        return whisper_rec, chromagram_rec, codebook_loss, all_indices

    def quantize(self, whisper_feats, chromagram_feats):
        """Return discrete codes and quantised features without decoding.

        Args:
            whisper_feats: [B, T, cfg.whisper_dim]
            chromagram_feats: [B, T, cfg.chromagram_dim]

        Returns:
            all_indices: [N, B, T'], or [B, T'] when there is a single quantizer.
            quantized_out: quantised features transposed to channels-last,
                [B, T', hidden_size].
        """
        all_indices, quantized_out = self.forward(
            whisper_feats,
            chromagram_feats,
            return_for_quantizer=True,
        )
        return all_indices, quantized_out

    def reset_parameters(self):
        """Re-initialise every Conv1d/Linear submodule via ``init_weights``."""
        self.apply(init_weights)
|
|
|
|
| |
|
|
|
|
class CocoStyle(CocoContentStyle):
    """Chromagram-only variant of ``CocoContentStyle``.

    By default only the chromagram projections are built
    (``use_whisper=False``), and ``forward`` / ``quantize`` take a single
    chromagram input instead of the (whisper, chromagram) pair.
    """

    def __init__(
        self,
        cfg,
        use_whisper=False,
        use_chromagram=True,
        construct_only_for_quantizer=False,
    ):
        """Build the model; all hyper-parameters are taken from ``cfg``.

        Args:
            cfg: Config object, interpreted exactly as by ``CocoContentStyle``.
            use_whisper: Also build the Whisper projections (unused by this
                class's ``forward``).
            use_chromagram: Build the chromagram projections.
            construct_only_for_quantizer: Skip the decoder and output heads.
        """
        super().__init__(
            cfg=cfg,
            use_whisper=use_whisper,
            use_chromagram=use_chromagram,
            construct_only_for_quantizer=construct_only_for_quantizer,
        )

    def forward(
        self,
        chromagram_feats,
        return_for_quantizer=False,
    ):
        """Encode, quantise and (by default) reconstruct chromagram features.

        Args:
            chromagram_feats: [B, T, cfg.chromagram_dim].
            return_for_quantizer: If True, stop after quantisation and return
                ``(all_indices, quantized_out)`` as described in ``quantize``.

        Returns:
            ``(chromagram_rec, codebook_loss, all_indices)`` when
            ``return_for_quantizer`` is False:
                chromagram_rec: [B, T, cfg.chromagram_dim]
                codebook_loss: scalar tensor, mean of codebook + commitment losses
                all_indices: [N, B, T'] (quantizer axis first, not squeezed);
                    T' is the frame count after optional downsampling.
        """
        T = chromagram_feats.shape[1]

        x = self.chromagram_input_layer(chromagram_feats)

        if self.do_downsample:
            x = self.downsample_layers(x.transpose(1, 2)).transpose(1, 2)

        # Encoder is fed channels-first; its trailing nn.Linear yields
        # channels-last, so transpose back to channels-first for the quantizer.
        x = self.encoder(x.transpose(1, 2)).transpose(1, 2)

        (
            quantized_out,
            all_indices,
            all_commit_losses,
            all_codebook_losses,
            _,
        ) = self.quantizer(x)

        if return_for_quantizer:
            # Drop the quantizer axis when there is only one quantizer.
            if all_indices.shape[0] == 1:
                all_indices = all_indices.squeeze(0)
            return all_indices, quantized_out.transpose(1, 2)

        x_rec = self.decoder(quantized_out)

        if self.do_downsample:
            x_rec = self.upsample_layers(x_rec.transpose(1, 2)).transpose(1, 2)

        # Down/up-sampling may not round-trip the frame count exactly: truncate
        # to T, or pad by repeating the last frame.
        if x_rec.shape[1] >= T:
            x_rec = x_rec[:, :T, :]
        else:
            padding_frames = T - x_rec.shape[1]
            padding = x_rec[:, -1:, :].repeat(1, padding_frames, 1)
            x_rec = torch.cat([x_rec, padding], dim=1)

        chromagram_rec = self.chromagram_output_layer(x_rec)

        codebook_loss = (all_codebook_losses + all_commit_losses).mean()

        return chromagram_rec, codebook_loss, all_indices

    def quantize(self, chromagram_feats):
        """Return discrete codes and quantised features without decoding.

        Args:
            chromagram_feats: [B, T, cfg.chromagram_dim]

        Returns:
            all_indices: [N, B, T'], or [B, T'] when there is a single quantizer.
            quantized_out: quantised features transposed to channels-last,
                [B, T', hidden_size].
        """
        all_indices, quantized_out = self.forward(
            chromagram_feats, return_for_quantizer=True
        )
        return all_indices, quantized_out
|
|
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|