# MusicTokenizer / train.py
# Author: ZheqiDAI
# Commit fc4c601 — "Initial commit with cleaned files" (7.03 kB)
from __future__ import annotations
import argparse
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from audidata.io.crops import RandomCrop
from audidata.transforms import Mono
from audidata.collate.default import collate_fn
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from llm.llama import Llama, LlamaConfig
from utils import LinearWarmUp, parse_yaml
class MusicLLM(pl.LightningModule):
    """Autoregressive LLM trained on token sequences from a frozen audio codec.

    Audio batches are encoded into discrete codes by a pretrained codec that is
    frozen and kept in eval mode; a Llama model is trained on those codes with
    teacher-forced next-token cross-entropy.

    Args:
        configs: Parsed YAML config. Keys read here: ``sample_rate``,
            ``max_caption_len``, ``train.lr``, ``train.warm_up_steps``,
            ``audio_codec.name``, ``audio_codec.vocab_size`` and
            ``llm.{block_size, n_layer, n_head, n_embd}``. The whole dict is
            also forwarded to the codec constructor.
    """

    def __init__(self, configs: dict):
        super().__init__()
        self.save_hyperparameters(configs)

        # Configs
        self.configs = configs
        self.sample_rate = configs["sample_rate"]
        # NOTE(review): not used by any code in this file; kept for callers.
        self.max_seq_len = configs["max_caption_len"]
        self.lr = float(configs["train"]["lr"])
        self.warm_up_steps = configs["train"]["warm_up_steps"]

        # Audio codec: a pretrained model that does not take part in training.
        self.codec = self._get_audio_codec(self.configs)
        self.codec.eval()
        for param in self.codec.parameters():
            param.requires_grad = False  # freeze codec weights

        # LLM over the codec's token vocabulary.
        vocab_size = configs["audio_codec"]["vocab_size"]
        self.llm = self._get_llm(vocab_size)

    def _get_audio_codec(self, configs: dict) -> nn.Module:
        """Instantiate the audio codec named by ``configs["audio_codec"]["name"]``.

        Raises:
            ValueError: If the codec name is not supported.
        """
        name = configs["audio_codec"]["name"]
        if name == "trancodec_fsq":
            from audio_codec.trancodec_fsq import Trancodec
            return Trancodec(configs)
        # Fail loudly here rather than returning None, which would only surface
        # later as an AttributeError on ``self.codec.eval()``.
        raise ValueError(f"Unsupported audio codec: {name}")

    def _get_llm(self, vocab_size: int) -> nn.Module:
        """Build the Llama model from ``configs["llm"]`` with the given vocabulary size."""
        config = LlamaConfig(
            block_size=self.configs["llm"]["block_size"],
            vocab_size=vocab_size,
            n_layer=self.configs["llm"]["n_layer"],
            n_head=self.configs["llm"]["n_head"],
            n_embd=self.configs["llm"]["n_embd"],
        )
        return Llama(config=config)

    def forward(self, audio_codes):
        """Return the LLM's logits for the given code sequence."""
        return self.llm(audio_codes)

    def _encode_audio(self, audio: torch.Tensor) -> torch.Tensor:
        """Tokenize ``audio`` with the frozen codec without tracking gradients.

        Returns the output of ``codec.encode`` with axis 2 squeezed; presumably
        (b, t) or (b, t, q) token ids — TODO confirm against the codec.
        """
        with torch.no_grad():
            # Lightning switches the whole module back to train mode (e.g. after
            # each validation run), so force the codec into eval mode every call.
            self.codec.eval()
            return self.codec.encode(audio).squeeze(dim=2)

    def _next_token_loss(self, audio_codes: torch.Tensor) -> torch.Tensor:
        """Teacher-forced next-token cross-entropy along time axis 1."""
        # Predict code t+1 from codes up to t.
        input_codes = audio_codes[:, :-1]
        target_codes = audio_codes[:, 1:].to(torch.long)

        logits = self(input_codes)

        # Merge batch and time axes. This assumes logits are (b, t-1, vocab)
        # and targets (b, t-1) — TODO confirm the Llama output layout.
        return F.cross_entropy(
            logits.flatten(0, 1),
            target_codes.flatten(0, 1),
            reduction="mean",
        )

    def training_step(self, batch, batch_idx):
        """Encode the batch with the codec and return the next-token loss."""
        audio, _ = self._process_batch(batch)
        audio_codes = self._encode_audio(audio)
        loss = self._next_token_loss(audio_codes)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the next-token loss on a validation batch."""
        audio, _ = self._process_batch(batch)
        audio_codes = self._encode_audio(audio)
        with torch.no_grad():
            loss = self._next_token_loss(audio_codes)
        # Training runs with strategy="ddp": average the epoch-level validation
        # metric across ranks instead of reporting only the local shard.
        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        """Build AdamW over the LLM parameters only (the codec is excluded).

        When ``warm_up_steps`` is set, also attach a LambdaLR driven by
        ``LinearWarmUp`` that is stepped every optimizer step.
        """
        optimizer = torch.optim.AdamW(self.llm.parameters(), lr=self.lr)
        if not self.warm_up_steps:
            return optimizer

        lr_lambda = LinearWarmUp(self.warm_up_steps)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        # Lightning steps schedulers once per *epoch* by default. The warm-up
        # length is given in steps (and training is bounded by max_steps), so
        # the scheduler must advance per optimizer step.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            },
        }

    def _process_batch(self, batch):
        """Split a collated batch into ``(audio, dataset_name)``."""
        return batch["audio"], batch["dataset_name"]
def _build_random_crop(configs: dict):
    """Return the RandomCrop shared by every supported dataset."""
    clip_duration = configs["clip_duration"]
    return RandomCrop(clip_duration=clip_duration, end_pad=clip_duration - 0.1)


def get_dataset(configs: dict, split: str) -> torch.utils.data.Dataset:
    """Build the dataset configured under ``configs["<split>_datasets"]``.

    Args:
        configs: Parsed config. Reads ``<split>_datasets``, ``sample_rate`` and
            ``clip_duration``, plus ``fps`` for MAESTRO.
        split: Section prefix, e.g. ``"train"`` or ``"test"``.

    Returns:
        The constructed dataset.

    Raises:
        ValueError: If no dataset is configured for the split, or if the
            configured dataset name is not supported.
    """
    datasets_split = "{}_datasets".format(split)
    dataset_configs = configs[datasets_split]
    if not dataset_configs:
        raise ValueError("No datasets configured under '{}'".format(datasets_split))

    # NOTE: only the first configured dataset is built; any further entries
    # under this split are ignored.
    name = next(iter(dataset_configs))
    dataset_config = dataset_configs[name]

    if name == "Mozart":
        from audidata.datasets import Mozart
        return Mozart(
            root=dataset_config["root"],
            split=dataset_config["split"],
            sr=configs["sample_rate"],
            crop=_build_random_crop(configs),
            transform=Mono(),
        )

    if name == "MAESTRO":
        from audidata.datasets import MAESTRO
        from audidata.transforms.midi import PianoRoll
        return MAESTRO(
            root=dataset_config["root"],
            split=dataset_config["split"],
            sr=configs["sample_rate"],
            crop=_build_random_crop(configs),
            transform=Mono(),
            load_target=True,
            extend_pedal=True,
            target_transform=PianoRoll(fps=configs["fps"], pitches_num=128),
        )

    raise ValueError(name)
def main(args):
    """Train MusicLLM using the YAML config at ``args.config``.

    Builds train/test dataloaders, the model, a step-based checkpoint callback
    and a DDP GPU trainer, then runs ``trainer.fit``.
    """
    # Parse configs
    configs = parse_yaml(args.config)
    train_configs = configs["train"]

    # Datasets
    train_dataset = get_dataset(configs, "train")
    val_dataset = get_dataset(configs, "test")

    # DataLoaders: identical settings except for shuffling.
    loader_kwargs = dict(
        batch_size=train_configs["batch_size_per_device"],
        num_workers=train_configs["num_workers"],
        collate_fn=collate_fn,
        pin_memory=True,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # Model
    model = MusicLLM(configs)

    # Checkpoint callback. The filename template already spells out "step=",
    # so disable Lightning's automatic metric-name insertion; with the default
    # it would produce names like "step=step=1000.ckpt".
    checkpoint_callback = ModelCheckpoint(
        dirpath=Path("./checkpoints", Path(__file__).stem, Path(args.config).stem),
        filename="step={step}",
        auto_insert_metric_name=False,
        every_n_train_steps=train_configs["save_every_n_steps"],
        save_top_k=-1,
    )

    # Trainer
    trainer = pl.Trainer(
        max_steps=train_configs["training_steps"],
        accelerator="gpu",
        devices=train_configs["device"],
        strategy="ddp",
        precision=32,
        callbacks=[checkpoint_callback],
        check_val_every_n_epoch=train_configs["test_every_n_epoch"],
    )

    # Train
    trainer.fit(model, train_loader, val_loader)
if __name__ == "__main__":
    # Command-line entry point: the only option is the training config path.
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, required=True, help="Path of config yaml.")
    main(cli.parse_args())