import random
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from module import SinePositionalEmbedding, TokenEmbedding, top_k_sampling
from simple_parsing import Serializable
from torch.cuda.amp import autocast
from torchmetrics.classification import MulticlassAccuracy
from transformer import AdaptiveLayerNorm, LayerNorm, TransformerEncoder, TransformerEncoderLayer
from transformers import EncodecModel
from vocos import Vocos

from audiozen.acoustics.audio_feature import stft


@dataclass
class ModelArgs(Serializable):
    num_cb: int = 8  # number of codebooks
    cb_size: int = 1024  # codebook size
    d_model: int = 512  # hidden size of the transformer
    n_fft: int = 768  # number of FFT points
    hop_length: int = 384  # hop length
    num_tokens: int = 1024  # number of code tokens per codebook (the EOS id is `num_tokens`)
    num_layers: int = 12  # number of transformer layers
    num_heads: int = 8  # number of heads in multi-head attention
    norm_first: bool = True  # whether to apply layer norm before self-attention
    share_embedding: bool = True  # whether to tie the NAR prediction layers to the NAR embedding layers
    prepend_bos: bool = False  # whether to prepend a BOS token to the target sequence
    add_prenet: bool = False  # whether to add a prenet before the transformer
    stage: int = 2  # training stage: 1 = AR only, 2 = AR + NAR


class Model(nn.Module):
    def __init__(self, args: ModelArgs = ModelArgs()):
        super().__init__()
        self.encodec = EncodecModel.from_pretrained("facebook/encodec_24khz")
        for param in self.encodec.parameters():
            param.requires_grad = False

        # ================================ AR ================================
        self.ar_audio_prepend_bos = args.prepend_bos
        self.ar_embedding_layer = TokenEmbedding(args.d_model, args.num_tokens + 1 + int(args.prepend_bos))
        self.ar_spec_encoder = nn.Sequential(
            nn.Linear(args.n_fft // 2 + 1, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, args.d_model),
        )
        self.ar_prenet = nn.Identity()
        self.ar_position = SinePositionalEmbedding(args.d_model, dropout=0.1, scale=False, alpha=True)

        # Sequence model
        self.ar_decoder = TransformerEncoder(
            TransformerEncoderLayer(
                args.d_model,
                args.num_heads,
                dim_feedforward=args.d_model * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=args.norm_first,
            ),
            num_layers=args.num_layers,
            norm=LayerNorm(args.d_model) if args.norm_first else None,
        )
        self.ar_pred_layer = nn.Linear(args.d_model, args.num_tokens + 1, bias=False)

        self.ar_acc_metric = MulticlassAccuracy(
            args.num_tokens + 1,
            top_k=10,
            average="micro",
            multidim_average="global",
            ignore_index=args.num_tokens,
        )

        # ================================ Non AR ================================
        self.nar_spec_encoder = nn.Sequential(
            nn.Linear(args.n_fft // 2 + 1, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, args.d_model),
        )
        self.nar_embedding_layers = nn.ModuleList(
            [TokenEmbedding(args.d_model, args.num_tokens + 1)]
            + [TokenEmbedding(args.d_model, args.num_tokens) for _ in range(args.num_cb - 1)]
        )
        self.nar_prenet = nn.Identity()
        self.nar_position = SinePositionalEmbedding(args.d_model, dropout=0.1, scale=False, alpha=False)

        # Sequence model
        self.nar_decoder = TransformerEncoder(
            TransformerEncoderLayer(
                args.d_model,
                args.num_heads,
                dim_feedforward=args.d_model * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=args.norm_first,
                adaptive_layer_norm=True,
            ),
            num_layers=args.num_layers,
            norm=AdaptiveLayerNorm(args.d_model, norm=nn.LayerNorm(args.d_model)) if args.norm_first else None,
        )
        self.nar_pred_layers = nn.ModuleList(
            [nn.Linear(args.d_model, args.num_tokens, bias=False) for _ in range(args.num_cb - 1)]
        )
        self.nar_stage_embeddings = nn.ModuleList([TokenEmbedding(args.d_model, 1) for _ in range(args.num_cb - 1)])

        if args.share_embedding:
            for j in range(0, args.num_cb - 2):
                # Share the parameters of the acoustic embedding layer and the
                # output prediction layer of the next stage (pred_layer_i).
                self.nar_pred_layers[j].weight = self.nar_embedding_layers[j + 2].weight

        self.nar_acc_metric = MulticlassAccuracy(
            args.num_tokens + 1,
            top_k=10,
            average="micro",
            multidim_average="global",
            ignore_index=args.num_tokens,
        )

        self.args = args
        self.rng = random.Random(0)
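    # Shape conventions used by the methods below (taken from their docstrings):
    #   waveform: [B, T]       raw samples
    #   codes:    [B, T, N_q]  Encodec code ids, N_q = num_cb codebooks
    #   embed:    [B, T, H]    summed token embeddings, H = d_model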
    def _encodec_encode(self, waveform):
        """Encode waveform to codes.

        Args:
            waveform: shape of [B, T]

        Returns:
            codes: shape of [B, T, N_q]
        """
        with torch.no_grad():
            # Run Encodec in full precision even under mixed-precision training.
            with autocast(enabled=False):
                waveform = rearrange(waveform, "b t -> b () t")
                codes = self.encodec.encode(input_values=waveform, return_dict=True, bandwidth=6.0)
                codes = codes.audio_codes  # [num_chunks, B, N_q, T]; [1, B, 8, T] for 6 kbps
                codes = rearrange(codes, "c b nq t -> (c b) t nq")
                codes = codes.to(dtype=torch.long)
        return codes

    def _encodec_decode(self, codes):
        """codes with shape [B, T, N_q] => [B, T]"""
        with torch.no_grad():
            codes = rearrange(codes, "b t nq -> 1 b nq t")
            audio_values = self.encodec.decode(audio_codes=codes, audio_scales=[None], return_dict=True).audio_values
            audio_values = rearrange(audio_values, "b 1 t -> b t")
        return audio_values

    def _vocos_decode(self, codes):
        """codes with shape [B, T, N_q] => [B, T]"""
        with torch.no_grad():
            # NOTE: loaded on every call; cache it on the module if generation is frequent.
            vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(codes.device)
            codes = codes.permute(2, 0, 1)  # [B, T, N_q] => [N_q, B, T]
            features = vocos.codes_to_features(codes)
            # bandwidth_id=2 selects the 6 kbps setting.
            audio_values = vocos.decode(features, bandwidth_id=torch.tensor([2], device=codes.device))
        return audio_values

    def _stack_embeddings(self, codes, cb_idx):
        """Sum the NAR embeddings of codebooks 0 to `cb_idx`.

        Args:
            codes: shape of [B, T, N_q]
            cb_idx: codebook index

        Returns:
            embed: shape of [B, T, H]
        """
        batch_size, time_steps, _ = codes.shape
        embed = torch.zeros(batch_size, time_steps, self.args.d_model, device=codes.device)
        for i in range(0, cb_idx + 1):
            embed += self.nar_embedding_layers[i](codes[..., i])
        return embed

    def _pad_eos(self, codes, eos_id):
        """Append EOS to codes of shape [B, T] and return the shifted (input, target) pair, each [B, T]."""
        codes = F.pad(codes, (0, 1), value=eos_id)
        return codes[..., :-1], codes[..., 1:]
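    # Example of the shift `_pad_eos` produces (hypothetical ids, EOS id = 1024):
    #   codes  = [[5, 7, 9]]
    #   input  = [[5, 7, 9]]      (EOS appended, then the last position dropped)
    #   target = [[7, 9, 1024]]   (shifted left so each step predicts the next id)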
    def forward(self, mix_wave, sep_wave):
        """Training forward pass. The AR branch predicts first-codebook codes of the
        separated waveforms; the NAR branch refines one randomly chosen residual codebook.

        Args:
            mix_wave: mixture waveform, shape of [B, T]
            sep_wave: separated waveforms, shape of [B, 2, T]
        """
        # Check the input shape; only two speakers are supported for now
        batch_size, num_spks, _ = sep_wave.shape
        device = sep_wave.device
        assert num_spks == 2, f"Only two speakers are supported, but got {num_spks}"

        # Convert waveforms to codes
        mix_codes = self._encodec_encode(mix_wave)  # [B, T, N_q]
        spk_1_codes, spk_2_codes = self._encodec_encode(sep_wave[:, 0]), self._encodec_encode(sep_wave[:, 1])
        sep_codes = torch.cat([spk_1_codes, spk_2_codes], dim=1)  # [B, T, N_q] => [B, 2 * T, N_q]

        # ========================================= AR =========================================
        # Append EOS and build the shifted (input, target) pair for the first codebook
        sep_code, target = self._pad_eos(sep_codes[..., 0], self.args.num_tokens)

        # Make the mixture embedding by summing the embeddings of all codebooks
        mix_embed = self.ar_embedding_layer(mix_codes[..., 0])  # [B, T] => [B, T, H]
        for j in range(1, self.args.num_cb):
            mix_embed += self.ar_embedding_layer(mix_codes[..., j])
        sep_embed = self.ar_embedding_layer(sep_code)  # [B, T] => [B, T, H]

        # Make the mixture spectrogram embedding
        mix_mag, *_ = stft(
            mix_wave, n_fft=self.args.n_fft, hop_length=self.args.hop_length, win_length=self.args.n_fft
        )  # [B, T] => [B, F, T]
        mix_mag = rearrange(mix_mag, "b f t -> b t f")  # [B, F, T] => [B, T, F]
        mix_mag = torch.log(mix_mag + 1e-8)
        mix_feat = self.ar_spec_encoder(mix_mag)  # [B, T, F] => [B, T, H]

        mix_len = mix_feat.shape[1] + mix_embed.shape[1]
        sep_len = sep_embed.shape[1]
        mix_sep_embed = torch.cat([mix_feat, mix_embed, sep_embed], dim=1)  # [B, mix_len + sep_len, H]
        mix_sep_embed = self.ar_prenet(mix_sep_embed)
        mix_sep_embed = self.ar_position(mix_sep_embed)

        # Attention mask (True = blocked), rows are queries and columns are keys:
        #               mix keys        sep keys
        #   mix rows    full attention  blocked
        #   sep rows    full attention  causal (upper triangle blocked)
        mix_attn_mask = F.pad(
            torch.zeros((mix_len, mix_len), dtype=torch.bool, device=device), (0, sep_len), value=True
        )
        sep_attn_mask = F.pad(
            torch.triu(torch.ones(sep_len, sep_len, dtype=torch.bool, device=device), diagonal=1),
            (mix_len, 0),
            value=False,
        )
        mix_sep_attn_mask = torch.cat([mix_attn_mask, sep_attn_mask], dim=0)

        mix_sep_dec, _ = self.ar_decoder((mix_sep_embed, None), mask=mix_sep_attn_mask)  # [B, T, H]
        logits = self.ar_pred_layer(mix_sep_dec[:, mix_len:])  # [B, T, H] => [B, T, num_tokens + 1]
        logits = rearrange(logits, "b t c -> b c t")  # [B, T, 1025] => [B, 1025, T]

        ar_loss = F.cross_entropy(logits, target, reduction="mean")
        ar_accuracy_metric = self.ar_acc_metric(logits.detach(), target).item()

        if self.args.stage == 1:
            return ar_loss, ar_loss, 0.0, ar_accuracy_metric, 0.0

        # ========================================= Non AR =========================================
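        # VALL-E-style NAR training (an assumption based on the structure here): one
        # codebook level `nar_stage` is drawn uniformly from {1, ..., num_cb - 1}; the
        # decoder is conditioned on the summed ground-truth embeddings of codebooks
        # 0..nar_stage-1 and trained to predict codebook `nar_stage` in one parallel pass.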
        num_nar_layers = self.args.num_cb - 1
        nar_stage = self.rng.choices(  # Randomly choose a stage from [1, num_cb - 1]
            list(range(1, self.args.num_cb)),
            weights=[1.0 / num_nar_layers] * num_nar_layers,
            k=1,
        )[0]

        # Create prompts
        mix_embed = self.nar_embedding_layers[0](mix_codes[..., 0])  # First-layer codes
        sep_embed = self.nar_embedding_layers[0](sep_code)  # same as sep_codes[..., 0] (EOS dropped by the shift)
        for j in range(1, self.args.num_cb):
            mix_embed += self.nar_embedding_layers[j](mix_codes[..., j])
            if j < nar_stage:
                sep_embed += self.nar_embedding_layers[j](sep_codes[..., j])

        mix_feat = self.nar_spec_encoder(mix_mag)  # [B, T, F] => [B, T, H]
        mix_sep_embed = torch.cat([mix_feat, mix_embed, sep_embed], dim=1)  # [B, mix_len + sep_len, H]
        mix_sep_embed = self.nar_prenet(mix_sep_embed)
        mix_sep_embed = self.nar_position(mix_sep_embed)

        target = sep_codes[..., nar_stage]

        mix_sep_dec, _ = self.nar_decoder(
            (mix_sep_embed, self.nar_stage_embeddings[nar_stage - 1].weight),
            src_key_padding_mask=None,
            is_causal=False,
        )
        mix_sep_dec = mix_sep_dec[:, mix_len:]
        logits = self.nar_pred_layers[nar_stage - 1](mix_sep_dec).permute(0, 2, 1)  # [B, T, H] => [B, 1024, T]

        # logits: [B, 1024, T], target: [B, T]
        nar_loss = F.cross_entropy(logits, target, ignore_index=self.args.num_tokens, reduction="mean")
        nar_acc_metric = self.nar_acc_metric(
            # Pad a dummy class so the logits match the metric's num_tokens + 1 classes;
            # the minimum fill value keeps it from ever being the top prediction.
            F.pad(
                logits.detach(),
                (0, 0, 0, 1, 0, 0),
                value=logits.min().cpu().item(),
            ),
            target,
        ).item()

        total_loss = ar_loss + nar_loss
        return total_loss, ar_loss, nar_loss, ar_accuracy_metric, nar_acc_metric

    def generate(self, mix_wave, sep_wave, top_k=-100, temperature=1.0):
        """Generate separated waveforms from the mixture waveform.

        Args:
            mix_wave: mixture waveform, shape of [B, T]
            sep_wave: separated waveforms, shape of [B, 2, T]; only the first frame of
                their codes is kept as an acoustic prompt (see the TODO below)
        """
        batch_size, _ = mix_wave.shape
        device = mix_wave.device
        assert batch_size == 1, f"Only batch size 1 is supported, but got {batch_size}."

        mix_codes = self._encodec_encode(mix_wave)
        spk_1_codes, spk_2_codes = self._encodec_encode(sep_wave[:, 0]), self._encodec_encode(sep_wave[:, 1])
        sep_codes = torch.cat([spk_1_codes, spk_2_codes], dim=1)  # [B, T, N_q] => [B, 2 * T, N_q]
        sep_codes = sep_codes[:, 0:1, :]  # TODO Fix this back

        # ========================================= AR =========================================
        sep_code = sep_codes[..., 0]  # [B, T, N_q] => [B, T]

        # Make the mixture embedding
        # TODO: We should use a different embedding layer for the AR mixture speech, as
        # different stages should have different embedding spaces
        mix_embed = self.ar_embedding_layer(mix_codes[..., 0])
        for j in range(1, self.args.num_cb):
            mix_embed += self.ar_embedding_layer(mix_codes[..., j])

        # NOTE: unlike `forward`, no mixture spectrogram features are prepended here,
        # so `mix_len` counts only the code embeddings.
        mix_len = mix_embed.shape[1]
        mix_attn_mask = torch.zeros((mix_len, mix_len), dtype=torch.bool, device=device)

        while True:
            sep_embed = self.ar_embedding_layer(sep_code)  # [B, T] => [B, T, H]
            mix_sep_embed = torch.cat([mix_embed, sep_embed], dim=1)  # [B, mix_len + sep_len, H]
            mix_sep_embed = self.ar_prenet(mix_sep_embed)
            mix_sep_embed = self.ar_position(mix_sep_embed)

            sep_len = sep_code.shape[1]

            # Create the attention mask. It is re-padded in every iteration.
            mix_attn_mask_pad = F.pad(mix_attn_mask, (0, sep_len), value=True)
            sep_attn_mask = F.pad(
                torch.triu(torch.ones(sep_len, sep_len, dtype=torch.bool, device=device), diagonal=1),
                (mix_len, 0),
                value=False,
            )
            mix_sep_attn_mask = torch.cat([mix_attn_mask_pad, sep_attn_mask], dim=0)

            mix_sep_dec, _ = self.ar_decoder((mix_sep_embed, None), mask=mix_sep_attn_mask)  # [B, T, H]
            logits = self.ar_pred_layer(mix_sep_dec[:, -1])  # [B, num_tokens + 1]

            # Sample from the logits
            samples = top_k_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)  # [B, 1]

            # Stop on a greedy or sampled EOS, or when the length cap is hit
            if (
                (torch.argmax(logits, dim=-1)[0] == self.args.num_tokens)
                or (samples[0][0] == self.args.num_tokens)
                or (sep_code.shape[1] >= 2 * mix_len)
            ):
                print(f"Generation stopped (EOS or length cap) at {sep_code.shape[1]}")
                break

            sep_code = torch.cat([sep_code, samples], dim=-1)

        # ========================================= Non AR =========================================
        # During training, the NAR stage sees first-codebook codes from the original waveform;
        # during inference, it sees the first-codebook codes generated by the AR stage above.
        codes = [sep_code]
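        # NAR refinement over codebooks 2..num_cb: each pass conditions on the sum of
        # embeddings of every codebook decoded so far plus a per-stage embedding, and
        # greedily predicts the next codebook for all frames in a single parallel step.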
        sep_embed = self.nar_embedding_layers[0](sep_code)

        # Create prompts
        mix_embed = self.nar_embedding_layers[0](mix_codes[..., 0])  # First-layer codes
        for j in range(1, self.args.num_cb):
            mix_embed += self.nar_embedding_layers[j](mix_codes[..., j])

        for i, (pred_layer, embed_layer) in enumerate(zip(self.nar_pred_layers, self.nar_embedding_layers[1:])):
            mix_sep_embed = torch.cat([mix_embed, sep_embed], dim=1)
            mix_sep_embed = self.nar_prenet(mix_sep_embed)
            mix_sep_embed = self.nar_position(mix_sep_embed)

            mix_sep_dec, _ = self.nar_decoder(
                (mix_sep_embed, self.nar_stage_embeddings[i].weight),
                src_key_padding_mask=None,
                is_causal=False,
            )
            mix_sep_dec = mix_sep_dec[:, mix_len:]
            logits = pred_layer(mix_sep_dec)  # [B, T, H] => [B, T, 1024]

            samples = torch.argmax(logits, dim=-1)  # [B, T]
            codes.append(samples)

            if i < self.args.num_cb - 2:
                sep_embed += embed_layer(samples)

        assert len(codes) == self.args.num_cb
        codes = torch.stack(codes, dim=-1)

        # Convert codes to a waveform
        return self._vocos_decode(codes)


if __name__ == "__main__":
    model = Model()

    mix_wave = torch.rand(2, 22000)
    sep_waves = torch.rand(2, 2, 22000)
    output = model(mix_wave, sep_waves)
    print(output)

    mix_wave = torch.rand(1, 22000)
    sep_waves = torch.rand(1, 2, 22000)
    audio_values = model.generate(mix_wave, sep_waves)
    print(audio_values.shape)
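    # A minimal stage-1 sketch: with `stage=1` the forward pass returns after the AR
    # branch, so only the AR loss and accuracy are populated.
    ar_only_model = Model(ModelArgs(stage=1))
    total_loss, ar_loss, nar_loss, ar_acc, nar_acc = ar_only_model(mix_wave, sep_waves)
    print(ar_loss, ar_acc)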