import random
from dataclasses import dataclass

import dac
import torch
import torch.nn as nn
from einops import rearrange
from module import SinePositionalEmbedding, TokenEmbedding
from simple_parsing import Serializable
from torchmetrics.classification import MulticlassAccuracy
from transformer import AdaptiveLayerNorm, LayerNorm, TransformerEncoder, TransformerEncoderLayer


@dataclass
class ModelArgs(Serializable):
    """Configuration for the two-stage (AR + NAR) codec-token model."""

    num_codebooks: int = 8  # number of residual codebooks produced by the DAC codec
    codebook_size: int = 1024
    d_model: int = 512
    num_heads: int = 8
    num_tokens: int = 256
    codec_ckpt_path: str = "/home/xhao/.cache/descript/dac/weights_16khz_8kbps_0.0.5.pth"
    prepend_bos: bool = False  # whether to prepend bos token to the target sequence
    n_fft: int = 512
    # BUG FIX: the three fields below were read by Model.__init__
    # (args.num_layers / args.norm_first / args.share_embedding) but were never
    # declared here, so constructing Model raised AttributeError. Defaults are
    # plausible transformer settings -- TODO(review): confirm against the
    # actual training configuration.
    num_layers: int = 12
    norm_first: bool = True
    share_embedding: bool = True


class Model(nn.Module):
    """AR + NAR transformer over DAC codec tokens, conditioned on spectrogram features.

    The AR branch predicts the first codebook autoregressively; the NAR branch
    predicts the remaining ``num_codebooks - 1`` codebooks in parallel, one
    stage per codebook.
    """

    def __init__(self, args: ModelArgs) -> None:
        super().__init__()

        # Frozen neural codec: used for audio tokenization only, never trained.
        self.codec = dac.DAC.load(args.codec_ckpt_path)
        for param in self.codec.parameters():
            param.requires_grad = False

        # ============================ AR model ============================
        self.ar_audio_prepend_bos = args.prepend_bos
        # +1 for the EOS/stop class; +1 more for BOS when prepend_bos is set.
        self.ar_embed_layer = TokenEmbedding(args.d_model, args.num_tokens + 1 + int(args.prepend_bos))
        # MLP projecting spectrogram frames (n_fft // 2 + 1 bins) to d_model.
        self.ar_spec_encoder = nn.Sequential(
            nn.Linear(args.n_fft // 2 + 1, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, args.d_model),
        )
        self.ar_prenet = nn.Identity()
        self.ar_position = SinePositionalEmbedding(args.d_model, dropout=0.1, alpha=True)
        # Sequence model
        self.ar_decoder = TransformerEncoder(
            TransformerEncoderLayer(
                args.d_model,
                args.num_heads,
                dim_feedforward=args.d_model * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=args.norm_first,
            ),
            num_layers=args.num_layers,
            norm=LayerNorm(args.d_model) if args.norm_first else None,
        )
        self.ar_pred_layer = nn.Linear(args.d_model, args.num_tokens + 1, bias=False)
        self.ar_acc_metric = MulticlassAccuracy(
            args.num_tokens + 1,
            top_k=10,
            average="micro",
            multidim_average="global",
            ignore_index=args.num_tokens,
        )

        # ================================ Non AR ================================
        self.nar_spec_encoder = nn.Sequential(
            nn.Linear(args.n_fft // 2 + 1, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, args.d_model),
        )
        # BUG FIX: the original referenced args.num_cb throughout this section,
        # a field ModelArgs never declares; the declared field is num_codebooks.
        # First embedding covers codebook 0 (+1 for EOS); the rest have no EOS.
        self.nar_embedding_layers = nn.ModuleList(
            [TokenEmbedding(args.d_model, args.num_tokens + 1)]
            + [TokenEmbedding(args.d_model, args.num_tokens) for _ in range(args.num_codebooks - 1)]
        )
        self.nar_prenet = nn.Identity()
        self.nar_position = SinePositionalEmbedding(args.d_model, dropout=0.1, scale=False, alpha=False)
        # Sequence model
        self.nar_decoder = TransformerEncoder(
            TransformerEncoderLayer(
                args.d_model,
                args.num_heads,
                dim_feedforward=args.d_model * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=args.norm_first,
                adaptive_layer_norm=True,
            ),
            num_layers=args.num_layers,
            norm=AdaptiveLayerNorm(args.d_model, norm=nn.LayerNorm(args.d_model)) if args.norm_first else None,
        )
        # One prediction head per NAR stage (codebooks 2..num_codebooks).
        self.nar_pred_layers = nn.ModuleList(
            [nn.Linear(args.d_model, args.num_tokens, bias=False) for _ in range(args.num_codebooks - 1)]
        )
        self.nar_stage_embeddings = nn.ModuleList(
            [TokenEmbedding(args.d_model, 1) for _ in range(args.num_codebooks - 1)]
        )
        if args.share_embedding:
            for j in range(0, args.num_codebooks - 2):
                # We share the parameters of the acoustic embedding layer and
                # the output prediction layer, pred_layer_i
                self.nar_pred_layers[j].weight = self.nar_embedding_layers[j + 2].weight
        self.nar_acc_metric = MulticlassAccuracy(
            args.num_tokens + 1,
            top_k=10,
            average="micro",
            multidim_average="global",
            ignore_index=args.num_tokens,
        )

        self.args = args
        # Seeded RNG so any stochastic choices made with it are reproducible.
        self.rng = random.Random(0)

    @torch.no_grad()
    def _dac_encode(self, waveform):
        # NOTE(review): "b n -> b n" is a no-op rearrange -- possibly a
        # leftover from squeezing a channel axis (e.g. "b 1 n -> b n").
        # The method continues beyond this chunk, so the statement is kept
        # exactly as written; confirm intent against the full file.
        waveform = rearrange(waveform, "b n -> b n")