"""Train a tag-conditioned audio latent model with a rectified-flow style loss.

The script loads pre-computed audio latents from disk, crops (or pads) them to
a fixed width, trains ``LocalSongModel`` through Hugging Face Accelerate, logs
to TensorBoard, periodically writes checkpoints to ``checkpoints/`` and decodes
sampled latents to WAV files in ``audio_samples/``.
"""
import argparse
import math
import os
import re
import sys
from collections import deque
from datetime import datetime

import torch
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
from accelerate import Accelerator
from datasets import load_from_disk
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid, save_image
from tqdm import tqdm
from transformers import get_cosine_schedule_with_warmup

from ddt_model import LocalSongModel
from tag_embedder import TagEmbedder
# Import MusicDCAE
from acestep.music_dcae.music_dcae_pipeline import MusicDCAE

# Import Muon optimizer
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import timm.optim

os.environ["TOKENIZERS_PARALLELISM"] = "false"

CHECKPOINT_DIR = "checkpoints"
AUDIO_SAMPLE_DIR = "audio_samples"
MAX_CHECKPOINTS_KEPT = 5

# Per-channel latent normalization statistics, shared by the data pipeline
# (normalization) and AudioVAE (normalization / denormalization).
LATENT_MEAN = (0.1207, -0.0186, -0.0947, -0.3779, 0.5956, 0.3422, 0.1796, -0.0526)
LATENT_STD = (0.4638, 0.3154, 0.6244, 1.5078, 0.4696, 0.4633, 0.5614, 0.2707)

_CHECKPOINT_RE = re.compile(r"checkpoint_(\d+)\.pth")


def _latent_stats(device=None):
    """Return ``(mean, std)`` tensors shaped ``(1, C, 1, 1)`` on *device*."""
    mean = torch.tensor(LATENT_MEAN, device=device).view(1, -1, 1, 1)
    std = torch.tensor(LATENT_STD, device=device).view(1, -1, 1, 1)
    return mean, std


def _list_checkpoints(checkpoint_dir):
    """Return ``(step, filename)`` pairs for step checkpoints, newest first.

    Only names of the exact form ``checkpoint_<int>.pth`` are considered, so
    unrelated files in the directory can neither crash the step parsing nor be
    deleted by the pruning logic.
    """
    found = []
    for name in os.listdir(checkpoint_dir):
        match = _CHECKPOINT_RE.fullmatch(name)
        if match:
            found.append((int(match.group(1)), name))
    found.sort(reverse=True)
    return found


def save(model, optimizer, scheduler, global_step, accelerator):
    """Write a training checkpoint and prune old ones (main process only).

    The checkpoint holds the unwrapped model weights, the optimizer state, the
    scheduler state (when a scheduler is given) and ``global_step``. After
    saving, only the ``MAX_CHECKPOINTS_KEPT`` highest-step checkpoints are kept.
    """
    if not accelerator.is_main_process:
        return

    checkpoint_dir = CHECKPOINT_DIR
    os.makedirs(checkpoint_dir, exist_ok=True)

    unwrapped_model = accelerator.unwrap_model(model)
    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_{global_step}.pth")
    save_dict = {
        'model_state_dict': unwrapped_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'global_step': global_step
    }
    # Persist the LR schedule position so a resumed run does not restart it.
    if scheduler is not None:
        save_dict['scheduler_state_dict'] = scheduler.state_dict()

    accelerator.save(save_dict, checkpoint_path)
    print(f"Checkpoint saved at step {global_step}: {checkpoint_path}")

    for _, old_checkpoint in _list_checkpoints(checkpoint_dir)[MAX_CHECKPOINTS_KEPT:]:
        os.remove(os.path.join(checkpoint_dir, old_checkpoint))
        print(f"Removed old checkpoint: {old_checkpoint}")


def load_checkpoint(model, optimizer, scheduler, checkpoint_path, accelerator):
    """Restore model/optimizer/scheduler state from *checkpoint_path*.

    ``_orig_mod.`` prefixes are stripped from parameter names before loading.
    Optimizer and scheduler state are restored only when present in the file,
    so checkpoints written without them still load.

    Returns:
        The ``global_step`` stored in the checkpoint.
    """
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

    unwrapped_model = accelerator.unwrap_model(model)
    state_dict = {
        k.replace("_orig_mod.", ""): v
        for k, v in checkpoint['model_state_dict'].items()
    }
    missing, unexpected = unwrapped_model.load_state_dict(state_dict, strict=True)
    print("MISSING:", missing)
    print("UNEXPECTED:", unexpected)

    if 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print("Optimizer loaded")

    if scheduler is not None and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        print("Scheduler loaded")

    global_step = checkpoint['global_step']
    print(f"Resumed from step {global_step}")
    return global_step


def resume(model, optimizer, scheduler, accelerator):
    """Load the highest-step checkpoint from ``checkpoints/`` if one exists.

    Returns:
        The restored ``global_step``, or ``0`` when nothing could be resumed.
    """
    checkpoint_dir = CHECKPOINT_DIR
    if not os.path.exists(checkpoint_dir):
        if accelerator.is_main_process:
            print("Checkpoint directory not found. Starting from scratch.")
        return 0

    checkpoints = _list_checkpoints(checkpoint_dir)
    if not checkpoints:
        if accelerator.is_main_process:
            print("No checkpoints found. Starting from scratch.")
        return 0

    _, latest_checkpoint = checkpoints[0]
    checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)
    if accelerator.is_main_process:
        print(f"Resuming from checkpoint: {checkpoint_path}")
    return load_checkpoint(model, optimizer, scheduler, checkpoint_path, accelerator)


class AudioVAE:
    """Inference-only wrapper around ``MusicDCAE`` with latent (de)normalization."""

    def __init__(self, device):
        self.model = MusicDCAE().to(device)
        self.model.eval()
        self.device = device
        self.latent_mean, self.latent_std = _latent_stats(device)

    def encode(self, audio):
        """Encode audio to normalized latents."""
        # audio should be (B, 2, T) at 48kHz
        with torch.no_grad():
            audio_lengths = torch.tensor([audio.shape[2]] * audio.shape[0]).to(self.device)
            latents, _ = self.model.encode(audio, audio_lengths, sr=48000)
            # Normalize latents: (latents - mean) / std
            latents = (latents - self.latent_mean) / self.latent_std
        return latents

    def decode(self, latents):
        """Decode normalized latents to a stacked batch of audio tensors."""
        with torch.no_grad():
            # Denormalize latents: latents * std + mean
            latents = latents * self.latent_std + self.latent_mean
            sr, audio_list = self.model.decode(latents, sr=48000)
            # Convert list of audio tensors to batch tensor
            audio_batch = torch.stack(audio_list).to(self.device)
        return audio_batch


class RF:
    """Training loss and Euler sampler for a model that predicts the clean input.

    ``model(zt, t, cond)`` is treated as a prediction of ``x``; both the loss
    and the sampler convert that prediction into a velocity.
    """

    def __init__(self, model, time_sampling="sigmoid"):
        self.model = model
        self.time_sampling = time_sampling

    @staticmethod
    def _warp(u):
        """Apply the "warped" time transform ``alpha*u / (1 + (alpha-1)*u)``."""
        pm = 128 * 16 * 16
        alpha = max(1.0, math.sqrt(pm / 4096.0))
        return alpha * u / (1.0 + (alpha - 1.0) * u)

    def sample_timesteps(self, batch, device):
        """Sample timesteps based on the configured strategy."""
        if self.time_sampling == "sigmoid":
            return torch.sigmoid(torch.randn((batch,), device=device))
        elif self.time_sampling == "warped":
            return self._warp(torch.rand(batch, device=device))
        elif self.time_sampling == "uniform":
            return torch.rand(batch, device=device)
        else:
            raise ValueError(f"Unknown time_sampling strategy: {self.time_sampling}")

    def forward(self, x, cond):
        """Return the MSE between target and predicted velocity for a batch.

        ``x`` is interpolated with Gaussian noise at a random ``t`` per sample
        (``t = 0`` is clean data, ``t = 1`` is pure noise).
        """
        b = x.size(0)
        t = self.sample_timesteps(b, x.device)
        texp = t.view([b, *([1] * len(x.shape[1:]))])
        z1 = torch.randn_like(x)
        zt = (1 - texp) * x + texp * z1

        x_pred = self.model(zt, t, cond)

        # The 0.05 offset keeps the division bounded as t approaches 0.
        target = (zt - x) / (texp + 0.05)
        v_pred = (zt - x_pred) / (texp + 0.05)
        loss = F.mse_loss(target, v_pred)
        return loss

    def get_sampling_timesteps(self, steps, device):
        """Generate ``steps`` decreasing timesteps starting at 1 (0 excluded)."""
        if self.time_sampling == "uniform" or self.time_sampling == "sigmoid":
            return torch.linspace(1.0, 0.0, steps + 1, device=device)[:-1]
        elif self.time_sampling == "warped":
            u = torch.linspace(1.0, 0.0, steps + 1, device=device)[:-1]
            return self._warp(u)
        else:
            raise ValueError(f"Unknown time_sampling strategy: {self.time_sampling}")

    def sample(self, z, cond, null_cond=None, sample_steps=100, cfg=3.0):
        """Integrate from noise ``z`` towards data with Euler steps.

        When ``null_cond`` is given, classifier-free guidance with scale
        ``cfg`` is applied to the velocity.

        Returns:
            A list with the initial ``z`` followed by the state after each step.
        """
        b = z.size(0)
        device = z.device
        timesteps = self.get_sampling_timesteps(sample_steps, device)
        images = [z]

        for idx in range(sample_steps):
            t_curr = timesteps[idx]
            # The final step integrates all the way down to t = 0.
            t_next = timesteps[idx + 1] if idx + 1 < sample_steps else torch.tensor(0.0, device=device)
            dt = t_curr - t_next
            t = t_curr.expand(b)

            vc = self.model(z, t, cond)
            vc = (z - vc) / t_curr
            if null_cond is not None:
                vu = self.model(z, t, null_cond)
                vu = (z - vu) / t_curr
                vc = vu + cfg * (vc - vu)

            z = z - dt * vc
            images.append(z)
        return images


def save_audio_samples(audio_batch, sample_rate, filename):
    """Save the first item of *audio_batch* as ``audio_samples/<filename>``."""
    os.makedirs(AUDIO_SAMPLE_DIR, exist_ok=True)

    # Take first sample from batch and convert to CPU
    audio = audio_batch[0].cpu()  # Shape: (2, T) for stereo

    # Save as WAV file
    filepath = os.path.join(AUDIO_SAMPLE_DIR, filename)
    torchaudio.save(filepath, audio, sample_rate)
    print(f"Saved audio sample: {filepath}")


def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` treats every non-empty string (including "False") as True;
    this parser accepts the usual spellings and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description='Audio training script with TensorBoard logging')
    parser.add_argument('--channels', type=int, default=8, help='Number of input channels in the audio latents')
    parser.add_argument('--audio_height', type=int, default=16, help='Height of audio latents')
    parser.add_argument('--max_audio_width', type=int, default=4096, help='Max width of audio latents')
    parser.add_argument('--subsection_length', type=int, default=256, help='Length of random subsection to sample from each audio latent')
    parser.add_argument('--n_layers', type=int, default=36, help='Number of layers in the model')
    parser.add_argument('--n_encoder_layers', type=int, default=36, help='Number of encoder layers in the model')
    parser.add_argument('--n_heads', type=int, default=16, help='Number of heads in the model')
    parser.add_argument('--dim', type=int, default=768, help='Dimension of the encoder')
    parser.add_argument('--decoder_dim', type=int, default=1536, help='Dimension of the decoder (if None, uses --dim)')
    parser.add_argument('--dataset_name', type=str, default="cache", help='Audio dataset name')
    parser.add_argument('--num_workers', type=int, default=16, help='Number of workers for dataloader')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=1000, help='Number of epochs to train')
    parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
    parser.add_argument('--warmup_steps', type=int, default=0, help='Number of warmup steps')
    parser.add_argument('--sample_every', type=int, default=500, help='Audio sampling interval (batches)')
    parser.add_argument('--save_every', type=int, default=1000, help='Model saving interval (batches)')
    parser.add_argument('--num_samples', type=int, default=16, help='Number of samples to generate')
    # nargs='?' keeps "--resume True/False" working and also allows a bare "--resume".
    parser.add_argument('--resume', type=_str2bool, nargs='?', const=True, default=True, help='Resume training from checkpoint')
    parser.add_argument('--pad_to_length', action='store_true', help='Pad short samples to subsection_length instead of filtering them out')
    parser.add_argument('--time_sampling', type=str, default='warped', choices=['sigmoid', 'warped', 'uniform'], help='Timestep sampling strategy')
    return parser.parse_args()


def _make_collate_fn(subsection_length, pad_to_length, tag_embedder, latent_mean, latent_std):
    """Build the collate function that crops/pads latents and looks up tags."""

    def collate_fn(batch):
        sampled_latents = []
        album_names = []
        song_names = []
        tags = []  # List of tag lists for each sample

        for item in batch:
            latent = item['latents']
            if len(latent.shape) == 3:
                # Add batch dimension if missing
                latent = latent.unsqueeze(0)

            # Get the width of the current latent
            _, _, _, width = latent.shape

            if width < subsection_length:
                if not pad_to_length:
                    # Fail loudly instead of reusing a stale/undefined crop.
                    raise ValueError(
                        f"Latent width {width} is shorter than subsection_length "
                        f"{subsection_length}; filter the dataset or pass --pad_to_length"
                    )
                # Pad the latent to subsection_length with zeros on the right
                pad_amount = subsection_length - width
                sampled_latent = F.pad(latent, (0, pad_amount), mode='constant', value=0)
            else:
                # Randomly sample a starting position
                max_start = width - subsection_length
                start_idx = torch.randint(0, max_start + 1, (1,)).item()
                # Extract the subsection
                sampled_latent = latent[:, :, :, start_idx:start_idx + subsection_length]

            sampled_latents.append(sampled_latent.squeeze(0))  # Remove batch dim for stacking

            album_name = item['album_name']
            song_name = item['song_name']
            album_names.append(album_name)
            song_names.append(song_name)
            tags.append(tag_embedder.get_tags(album_name, song_name))

        # Stack latents and normalize
        stacked_latents = torch.stack(sampled_latents)
        normalized_latents = (stacked_latents - latent_mean) / latent_std

        return {
            'latents': normalized_latents,
            'tags': tags,
            'album_names': album_names,
            'song_names': song_names,
        }

    return collate_fn


def _sample_and_save(rf, vae, fixed_noise, fixed_tags, fixed_latents, global_step, save_originals):
    """Sample from fixed noise/tags and write up to 8 generated WAV files.

    Decoding and file writing are best-effort: failures are printed and do not
    interrupt training. When *save_originals* is true, the decoded fixed
    latents are written as well for comparison.
    """
    # Use fixed tags for conditional sampling
    cond = fixed_tags
    # Unconditional is empty tags for all samples
    null_cond = [[] for _ in range(len(cond))]

    sampled_latents = rf.sample(fixed_noise, cond, null_cond)[-1]

    # Decode latents to audio
    try:
        sampled_audio = vae.decode(sampled_latents)

        # Save up to the first 8 generated samples
        for i in range(min(8, sampled_audio.shape[0])):
            save_audio_samples(
                sampled_audio[i:i+1],
                48000,
                f"sample_{global_step}_generated_{i}.wav"
            )

        # Also save original for comparison
        if save_originals:
            original_audio = vae.decode(fixed_latents)
            for i in range(min(8, original_audio.shape[0])):
                save_audio_samples(
                    original_audio[i:i+1],
                    48000,
                    f"sample_{global_step}_original_{i}.wav"
                )
    except Exception as e:
        print(f"Error during audio generation: {e}")


def main():
    """Entry point: build data, model and optimizer, then run the training loop."""
    args = parse_args()

    accelerator = Accelerator(mixed_precision="bf16" if torch.cuda.is_available() else "no")
    is_main_process = accelerator.is_main_process

    writer = None
    if is_main_process:
        run_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        writer = SummaryWriter(log_dir=f"runs/{run_datetime}")

    dataset = load_from_disk(args.dataset_name).with_format(type="torch")

    # Filter out audio samples narrower than 2 * subsection_length (unless padding is enabled)
    if not args.pad_to_length:
        def filter_by_length(example):
            latent_width = example['latents'].shape[-1]
            return latent_width >= args.subsection_length * 2

        dataset = dataset.filter(filter_by_length)
        if is_main_process:
            print(f"Dataset filtered to {len(dataset)} samples with width >= {args.subsection_length * 2}")
    else:
        if is_main_process:
            print(f"Padding enabled: short samples will be zero-padded to {args.subsection_length}")

    # Latent normalization parameters (per-channel), kept on CPU for the collate step
    latent_mean, latent_std = _latent_stats()

    # Initialize tag embedder for converting metadata to tag indices
    num_classes = 2304
    tag_embedder = TagEmbedder(num_classes=num_classes)

    # Custom collate function to randomly sample subsections from variable-width audio latents
    collate_fn = _make_collate_fn(
        args.subsection_length, args.pad_to_length, tag_embedder, latent_mean, latent_std
    )

    use_cuda = torch.cuda.is_available()
    num_workers = args.num_workers if use_cuda else 0
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=num_workers > 0,
        num_workers=num_workers,
        pin_memory=use_cuda,
        collate_fn=collate_fn
    )

    channels = args.channels
    model = LocalSongModel(
        in_channels=channels,
        num_groups=args.n_heads,
        hidden_size=args.dim,
        decoder_hidden_size=args.decoder_dim,
        num_blocks=args.n_layers,
        patch_size=(16, 1),  # Audio patch size (16 in height, 1 in width)
        num_classes=num_classes,  # Number of tag classes
        max_tags=8,  # Maximum number of tags per sample
    )

    vae = AudioVAE(accelerator.device)
    rf = RF(model, time_sampling=args.time_sampling)

    optimizer = timm.optim.Muon(model.parameters(), lr=args.lr)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=args.epochs * len(dataloader)
    )

    global_step = 0
    if args.resume:
        global_step = resume(model, optimizer, scheduler, accelerator)

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        model.forward_emb = torch.compile(model.forward_emb)

    model, optimizer, scheduler, dataloader = accelerator.prepare(
        model, optimizer, scheduler, dataloader
    )
    rf.model = model

    if is_main_process:
        model_size = sum(p.numel() for p in accelerator.unwrap_model(model).parameters() if p.requires_grad)
        print(f"Number of parameters: {model_size}, {model_size / 1e6}M")
        os.makedirs(AUDIO_SAMPLE_DIR, exist_ok=True)

    num_samples = args.num_samples
    fixed_latents = None
    fixed_noise = None
    fixed_tags = []

    # Every rank draws one batch: iterating a prepared dataloader may involve
    # cross-process synchronization, so doing it on the main process alone
    # risks a hang in distributed runs (NOTE(review): confirm for your setup).
    fixed_batch = next(iter(dataloader))

    if is_main_process:
        fixed_latents = fixed_batch["latents"][:num_samples].to(accelerator.device)
        print("Fixed ids:", fixed_batch["album_names"])

        # Get fixed tags for sampling
        fixed_tags = fixed_batch["tags"][:num_samples]

        # Create reverse mapping from tag indices to strings
        idx_to_tag = {v: k for k, v in tag_embedder.tag_mapping.items()}

        # Print string labels for fixed tags
        print("Fixed tag labels:")
        for i, tag_list in enumerate(fixed_tags):
            labels = [idx_to_tag.get(idx, "") for idx in tag_list]
            print(f"  Sample {i}: {labels}")

        # Create noise matching the fixed latents; the batch size follows the
        # latents so it stays aligned with fixed_tags when batch_size < num_samples.
        B, C, H, W = fixed_latents.shape
        fixed_noise = torch.randn(B, C, H, args.subsection_length, device=accelerator.device)

    if is_main_process:
        print("Begin training")

    mse_loss_window = deque(maxlen=100)

    # Continue the epoch count from the resumed step so the total number of
    # steps stays close to the scheduler's horizon (epochs * steps per epoch).
    steps_per_epoch = len(dataloader)
    start_epoch = global_step // steps_per_epoch if steps_per_epoch else 0

    for epoch in range(start_epoch, args.epochs):
        pbar = tqdm(dataloader) if is_main_process else dataloader

        for batch in pbar:
            x = batch["latents"]

            # Get tags from batch
            tags = batch["tags"]

            # Apply classifier-free guidance dropout (10% chance to drop all tags)
            dropout_tags = []
            for tag_list in tags:
                if torch.rand(1).item() < 0.1:
                    # Replace with empty list (will be padded to [0] in embed_condition)
                    dropout_tags.append([])
                else:
                    dropout_tags.append(tag_list)

            # Tags will be embedded inside the model's forward pass
            c = dropout_tags

            with accelerator.accumulate(model):
                optimizer.zero_grad()
                mse_loss = rf.forward(x, c)
                loss = mse_loss
                accelerator.backward(loss)
                # Clip only on steps where gradients are actually synchronized.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()

            if is_main_process:
                mse_loss_window.append(mse_loss.item())
                avg_mse_loss = sum(mse_loss_window) / len(mse_loss_window)

                if isinstance(pbar, tqdm):
                    pbar.set_postfix({"mse_loss": avg_mse_loss, "lr": optimizer.param_groups[0]['lr']})

                if writer is not None:
                    writer.add_scalar('Learning_Rate', optimizer.param_groups[0]['lr'], global_step)
                    writer.add_scalar('MSE_Loss', avg_mse_loss, global_step)

            global_step += 1

            if is_main_process and global_step % args.save_every == 0:
                save(model, optimizer, scheduler, global_step, accelerator)

            if is_main_process and global_step % args.sample_every == 0:
                model.eval()
                try:
                    with torch.no_grad():
                        _sample_and_save(
                            rf,
                            vae,
                            fixed_noise,
                            fixed_tags,
                            fixed_latents,
                            global_step,
                            save_originals=global_step == args.sample_every,
                        )
                finally:
                    # Always return to training mode, even if sampling raised.
                    model.train()

    print("Saving final model")
    save(model, optimizer, scheduler, global_step, accelerator)

    if writer is not None:
        writer.close()


if __name__ == '__main__':
    main()