import random
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from simple_parsing import Serializable
from torch.cuda.amp import autocast
from torchmetrics.classification import MulticlassAccuracy
from transformers import EncodecModel
from vocos import Vocos

from audiozen.acoustics.audio_feature import stft
from module import SinePositionalEmbedding, TokenEmbedding, top_k_sampling
from transformer import AdaptiveLayerNorm, LayerNorm, TransformerEncoder, TransformerEncoderLayer


@dataclass
class ModelArgs(Serializable):
num_cb: int = 8 # number of codebooks
cb_size: int = 1024 # codebook size
d_model: int = 512 # hidden size of transformer
n_fft: int = 768 # number of fft points
hop_length: int = 384 # hop length
    num_tokens: int = 1024 # number of codec tokens (same as cb_size); id `num_tokens` itself is reserved for EOS
num_layers: int = 12 # number of transformer layers
num_heads: int = 8 # number of heads in multi-head attention
norm_first: bool = True # whether to apply layer norm before self-attention
share_embedding: bool = True # whether to share embedding between encoder and decoder
prepend_bos: bool = False # whether to prepend bos token to the target sequence
add_prenet: bool = False # whether to add prenet before transformer
    stage: int = 2 # training stage: 1 = AR only, 2 = AR + NAR


class Model(nn.Module):
def __init__(self, args: ModelArgs = ModelArgs()):
super().__init__()
self.encodec = EncodecModel.from_pretrained("facebook/encodec_24khz")
for param in self.encodec.parameters():
param.requires_grad = False
# ================================ AR ================================
self.ar_audio_prepend_bos = args.prepend_bos
self.ar_embedding_layer = TokenEmbedding(args.d_model, args.num_tokens + 1 + int(args.prepend_bos))
self.ar_spec_encoder = nn.Sequential(
nn.Linear(args.n_fft // 2 + 1, 256),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(256, 256),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(256, args.d_model),
)
self.ar_prenet = nn.Identity()
self.ar_position = SinePositionalEmbedding(args.d_model, dropout=0.1, scale=False, alpha=True)
# Sequence model
self.ar_decoder = TransformerEncoder(
TransformerEncoderLayer(
args.d_model,
args.num_heads,
dim_feedforward=args.d_model * 4,
dropout=0.1,
batch_first=True,
norm_first=args.norm_first,
),
num_layers=args.num_layers,
norm=LayerNorm(args.d_model) if args.norm_first else None,
)
self.ar_pred_layer = nn.Linear(args.d_model, args.num_tokens + 1, bias=False)
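        # Top-10 accuracy over the AR logits; EOS targets (class id `num_tokens`) are excluded via ignore_index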
self.ar_acc_metric = MulticlassAccuracy(
args.num_tokens + 1,
top_k=10,
average="micro",
multidim_average="global",
ignore_index=args.num_tokens,
)
# ================================ Non AR ================================
self.nar_spec_encoder = nn.Sequential(
nn.Linear(args.n_fft // 2 + 1, 256),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(256, 256),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(256, args.d_model),
)
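        # The first NAR embedding has num_tokens + 1 entries, presumably because the first-codebook
        # stream shares its vocabulary with the AR stage (whose ids include EOS); the remaining
        # codebooks only ever carry the num_tokens codec ids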
self.nar_embedding_layers = nn.ModuleList(
[TokenEmbedding(args.d_model, args.num_tokens + 1)]
+ [TokenEmbedding(args.d_model, args.num_tokens) for _ in range(args.num_cb - 1)]
)
self.nar_prenet = nn.Identity()
self.nar_position = SinePositionalEmbedding(args.d_model, dropout=0.1, scale=False, alpha=False)
# Sequence model
self.nar_decoder = TransformerEncoder(
TransformerEncoderLayer(
args.d_model,
args.num_heads,
dim_feedforward=args.d_model * 4,
dropout=0.1,
batch_first=True,
norm_first=args.norm_first,
adaptive_layer_norm=True,
),
num_layers=args.num_layers,
norm=AdaptiveLayerNorm(args.d_model, norm=nn.LayerNorm(args.d_model)) if args.norm_first else None,
)
self.nar_pred_layers = nn.ModuleList(
[nn.Linear(args.d_model, args.num_tokens, bias=False) for _ in range(args.num_cb - 1)]
)
        self.nar_stage_embeddings = nn.ModuleList([TokenEmbedding(args.d_model, 1) for _ in range(args.num_cb - 1)])
if args.share_embedding:
for j in range(0, args.num_cb - 2):
                # We share the parameters of the acoustic embedding layers and the output prediction
                # layers: pred_layer[j] reuses the weight of embedding layer j + 2
self.nar_pred_layers[j].weight = self.nar_embedding_layers[j + 2].weight
self.nar_acc_metric = MulticlassAccuracy(
args.num_tokens + 1,
top_k=10,
average="micro",
multidim_average="global",
ignore_index=args.num_tokens,
)
self.args = args
        self.rng = random.Random(0)

    def _encodec_encode(self, waveform):
"""Encode waveform to codes.
Args:
waveform: shape of [B, T]
Returns:
codes: shape of [B, T, N_q]
"""
with torch.no_grad():
            with autocast(enabled=False):  # keep Encodec in full precision even under an outer autocast
waveform = rearrange(waveform, "b t -> b () t")
codes = self.encodec.encode(input_values=waveform, return_dict=True, bandwidth=6)
codes = codes.audio_codes # [num_chunks, B, N_q, T], more specifically, [1, B, 8, T] for 6 kbps
codes = rearrange(codes, "c b nq t -> (c b) t nq")
codes = codes.to(dtype=torch.long)
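                # e.g. 1 s of 24 kHz audio ([B, 24000]) yields [B, 75, 8]: Encodec runs at
                # 75 frames/s and uses 8 codebooks at 6 kbps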
        return codes

    def _encodec_decode(self, codes):
"""codes with shape [B, T, N_q] => [B, T]"""
with torch.no_grad():
codes = rearrange(codes, "b t nq -> 1 b nq t")
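            # The 24 kHz Encodec model applies no chunk-wise scale normalization, hence a single `None` scale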
audio_values = self.encodec.decode(audio_codes=codes, audio_scales=[None], return_dict=True).audio_values
audio_values = rearrange(audio_values, "b 1 t -> b t")
        return audio_values

    def _vocos_decode(self, codes):
"""codes with shape [B, T, N_q] => [B, T]"""
with torch.no_grad():
vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(codes.device)
codes = codes.permute(2, 0, 1) # [B, T, N_q] => [N_q, B, T]
features = vocos.codes_to_features(codes)
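            # bandwidth_id=2 selects 6 kbps among Vocos' supported Encodec bandwidths (1.5, 3, 6, 12 kbps),
            # matching the bandwidth used in _encodec_encode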
audio_values = vocos.decode(features, bandwidth_id=torch.tensor([2], device=codes.device))
        return audio_values

    def _stack_embeddings(self, codes, cb_idx):
"""Stack embeddings from stage 0 to stage `cb_idx`.
Args:
codes: shape of [B, T, N_q]
cb_idx: codebook index
Returns:
cumsum: shape of [B, T, cb_size]
"""
batch_size, time_steps, _ = codes.shape
embed = torch.zeros(batch_size, time_steps, self.args.cb_size, device=codes.device)
for i in range(0, cb_idx + 1):
emb_i = self.embed_layers[i](codes[..., i])
embed += emb_i
        return embed

    def _pad_eos(self, codes, eos_id):
        """Append EOS to codes of shape [B, T] and return the shifted (input, target) pair, each [B, T]."""
codes = F.pad(codes, (0, 1), value=eos_id)
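        # e.g. (hypothetical values) codes = [[5, 9, 3]], eos_id = 1024 -> padded [[5, 9, 3, 1024]],
        # giving input [[5, 9, 3]] and next-token target [[9, 3, 1024]]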
        return codes[..., :-1], codes[..., 1:]

    def forward(self, mix_wave, sep_wav):
        """Autoregressive (and non-autoregressive) pass predicting the codes of the separated waveforms.

        Args:
            mix_wave: mixture waveform, shape of [B, T]
            sep_wav: separated waveforms, shape of [B, 2, T]
        """
        # Check the input shape; only two speakers are supported for now
        batch_size, num_spks, _ = sep_wav.shape
        device = sep_wav.device
        assert num_spks == 2, f"Only two speakers are supported, but got {num_spks}"
# Convert waveform to codes
mix_codes = self._encodec_encode(mix_wave) # [B, T, N_q]
spk_1_codes, spk_2_codes = self._encodec_encode(sep_wav[:, 0]), self._encodec_encode(sep_wav[:, 1])
sep_codes = torch.cat([spk_1_codes, spk_2_codes], dim=1) # [B, T, N_q] => [B, 2 * T, N_q]
# ========================================= AR =========================================
# Pad EOS to codes
sep_code, target = self._pad_eos(sep_codes[..., 0], self.args.num_tokens)
# Make mixture embedding
mix_embed = self.ar_embedding_layer(mix_codes[..., 0]) # [B, T] => [B, T, H]
for j in range(1, self.args.num_cb):
mix_embed += self.ar_embedding_layer(mix_codes[..., j])
sep_embed = self.ar_embedding_layer(sep_code) # [B, T] => [B, T, H]
# Make mixture spectrogram embedding
mix_mag, *_ = stft(
mix_wave, n_fft=self.args.n_fft, hop_length=self.args.hop_length, win_length=self.args.n_fft
) # [B, T] => [B, F, T]
mix_mag = rearrange(mix_mag, "b f t -> b t f") # [B, F, T] => [B, T, F]
mix_mag = torch.log(mix_mag + 1e-8)
mix_feat = self.ar_spec_encoder(mix_mag) # [B, T, F] => [B, T, H]
mix_len = mix_feat.shape[1] + mix_embed.shape[1]
sep_len = sep_embed.shape[1]
        mix_sep_embed = torch.cat([mix_feat, mix_embed, sep_embed], dim=1) # [B, mix_len + sep_len, H]
mix_sep_embed = self.ar_prenet(mix_sep_embed)
mix_sep_embed = self.ar_position(mix_sep_embed)
mix_attn_mask = F.pad(
torch.zeros((mix_len, mix_len), dtype=torch.bool, device=device), (0, sep_len), value=True
)
sep_attn_mask = F.pad(
torch.triu(torch.ones(sep_len, sep_len, dtype=torch.bool, device=device), diagonal=1),
(mix_len, 0),
value=False,
)
mix_sep_attn_mask = torch.cat([mix_attn_mask, sep_attn_mask], dim=0)
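        # In the combined mask (True = blocked), mixture positions attend only to the mixture block,
        # while separated positions attend to the whole mixture block and causally to themselves.
        # e.g. for mix_len = 2 and sep_len = 3:
        #   [[F, F, T, T, T],
        #    [F, F, T, T, T],
        #    [F, F, F, T, T],
        #    [F, F, F, F, T],
        #    [F, F, F, F, F]]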
mix_sep_dec, _ = self.ar_decoder((mix_sep_embed, None), mask=mix_sep_attn_mask) # [B, T, H]
        logits = self.ar_pred_layer(mix_sep_dec[:, mix_len:]) # [B, T, H] => [B, T, num_tokens + 1]
        logits = rearrange(logits, "b t h -> b h t") # [B, T, num_tokens + 1] => [B, num_tokens + 1, T]
ar_loss = F.cross_entropy(logits, target, reduction="mean")
ar_accuracy_metric = self.ar_acc_metric(logits.detach(), target)
ar_accuracy_metric = ar_accuracy_metric.detach().cpu().numpy().item()
if self.args.stage == 1:
return ar_loss, ar_loss, 0.0, ar_accuracy_metric, 0.0
# ========================================= Non AR =========================================
num_nar_layers = self.args.num_cb - 1
        nar_stage = self.rng.choices( # Uniformly choose a NAR stage from [1, num_cb - 1]
list(range(1, self.args.num_cb)),
weights=[1.0 / num_nar_layers] * num_nar_layers,
k=1,
)[0]
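        # With uniform weights this is equivalent to rng.randint(1, num_cb - 1); the dedicated,
        # seeded RNG keeps the stage schedule reproducible across runs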
# Create prompts
mix_embed = self.nar_embedding_layers[0](mix_codes[..., 0]) # First layer codes
        sep_embed = self.nar_embedding_layers[0](sep_code) # sep_code equals sep_codes[..., 0] (the EOS pad is dropped)
for j in range(1, self.args.num_cb):
mix_embed += self.nar_embedding_layers[j](mix_codes[..., j])
if j < nar_stage:
sep_embed += self.nar_embedding_layers[j](sep_codes[..., j])
mix_feat = self.nar_spec_encoder(mix_mag) # [B, T, F] => [B, T, H]
        mix_sep_embed = torch.cat([mix_feat, mix_embed, sep_embed], dim=1) # [B, mix_len + sep_len, H]
mix_sep_embed = self.nar_prenet(mix_sep_embed)
mix_sep_embed = self.nar_position(mix_sep_embed)
target = sep_codes[..., nar_stage]
mix_sep_dec, _ = self.nar_decoder(
(mix_sep_embed, self.nar_stage_embeddings[nar_stage - 1].weight),
src_key_padding_mask=None,
# is_causal=False,
)
mix_sep_dec = mix_sep_dec[:, mix_len:]
logits = self.nar_pred_layers[nar_stage - 1](mix_sep_dec).permute(0, 2, 1) # [B, T, H] => [B, 1024, T]
# logits: [B, 1024, T], target: [B, T]
nar_loss = F.cross_entropy(logits, target, ignore_index=self.args.num_tokens, reduction="mean")
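        # The accuracy metric expects num_tokens + 1 classes, so pad the class dimension with a dummy
        # class filled with the minimum logit; ranking last, it effectively never enters the top-k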
nar_acc_metric = self.nar_acc_metric(
F.pad(
logits.detach(),
(0, 0, 0, 1, 0, 0),
value=logits.min().cpu().item(),
),
target,
).item()
total_loss = ar_loss + nar_loss
        return total_loss, ar_loss, nar_loss, ar_accuracy_metric, nar_acc_metric

    def generate(self, mix_wave, sep_wav, top_k=-100, temperature=1.0):
        """Generate separated waveforms from a mixture waveform.

        Args:
            mix_wave: mixture waveform, shape of [B, T]
            sep_wav: reference separated waveforms, shape of [B, 2, T]; only the first frame of
                their codes is used to prime the AR decoder (see the TODO below)
            top_k: top-k sampling parameter
            temperature: sampling temperature
        """
        batch_size, seq_len = mix_wave.shape
        device = mix_wave.device
        assert batch_size == 1, f"Only batch size 1 is supported, but got {batch_size}."
mix_codes = self._encodec_encode(mix_wave)
spk_1_codes, spk_2_codes = self._encodec_encode(sep_wav[:, 0]), self._encodec_encode(sep_wav[:, 1])
sep_codes = torch.cat([spk_1_codes, spk_2_codes], dim=1) # [B, T, N_q] => [B, 2 * T, N_q]
sep_codes = sep_codes[:, 0:1, :] # TODO Fix this back
# ========================================= AR =========================================
sep_code = sep_codes[..., 0] # [B, T, N_q] => [B, T]
# Make mixture embedding
        # TODO: We should use a different embedding layer for the AR mixture speech, since different
        # stages should have different embedding spaces
mix_embed = self.ar_embedding_layer(mix_codes[..., 0])
for j in range(1, self.args.num_cb):
mix_embed += self.ar_embedding_layer(mix_codes[..., j])
mix_len = mix_embed.shape[1]
mix_attn_mask = torch.zeros((mix_len, mix_len), dtype=torch.bool, device=device)
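        # The mixture block of the attention mask is fixed; only its padding over the growing
        # separated sequence is recomputed inside the loop below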
while True:
sep_embed = self.ar_embedding_layer(sep_code) # [B, T] => [B, T, H]
mix_sep_embed = torch.cat([mix_embed, sep_embed], dim=1) # [B, T + 1, H] in the first iteration
mix_sep_embed = self.ar_prenet(mix_sep_embed)
mix_sep_embed = self.ar_position(mix_sep_embed)
sep_len = sep_code.shape[1]
            # Create the attention mask; it is rebuilt every iteration as the separated sequence grows
mix_attn_mask_pad = F.pad(mix_attn_mask, (0, sep_len), value=True)
sep_attn_mask = F.pad(
torch.triu(torch.ones(sep_len, sep_len, dtype=torch.bool, device=device), diagonal=1),
(mix_len, 0),
value=False,
)
mix_sep_attn_mask = torch.cat([mix_attn_mask_pad, sep_attn_mask], dim=0)
mix_sep_dec, _ = self.ar_decoder((mix_sep_embed, None), mask=mix_sep_attn_mask) # [B, T, H]
logits = self.ar_pred_layer(mix_sep_dec[:, -1]) # [B, 1025]
# Sample from logits
samples = top_k_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature) # [B, 1]
if (
(torch.argmax(logits, dim=-1)[0] == self.args.num_tokens)
or (samples[0][0] == self.args.num_tokens)
or (sep_code.shape[1] >= 2 * mix_len)
):
print(f"EOS token reached at {sep_code.shape[1]}")
break
sep_code = torch.cat([sep_code, samples], dim=-1)
# ========================================= Non AR =========================================
codes = [sep_code]
        # During training, the first-codebook prompt comes from the codes of the original waveform;
        # during inference, it comes from the AR-generated first codebook
sep_embed = self.nar_embedding_layers[0](sep_code)
# Create prompts
mix_embed = self.nar_embedding_layers[0](mix_codes[..., 0]) # First layer codes
for j in range(1, self.args.num_cb):
mix_embed += self.nar_embedding_layers[j](mix_codes[..., j])
for i, (pred_layer, embed_layer) in enumerate(zip(self.nar_pred_layers, self.nar_embedding_layers[1:])):
mix_sep_embed = torch.cat([mix_embed, sep_embed], dim=1)
mix_sep_embed = self.nar_prenet(mix_sep_embed)
mix_sep_embed = self.nar_position(mix_sep_embed)
mix_sep_dec, _ = self.nar_decoder(
(mix_sep_embed, self.nar_stage_embeddings[i].weight),
src_key_padding_mask=None,
# is_causal=False,
)
mix_sep_dec = mix_sep_dec[:, mix_len:]
logits = pred_layer(mix_sep_dec) # [B, T, H] => [B, T, 1024]
samples = torch.argmax(logits, dim=-1) # [B, T]
codes.append(samples)
if i < self.args.num_cb - 2:
sep_embed += embed_layer(samples)
assert len(codes) == self.args.num_cb
codes = torch.stack(codes, dim=-1)
# Convert codes to waveform
sep_wave = self._vocos_decode(codes)
        return sep_wave


if __name__ == "__main__":
    model = Model()
    mix_wave = torch.rand(2, 22000)
    sep_wav = torch.rand(2, 2, 22000)
    output = model(mix_wave, sep_wav) # (total_loss, ar_loss, nar_loss, ar_acc, nar_acc)
    print(output)
    mix_wave = torch.rand(1, 22000)
    sep_wav = torch.rand(1, 2, 22000)
    audio_values = model.generate(mix_wave, sep_wav)
    print(audio_values.shape)