import random
from dataclasses import dataclass

import dac
import torch
import torch.nn as nn
from einops import rearrange
from simple_parsing import Serializable
from torchmetrics.classification import MulticlassAccuracy

from module import SinePositionalEmbedding, TokenEmbedding
from transformer import AdaptiveLayerNorm, LayerNorm, TransformerEncoder, TransformerEncoderLayer


@dataclass
class ModelArgs(Serializable):
    num_codebooks: int = 8
    codebook_size: int = 1024
    d_model: int = 512
    num_heads: int = 8
    num_tokens: int = 256
    codec_ckpt_path: str = "/home/xhao/.cache/descript/dac/weights_16khz_8kbps_0.0.5.pth"
    prepend_bos: bool = False
    n_fft: int = 512
    # Referenced in Model.__init__ below; the defaults here are assumptions.
    num_layers: int = 12
    norm_first: bool = True
    share_embedding: bool = True
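
# Usage sketch (hypothetical values, not from the original file):
#
#     args = ModelArgs(num_codebooks=8, d_model=512, num_heads=8)
#     model = Model(args)
#
# Since ModelArgs subclasses simple_parsing's Serializable, configs can also be
# round-tripped to disk with the library's save()/load() helpers.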


class Model(nn.Module):
    def __init__(self, args: ModelArgs) -> None:
        super().__init__()

        # Frozen Descript Audio Codec (DAC); used for waveform/token conversion,
        # not trained here.
        self.codec = dac.DAC.load(args.codec_ckpt_path)
        for param in self.codec.parameters():
            param.requires_grad = False

        # --- AR stage: predicts the first codebook autoregressively. ---
        self.ar_audio_prepend_bos = args.prepend_bos
        # +1 for the stop class (index args.num_tokens), +1 more for BOS if enabled.
        self.ar_embed_layer = TokenEmbedding(args.d_model, args.num_tokens + 1 + int(args.prepend_bos))
        # Projects spectrogram frames (n_fft // 2 + 1 bins) to the model width.
        self.ar_spec_encoder = nn.Sequential(
            nn.Linear(args.n_fft // 2 + 1, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, args.d_model),
        )
        self.ar_prenet = nn.Identity()
        self.ar_position = SinePositionalEmbedding(args.d_model, dropout=0.1, alpha=True)

        self.ar_decoder = TransformerEncoder(
            TransformerEncoderLayer(
                args.d_model,
                args.num_heads,
                dim_feedforward=args.d_model * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=args.norm_first,
            ),
            num_layers=args.num_layers,
            norm=LayerNorm(args.d_model) if args.norm_first else None,
        )

        # Output head over args.num_tokens codebook entries plus the stop class,
        # which the accuracy metric ignores via ignore_index.
        self.ar_pred_layer = nn.Linear(args.d_model, args.num_tokens + 1, bias=False)
        self.ar_acc_metric = MulticlassAccuracy(
            args.num_tokens + 1,
            top_k=10,
            average="micro",
            multidim_average="global",
            ignore_index=args.num_tokens,
        )
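
        # Shape sketch for the AR front-end (illustration only; the forward pass
        # is not shown in this section): ar_spec_encoder maps spectrogram frames
        # (B, T, n_fft // 2 + 1) -> (B, T, d_model); ar_embed_layer maps token
        # ids (B, S) -> (B, S, d_model); ar_pred_layer scores args.num_tokens + 1
        # classes per position (codebook tokens plus the stop class).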

        # --- NAR stage: predicts the remaining codebooks in parallel. ---
        self.nar_spec_encoder = nn.Sequential(
            nn.Linear(args.n_fft // 2 + 1, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, args.d_model),
        )

        # One embedding table per codebook; only the first carries the extra
        # stop class, matching the AR stage's vocabulary.
        self.nar_embedding_layers = nn.ModuleList(
            [TokenEmbedding(args.d_model, args.num_tokens + 1)]
            + [TokenEmbedding(args.d_model, args.num_tokens) for _ in range(args.num_codebooks - 1)]
        )

        self.nar_prenet = nn.Identity()
        self.nar_position = SinePositionalEmbedding(args.d_model, dropout=0.1, scale=False, alpha=False)

        # Adaptive LayerNorm lets the shared NAR stack condition on which
        # codebook stage it is decoding.
        self.nar_decoder = TransformerEncoder(
            TransformerEncoderLayer(
                args.d_model,
                args.num_heads,
                dim_feedforward=args.d_model * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=args.norm_first,
                adaptive_layer_norm=True,
            ),
            num_layers=args.num_layers,
            norm=AdaptiveLayerNorm(args.d_model, norm=nn.LayerNorm(args.d_model)) if args.norm_first else None,
        )

        # One prediction head per residual codebook (stages 2..num_codebooks).
        self.nar_pred_layers = nn.ModuleList(
            [nn.Linear(args.d_model, args.num_tokens, bias=False) for _ in range(args.num_codebooks - 1)]
        )

        # One learned stage embedding per residual codebook, used to condition
        # the stack on the current stage.
        self.nar_stage_embeddings = nn.ModuleList(
            [TokenEmbedding(args.d_model, 1) for _ in range(args.num_codebooks - 1)]
        )

        if args.share_embedding:
            # Tie each prediction head to the embedding table of the next stage:
            # the head that scores codebook j + 1 shares its weight matrix with
            # the embedding that consumes those tokens when decoding stage j + 2.
            for j in range(args.num_codebooks - 2):
                self.nar_pred_layers[j].weight = self.nar_embedding_layers[j + 2].weight

        self.nar_acc_metric = MulticlassAccuracy(
            args.num_tokens + 1,
            top_k=10,
            average="micro",
            multidim_average="global",
            ignore_index=args.num_tokens,
        )

        self.args = args
        self.rng = random.Random(0)

    @torch.no_grad()
    def _dac_encode(self, waveform):
        # DAC expects (batch, channels, time); add a channel axis for mono input.
        waveform = rearrange(waveform, "b n -> b 1 n")
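        # A typical continuation (a sketch based on the public
        # descript-audio-codec API, not necessarily the original code):
        # codec.encode(audio_data) returns
        # (z, codes, latents, commitment_loss, codebook_loss), where `codes`
        # has shape (batch, num_codebooks, frames).
        #
        #     _, codes, _, _, _ = self.codec.encode(waveform)
        #     return codes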