|
|
import torch |
|
|
import torch.optim as optim |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import DataLoader |
|
|
from torchvision.utils import make_grid, save_image |
|
|
from tqdm import tqdm |
|
|
from ddt_model import LocalSongModel |
|
|
from transformers import get_cosine_schedule_with_warmup |
|
|
from datasets import load_from_disk |
|
|
from accelerate import Accelerator |
|
|
import os |
|
|
import argparse |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
from datetime import datetime |
|
|
from collections import deque |
|
|
import torchaudio |
|
|
import re |
|
|
import sys |
|
|
import math |
|
|
from tag_embedder import TagEmbedder |
|
|
|
|
|
|
|
|
# Make the parent of this script's directory importable *before* importing
# packages that may live there (presumably a checked-out ``acestep``; verify
# against the repo layout). Previously this insert ran after those imports,
# where it could not influence how they were resolved.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from acestep.music_dcae.music_dcae_pipeline import MusicDCAE

import timm.optim

import os

# Disable HuggingFace tokenizers' internal thread parallelism; it otherwise
# warns (and can misbehave) once DataLoader worker processes are forked.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
|
|
def save(model, optimizer, scheduler, global_step, accelerator,
         checkpoint_dir="checkpoints", keep_last=5):
    """Write a training checkpoint and prune older ones (main process only).

    The file ``<checkpoint_dir>/checkpoint_<global_step>.pth`` is written via
    ``accelerator.save`` and holds the unwrapped model's state dict, the
    optimizer state dict, the scheduler state dict (when ``scheduler`` is not
    None) and ``global_step``. Afterwards only the ``keep_last`` checkpoints
    with the highest step numbers are kept.

    Args:
        model: Model to save; ``accelerator.unwrap_model`` is applied first.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        scheduler: LR scheduler whose ``state_dict()`` is stored, or None.
        global_step: Step number used in the filename and stored in the file.
        accelerator: ``accelerate.Accelerator``; non-main processes return
            immediately without writing anything.
        checkpoint_dir: Directory holding the checkpoints (created if missing).
        keep_last: How many of the newest checkpoints to retain.
    """
    if not accelerator.is_main_process:
        return

    os.makedirs(checkpoint_dir, exist_ok=True)

    unwrapped_model = accelerator.unwrap_model(model)
    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_{global_step}.pth")

    save_dict = {
        'model_state_dict': unwrapped_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'global_step': global_step,
    }
    # Persist the LR schedule as well; without it a resumed run restarts
    # warmup/cosine from step 0 while global_step keeps counting.
    if scheduler is not None:
        save_dict['scheduler_state_dict'] = scheduler.state_dict()

    accelerator.save(save_dict, checkpoint_path)
    print(f"Checkpoint saved at step {global_step}: {checkpoint_path}")

    # Only files named exactly like the ones written above take part in
    # rotation, so stray files (e.g. "checkpoint_best.pth") neither crash the
    # step parse nor get deleted.
    found = []
    for name in os.listdir(checkpoint_dir):
        match = re.fullmatch(r"checkpoint_(\d+)\.pth", name)
        if match:
            found.append((int(match.group(1)), name))
    found.sort(reverse=True)

    for _, old_checkpoint in found[keep_last:]:
        os.remove(os.path.join(checkpoint_dir, old_checkpoint))
        print(f"Removed old checkpoint: {old_checkpoint}")
|
|
|
|
|
|
|
|
def load_checkpoint(model, optimizer, scheduler, checkpoint_path, accelerator):
    """Restore model, optimizer and (if saved) scheduler state from a file.

    The checkpoint is loaded onto the CPU first. Model keys have any
    ``"_orig_mod."`` prefix stripped and are loaded with ``strict=True``.

    Args:
        model: Model to restore; ``accelerator.unwrap_model`` is applied first.
        optimizer: Optimizer restored from ``'optimizer_state_dict'`` if present.
        scheduler: LR scheduler restored from ``'scheduler_state_dict'`` if
            both the scheduler is not None and the key is present.
        checkpoint_path: Path of the ``.pth`` file to read.
        accelerator: ``accelerate.Accelerator``; only its main process prints.

    Returns:
        The ``global_step`` stored in the checkpoint.

    Raises:
        KeyError: If the checkpoint lacks ``'model_state_dict'`` or ``'global_step'``.
        RuntimeError: From ``load_state_dict`` when model keys do not match.
    """
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    is_main = accelerator.is_main_process

    unwrapped_model = accelerator.unwrap_model(model)
    # torch.compile'd modules prefix parameter names with "_orig_mod."; strip it
    # so compiled and uncompiled checkpoints load into the same model.
    state_dict = {k.replace("_orig_mod.", ""): v for k, v in checkpoint['model_state_dict'].items()}
    # With strict=True any mismatch raises, so these lists are normally empty.
    missing, unexpected = unwrapped_model.load_state_dict(state_dict, strict=True)
    if is_main:
        print("MISSING:", missing)
        print("UNEXPECTED:", unexpected)

    if 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if is_main:
            print("Optimizer loaded")

    # Older checkpoints carry no scheduler state; in that case the schedule
    # simply starts fresh, as before.
    if scheduler is not None and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        if is_main:
            print("Scheduler loaded")

    global_step = checkpoint['global_step']
    if is_main:
        print(f"Resumed from step {global_step}")

    return global_step
|
|
|
|
|
def resume(model, optimizer, scheduler, accelerator, checkpoint_dir="checkpoints"):
    """Load the newest ``checkpoint_<step>.pth`` from ``checkpoint_dir``, if any.

    Args:
        model, optimizer, scheduler: Objects passed through to ``load_checkpoint``.
        accelerator: ``accelerate.Accelerator``; only its main process prints.
        checkpoint_dir: Directory to search for checkpoints.

    Returns:
        The restored global step, or 0 when the directory is missing or holds
        no checkpoint whose name matches ``checkpoint_<digits>.pth``.
    """
    if not os.path.isdir(checkpoint_dir):
        if accelerator.is_main_process:
            print("Checkpoint directory not found. Starting from scratch.")
        return 0

    # Accept only the exact naming scheme; anything else (e.g.
    # "checkpoint_best.pth") would previously crash the int() parse.
    candidates = []
    for name in os.listdir(checkpoint_dir):
        match = re.fullmatch(r"checkpoint_(\d+)\.pth", name)
        if match:
            candidates.append((int(match.group(1)), name))

    if not candidates:
        if accelerator.is_main_process:
            print("No checkpoints found. Starting from scratch.")
        return 0

    _, latest_checkpoint = max(candidates)
    checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)
    if accelerator.is_main_process:
        print(f"Resuming from checkpoint: {checkpoint_path}")

    return load_checkpoint(model, optimizer, scheduler, checkpoint_path, accelerator)
|
|
|
|
|
class AudioVAE:
    """MusicDCAE wrapper that exposes per-channel standardized latents.

    ``encode`` runs the DCAE encoder at 48 kHz and standardizes the result
    with fixed per-channel statistics; ``decode`` reverses the
    standardization and runs the DCAE decoder. Both run under ``no_grad``.
    """

    def __init__(self, device):
        self.model = MusicDCAE().to(device)
        self.model.eval()
        self.device = device

        # Fixed per-channel statistics, reshaped to broadcast over dim 1.
        channel_mean = [0.1207, -0.0186, -0.0947, -0.3779, 0.5956, 0.3422, 0.1796, -0.0526]
        channel_std = [0.4638, 0.3154, 0.6244, 1.5078, 0.4696, 0.4633, 0.5614, 0.2707]
        self.latent_mean = torch.tensor(channel_mean, device=device).view(1, -1, 1, 1)
        self.latent_std = torch.tensor(channel_std, device=device).view(1, -1, 1, 1)

    def encode(self, audio):
        """Return standardized DCAE latents for ``audio``.

        Every item is reported to the encoder with length ``audio.shape[2]``.
        """
        with torch.no_grad():
            n_items = audio.shape[0]
            n_samples = audio.shape[2]
            lengths = torch.tensor([n_samples] * n_items).to(self.device)
            raw_latents, _ = self.model.encode(audio, lengths, sr=48000)
            return (raw_latents - self.latent_mean) / self.latent_std

    def decode(self, latents):
        """Undo standardization, decode, and stack the waveforms on ``self.device``."""
        with torch.no_grad():
            raw_latents = latents * self.latent_std + self.latent_mean
            _, waveforms = self.model.decode(raw_latents, sr=48000)
            return torch.stack(waveforms).to(self.device)
|
|
|
|
|
class RF:
    """Rectified-flow training objective and sampler for an x-predicting model.

    Noisy inputs are ``z_t = (1 - t) * x + t * noise``, so t = 0 is data and
    t = 1 is pure noise. ``self.model(z_t, t, cond)`` is treated as a
    prediction of the clean sample ``x``; velocities are derived from it as
    ``(z_t - x_pred) / t``.

    Args:
        model: Callable ``model(z, t, cond)`` returning a tensor shaped like ``z``.
        time_sampling: One of ``"sigmoid"``, ``"warped"`` or ``"uniform"``.
    """

    def __init__(self, model, time_sampling="sigmoid"):
        self.model = model
        self.time_sampling = time_sampling

    @staticmethod
    def _warp(u):
        """Map ``u`` in [0, 1] to ``alpha*u / (1 + (alpha - 1)*u)``.

        ``alpha = max(1, sqrt(128*16*16 / 4096))``. Since alpha >= 1 the map
        fixes 0 and 1 and pushes intermediate values towards 1 (noisier).
        Shared by training-time sampling and the inference schedule.
        """
        pm = 128 * 16 * 16
        alpha = max(1.0, math.sqrt(pm / 4096.0))
        return alpha * u / (1.0 + (alpha - 1.0) * u)

    def sample_timesteps(self, batch, device):
        """Draw ``batch`` training timesteps according to ``self.time_sampling``.

        Raises:
            ValueError: If ``time_sampling`` is not a known strategy.
        """
        if self.time_sampling == "sigmoid":
            return torch.sigmoid(torch.randn((batch,), device=device))
        if self.time_sampling == "warped":
            return self._warp(torch.rand(batch, device=device))
        if self.time_sampling == "uniform":
            return torch.rand(batch, device=device)
        raise ValueError(f"Unknown time_sampling strategy: {self.time_sampling}")

    def forward(self, x, cond):
        """Return the scalar training loss for clean inputs ``x``.

        Both the target ``(z_t - x)`` and the prediction ``(z_t - x_pred)`` are
        divided by ``t + 0.05``; the 0.05 keeps that weighting finite as t -> 0.
        """
        b = x.size(0)
        t = self.sample_timesteps(b, x.device)
        # Reshape t to (b, 1, ..., 1) so it broadcasts over x's trailing dims.
        texp = t.view([b, *([1] * len(x.shape[1:]))])

        z1 = torch.randn_like(x)
        zt = (1 - texp) * x + texp * z1

        x_pred = self.model(zt, t, cond)

        target = (zt - x) / (texp + 0.05)
        v_pred = (zt - x_pred) / (texp + 0.05)
        return F.mse_loss(target, v_pred)

    def get_sampling_timesteps(self, steps, device):
        """Return ``steps`` decreasing timesteps starting at 1.0 (0.0 excluded).

        Raises:
            ValueError: If ``time_sampling`` is not a known strategy.
        """
        if self.time_sampling in ("uniform", "sigmoid"):
            return torch.linspace(1.0, 0.0, steps + 1, device=device)[:-1]
        if self.time_sampling == "warped":
            u = torch.linspace(1.0, 0.0, steps + 1, device=device)[:-1]
            return self._warp(u)
        raise ValueError(f"Unknown time_sampling strategy: {self.time_sampling}")

    def sample(self, z, cond, null_cond=None, sample_steps=100, cfg=3.0):
        """Integrate from noise ``z`` to data with ``sample_steps`` Euler steps.

        When ``null_cond`` is given, classifier-free guidance with scale ``cfg``
        combines the conditional and unconditional velocities.

        Returns:
            A list of ``sample_steps + 1`` tensors: the initial ``z`` followed by
            the state after every step (the last entry is the final sample).
        """
        b = z.size(0)
        timesteps = self.get_sampling_timesteps(sample_steps, z.device)
        # Append t = 0 so every step can read its successor uniformly.
        schedule = torch.cat([timesteps, timesteps.new_zeros(1)])

        images = [z]
        for idx in range(sample_steps):
            t_curr = schedule[idx]
            dt = t_curr - schedule[idx + 1]
            t = t_curr.expand(b)

            # t_curr is never 0 here: the schedule excludes 0 until the appended tail.
            v = (z - self.model(z, t, cond)) / t_curr
            if null_cond is not None:
                v_uncond = (z - self.model(z, t, null_cond)) / t_curr
                v = v_uncond + cfg * (v - v_uncond)

            z = z - dt * v
            images.append(z)

        return images
|
|
|
|
|
def save_audio_samples(audio_batch, sample_rate, filename, out_dir="audio_samples"):
    """Write the first clip of ``audio_batch`` to ``<out_dir>/<filename>``.

    Args:
        audio_batch: Batch of clips; only ``audio_batch[0]`` is written. It is
            passed to ``torchaudio.save`` as-is, so presumably it is
            (channels, samples) — TODO confirm against AudioVAE.decode output.
        sample_rate: Sample rate recorded in the output file.
        filename: File name inside ``out_dir``.
        out_dir: Output directory; created if it does not exist.
    """
    os.makedirs(out_dir, exist_ok=True)

    audio = audio_batch[0].detach().cpu()
    # torchaudio.save documents float32/int32/int16/uint8 inputs; upcast
    # half-precision tensors so they can still be written.
    if audio.dtype in (torch.float16, torch.bfloat16):
        audio = audio.float()

    filepath = os.path.join(out_dir, filename)
    torchaudio.save(filepath, audio, sample_rate)
    print(f"Saved audio sample: {filepath}")
|
|
|
|
|
def parse_args():
    """Parse the training script's command-line options.

    Returns:
        argparse.Namespace with one attribute per option below.
    """

    def str2bool(value):
        """Parse 'true'/'false'-style text (also yes/no, on/off, 1/0) into a bool."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description='Audio training script with TensorBoard logging')

    # Model / data shape options.
    # NOTE(review): main() in this file never reads --audio_height,
    # --max_audio_width or --n_encoder_layers.
    parser.add_argument('--channels', type=int, default=8, help='Number of input channels in the audio latents')
    parser.add_argument('--audio_height', type=int, default=16, help='Height of audio latents')
    parser.add_argument('--max_audio_width', type=int, default=4096, help='Max width of audio latents')
    parser.add_argument('--subsection_length', type=int, default=256, help='Length of random subsection to sample from each audio latent')
    parser.add_argument('--n_layers', type=int, default=36, help='Number of layers in the model')
    parser.add_argument('--n_encoder_layers', type=int, default=36, help='Number of encoder layers in the model')
    parser.add_argument('--n_heads', type=int, default=16, help='Number of heads in the model')
    parser.add_argument('--dim', type=int, default=768, help='Dimension of the encoder')
    parser.add_argument('--decoder_dim', type=int, default=1536, help='Dimension of the decoder')
    parser.add_argument('--dataset_name', type=str, default="cache", help='Audio dataset name')
    parser.add_argument('--num_workers', type=int, default=16, help='Number of workers for dataloader')

    # Optimization options.
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=1000, help='Number of epochs to train')
    parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
    parser.add_argument('--warmup_steps', type=int, default=0, help='Number of warmup steps')

    # Logging / checkpointing options.
    parser.add_argument('--sample_every', type=int, default=500, help='Audio sampling interval (batches)')
    parser.add_argument('--save_every', type=int, default=1000, help='Model saving interval (batches)')
    parser.add_argument('--num_samples', type=int, default=16, help='Number of samples to generate')
    # type=bool would turn any non-empty string (including "False") into True,
    # making resume impossible to disable; parse the text explicitly instead.
    parser.add_argument('--resume', type=str2bool, nargs='?', const=True, default=True,
                        help='Resume training from checkpoint (true/false; bare --resume means true)')
    parser.add_argument('--pad_to_length', action='store_true', help='Pad short samples to subsection_length instead of filtering them out')
    parser.add_argument('--time_sampling', type=str, default='warped', choices=['sigmoid', 'warped', 'uniform'], help='Timestep sampling strategy')

    return parser.parse_args()
|
|
|
|
|
def main():
    """Train LocalSongModel on cached audio latents with rectified flow.

    Parses CLI options, loads the dataset from disk (dropping short items
    unless --pad_to_length), builds the collate function and DataLoader, the
    model, the MusicDCAE wrapper, the RF objective, a Muon optimizer and a
    cosine LR schedule, optionally resumes from the newest checkpoint, then
    trains. The main process logs to TensorBoard, saves a checkpoint every
    ``args.save_every`` steps and writes generated audio for a fixed batch
    every ``args.sample_every`` steps.
    """
    args = parse_args()

    accelerator = Accelerator(mixed_precision="bf16" if torch.cuda.is_available() else "no")
    is_main_process = accelerator.is_main_process

    writer = None
    if is_main_process:
        run_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        writer = SummaryWriter(log_dir=f"runs/{run_datetime}")

    dataset = load_from_disk(args.dataset_name).with_format(type="torch")

    if not args.pad_to_length:
        # Without padding, keep only items at least 2 * subsection_length wide.
        def filter_by_length(example):
            latent_width = example['latents'].shape[-1]
            return latent_width >= args.subsection_length * 2

        dataset = dataset.filter(filter_by_length)

        if is_main_process:
            print(f"Dataset filtered to {len(dataset)} samples with width >= {args.subsection_length * 2}")
    elif is_main_process:
        print(f"Padding enabled: short samples will be zero-padded to {args.subsection_length}")

    # Per-channel latent statistics (same values AudioVAE uses), reshaped to
    # broadcast over (batch, channel, height, width).
    latent_mean = torch.tensor([0.1207, -0.0186, -0.0947, -0.3779, 0.5956, 0.3422, 0.1796, -0.0526]).view(1, -1, 1, 1)
    latent_std = torch.tensor([0.4638, 0.3154, 0.6244, 1.5078, 0.4696, 0.4633, 0.5614, 0.2707]).view(1, -1, 1, 1)

    num_classes = 2304
    tag_embedder = TagEmbedder(num_classes=num_classes)

    def collate_fn(batch):
        """Crop (or pad) each item to ``subsection_length``, normalize, add tags.

        Returns a dict with 'latents' (stacked, normalized), 'tags' (one list
        per item from ``tag_embedder.get_tags``) and the items' 'album_names'
        and 'song_names'.
        """
        subsection_length = args.subsection_length
        # Honour --pad_to_length; this was hard-coded to False, so the flag only
        # disabled the length filter and short items were never padded.
        pad_to_length = args.pad_to_length

        sampled_latents = []
        album_names = []
        song_names = []
        tags = []

        for item in batch:
            latent = item['latents']
            if len(latent.shape) == 3:
                latent = latent.unsqueeze(0)

            _, _, _, width = latent.shape

            if width < subsection_length:
                if not pad_to_length:
                    raise ValueError(
                        f"Latent width {width} is shorter than subsection_length "
                        f"{subsection_length}; enable --pad_to_length or filter the dataset."
                    )
                # Right-pad the width axis with zeros (before normalization).
                pad_amount = subsection_length - width
                sampled_latent = torch.nn.functional.pad(latent, (0, pad_amount), mode='constant', value=0)
            else:
                # Random crop of exactly subsection_length along the width axis.
                max_start = width - subsection_length
                start_idx = torch.randint(0, max_start + 1, (1,)).item()
                sampled_latent = latent[:, :, :, start_idx:start_idx + subsection_length]

            sampled_latents.append(sampled_latent.squeeze(0))

            album_name = item['album_name']
            song_name = item['song_name']
            album_names.append(album_name)
            song_names.append(song_name)
            tags.append(tag_embedder.get_tags(album_name, song_name))

        stacked_latents = torch.stack(sampled_latents)
        normalized_latents = (stacked_latents - latent_mean) / latent_std

        return {
            'latents': normalized_latents,
            'tags': tags,
            'album_names': album_names,
            'song_names': song_names,
        }

    use_cuda = torch.cuda.is_available()
    num_workers = args.num_workers if use_cuda else 0
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        # DataLoader rejects persistent_workers=True with num_workers == 0
        # (e.g. CPU-only runs), so only request it when workers exist.
        persistent_workers=num_workers > 0,
        num_workers=num_workers,
        pin_memory=use_cuda,
        collate_fn=collate_fn,
    )

    channels = args.channels

    model = LocalSongModel(
        in_channels=channels,
        num_groups=args.n_heads,
        hidden_size=args.dim,
        decoder_hidden_size=args.decoder_dim,
        num_blocks=args.n_layers,
        patch_size=(16, 1),
        num_classes=num_classes,
        max_tags=8,
    )

    vae = AudioVAE(accelerator.device)

    rf = RF(model, time_sampling=args.time_sampling)

    optimizer = timm.optim.Muon(model.parameters(), lr=args.lr)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=args.epochs * len(dataloader),
    )

    global_step = 0
    if args.resume:
        global_step = resume(model, optimizer, scheduler, accelerator)

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        model.forward_emb = torch.compile(model.forward_emb)

    model, optimizer, scheduler, dataloader = accelerator.prepare(
        model, optimizer, scheduler, dataloader
    )

    rf.model = model

    # Sampling only happens on the main process, so it goes through the
    # unwrapped module: a DDP-wrapped forward on a single rank may issue
    # collective ops (e.g. buffer broadcasts) the other ranks never join.
    sample_rf = RF(accelerator.unwrap_model(model), time_sampling=args.time_sampling)

    if is_main_process:
        model_size = sum(p.numel() for p in accelerator.unwrap_model(model).parameters() if p.requires_grad)
        print(f"Number of parameters: {model_size}, {model_size / 1e6}M")

    os.makedirs("audio_samples", exist_ok=True)
    num_samples = args.num_samples

    fixed_latents = None
    fixed_tags = None
    fixed_noise = None

    if is_main_process:
        # Build the preview batch straight from the dataset instead of iterating
        # the prepared dataloader on this rank alone (accelerate's sharded
        # loader may synchronize RNG state across processes when iterated).
        n_fixed = max(1, min(num_samples, len(dataset)))
        fixed_indices = torch.randperm(len(dataset))[:n_fixed].tolist()
        fixed_batch = collate_fn([dataset[i] for i in fixed_indices])

        fixed_latents = fixed_batch["latents"].to(accelerator.device)
        # Condition previews on the fixed items' own tags (this list used to be
        # left empty, so sampling received no conditioning entries at all).
        fixed_tags = list(fixed_batch["tags"])

        print("Fixed ids:", fixed_batch["album_names"])

        idx_to_tag = {v: k for k, v in tag_embedder.tag_mapping.items()}
        print("Fixed tag labels:")
        for i, tag_list in enumerate(fixed_tags):
            labels = [idx_to_tag.get(idx, f"<unknown:{idx}>") for idx in tag_list]
            print(f" Sample {i}: {labels}")

        # One noise tensor per fixed item so noise, cond and null_cond agree
        # in batch size.
        n_fixed_items, C, H, _ = fixed_latents.shape
        fixed_noise = torch.randn(n_fixed_items, C, H, args.subsection_length, device=accelerator.device)

    if is_main_process:
        print("Begin training")

    mse_loss_window = deque(maxlen=100)

    # On resume, continue the epoch count so --epochs remains the total number
    # of epochs (matching the scheduler's num_training_steps). After prepare(),
    # len(dataloader) is the per-process batch count, which global_step counts.
    steps_per_epoch = len(dataloader)
    start_epoch = global_step // steps_per_epoch if steps_per_epoch > 0 else 0
    if is_main_process and start_epoch > 0:
        print(f"Continuing at epoch {start_epoch}")

    for epoch in range(start_epoch, args.epochs):
        pbar = tqdm(dataloader) if is_main_process else dataloader
        for batch in pbar:
            x = batch["latents"]
            tags = batch["tags"]

            # Drop the whole tag list for ~10% of items so the model also learns
            # the empty-tag (unconditional) case used as null_cond when sampling.
            c = [[] if torch.rand(1).item() < 0.1 else tag_list for tag_list in tags]

            with accelerator.accumulate(model):
                optimizer.zero_grad()
                mse_loss = rf.forward(x, c)
                loss = mse_loss
                accelerator.backward(loss)
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()

            if is_main_process:
                mse_loss_window.append(mse_loss.item())
                avg_mse_loss = sum(mse_loss_window) / len(mse_loss_window)
                current_lr = optimizer.param_groups[0]['lr']

                if isinstance(pbar, tqdm):
                    pbar.set_postfix({"mse_loss": avg_mse_loss, "lr": current_lr})

                if writer is not None:
                    writer.add_scalar('Learning_Rate', current_lr, global_step)
                    writer.add_scalar('MSE_Loss', avg_mse_loss, global_step)

            global_step += 1

            if is_main_process and global_step % args.save_every == 0:
                save(model, optimizer, scheduler, global_step, accelerator)

            if is_main_process and global_step % args.sample_every == 0:
                model.eval()
                with torch.no_grad():
                    cond = fixed_tags
                    null_cond = [[] for _ in range(len(cond))]
                    sampled_latents = sample_rf.sample(fixed_noise, cond, null_cond)[-1]

                    # Decoding/saving is best-effort: errors are reported and
                    # training continues.
                    try:
                        sampled_audio = vae.decode(sampled_latents)
                        for i in range(min(8, sampled_audio.shape[0])):
                            save_audio_samples(
                                sampled_audio[i:i + 1],
                                48000,
                                f"sample_{global_step}_generated_{i}.wav",
                            )

                        # Reference decodes of the real fixed latents, written
                        # only when global_step == sample_every.
                        if global_step == args.sample_every:
                            original_audio = vae.decode(fixed_latents)
                            for i in range(min(8, original_audio.shape[0])):
                                save_audio_samples(
                                    original_audio[i:i + 1],
                                    48000,
                                    f"sample_{global_step}_original_{i}.wav",
                                )
                    except Exception as e:
                        print(f"Error during audio generation: {e}")

                model.train()

    if is_main_process:
        print("Saving final model")
    save(model, optimizer, scheduler, global_step, accelerator)

    if writer is not None:
        writer.close()
|
|
|
|
|
if __name__ == "__main__":
    # Start training only when executed as a script, not when imported.
    main()
|
|
|