| | import random |
| | from dataclasses import dataclass |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from einops import rearrange |
| | from module import SinePositionalEmbedding, TokenEmbedding, top_k_sampling |
| | from simple_parsing import Serializable |
| | from torch.cuda.amp import autocast |
| | from torchmetrics.classification import MulticlassAccuracy |
| | from transformer import AdaptiveLayerNorm, LayerNorm, TransformerEncoder, TransformerEncoderLayer |
| | from transformers import EncodecModel |
| | from vocos import Vocos |
| |
|
| | from audiozen.acoustics.audio_feature import stft |
| |
|
| |
|
@dataclass
class ModelArgs(Serializable):
    """Hyper-parameters for the two-stage (AR + NAR) codec-token model."""

    num_cb: int = 8  # number of EnCodec codebooks (quantizer stages) used
    cb_size: int = 1024  # entries per codebook
    d_model: int = 512  # transformer hidden dimension
    n_fft: int = 768  # STFT size for the spectral conditioning branch
    hop_length: int = 384  # STFT hop length
    num_tokens: int = 1024  # codec vocabulary size; index num_tokens doubles as EOS/pad id
    num_layers: int = 12  # transformer encoder layers (per AR and NAR decoder)
    num_heads: int = 8  # attention heads
    norm_first: bool = True  # pre-norm transformer layers when True
    share_embedding: bool = True  # tie NAR prediction heads to the next stage's embeddings
    prepend_bos: bool = False  # reserve an extra BOS id in the AR embedding table
    add_prenet: bool = False  # NOTE(review): currently unused by Model — confirm intent
    stage: int = 2  # 1 = train AR only, 2 = train AR + NAR jointly
|
| |
|
class Model(nn.Module):
    """VALL-E-style two-stage model that separates a two-speaker mixture in the
    EnCodec token domain.

    Stage 1 (AR): an autoregressive transformer predicts the first-codebook
    tokens of the separated speech (speaker 1's codes concatenated with
    speaker 2's along time), conditioned on the mixture's codec tokens and its
    log-magnitude spectrum.

    Stage 2 (NAR): a non-autoregressive transformer predicts the remaining
    ``num_cb - 1`` codebooks; during training one stage is sampled uniformly
    per step, during generation all stages are decoded greedily in order.
    """

    def __init__(self, args: ModelArgs = None):
        """
        Args:
            args: model hyper-parameters; ``None`` means a fresh ``ModelArgs()``.
                (A ``None`` sentinel avoids the mutable-default pitfall of
                sharing one ``ModelArgs`` instance across all constructions.)
        """
        super().__init__()
        if args is None:
            args = ModelArgs()

        # Frozen pretrained EnCodec codec used purely as an audio tokenizer.
        self.encodec = EncodecModel.from_pretrained("facebook/encodec_24khz")
        for param in self.encodec.parameters():
            param.requires_grad = False

        # ---------------- AR (stage-1) modules ----------------
        self.ar_audio_prepend_bos = args.prepend_bos
        # Vocabulary: num_tokens codec entries + 1 EOS (+1 BOS when enabled).
        self.ar_embedding_layer = TokenEmbedding(args.d_model, args.num_tokens + 1 + int(args.prepend_bos))
        # Projects the mixture's log-magnitude STFT frames into the model dim.
        self.ar_spec_encoder = nn.Sequential(
            nn.Linear(args.n_fft // 2 + 1, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, args.d_model),
        )
        self.ar_prenet = nn.Identity()
        self.ar_position = SinePositionalEmbedding(args.d_model, dropout=0.1, scale=False, alpha=True)

        self.ar_decoder = TransformerEncoder(
            TransformerEncoderLayer(
                args.d_model,
                args.num_heads,
                dim_feedforward=args.d_model * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=args.norm_first,
            ),
            num_layers=args.num_layers,
            norm=LayerNorm(args.d_model) if args.norm_first else None,
        )

        # One extra output class for EOS.
        self.ar_pred_layer = nn.Linear(args.d_model, args.num_tokens + 1, bias=False)

        self.ar_acc_metric = MulticlassAccuracy(
            args.num_tokens + 1,
            top_k=10,
            average="micro",
            multidim_average="global",
            ignore_index=args.num_tokens,  # EOS/pad id excluded from accuracy
        )

        # ---------------- NAR (stage-2) modules ----------------
        self.nar_spec_encoder = nn.Sequential(
            nn.Linear(args.n_fft // 2 + 1, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, args.d_model),
        )
        # Embedding 0 also covers the EOS id (num_tokens + 1 entries); the
        # remaining num_cb - 1 codebooks have exactly num_tokens entries.
        self.nar_embedding_layers = nn.ModuleList(
            [TokenEmbedding(args.d_model, args.num_tokens + 1)]
            + [TokenEmbedding(args.d_model, args.num_tokens) for _ in range(args.num_cb - 1)]
        )
        self.nar_prenet = nn.Identity()
        self.nar_position = SinePositionalEmbedding(args.d_model, dropout=0.1, scale=False, alpha=False)

        self.nar_decoder = TransformerEncoder(
            TransformerEncoderLayer(
                args.d_model,
                args.num_heads,
                dim_feedforward=args.d_model * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=args.norm_first,
                adaptive_layer_norm=True,  # conditioned on the NAR stage embedding
            ),
            num_layers=args.num_layers,
            norm=AdaptiveLayerNorm(args.d_model, norm=nn.LayerNorm(args.d_model)) if args.norm_first else None,
        )

        # One greedy prediction head per NAR codebook (codebooks 1 .. num_cb-1).
        self.nar_pred_layers = nn.ModuleList(
            [nn.Linear(args.d_model, args.num_tokens, bias=False) for _ in range(args.num_cb - 1)]
        )

        # One learned conditioning vector per NAR stage (fed to adaptive LayerNorm).
        self.nar_stage_embeddings = nn.ModuleList([TokenEmbedding(args.d_model, 1) for i in range(args.num_cb - 1)])

        if args.share_embedding:
            # Tie the head predicting codebook j+1 to the embedding table of
            # codebook j+2 (VALL-E-style output/input embedding sharing).
            for j in range(0, args.num_cb - 2):
                self.nar_pred_layers[j].weight = self.nar_embedding_layers[j + 2].weight

        self.nar_acc_metric = MulticlassAccuracy(
            args.num_tokens + 1,
            top_k=10,
            average="micro",
            multidim_average="global",
            ignore_index=args.num_tokens,
        )

        self.args = args
        self.rng = random.Random(0)  # fixed seed => reproducible NAR stage sampling

    def _encodec_encode(self, waveform):
        """Encode waveform to codec tokens with the frozen EnCodec model.

        Args:
            waveform: shape of [B, T]

        Returns:
            codes: long tensor, shape of [B, T', N_q]
        """
        with torch.no_grad():
            # EnCodec is kept in fp32 even under mixed-precision training.
            with autocast(dtype=torch.float32):
                waveform = rearrange(waveform, "b t -> b () t")
                codes = self.encodec.encode(input_values=waveform, return_dict=True, bandwidth=6)
                codes = codes.audio_codes
                codes = rearrange(codes, "c b nq t -> (c b) t nq")
                codes = codes.to(dtype=torch.long)
        return codes

    def _encodec_decode(self, codes):
        """Decode codec tokens back to audio: [B, T, N_q] => [B, T]."""
        with torch.no_grad():
            codes = rearrange(codes, "b t nq -> 1 b nq t")
            audio_values = self.encodec.decode(audio_codes=codes, audio_scales=[None], return_dict=True).audio_values
            audio_values = rearrange(audio_values, "b 1 t -> b t")
        return audio_values

    def _vocos_decode(self, codes):
        """Decode codec tokens to audio with Vocos: [B, T, N_q] => [B, T].

        NOTE(review): the pretrained Vocos vocoder is re-loaded on every call;
        consider caching it if this method is used in a hot path.
        """
        with torch.no_grad():
            vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(codes.device)

            codes = codes.permute(2, 0, 1)  # [B, T, N_q] -> [N_q, B, T] as Vocos expects
            features = vocos.codes_to_features(codes)
            # bandwidth_id=2 corresponds to the 6 kbps setting used in _encodec_encode.
            audio_values = vocos.decode(features, bandwidth_id=torch.tensor([2], device=codes.device))
        return audio_values

    def _stack_embeddings(self, codes, cb_idx):
        """Sum the NAR token embeddings of stages 0 .. `cb_idx`.

        Args:
            codes: shape of [B, T, N_q]
            cb_idx: last codebook index (inclusive) to accumulate

        Returns:
            embed: shape of [B, T, d_model]
        """
        batch_size, time_steps, _ = codes.shape
        # Fixed: original referenced nonexistent `self.embed_layers` and sized the
        # accumulator with cb_size; TokenEmbedding outputs d_model-dim vectors.
        embed = torch.zeros(batch_size, time_steps, self.args.d_model, device=codes.device)
        for i in range(0, cb_idx + 1):
            emb_i = self.nar_embedding_layers[i](codes[..., i])
            embed += emb_i
        return embed

    def _pad_eos(self, codes, eos_id):
        """Append EOS and build the (input, target) shift pair.

        [B, T] => input [B, T] (original sequence), target [B, T]
        (sequence shifted left by one, ending in `eos_id`).
        """
        codes = F.pad(codes, (0, 1), value=eos_id)
        return codes[..., :-1], codes[..., 1:]

    def forward(self, mix_wave, sep_wav):
        """Compute the AR (and optionally NAR) training losses.

        Args:
            mix_wave: mixture waveform, shape of [B, T]
            sep_wav: separated waveforms, shape of [B, 2, T]

        Returns:
            (total_loss, ar_loss, nar_loss, ar_accuracy, nar_accuracy);
            when ``args.stage == 1`` the NAR entries are 0.0.
        """
        batch_size, num_spks, _ = sep_wav.shape
        device = sep_wav.device
        assert num_spks == 2, f"Only support to process two speakers, but got {num_spks}"

        # Tokenize: the two speakers' codes are concatenated along time.
        mix_codes = self._encodec_encode(mix_wave)
        spk_1_codes, spk_2_codes = self._encodec_encode(sep_wav[:, 0]), self._encodec_encode(sep_wav[:, 1])
        sep_codes = torch.cat([spk_1_codes, spk_2_codes], dim=1)

        # AR input/target pair for codebook 0; num_tokens doubles as the EOS id.
        sep_code, target = self._pad_eos(sep_codes[..., 0], self.args.num_tokens)

        # Mixture conditioning: sum of all codebook embeddings (shared AR table).
        mix_embed = self.ar_embedding_layer(mix_codes[..., 0])
        for j in range(1, self.args.num_cb):
            mix_embed += self.ar_embedding_layer(mix_codes[..., j])
        sep_embed = self.ar_embedding_layer(sep_code)

        # Spectral conditioning: log-magnitude STFT of the mixture.
        mix_mag, *_ = stft(
            mix_wave, n_fft=self.args.n_fft, hop_length=self.args.hop_length, win_length=self.args.n_fft
        )
        mix_mag = rearrange(mix_mag, "b f t -> b t f")
        mix_mag = torch.log(mix_mag + 1e-8)  # epsilon guards log(0)
        mix_feat = self.ar_spec_encoder(mix_mag)

        # Prompt = [spec features | mixture code embeddings], then target codes.
        mix_len = mix_feat.shape[1] + mix_embed.shape[1]
        sep_len = sep_embed.shape[1]
        mix_sep_embed = torch.cat([mix_feat, mix_embed, sep_embed], dim=1)
        mix_sep_embed = self.ar_prenet(mix_sep_embed)
        mix_sep_embed = self.ar_position(mix_sep_embed)

        # Prompt rows: attend to the whole prompt, never to the target part.
        mix_attn_mask = F.pad(
            torch.zeros((mix_len, mix_len), dtype=torch.bool, device=device), (0, sep_len), value=True
        )

        # Target rows: full prompt visibility + causal mask over the target part.
        sep_attn_mask = F.pad(
            torch.triu(torch.ones(sep_len, sep_len, dtype=torch.bool, device=device), diagonal=1),
            (mix_len, 0),
            value=False,
        )

        mix_sep_attn_mask = torch.cat([mix_attn_mask, sep_attn_mask], dim=0)

        mix_sep_dec, _ = self.ar_decoder((mix_sep_embed, None), mask=mix_sep_attn_mask)
        logits = self.ar_pred_layer(mix_sep_dec[:, mix_len:])
        logits = rearrange(logits, "b t h -> b h t")
        ar_loss = F.cross_entropy(logits, target, reduction="mean")
        ar_accuracy_metric = self.ar_acc_metric(logits.detach(), target)
        ar_accuracy_metric = ar_accuracy_metric.detach().cpu().numpy().item()

        if self.args.stage == 1:
            # AR-only training: report zeros for the NAR slots.
            return ar_loss, ar_loss, 0.0, ar_accuracy_metric, 0.0

        # Sample one NAR stage uniformly from {1, ..., num_cb - 1}.
        num_nar_layers = self.args.num_cb - 1
        nar_stage = self.rng.choices(
            list(range(1, self.args.num_cb)),
            weights=[1.0 / num_nar_layers] * num_nar_layers,
            k=1,
        )[0]

        # NAR inputs: mixture uses all codebooks; target side uses ground-truth
        # codebooks below the sampled stage (teacher forcing).
        mix_embed = self.nar_embedding_layers[0](mix_codes[..., 0])
        sep_embed = self.nar_embedding_layers[0](sep_code)

        for j in range(1, self.args.num_cb):
            mix_embed += self.nar_embedding_layers[j](mix_codes[..., j])
            if j < nar_stage:
                sep_embed += self.nar_embedding_layers[j](sep_codes[..., j])

        mix_feat = self.nar_spec_encoder(mix_mag)

        mix_sep_embed = torch.cat([mix_feat, mix_embed, sep_embed], dim=1)
        mix_sep_embed = self.nar_prenet(mix_sep_embed)
        mix_sep_embed = self.nar_position(mix_sep_embed)

        target = sep_codes[..., nar_stage]

        # Full (bidirectional) attention; the stage embedding conditions AdaLN.
        mix_sep_dec, _ = self.nar_decoder(
            (mix_sep_embed, self.nar_stage_embeddings[nar_stage - 1].weight),
            src_key_padding_mask=None,
        )

        mix_sep_dec = mix_sep_dec[:, mix_len:]
        logits = self.nar_pred_layers[nar_stage - 1](mix_sep_dec).permute(0, 2, 1)

        nar_loss = F.cross_entropy(logits, target, ignore_index=self.args.num_tokens, reduction="mean")

        # The metric expects num_tokens + 1 classes; pad a minimal-score class
        # at index num_tokens so ignore_index lines up with the AR metric.
        nar_acc_metric = self.nar_acc_metric(
            F.pad(
                logits.detach(),
                (0, 0, 0, 1, 0, 0),
                value=logits.min().cpu().item(),
            ),
            target,
        ).item()

        total_loss = ar_loss + nar_loss

        return total_loss, ar_loss, nar_loss, ar_accuracy_metric, nar_acc_metric

    def generate(self, mix_wave, sep_wav, top_k=-100, temperature=1.0):
        """Generate separated audio from a mixture (batch size 1 only).

        Args:
            mix_wave: mixture waveform, shape of [B, T]
            sep_wav: reference separated waveforms, shape of [B, 2, T]; only the
                first codec frame is used, as the AR decoding prompt.
            top_k: top-k sampling parameter passed to ``top_k_sampling``
                (semantics of negative values defined by that helper).
            temperature: sampling temperature.

        Returns:
            sep_wave: generated waveform, shape of [B, T']

        NOTE(review): unlike ``forward``, this conditions only on the mixture's
        codec embeddings (no STFT feature branch) — confirm this is intentional.
        """
        batch_size, seq_len = mix_wave.shape
        device = mix_wave.device
        assert batch_size == 1, f"Only support batch size 1, but got {batch_size}."

        mix_codes = self._encodec_encode(mix_wave)
        spk_1_codes, spk_2_codes = self._encodec_encode(sep_wav[:, 0]), self._encodec_encode(sep_wav[:, 1])
        sep_codes = torch.cat([spk_1_codes, spk_2_codes], dim=1)
        sep_codes = sep_codes[:, 0:1, :]  # keep only the first frame as the prompt

        # AR decoding state: codebook-0 token sequence, seeded with the prompt.
        sep_code = sep_codes[..., 0]

        # Mixture conditioning: sum of all codebook embeddings.
        mix_embed = self.ar_embedding_layer(mix_codes[..., 0])
        for j in range(1, self.args.num_cb):
            mix_embed += self.ar_embedding_layer(mix_codes[..., j])

        mix_len = mix_embed.shape[1]
        mix_attn_mask = torch.zeros((mix_len, mix_len), dtype=torch.bool, device=device)

        # ---------------- AR decoding of codebook 0 ----------------
        while True:
            sep_embed = self.ar_embedding_layer(sep_code)
            mix_sep_embed = torch.cat([mix_embed, sep_embed], dim=1)
            mix_sep_embed = self.ar_prenet(mix_sep_embed)
            mix_sep_embed = self.ar_position(mix_sep_embed)

            sep_len = sep_code.shape[1]

            # Same mask layout as in forward(): prompt rows can't see targets;
            # target rows are causal.
            mix_attn_mask_pad = F.pad(mix_attn_mask, (0, sep_len), value=True)

            sep_attn_mask = F.pad(
                torch.triu(torch.ones(sep_len, sep_len, dtype=torch.bool, device=device), diagonal=1),
                (mix_len, 0),
                value=False,
            )

            mix_sep_attn_mask = torch.cat([mix_attn_mask_pad, sep_attn_mask], dim=0)

            mix_sep_dec, _ = self.ar_decoder((mix_sep_embed, None), mask=mix_sep_attn_mask)
            logits = self.ar_pred_layer(mix_sep_dec[:, -1])

            samples = top_k_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)

            # Stop on EOS (greedy or sampled) or on the hard length cap.
            if (
                (torch.argmax(logits, dim=-1)[0] == self.args.num_tokens)
                or (samples[0][0] == self.args.num_tokens)
                or (sep_code.shape[1] >= 2 * mix_len)
            ):
                print(f"EOS token reached at {sep_code.shape[1]}")
                break

            sep_code = torch.cat([sep_code, samples], dim=-1)

        # ---------------- NAR decoding of codebooks 1 .. num_cb-1 ----------------
        codes = [sep_code]
        sep_embed = self.nar_embedding_layers[0](sep_code)

        mix_embed = self.nar_embedding_layers[0](mix_codes[..., 0])
        for j in range(1, self.args.num_cb):
            mix_embed += self.nar_embedding_layers[j](mix_codes[..., j])

        for i, (pred_layer, embed_layer) in enumerate(zip(self.nar_pred_layers, self.nar_embedding_layers[1:])):
            mix_sep_embed = torch.cat([mix_embed, sep_embed], dim=1)
            mix_sep_embed = self.nar_prenet(mix_sep_embed)
            mix_sep_embed = self.nar_position(mix_sep_embed)

            mix_sep_dec, _ = self.nar_decoder(
                (mix_sep_embed, self.nar_stage_embeddings[i].weight),
                src_key_padding_mask=None,
            )

            mix_sep_dec = mix_sep_dec[:, mix_len:]
            logits = pred_layer(mix_sep_dec)
            samples = torch.argmax(logits, dim=-1)  # greedy per-stage decoding
            codes.append(samples)

            # Feed this stage's tokens into the next stage's input sum
            # (not needed after the final stage).
            if i < self.args.num_cb - 2:
                sep_embed += embed_layer(samples)

        assert len(codes) == self.args.num_cb
        codes = torch.stack(codes, dim=-1)

        sep_wave = self._vocos_decode(codes)

        return sep_wave
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test: one training forward pass (batch of 2), then generation
    # (which supports batch size 1 only). Variables renamed so the builtin
    # `input` is not shadowed.
    model = Model()
    mix_wave = torch.rand(2, 22000)
    sep_wave = torch.rand(2, 2, 22000)
    output = model(mix_wave, sep_wave)
    print(output)
    mix_wave = torch.rand(1, 22000)
    sep_wave = torch.rand(1, 2, 22000)
    audio_values = model.generate(mix_wave, sep_wave)
    print(audio_values.shape)
| |
|
| | |
| |
|