# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
import torch.nn as nn
from models.codec.amphion_codec.quantize import ResidualVQ
from models.codec.amphion_codec.vocos import VocosBackbone


def init_weights(m):
    # Truncated-normal init for Conv1d / Linear weights; zero any biases.
    if isinstance(m, (nn.Conv1d, nn.Linear)):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:  # bias may be disabled on these modules
            nn.init.constant_(m.bias, 0)


def compute_codebook_perplexity(indices, codebook_size):
    """Perplexity (exp of entropy) of the empirical code distribution, ranging
    from 1 (a single code reused everywhere) to codebook_size (uniform usage)."""
    indices = indices.flatten()
    prob = torch.bincount(indices, minlength=codebook_size).float() / indices.size(0)
    perp = torch.exp(-torch.sum(prob * torch.log(prob + 1e-10)))
    return perp
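

# Example (illustrative): given enough frames, uniformly distributed indices push
# the perplexity toward codebook_size, while a collapsed codebook gives ~1:
#   compute_codebook_perplexity(torch.randint(0, 8192, (64, 1500)), 8192)  # near 8192
#   compute_codebook_perplexity(torch.zeros(64, 1500, dtype=torch.long), 8192)  # ~1.0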


class CocoContentStyle(nn.Module):
    """Joint content-style codec: fuses Whisper and chromagram features into a
    shared hidden space, quantizes it with a residual VQ, and reconstructs
    both feature streams from the codes."""

def __init__(
self,
codebook_size=8192,
hidden_size=1024,
codebook_dim=8,
num_quantizers=1,
quantizer_type="fvq",
use_whisper=True,
use_chromagram=True,
construct_only_for_quantizer=False,
cfg=None,
):
super().__init__()
assert cfg is not None
self.cfg = cfg
codebook_size = getattr(cfg, "codebook_size", codebook_size)
hidden_size = getattr(cfg, "hidden_size", hidden_size)
codebook_dim = getattr(cfg, "codebook_dim", codebook_dim)
num_quantizers = getattr(cfg, "num_quantizers", num_quantizers)
quantizer_type = getattr(cfg, "quantizer_type", quantizer_type)
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.hidden_size = hidden_size
self.num_quantizers = num_quantizers
self.quantizer_type = quantizer_type
if use_whisper:
self.whisper_input_layer = nn.Linear(self.cfg.whisper_dim, hidden_size)
if use_chromagram:
self.chromagram_input_layer = nn.Linear(
self.cfg.chromagram_dim, hidden_size
)
downsample_rate = getattr(cfg, "downsample_rate", 1)
if downsample_rate > 1:
self.do_downsample = True
assert np.log2(downsample_rate).is_integer()
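            # Each stride-2 Conv1d halves the frame rate, so stacking
            # log2(downsample_rate) of them (mirrored by ConvTranspose1d
            # blocks on the way back up) realizes any power-of-two rate.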
down_layers = []
up_layers = []
for _ in range(int(np.log2(downsample_rate))):
down_layers.extend(
[
nn.Conv1d(
hidden_size,
hidden_size,
kernel_size=3,
stride=2,
padding=1,
),
nn.GELU(),
]
)
up_layers.extend(
[
nn.ConvTranspose1d(
hidden_size, hidden_size, kernel_size=4, stride=2, padding=1
),
nn.GELU(),
]
)
self.downsample_layers = nn.Sequential(*down_layers)
self.upsample_layers = nn.Sequential(*up_layers)
else:
self.do_downsample = False
self.encoder = nn.Sequential(
VocosBackbone(
input_channels=self.hidden_size,
dim=self.cfg.encoder.vocos_dim,
intermediate_dim=self.cfg.encoder.vocos_intermediate_dim,
num_layers=self.cfg.encoder.vocos_num_layers,
adanorm_num_embeddings=None,
),
nn.Linear(self.cfg.encoder.vocos_dim, self.hidden_size),
)
self.quantizer = ResidualVQ(
input_dim=hidden_size,
num_quantizers=num_quantizers,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
quantizer_type=quantizer_type,
quantizer_dropout=0.0,
commitment=0.15,
codebook_loss_weight=1.0,
use_l2_normlize=True,
)
if not construct_only_for_quantizer:
self.decoder = nn.Sequential(
VocosBackbone(
input_channels=self.hidden_size,
dim=self.cfg.decoder.vocos_dim,
intermediate_dim=self.cfg.decoder.vocos_intermediate_dim,
num_layers=self.cfg.decoder.vocos_num_layers,
adanorm_num_embeddings=None,
),
nn.Linear(self.cfg.decoder.vocos_dim, self.hidden_size),
)
if use_whisper:
self.whisper_output_layer = nn.Linear(
self.hidden_size, self.cfg.whisper_dim
)
if use_chromagram:
self.chromagram_output_layer = nn.Linear(
self.hidden_size, self.cfg.chromagram_dim
)
self.reset_parameters()
def forward(
self,
whisper_feats,
chromagram_feats,
return_for_quantizer=False,
):
"""
Args:
whisper_feats: [B, T, 1024]
chromagram_feats: [B, T, 24]
Returns:
whisper_rec: [B, T, 1024]
chromagram_rec: [B, T, 24]
codebook_loss: float
all_indices: [N, B, T] or [B, T] if num_of_quantizers == 1
"""
T = whisper_feats.shape[1]
# [B, T, D]
x = self.whisper_input_layer(whisper_feats) + self.chromagram_input_layer(
chromagram_feats
)
# print("Before downsample:", x.shape)
# ====== Downsample ======
if self.do_downsample:
x = self.downsample_layers(x.transpose(1, 2)).transpose(1, 2)
# print("After downsample:", x.shape)
# ====== Encoder ======
x = self.encoder(x.transpose(1, 2)).transpose(1, 2) # [B, T, D] -> [B, D, T]
# ====== Quantizer ======
(
quantized_out, # [B, D, T]
all_indices, # [num_of_quantizers, B, T]
all_commit_losses, # [num_of_quantizers]
all_codebook_losses, # [num_of_quantizers]
_,
) = self.quantizer(x)
if return_for_quantizer:
if all_indices.shape[0] == 1:
return all_indices.squeeze(0), quantized_out.transpose(1, 2)
return all_indices, quantized_out.transpose(1, 2)
# ====== Decoder ======
x_rec = self.decoder(quantized_out) # [B, T, D]
# ====== Upsample ======
if self.do_downsample:
x_rec = self.upsample_layers(x_rec.transpose(1, 2)).transpose(1, 2)
# print("After upsample:", x_rec.shape)
        # Stride-2 convs / deconvs can change the length by a frame when T is
        # odd, so trim or pad (repeating the last frame) back to the input T.
        if x_rec.shape[1] >= T:
x_rec = x_rec[:, :T, :]
else:
padding_frames = T - x_rec.shape[1]
last_frame = x_rec[:, -1:, :]
padding = last_frame.repeat(1, padding_frames, 1)
x_rec = torch.cat([x_rec, padding], dim=1)
# ====== Loss ======
whisper_rec = self.whisper_output_layer(x_rec) # [B, T, 1024]
chromagram_rec = self.chromagram_output_layer(x_rec) # [B, T, 24]
codebook_loss = (all_codebook_losses + all_commit_losses).mean()
return whisper_rec, chromagram_rec, codebook_loss, all_indices

    def quantize(self, whisper_feats, chromagram_feats):
        """
        Args:
            whisper_feats: [B, T, 1024]
            chromagram_feats: [B, T, 24]
        Returns:
            all_indices: [num_quantizers, B, T'], or [B, T'] if num_quantizers == 1
            quantized_out: [B, T', D]
        """
all_indices, quantized_out = self.forward(
whisper_feats,
chromagram_feats,
return_for_quantizer=True,
)
return all_indices, quantized_out

    def reset_parameters(self):
        self.apply(init_weights)
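

# Illustrative tokenizer-mode usage of CocoContentStyle (a sketch; assumes
# num_quantizers == 1, with T' = T / downsample_rate):
#     tokenizer = CocoContentStyle(cfg=cfg, construct_only_for_quantizer=True)
#     indices, quantized = tokenizer.quantize(whisper_feats, chromagram_feats)
#     # indices: [B, T'], quantized: [B, T', hidden_size]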


class CocoContent(CocoContentStyle):
    """Content-only variant: quantizes Whisper features alone."""

def __init__(
self,
cfg,
use_whisper=True,
use_chromagram=False,
construct_only_for_quantizer=False,
):
super().__init__(
cfg=cfg,
use_whisper=use_whisper,
use_chromagram=use_chromagram,
construct_only_for_quantizer=construct_only_for_quantizer,
)

    def forward(
self,
whisper_feats,
return_for_quantizer=False,
):
"""
Args:
whisper_feats: [B, T, 1024]
Returns:
whisper_rec: [B, T, 1024]
codebook_loss: float
all_indices: [N, B, T]
"""
T = whisper_feats.shape[1]
# [B, T, D]
x = self.whisper_input_layer(whisper_feats)
# ====== Downsample ======
if self.do_downsample:
x = self.downsample_layers(x.transpose(1, 2)).transpose(1, 2)
# ====== Encoder ======
x = self.encoder(x.transpose(1, 2)).transpose(1, 2) # [B, T, D] -> [B, D, T]
# ====== Quantizer ======
(
quantized_out, # [B, D, T]
all_indices, # [num_of_quantizers, B, T]
all_commit_losses, # [num_of_quantizers]
all_codebook_losses, # [num_of_quantizers]
_,
) = self.quantizer(x)
if return_for_quantizer:
if all_indices.shape[0] == 1:
return all_indices.squeeze(0), quantized_out.transpose(1, 2)
return all_indices, quantized_out.transpose(1, 2)
# ====== Decoder ======
x_rec = self.decoder(quantized_out) # [B, T, D]
# ====== Upsample ======
if self.do_downsample:
x_rec = self.upsample_layers(x_rec.transpose(1, 2)).transpose(1, 2)
        # Stride-2 convs / deconvs can change the length by a frame when T is
        # odd, so trim or pad (repeating the last frame) back to the input T.
        if x_rec.shape[1] >= T:
x_rec = x_rec[:, :T, :]
else:
padding_frames = T - x_rec.shape[1]
last_frame = x_rec[:, -1:, :]
padding = last_frame.repeat(1, padding_frames, 1)
x_rec = torch.cat([x_rec, padding], dim=1)
# ====== Loss ======
whisper_rec = self.whisper_output_layer(x_rec) # [B, T, 1024]
codebook_loss = (all_codebook_losses + all_commit_losses).mean()
return whisper_rec, codebook_loss, all_indices

    def quantize(self, whisper_feats):
all_indices, quantized_out = self.forward(
whisper_feats, return_for_quantizer=True
)
return all_indices, quantized_out


class CocoStyle(CocoContentStyle):
    """Style-only variant: quantizes chromagram features alone."""

def __init__(
self,
cfg,
use_whisper=False,
use_chromagram=True,
construct_only_for_quantizer=False,
):
super().__init__(
cfg=cfg,
use_whisper=use_whisper,
use_chromagram=use_chromagram,
construct_only_for_quantizer=construct_only_for_quantizer,
)

    def forward(
self,
chromagram_feats,
return_for_quantizer=False,
):
"""
Args:
chromagram_feats: [B, T, 24]
Returns:
chromagram_rec: [B, T, 24]
codebook_loss: float
all_indices: [N, B, T]
"""
T = chromagram_feats.shape[1]
# [B, T, D]
x = self.chromagram_input_layer(chromagram_feats)
# ====== Downsample ======
if self.do_downsample:
x = self.downsample_layers(x.transpose(1, 2)).transpose(1, 2)
# ====== Encoder ======
x = self.encoder(x.transpose(1, 2)).transpose(1, 2) # [B, T, D] -> [B, D, T]
# ====== Quantizer ======
(
quantized_out, # [B, D, T]
all_indices, # [num_of_quantizers, B, T]
all_commit_losses, # [num_of_quantizers]
all_codebook_losses, # [num_of_quantizers]
_,
) = self.quantizer(x)
if return_for_quantizer:
if all_indices.shape[0] == 1:
return all_indices.squeeze(0), quantized_out.transpose(1, 2)
return all_indices, quantized_out.transpose(1, 2)
# ====== Decoder ======
x_rec = self.decoder(quantized_out) # [B, T, D]
# ====== Upsample ======
if self.do_downsample:
x_rec = self.upsample_layers(x_rec.transpose(1, 2)).transpose(1, 2)
        # Stride-2 convs / deconvs can change the length by a frame when T is
        # odd, so trim or pad (repeating the last frame) back to the input T.
        if x_rec.shape[1] >= T:
x_rec = x_rec[:, :T, :]
else:
padding_frames = T - x_rec.shape[1]
last_frame = x_rec[:, -1:, :]
padding = last_frame.repeat(1, padding_frames, 1)
x_rec = torch.cat([x_rec, padding], dim=1)
# ====== Loss ======
chromagram_rec = self.chromagram_output_layer(x_rec) # [B, T, 24]
codebook_loss = (all_codebook_losses + all_commit_losses).mean()
return chromagram_rec, codebook_loss, all_indices

    def quantize(self, chromagram_feats):
all_indices, quantized_out = self.forward(
chromagram_feats, return_for_quantizer=True
)
return all_indices, quantized_out
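

if __name__ == "__main__":
    # Minimal smoke test: a sketch, not the shipped training entry point.
    # A stdlib SimpleNamespace stands in for the repo's attribute-style config
    # (utils.util.JsonHParams); the encoder/decoder hyperparameters below are
    # illustrative values, not the released checkpoints' settings.
    from types import SimpleNamespace

    vocos_kwargs = dict(vocos_dim=384, vocos_intermediate_dim=1536, vocos_num_layers=4)
    cfg = SimpleNamespace(
        whisper_dim=1024,
        chromagram_dim=24,
        codebook_size=8192,
        hidden_size=1024,
        codebook_dim=8,
        num_quantizers=1,
        quantizer_type="fvq",
        downsample_rate=2,
        encoder=SimpleNamespace(**vocos_kwargs),
        decoder=SimpleNamespace(**vocos_kwargs),
    )
    model = CocoContentStyle(cfg=cfg)
    whisper_feats = torch.randn(2, 150, cfg.whisper_dim)
    chromagram_feats = torch.randn(2, 150, cfg.chromagram_dim)
    whisper_rec, chromagram_rec, codebook_loss, all_indices = model(
        whisper_feats, chromagram_feats
    )
    # Reconstructions match the input shapes; indices are at the downsampled rate.
    print(whisper_rec.shape, chromagram_rec.shape, codebook_loss)
    print(all_indices.shape)  # [num_quantizers, B, T / downsample_rate]
    print(compute_codebook_perplexity(all_indices, cfg.codebook_size))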