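"""AR + non-AR (NAR) acoustic model over DAC codec tokens, conditioned on
spectrogram features. This module-level summary is inferred from the code
below, not from upstream documentation."""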
import random
from dataclasses import dataclass
import dac
import torch
import torch.nn as nn
from einops import rearrange
from module import SinePositionalEmbedding, TokenEmbedding
from simple_parsing import Serializable
from torchmetrics.classification import MulticlassAccuracy
from transformer import AdaptiveLayerNorm, LayerNorm, TransformerEncoder, TransformerEncoderLayer


@dataclass
class ModelArgs(Serializable):
    num_codebooks: int = 8
    codebook_size: int = 1024
    d_model: int = 512
    num_heads: int = 8
    num_layers: int = 12  # referenced by the decoders below; default value is an assumption
    num_tokens: int = 256
    norm_first: bool = True  # referenced below; default value is an assumption
    share_embedding: bool = True  # referenced below; default value is an assumption
    codec_ckpt_path: str = "/home/xhao/.cache/descript/dac/weights_16khz_8kbps_0.0.5.pth"
    prepend_bos: bool = False  # whether to prepend a BOS token to the target sequence
    n_fft: int = 512


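# Minimal construction sketch (hypothetical usage; assumes the defaults above
# and a valid local DAC checkpoint at `codec_ckpt_path`):
#
#   args = ModelArgs(codec_ckpt_path="/path/to/dac_weights.pth")
#   model = Model(args)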
class Model(nn.Module):
    def __init__(self, args: ModelArgs) -> None:
        super().__init__()
        # Load the pretrained DAC codec and freeze it: it serves only as a
        # fixed audio tokenizer and is never trained here.
        self.codec = dac.DAC.load(args.codec_ckpt_path)
        for param in self.codec.parameters():
            param.requires_grad = False
        # ================================ AR ================================
        self.ar_audio_prepend_bos = args.prepend_bos
        self.ar_embed_layer = TokenEmbedding(args.d_model, args.num_tokens + 1 + int(args.prepend_bos))
        # Project (n_fft // 2 + 1)-bin spectrogram frames to the model width.
        self.ar_spec_encoder = nn.Sequential(
            nn.Linear(args.n_fft // 2 + 1, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, args.d_model),
        )
        self.ar_prenet = nn.Identity()
        self.ar_position = SinePositionalEmbedding(args.d_model, dropout=0.1, alpha=True)
        # Sequence model
        self.ar_decoder = TransformerEncoder(
            TransformerEncoderLayer(
                args.d_model,
                args.num_heads,
                dim_feedforward=args.d_model * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=args.norm_first,
            ),
            num_layers=args.num_layers,
            norm=LayerNorm(args.d_model) if args.norm_first else None,
        )
        self.ar_pred_layer = nn.Linear(args.d_model, args.num_tokens + 1, bias=False)
        self.ar_acc_metric = MulticlassAccuracy(
            args.num_tokens + 1,
            top_k=10,
            average="micro",
            multidim_average="global",
            ignore_index=args.num_tokens,
        )
        # ================================ Non AR ================================
        self.nar_spec_encoder = nn.Sequential(
            nn.Linear(args.n_fft // 2 + 1, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, args.d_model),
        )
        # One embedding table per codebook; only the first has the extra
        # special-token slot shared with the AR stage's vocabulary.
        self.nar_embedding_layers = nn.ModuleList(
            [TokenEmbedding(args.d_model, args.num_tokens + 1)]
            + [TokenEmbedding(args.d_model, args.num_tokens) for _ in range(args.num_codebooks - 1)]
        )
        self.nar_prenet = nn.Identity()
        self.nar_position = SinePositionalEmbedding(args.d_model, dropout=0.1, scale=False, alpha=False)
        # Sequence model
        self.nar_decoder = TransformerEncoder(
            TransformerEncoderLayer(
                args.d_model,
                args.num_heads,
                dim_feedforward=args.d_model * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=args.norm_first,
                adaptive_layer_norm=True,
            ),
            num_layers=args.num_layers,
            norm=AdaptiveLayerNorm(args.d_model, norm=nn.LayerNorm(args.d_model)) if args.norm_first else None,
        )
        self.nar_pred_layers = nn.ModuleList(
            [nn.Linear(args.d_model, args.num_tokens, bias=False) for _ in range(args.num_codebooks - 1)]
        )
        self.nar_stage_embeddings = nn.ModuleList(
            [TokenEmbedding(args.d_model, 1) for _ in range(args.num_codebooks - 1)]
        )
        if args.share_embedding:
            # Tie the parameters of each NAR prediction layer to the embedding
            # table of the next codebook stage (pred_layer_j <- embedding_{j+2}).
            for j in range(args.num_codebooks - 2):
                self.nar_pred_layers[j].weight = self.nar_embedding_layers[j + 2].weight
        self.nar_acc_metric = MulticlassAccuracy(
            args.num_tokens + 1,
            top_k=10,
            average="micro",
            multidim_average="global",
            ignore_index=args.num_tokens,
        )
        self.args = args
        self.rng = random.Random(0)

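    # `DAC.encode` returns `(z, codes, latents, commitment_loss, codebook_loss)`,
    # with `codes` of shape (batch, n_codebooks, frames); the helper below is
    # presumably meant to extract those discrete codes as training targets.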
    @torch.no_grad()
    def _dac_encode(self, waveform):
        # DAC expects audio of shape (batch, channels, time); add a channel dim.
        waveform = rearrange(waveform, "b n -> b 1 n")