Upload train.py
train.py
ADDED
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from ddt_model import LocalSongModel
from transformers import get_cosine_schedule_with_warmup
from datasets import load_from_disk
from accelerate import Accelerator
import os
import argparse
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from collections import deque
import torchaudio
import sys
import math
from tag_embedder import TagEmbedder

# Import MusicDCAE
from acestep.music_dcae.music_dcae_pipeline import MusicDCAE

# Make the repository root importable, then import timm for its Muon optimizer
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import timm.optim

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def save(model, optimizer, scheduler, global_step, accelerator):
    if accelerator.is_main_process:
        checkpoint_dir = "checkpoints"
        os.makedirs(checkpoint_dir, exist_ok=True)

        unwrapped_model = accelerator.unwrap_model(model)

        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_{global_step}.pth")
        save_dict = {
            'model_state_dict': unwrapped_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'global_step': global_step
        }

        accelerator.save(save_dict, checkpoint_path)
        print(f"Checkpoint saved at step {global_step}: {checkpoint_path}")

        # Keep only the five most recent checkpoints
        checkpoints = sorted([f for f in os.listdir(checkpoint_dir) if f.startswith("checkpoint_") and f.endswith(".pth")],
                             key=lambda x: int(x.split("_")[1].split(".")[0]), reverse=True)

        for old_checkpoint in checkpoints[5:]:
            os.remove(os.path.join(checkpoint_dir, old_checkpoint))
            print(f"Removed old checkpoint: {old_checkpoint}")


def load_checkpoint(model, optimizer, scheduler, checkpoint_path, accelerator):
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

    unwrapped_model = accelerator.unwrap_model(model)
    # Strip the torch.compile wrapper prefix so compiled and uncompiled runs interoperate
    state_dict = {k.replace("_orig_mod.", ""): v for k, v in checkpoint['model_state_dict'].items()}
    missing, unexpected = unwrapped_model.load_state_dict(state_dict, strict=True)
    print("MISSING:", missing)
    print("UNEXPECTED:", unexpected)

    if 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print("Optimizer loaded")

    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        print("Scheduler loaded")

    global_step = checkpoint['global_step']
    print(f"Resumed from step {global_step}")
    return global_step


def resume(model, optimizer, scheduler, accelerator):
    checkpoint_dir = "checkpoints"
    if os.path.exists(checkpoint_dir):
        checkpoints = [f for f in os.listdir(checkpoint_dir) if f.startswith("checkpoint_") and f.endswith(".pth")]
        if checkpoints:
            latest_checkpoint = max(checkpoints, key=lambda x: int(x.split("_")[1].split(".")[0]))
            checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)
            if accelerator.is_main_process:
                print(f"Resuming from checkpoint: {checkpoint_path}")

            return load_checkpoint(model, optimizer, scheduler, checkpoint_path, accelerator)
        else:
            if accelerator.is_main_process:
                print("No checkpoints found. Starting from scratch.")
    else:
        if accelerator.is_main_process:
            print("Checkpoint directory not found. Starting from scratch.")

    return 0


class AudioVAE:
    def __init__(self, device):
        self.model = MusicDCAE().to(device)
        self.model.eval()
        self.device = device

        self.latent_mean = torch.tensor([0.1207, -0.0186, -0.0947, -0.3779, 0.5956, 0.3422, 0.1796, -0.0526], device=device).view(1, -1, 1, 1)
        self.latent_std = torch.tensor([0.4638, 0.3154, 0.6244, 1.5078, 0.4696, 0.4633, 0.5614, 0.2707], device=device).view(1, -1, 1, 1)

    def encode(self, audio):
        """Encode audio to latents"""
        # audio should be (B, 2, T) at 48kHz
        with torch.no_grad():
            audio_lengths = torch.tensor([audio.shape[2]] * audio.shape[0]).to(self.device)
            latents, _ = self.model.encode(audio, audio_lengths, sr=48000)
            # Normalize latents: (latents - mean) / std
            latents = (latents - self.latent_mean) / self.latent_std
            return latents

    def decode(self, latents):
        """Decode latents to audio"""
        with torch.no_grad():
            # Denormalize latents: latents * std + mean
            latents = latents * self.latent_std + self.latent_mean
            sr, audio_list = self.model.decode(latents, sr=48000)
            # Convert list of audio tensors to batch tensor
            audio_batch = torch.stack(audio_list).to(self.device)
            return audio_batch

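# Note: the per-channel latent_mean/latent_std above are presumably statistics
# precomputed over the training corpus. A minimal sketch of how such values
# could be derived (hypothetical `iter_audio_batches` helper, not part of this
# script):
#
#   feats = torch.cat([vae.model.encode(a, lengths, sr=48000)[0]
#                      for a, lengths in iter_audio_batches()])
#   mean = feats.mean(dim=(0, 2, 3))   # one value per latent channel
#   std = feats.std(dim=(0, 2, 3))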


class RF:
    def __init__(self, model, time_sampling="sigmoid"):
        self.model = model
        self.time_sampling = time_sampling

    def sample_timesteps(self, batch, device):
        """Sample timesteps based on the configured strategy."""
        if self.time_sampling == "sigmoid":
            return torch.sigmoid(torch.randn((batch,), device=device))
        elif self.time_sampling == "warped":
            pm = 128 * 16 * 16
            alpha = max(1.0, math.sqrt(pm / 4096.0))
            u = torch.rand(batch, device=device)
            return alpha * u / (1.0 + (alpha - 1.0) * u)
        elif self.time_sampling == "uniform":
            return torch.rand(batch, device=device)
        else:
            raise ValueError(f"Unknown time_sampling strategy: {self.time_sampling}")

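    # The "warped" strategy appears to be the resolution-dependent timestep
    # shift from SD3-style flow matching: uniform u is mapped through
    # t = a*u / (1 + (a - 1)*u) with a = sqrt(pm / 4096). Here
    # pm = 128 * 16 * 16 = 32768 tokens, so a = sqrt(8) ~ 2.83, biasing sampled
    # t toward 1 (the high-noise end) for these larger-than-reference latents.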

    def forward(self, x, cond):
        b = x.size(0)

        t = self.sample_timesteps(b, x.device)

        texp = t.view([b, *([1] * len(x.shape[1:]))])
        z1 = torch.randn_like(x)
        zt = (1 - texp) * x + texp * z1

        x_pred = self.model(zt, t, cond)

        target = (zt - x) / (texp + 0.05)
        v_pred = (zt - x_pred) / (texp + 0.05)
        loss = F.mse_loss(target, v_pred)

        return loss

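    # Since target - v_pred = (x_pred - x) / (t + 0.05), the loss above reduces
    # to an x-prediction objective with a 1 / (t + 0.05)^2 weighting: low-noise
    # timesteps (t near 0) are upweighted by roughly 400x relative to t = 1,
    # and the 0.05 offset keeps the weight finite at t = 0.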

    def get_sampling_timesteps(self, steps, device):
        """Generate timesteps for sampling."""
        if self.time_sampling == "uniform" or self.time_sampling == "sigmoid":
            return torch.linspace(1.0, 0.0, steps + 1, device=device)[:-1]
        elif self.time_sampling == "warped":
            pm = 128 * 16 * 16
            alpha = max(1.0, math.sqrt(pm / 4096.0))
            u = torch.linspace(1.0, 0.0, steps + 1, device=device)[:-1]
            return alpha * u / (1.0 + (alpha - 1.0) * u)
        else:
            raise ValueError(f"Unknown time_sampling strategy: {self.time_sampling}")

    def sample(self, z, cond, null_cond=None, sample_steps=100, cfg=3.0):
        b = z.size(0)
        device = z.device

        timesteps = self.get_sampling_timesteps(sample_steps, device)
        images = [z]

        for idx in range(sample_steps):
            t_curr = timesteps[idx]
            t_next = timesteps[idx + 1] if idx + 1 < sample_steps else torch.tensor(0.0, device=device)
            dt = t_curr - t_next
            t = t_curr.expand(b)

            # Convert the model's x-prediction into a velocity estimate
            vc = self.model(z, t, cond)
            vc = (z - vc) / t_curr
            if null_cond is not None:
                vu = self.model(z, t, null_cond)
                vu = (z - vu) / t_curr
                vc = vu + cfg * (vc - vu)

            z = z - dt * vc
            images.append(z)
        return images
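
    # Sampling is a plain Euler integration of dz/dt = v from t = 1 (noise) to
    # t = 0 (data), with classifier-free guidance applied in velocity space:
    # v = v_uncond + cfg * (v_cond - v_uncond).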


def save_audio_samples(audio_batch, sample_rate, filename):
    """Save audio samples to file"""
    os.makedirs("audio_samples", exist_ok=True)

    # Take first sample from batch and convert to CPU
    audio = audio_batch[0].cpu()  # Shape: (2, T) for stereo

    # Save as WAV file
    filepath = os.path.join("audio_samples", filename)
    torchaudio.save(filepath, audio, sample_rate)
    print(f"Saved audio sample: {filepath}")


def parse_args():
    parser = argparse.ArgumentParser(description='Audio training script with TensorBoard logging')

    parser.add_argument('--channels', type=int, default=8, help='Number of input channels in the audio latents')
    parser.add_argument('--audio_height', type=int, default=16, help='Height of audio latents')
    parser.add_argument('--max_audio_width', type=int, default=4096, help='Max width of audio latents')
    parser.add_argument('--subsection_length', type=int, default=256, help='Length of random subsection to sample from each audio latent')
    parser.add_argument('--n_layers', type=int, default=36, help='Number of layers in the model')
    parser.add_argument('--n_encoder_layers', type=int, default=36, help='Number of encoder layers in the model')
    parser.add_argument('--n_heads', type=int, default=16, help='Number of heads in the model')
    parser.add_argument('--dim', type=int, default=768, help='Dimension of the encoder')
    parser.add_argument('--decoder_dim', type=int, default=1536, help='Dimension of the decoder')
    parser.add_argument('--dataset_name', type=str, default="cache", help='Audio dataset name')
    parser.add_argument('--num_workers', type=int, default=16, help='Number of workers for dataloader')

    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=1000, help='Number of epochs to train')
    parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
    parser.add_argument('--warmup_steps', type=int, default=0, help='Number of warmup steps')

    parser.add_argument('--sample_every', type=int, default=500, help='Audio sampling interval (batches)')
    parser.add_argument('--save_every', type=int, default=1000, help='Model saving interval (batches)')
    parser.add_argument('--num_samples', type=int, default=16, help='Number of samples to generate')
    # type=bool would treat any non-empty string (including "False") as True,
    # so use BooleanOptionalAction: pass --no-resume to start from scratch
    parser.add_argument('--resume', action=argparse.BooleanOptionalAction, default=True, help='Resume training from checkpoint')
    parser.add_argument('--pad_to_length', action='store_true', help='Pad short samples to subsection_length instead of filtering them out')
    parser.add_argument('--time_sampling', type=str, default='warped', choices=['sigmoid', 'warped', 'uniform'], help='Timestep sampling strategy')

    return parser.parse_args()

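# Example launch, assuming `accelerate config` has been run (all flags below are
# defined in parse_args above):
#
#   accelerate launch train.py --dataset_name cache --batch_size 64 --time_sampling warped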


def main():
    args = parse_args()

    accelerator = Accelerator(mixed_precision="bf16" if torch.cuda.is_available() else "no")

    is_main_process = accelerator.is_main_process

    writer = None
    if is_main_process:
        run_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        writer = SummaryWriter(log_dir=f"runs/{run_datetime}")

    dataset = load_from_disk(args.dataset_name).with_format(type="torch")

    # Filter out audio samples shorter than subsection_length (unless padding is enabled)
    if not args.pad_to_length:
        def filter_by_length(example):
            latent_width = example['latents'].shape[-1]
            return latent_width >= args.subsection_length * 2

        dataset = dataset.filter(filter_by_length)

        if is_main_process:
            print(f"Dataset filtered to {len(dataset)} samples with width >= {args.subsection_length * 2}")
    else:
        if is_main_process:
            print(f"Padding enabled: short samples will be zero-padded to {args.subsection_length}")

    # Latent normalization parameters (per-channel); must match AudioVAE's statistics
    latent_mean = torch.tensor([0.1207, -0.0186, -0.0947, -0.3779, 0.5956, 0.3422, 0.1796, -0.0526]).view(1, -1, 1, 1)
    latent_std = torch.tensor([0.4638, 0.3154, 0.6244, 1.5078, 0.4696, 0.4633, 0.5614, 0.2707]).view(1, -1, 1, 1)

    # Initialize tag embedder for converting metadata to tag indices
    num_classes = 2304
    tag_embedder = TagEmbedder(num_classes=num_classes)

    # Custom collate function to randomly sample subsections from variable-width audio latents
    def collate_fn(batch):
        subsection_length = args.subsection_length

        sampled_latents = []
        album_names = []
        song_names = []
        tags = []  # List of tag lists for each sample

        for item in batch:
            latent = item['latents']
            if len(latent.shape) == 3:  # Add batch dimension if missing
                latent = latent.unsqueeze(0)

            # Get the width of the current latent
            _, _, _, width = latent.shape

            if width < subsection_length:
                # Only reachable when --pad_to_length is set (short samples are
                # filtered out otherwise): zero-pad the latent on the right
                pad_amount = subsection_length - width
                sampled_latent = torch.nn.functional.pad(latent, (0, pad_amount), mode='constant', value=0)
            else:
                # Randomly sample a starting position
                max_start = width - subsection_length
                start_idx = torch.randint(0, max_start + 1, (1,)).item()

                # Extract the subsection
                sampled_latent = latent[:, :, :, start_idx:start_idx + subsection_length]

            sampled_latents.append(sampled_latent.squeeze(0))  # Remove batch dim for stacking
            album_name = item['album_name']
            song_name = item['song_name']
            album_names.append(album_name)
            song_names.append(song_name)

            sample_tags = tag_embedder.get_tags(album_name, song_name)
            tags.append(sample_tags)

        # Stack latents and normalize
        stacked_latents = torch.stack(sampled_latents)
        normalized_latents = (stacked_latents - latent_mean) / latent_std

        return {
            'latents': normalized_latents,
            'album_names': album_names,
            'song_names': song_names,
            'tags': tags
        }

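    # Each batch from collate_fn is a dict: 'latents' is a float tensor of shape
    # (batch_size, channels, audio_height, subsection_length), already normalized;
    # 'tags' is a list of per-sample tag-index lists of varying length.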
    num_workers = args.num_workers if torch.cuda.is_available() else 0
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,  # persistent_workers requires num_workers > 0
        pin_memory=True,
        collate_fn=collate_fn
    )

    channels = args.channels

    model = LocalSongModel(
        in_channels=channels,
        num_groups=args.n_heads,
        hidden_size=args.dim,
        decoder_hidden_size=args.decoder_dim,
        num_blocks=args.n_layers,
        patch_size=(16, 1),  # Audio patch size (16 in height, 1 in width)
        num_classes=num_classes,  # Number of tag classes
        max_tags=8,  # Maximum number of tags per sample
    )

    vae = AudioVAE(accelerator.device)

    rf = RF(model, time_sampling=args.time_sampling)

    optimizer = timm.optim.Muon(model.parameters(), lr=args.lr)
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=args.epochs * len(dataloader))

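    # Note: Muon applies orthogonalized momentum updates to 2-D weight matrices;
    # depending on the installed timm version, non-2-D parameters (embeddings,
    # biases, norms) may be handled by an internal fallback or may need a
    # separate parameter group, which is worth verifying against timm.optim.Muon.
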
    global_step = 0
    if args.resume:
        global_step = resume(model, optimizer, scheduler, accelerator)

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        model.forward_emb = torch.compile(model.forward_emb)

    model, optimizer, scheduler, dataloader = accelerator.prepare(
        model, optimizer, scheduler, dataloader
    )

    rf.model = model

    if is_main_process:
        model_size = sum(p.numel() for p in accelerator.unwrap_model(model).parameters() if p.requires_grad)
        print(f"Number of parameters: {model_size}, {model_size / 1e6:.1f}M")

    os.makedirs("audio_samples", exist_ok=True)
    num_samples = args.num_samples

    fixed_batch = None
    fixed_latents = None
    fixed_tags = None
    fixed_noise = None

    if is_main_process:
        data_iter = iter(dataloader)
        fixed_batch = next(data_iter)
        fixed_latents = fixed_batch["latents"][:num_samples]

        print("Fixed ids:", fixed_batch["album_names"][:num_samples])

        # Get fixed tags for sampling
        fixed_tags = fixed_batch["tags"][:num_samples]

        # Create reverse mapping from tag indices to strings
        idx_to_tag = {v: k for k, v in tag_embedder.tag_mapping.items()}

        # Print string labels for fixed tags
        print("Fixed tag labels:")
        for i, tag_list in enumerate(fixed_tags):
            labels = [idx_to_tag.get(idx, f"<unknown:{idx}>") for idx in tag_list]
            print(f"  Sample {i}: {labels}")

        # Create noise with same shape as fixed latents
        _, C, H, _ = fixed_latents.shape
        fixed_noise = torch.randn(num_samples, C, H, args.subsection_length, device=accelerator.device)

        fixed_latents = fixed_latents.to(accelerator.device)

    if is_main_process:
        print("Begin training")

    mse_loss_window = deque(maxlen=100)
    start_epoch = 0
    for epoch in range(start_epoch, args.epochs):

        pbar = tqdm(dataloader) if is_main_process else dataloader
        for batch in pbar:
            x = batch["latents"]

            # Get tags from batch
            tags = batch["tags"]

            # Apply classifier-free guidance dropout (10% chance to drop all tags)
            dropout_tags = []
            for tag_list in tags:
                if torch.rand(1).item() < 0.1:
                    # Replace with empty list (will be padded to [0] in embed_condition)
                    dropout_tags.append([])
                else:
                    dropout_tags.append(tag_list)

            # Tags will be embedded inside the model's forward pass
            c = dropout_tags

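            # Dropping the condition on ~10% of samples is what makes the
            # null_cond branch in RF.sample meaningful: the model learns an
            # unconditional velocity field alongside the conditional one, and
            # classifier-free guidance extrapolates between the two.
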
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                mse_loss = rf.forward(x, c)

                loss = mse_loss

                accelerator.backward(loss)
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()

            if is_main_process:

                mse_loss_window.append(mse_loss.item())

                avg_mse_loss = sum(mse_loss_window) / len(mse_loss_window)

                if isinstance(pbar, tqdm):
                    pbar.set_postfix({"mse_loss": avg_mse_loss, "lr": optimizer.param_groups[0]['lr']})

                if writer is not None:
                    writer.add_scalar('Learning_Rate', optimizer.param_groups[0]['lr'], global_step)
                    writer.add_scalar('MSE_Loss', avg_mse_loss, global_step)

            global_step += 1

            if is_main_process and global_step % args.save_every == 0:
                save(model, optimizer, scheduler, global_step, accelerator)

            if is_main_process and global_step % args.sample_every == 0:
                model.eval()

                with torch.no_grad():
                    # Use fixed tags for conditional sampling
                    cond = fixed_tags
                    # Unconditional is empty tags for all samples
                    null_cond = [[] for _ in range(len(cond))]

                    sampled_latents = rf.sample(fixed_noise, cond, null_cond)[-1]

                    # Decode latents to audio
                    try:
                        sampled_audio = vae.decode(sampled_latents)

                        # Save the first 8 generated samples
                        for i in range(min(8, sampled_audio.shape[0])):
                            save_audio_samples(
                                sampled_audio[i:i+1],
                                48000,
                                f"sample_{global_step}_generated_{i}.wav"
                            )

                        # Also save originals for comparison (first sampling round only)
                        if global_step == args.sample_every:
                            original_audio = vae.decode(fixed_latents)
                            for i in range(min(8, original_audio.shape[0])):
                                save_audio_samples(
                                    original_audio[i:i+1],
                                    48000,
                                    f"sample_{global_step}_original_{i}.wav"
                                )

                    except Exception as e:
                        print(f"Error during audio generation: {e}")

                model.train()

    if is_main_process:
        print("Saving final model")
    save(model, optimizer, scheduler, global_step, accelerator)

    if writer is not None:
        writer.close()


if __name__ == '__main__':
    main()