from __future__ import annotations import argparse from pathlib import Path import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from audidata.io.crops import RandomCrop from audidata.transforms import Mono from audidata.collate.default import collate_fn from torch.utils.data import DataLoader import pytorch_lightning as pl from pytorch_lightning.callbacks import ModelCheckpoint from llm.llama import Llama, LlamaConfig from utils import LinearWarmUp, parse_yaml class MusicLLM(pl.LightningModule): def __init__(self, configs: dict): super().__init__() self.save_hyperparameters(configs) # Configs self.configs = configs self.sample_rate = configs["sample_rate"] self.max_seq_len = configs["max_caption_len"] self.lr = float(configs["train"]["lr"]) self.warm_up_steps = configs["train"]["warm_up_steps"] # Audio codec (预训练模型,不参与训练) self.codec = self._get_audio_codec(self.configs) self.codec.eval() # 设置为 eval 模式 for param in self.codec.parameters(): param.requires_grad = False # 冻结 codec 参数 # LLM vocab_size = configs["audio_codec"]["vocab_size"] self.llm = self._get_llm(vocab_size) def _get_audio_codec(self, configs: dict) -> nn.Module: name = configs["audio_codec"]["name"] if name == "trancodec_fsq": from audio_codec.trancodec_fsq import Trancodec return Trancodec(configs) def _get_llm(self, vocab_size: int) -> nn.Module: config = LlamaConfig( block_size=self.configs["llm"]["block_size"], vocab_size=vocab_size, n_layer=self.configs["llm"]["n_layer"], n_head=self.configs["llm"]["n_head"], n_embd=self.configs["llm"]["n_embd"] ) return Llama(config=config) def forward(self, audio_codes): return self.llm(audio_codes) def training_step(self, batch, batch_idx): audio, _ = self._process_batch(batch) with torch.no_grad(): # 确保 codec 不计算梯度 self.codec.eval() # 强制 eval 模式 audio_codes = self.codec.encode(audio).squeeze(dim=2) # (b, t, q) # Prepare input and target input_codes = audio_codes[:, :-1] # (b, t-1, q) target_codes = audio_codes[:, 1:].to(torch.long) # (b, t-1, q) # Forward logits = self(input_codes) # (b, t-1, q) # Loss loss = F.cross_entropy( logits.flatten(0, 1), target_codes.flatten(0, 1), reduction="mean" ) self.log("train_loss", loss, prog_bar=True) return loss def validation_step(self, batch, batch_idx): audio, _ = self._process_batch(batch) with torch.no_grad(): # 确保 codec 不计算梯度 self.codec.eval() audio_codes = self.codec.encode(audio).squeeze(dim=2) input_codes = audio_codes[:, :-1] target_codes = audio_codes[:, 1:].to(torch.long) with torch.no_grad(): logits = self(input_codes) loss = F.cross_entropy( logits.flatten(0, 1), target_codes.flatten(0, 1), reduction="mean" ) self.log("val_loss", loss, prog_bar=True) return loss def configure_optimizers(self): # 只优化 llm 参数,排除 codec optimizer = torch.optim.AdamW(self.llm.parameters(), lr=self.lr) if self.warm_up_steps: lr_lambda = LinearWarmUp(self.warm_up_steps) scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) return [optimizer], [scheduler] return optimizer def _process_batch(self, batch): return batch["audio"], batch["dataset_name"] def get_dataset(configs: dict, split: str) -> torch.utils.data.Dataset: datasets_split = "{}_datasets".format(split) for name in configs[datasets_split].keys(): if name == "Mozart": from audidata.datasets import Mozart dataset = Mozart( root=configs[datasets_split][name]["root"], split=configs[datasets_split][name]["split"], sr=configs["sample_rate"], crop=RandomCrop(clip_duration=configs["clip_duration"], end_pad=configs["clip_duration"] - 0.1), transform=Mono(), ) return dataset elif name == "MAESTRO": from audidata.datasets import MAESTRO from audidata.transforms.midi import PianoRoll dataset = MAESTRO( root=configs[datasets_split][name]["root"], split=configs[datasets_split][name]["split"], sr=configs["sample_rate"], crop=RandomCrop(clip_duration=configs["clip_duration"], end_pad=configs["clip_duration"] - 0.1), transform=Mono(), load_target=True, extend_pedal=True, target_transform=PianoRoll(fps=configs["fps"], pitches_num=128), ) return dataset else: raise ValueError(name) def main(args): # Parse configs configs = parse_yaml(args.config) # Datasets train_dataset = get_dataset(configs, "train") val_dataset = get_dataset(configs, "test") # DataLoaders train_loader = DataLoader( train_dataset, batch_size=configs["train"]["batch_size_per_device"], shuffle=True, num_workers=configs["train"]["num_workers"], collate_fn=collate_fn, pin_memory=True ) val_loader = DataLoader( val_dataset, batch_size=configs["train"]["batch_size_per_device"], shuffle=False, num_workers=configs["train"]["num_workers"], collate_fn=collate_fn, pin_memory=True ) # Model model = MusicLLM(configs) # Checkpoint callback checkpoint_callback = ModelCheckpoint( dirpath=Path("./checkpoints", Path(__file__).stem, Path(args.config).stem), filename="step={step}", every_n_train_steps=configs["train"]["save_every_n_steps"], save_top_k=-1 ) # Trainer trainer = pl.Trainer( max_steps=configs["train"]["training_steps"], accelerator="gpu", devices=configs["train"]["device"], strategy="ddp", precision=32, callbacks=[checkpoint_callback], check_val_every_n_epoch=configs["train"]["test_every_n_epoch"], ) # Train trainer.fit(model, train_loader, val_loader) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, required=True, help="Path of config yaml.") args = parser.parse_args() main(args)