| from __future__ import annotations |
| import argparse |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from audidata.io.crops import RandomCrop |
| from audidata.transforms import Mono |
| from audidata.collate.default import collate_fn |
| from torch.utils.data import DataLoader |
| import pytorch_lightning as pl |
| from pytorch_lightning.callbacks import ModelCheckpoint |
|
|
| from llm.llama import Llama, LlamaConfig |
| from utils import LinearWarmUp, parse_yaml |
|
|
class MusicLLM(pl.LightningModule):
    """Lightning module training a Llama LM on tokens from a frozen audio codec.

    Each batch's audio is encoded into discrete codes by ``self.codec`` (frozen,
    no gradients), and ``self.llm`` is trained with next-token cross-entropy on
    those codes.
    """

    def __init__(self, configs: dict):
        super().__init__()
        self.save_hyperparameters(configs)

        # Hyper-parameters read from the parsed YAML config.
        self.configs = configs
        self.sample_rate = configs["sample_rate"]
        # NOTE(review): read from "max_caption_len" but not used in this class.
        self.max_seq_len = configs["max_caption_len"]
        self.lr = float(configs["train"]["lr"])
        self.warm_up_steps = configs["train"]["warm_up_steps"]

        # Frozen audio codec: only used to produce tokens, never optimized.
        self.codec = self._get_audio_codec(self.configs)
        self.codec.eval()
        for param in self.codec.parameters():
            param.requires_grad = False

        # The LLM's vocabulary is the codec's code vocabulary.
        vocab_size = configs["audio_codec"]["vocab_size"]
        self.llm = self._get_llm(vocab_size)

    def _get_audio_codec(self, configs: dict) -> nn.Module:
        """Instantiate the codec named by ``configs["audio_codec"]["name"]``.

        Raises:
            ValueError: if the codec name is not supported.
        """
        name = configs["audio_codec"]["name"]
        if name == "trancodec_fsq":
            from audio_codec.trancodec_fsq import Trancodec
            return Trancodec(configs)
        # Fail here instead of returning None, which would only surface later
        # as an AttributeError on ``self.codec.eval()``.
        raise ValueError(f"Unsupported audio codec: {name!r}")

    def _get_llm(self, vocab_size: int) -> nn.Module:
        """Build the Llama model from ``configs["llm"]`` with ``vocab_size`` tokens."""
        llm_configs = self.configs["llm"]
        config = LlamaConfig(
            block_size=llm_configs["block_size"],
            vocab_size=vocab_size,
            n_layer=llm_configs["n_layer"],
            n_head=llm_configs["n_head"],
            n_embd=llm_configs["n_embd"],
        )
        return Llama(config=config)

    def forward(self, audio_codes):
        """Return the LLM output (used as logits) for a batch of code sequences."""
        return self.llm(audio_codes)

    def _compute_loss(self, batch) -> torch.Tensor:
        """Encode the batch's audio with the frozen codec; return next-token CE loss."""
        audio, _ = self._process_batch(batch)

        with torch.no_grad():
            # Re-assert eval mode before each encode in case the trainer has
            # switched the whole module (codec included) back to train mode.
            self.codec.eval()
            # NOTE(review): assumes encode() output has a singleton dim 2 that
            # squeeze removes, leaving dim 1 as the token axis — confirm.
            audio_codes = self.codec.encode(audio).squeeze(dim=2)

        # Teacher forcing: predict token t+1 from tokens up to t.
        input_codes = audio_codes[:, :-1]
        target_codes = audio_codes[:, 1:].to(torch.long)

        logits = self(input_codes)

        return F.cross_entropy(
            logits.flatten(0, 1),
            target_codes.flatten(0, 1),
            reduction="mean",
        )

    def training_step(self, batch, batch_idx):
        """Compute, log and return the next-token loss on a training batch."""
        loss = self._compute_loss(batch)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the next-token loss on a validation batch.

        Lightning runs validation with gradients disabled, so no explicit
        ``torch.no_grad()`` is needed around the LLM forward.
        """
        loss = self._compute_loss(batch)
        # Training uses DDP; sync_dist averages the logged metric across ranks.
        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        """AdamW over the LLM parameters only, with optional linear warm-up."""
        optimizer = torch.optim.AdamW(self.llm.parameters(), lr=self.lr)

        if not self.warm_up_steps:
            return optimizer

        lr_lambda = LinearWarmUp(self.warm_up_steps)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        # Lightning steps schedulers once per epoch by default. The warm-up
        # length is configured as ``warm_up_steps`` (training is bounded by
        # max_steps), so step the scheduler after every optimizer update.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            },
        }

    def _process_batch(self, batch):
        """Return ``(batch["audio"], batch["dataset_name"])`` from a collated batch."""
        return batch["audio"], batch["dataset_name"]
|
|
def _make_random_crop(configs: dict):
    """Return a fresh RandomCrop sized by ``configs["clip_duration"]``."""
    clip_duration = configs["clip_duration"]
    # end_pad = clip_duration - 0.1 is kept from the original setup; see
    # audidata's RandomCrop for the exact meaning of end_pad.
    return RandomCrop(clip_duration=clip_duration, end_pad=clip_duration - 0.1)


def _build_dataset(name: str, dataset_configs: dict, configs: dict) -> torch.utils.data.Dataset:
    """Instantiate the single dataset ``name`` from its own config section.

    Raises:
        ValueError: if ``name`` is not a supported dataset.
    """
    if name == "Mozart":
        from audidata.datasets import Mozart

        return Mozart(
            root=dataset_configs["root"],
            split=dataset_configs["split"],
            sr=configs["sample_rate"],
            crop=_make_random_crop(configs),
            transform=Mono(),
        )

    if name == "MAESTRO":
        from audidata.datasets import MAESTRO
        from audidata.transforms.midi import PianoRoll

        return MAESTRO(
            root=dataset_configs["root"],
            split=dataset_configs["split"],
            sr=configs["sample_rate"],
            crop=_make_random_crop(configs),
            transform=Mono(),
            load_target=True,
            extend_pedal=True,
            target_transform=PianoRoll(fps=configs["fps"], pitches_num=128),
        )

    raise ValueError(f"Unsupported dataset name: {name!r}")


def get_dataset(configs: dict, split: str) -> torch.utils.data.Dataset:
    """Build the dataset(s) configured under ``configs["<split>_datasets"]``.

    Args:
        configs: Parsed config. ``configs["<split>_datasets"]`` maps dataset
            names ("Mozart" or "MAESTRO") to settings with "root" and "split".
        split: Prefix of the config section to read, e.g. "train" or "test".

    Returns:
        The dataset itself when exactly one is configured, otherwise a
        ``ConcatDataset`` over all configured datasets in config order.

    Raises:
        ValueError: if no dataset is configured or a name is unsupported.
    """
    section = "{}_datasets".format(split)
    # A YAML key with no value typically loads as None; treat it as empty.
    dataset_configs = configs[section] or {}

    datasets = [
        _build_dataset(name, dataset_cfg, configs)
        for name, dataset_cfg in dataset_configs.items()
    ]

    if not datasets:
        raise ValueError(f"No datasets configured under {section!r}.")

    if len(datasets) == 1:
        return datasets[0]

    # NOTE(review): concatenated datasets must yield items with compatible keys
    # for collate_fn to batch them together — confirm before mixing datasets.
    return torch.utils.data.ConcatDataset(datasets)
|
|
def _make_loader(dataset, train_configs: dict, shuffle: bool) -> DataLoader:
    """Wrap ``dataset`` in a DataLoader using the shared batch/worker settings."""
    return DataLoader(
        dataset,
        batch_size=train_configs["batch_size_per_device"],
        shuffle=shuffle,
        num_workers=train_configs["num_workers"],
        collate_fn=collate_fn,
        pin_memory=True,
    )


def main(args):
    """Train MusicLLM from the YAML config at ``args.config``.

    Builds train/test dataloaders, the model, a step-based checkpoint callback
    and a DDP GPU trainer, then runs ``trainer.fit``.

    Args:
        args: Parsed CLI namespace; only ``args.config`` (config path) is read.
    """
    configs = parse_yaml(args.config)
    train_configs = configs["train"]

    # The "test" split doubles as the validation set.
    train_dataset = get_dataset(configs, "train")
    val_dataset = get_dataset(configs, "test")

    train_loader = _make_loader(train_dataset, train_configs, shuffle=True)
    val_loader = _make_loader(val_dataset, train_configs, shuffle=False)

    model = MusicLLM(configs)

    # Save a checkpoint every `save_every_n_steps` steps and keep all of them.
    checkpoint_callback = ModelCheckpoint(
        dirpath=Path("./checkpoints", Path(__file__).stem, Path(args.config).stem),
        filename="step={step}",
        # The template already spells out "step=", so disable Lightning's
        # automatic "<metric>=" prefix, which would yield "step=step=<n>.ckpt".
        auto_insert_metric_name=False,
        every_n_train_steps=train_configs["save_every_n_steps"],
        save_top_k=-1,
    )

    trainer = pl.Trainer(
        max_steps=train_configs["training_steps"],
        accelerator="gpu",
        devices=train_configs["device"],
        strategy="ddp",
        precision=32,
        callbacks=[checkpoint_callback],
        check_val_every_n_epoch=train_configs["test_every_n_epoch"],
    )

    trainer.fit(model, train_loader, val_loader)
|
|
if __name__ == "__main__":
    # Command-line entry point: a single required --config pointing at the YAML.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path of config yaml.",
    )
    main(cli.parse_args())