Spaces:
Runtime error
Runtime error
| import time | |
| import os | |
| from statistics import mean | |
| from collections import defaultdict | |
| import torch | |
| import numpy as np | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from tqdm.auto import tqdm | |
| import pprint | |
| import math | |
| import constants | |
| from constants import PitchToken, DurationToken | |
| from utils import append_dict, print_divider | |
class StepBetaScheduler():
    """Step-wise linear annealing schedule for the KL weight (beta).

    Beta is 0 until ``anneal_start`` update steps have been performed. Inside
    the window ``[anneal_start, anneal_end)`` it grows by ``step_size`` every
    ``inc_every`` update steps, and it is never allowed to exceed
    ``beta_max``. After ``anneal_end`` beta keeps its last value.

    Args:
        anneal_start: update step at which annealing begins (inclusive).
        beta_max: maximum value beta can reach.
        step_size: amount beta increases at each increment.
        anneal_end: update step at which annealing stops (exclusive).
    """

    def __init__(self, anneal_start, beta_max, step_size, anneal_end):
        self.anneal_start = anneal_start
        self.beta_max = beta_max
        self.step_size = step_size
        self.anneal_end = anneal_end
        self.update_steps = 0
        self.beta = 0

        # Number of increments needed to go from 0 to beta_max. Clamped to at
        # least 1 so that step_size > beta_max does not divide by zero below.
        n_steps = max(1, self.beta_max // self.step_size)
        # Update steps between two consecutive increments. Clamped to at
        # least 1 so that a short annealing window does not make step()
        # divide by zero.
        self.inc_every = max(
            1, (self.anneal_end - self.anneal_start) // n_steps)

    def step(self):
        """Advance the schedule by one update step and return the new beta."""
        self.update_steps += 1

        # Only anneal inside the [anneal_start, anneal_end) window; outside
        # of it beta keeps its current value (0 before the window starts).
        if self.anneal_start <= self.update_steps < self.anneal_end:
            curr_step = (self.update_steps - self.anneal_start) // self.inc_every
            # Cap at beta_max: when the window is not a multiple of n_steps
            # the last increments would otherwise overshoot it.
            self.beta = min(self.step_size * (curr_step + 1), self.beta_max)

        return self.beta
class ExpDecayLRScheduler():
    """Constant-then-exponential-decay learning rate schedule.

    The learning rate is held at ``peak_lr`` for the first ``warmup_steps``
    updates. It then decays exponentially so that it equals
    ``peak_lr * final_lr_scale`` after ``decay_steps`` further updates, and
    keeps decaying at the same rate afterwards.

    Args:
        optimizer: optimizer whose param groups' ``'lr'`` entry is
            overwritten at every ``step()``.
        peak_lr: learning rate used during warmup and at the start of decay.
        warmup_steps: number of updates during which lr stays at peak_lr.
        final_lr_scale: fraction of peak_lr reached after ``decay_steps``
            decay updates; must be > 0 (it goes through ``math.log``).
        decay_steps: number of decay updates needed to reach
            ``final_lr_scale * peak_lr``; must be non-zero.
    """

    def __init__(self, optimizer, peak_lr, warmup_steps, final_lr_scale,
                 decay_steps):
        self.optimizer = optimizer
        self.peak_lr = peak_lr
        self.warmup_steps = warmup_steps
        self.decay_steps = decay_steps
        # Decay rate k such that exp(-k * decay_steps) == final_lr_scale
        self.decay_factor = -math.log(final_lr_scale) / self.decay_steps
        self.update_steps = 0
        # Expose the current lr even before the first step(), so callers that
        # read `.lr` early do not hit an AttributeError.
        self.lr = self.peak_lr

    def set_lr(self, optimizer, lr):
        """Set ``lr`` on every param group of ``optimizer``."""
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def step(self):
        """Advance by one update, apply the new lr to the optimizer, return it."""
        self.update_steps += 1

        if self.update_steps <= self.warmup_steps:
            self.lr = self.peak_lr
        else:
            # Decay lr exponentially
            steps_after_warmup = self.update_steps - self.warmup_steps
            self.lr = self.peak_lr * math.exp(
                -self.decay_factor * steps_after_warmup)

        self.set_lr(self.optimizer, self.lr)
        return self.lr
class PolyphemusTrainer():
    """Training loop for the Polyphemus VAE.

    Handles mixed-precision training (on CUDA only), gradient accumulation,
    optional learning-rate and KL-weight (beta) schedulers, periodic logging,
    validation, best-model selection and checkpointing.

    Args:
        model_dir: directory where 'checkpoint' and 'best_model' are saved.
        model: module called as ``model(graph)`` that returns
            ``((s_logits, c_logits), mu, log_var)``.
        optimizer: torch optimizer for the model parameters.
        init_lr: learning rate recorded in the stats when the lr scheduler
            is missing or has no ``lr`` attribute yet.
        lr_scheduler: optional object with a ``step()`` method and an ``lr``
            attribute.
        beta_scheduler: optional object whose ``step()`` returns the new KL
            weight.
        device: torch device to train on (defaults to CPU).
        print_every: print averaged training stats every this many batches.
        save_every: save a checkpoint every this many batches (<= 0 disables
            periodic checkpoints).
        eval_every: evaluate on the validation set every this many batches.
        iters_to_accumulate: number of batches whose gradients are
            accumulated before each optimizer step.
        **kwargs: extra attributes stored on the instance.
    """

    def __init__(self, model_dir, model, optimizer, init_lr=1e-4,
                 lr_scheduler=None, beta_scheduler=None, device=None,
                 print_every=1, save_every=1, eval_every=100,
                 iters_to_accumulate=1, **kwargs):
        # Run-state defaults so that evaluate() and _save_model() also work
        # when called before train(). Set before applying kwargs so callers
        # can still override them through kwargs.
        self.beta = 0
        self.min_val_loss = np.inf
        self.tot_batches = 0
        self.cur_epoch = 0
        self.cur_batch_idx = 0

        self.__dict__.update(kwargs)
        self.model_dir = model_dir
        self.model = model
        self.optimizer = optimizer
        self.init_lr = init_lr
        self.lr_scheduler = lr_scheduler
        self.beta_scheduler = beta_scheduler
        self.device = device if device is not None else torch.device("cpu")
        # Mixed precision (autocast + GradScaler) is only used on CUDA
        self.cuda = self.device.type == 'cuda'
        self.print_every = print_every
        self.save_every = save_every
        self.eval_every = eval_every
        self.iters_to_accumulate = iters_to_accumulate

        # Losses (ignoring PAD tokens)
        self.bce_unreduced = nn.BCEWithLogitsLoss(reduction='none')
        self.ce_p = nn.CrossEntropyLoss(ignore_index=PitchToken.PAD.value)
        self.ce_d = nn.CrossEntropyLoss(ignore_index=DurationToken.PAD.value)

        # Training stats (tr_*/lrs/betas: one entry per training batch;
        # val_*: one entry per evaluation)
        self.tr_losses = defaultdict(list)
        self.tr_accuracies = defaultdict(list)
        self.val_losses = defaultdict(list)
        self.val_accuracies = defaultdict(list)
        self.lrs = []
        self.betas = []
        self.times = []

    def train(self, trainloader, validloader=None, epochs=100, early_exit=None):
        """Train the model.

        Args:
            trainloader: sized iterable of batched graphs supporting
                ``.to(device)`` and exposing ``s_tensor``, ``c_tensor`` and
                ``is_drum``.
            validloader: optional loader with the same interface, used for
                periodic evaluation and best-model selection.
            epochs: number of passes over ``trainloader``.
            early_exit: if not None, stop training (across all epochs) once
                this many batches have been processed.
        """
        self.tot_batches = 0
        self.beta = 0
        self.min_val_loss = np.inf
        self.cur_epoch = 0
        self.cur_batch_idx = 0

        start = time.time()
        self.times.append(start)

        self.model.train()
        scaler = torch.cuda.amp.GradScaler() if self.cuda else None
        self.optimizer.zero_grad()

        # One tick per batch over all epochs
        progress_bar = tqdm(total=len(trainloader) * epochs)
        stop = False

        for epoch in range(epochs):
            self.cur_epoch = epoch

            for batch_idx, graph in enumerate(trainloader):
                self.cur_batch_idx = batch_idx

                # Move batch of graphs to device. Note: a single graph here
                # represents a bar in the original sequence.
                graph = graph.to(self.device)
                s_tensor, c_tensor = graph.s_tensor, graph.c_tensor

                with torch.cuda.amp.autocast(enabled=self.cuda):
                    # Forward pass to obtain mu, log(sigma^2), computed by the
                    # encoder, and structure and content logits, computed by
                    # the decoder
                    (s_logits, c_logits), mu, log_var = self.model(graph)
                    tot_loss, losses = self._losses(
                        s_tensor, s_logits,
                        c_tensor, c_logits,
                        mu, log_var
                    )
                    # Average gradients over the accumulated batches
                    tot_loss = tot_loss / self.iters_to_accumulate

                # Backpropagation
                if self.cuda:
                    scaler.scale(tot_loss).backward()
                else:
                    tot_loss.backward()

                # 1-based count of the batches processed so far
                batch_num = self.tot_batches + 1

                # Update weights with accumulated gradients
                if batch_num % self.iters_to_accumulate == 0:
                    self._update_weights(scaler)

                accs = self._accuracies(
                    s_tensor, s_logits,
                    c_tensor, c_logits,
                    graph.is_drum
                )

                # Update the stats
                append_dict(self.tr_losses, losses)
                append_dict(self.tr_accuracies, accs)
                # getattr covers both "no scheduler" and a scheduler that has
                # not exposed `lr` yet (e.g. while accumulating gradients)
                self.lrs.append(getattr(self.lr_scheduler, 'lr', self.init_lr))
                self.betas.append(self.beta)
                self.times.append(time.time())

                # Print stats
                if batch_num % self.print_every == 0:
                    print("Training on batch {}/{} of epoch {}/{} complete."
                          .format(batch_idx+1,
                                  len(trainloader),
                                  epoch+1,
                                  epochs))
                    self._print_stats()
                    print_divider()

                # Eval on VL every `eval_every` batches
                if (validloader is not None and
                        batch_num % self.eval_every == 0):
                    self._run_validation(validloader)

                progress_bar.update(1)

                # Save model and stats on disk
                if self.save_every > 0 and batch_num % self.save_every == 0:
                    self._save_model('checkpoint')

                self.tot_batches += 1

                # Stop prematurely once early_exit batches have been processed
                if early_exit is not None and self.tot_batches >= early_exit:
                    stop = True
                    break

            # Leave the epoch loop too, not just the batch loop
            if stop:
                break

        progress_bar.close()

        end = time.time()
        hours, rem = divmod(end-start, 3600)
        minutes, seconds = divmod(rem, 60)
        print("Training completed in (h:m:s): {:0>2}:{:0>2}:{:05.2f}"
              .format(int(hours), int(minutes), seconds))

        self._save_model('checkpoint')

    def _update_weights(self, scaler):
        """Apply accumulated gradients, reset them and step the schedulers."""
        if self.cuda:
            scaler.step(self.optimizer)
            scaler.update()
        else:
            self.optimizer.step()

        self.optimizer.zero_grad()

        # Update lr and beta (once per optimizer update)
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        if self.beta_scheduler is not None:
            # _losses() weights the KL term with self.beta, so it must track
            # the scheduler's value
            self.beta = self.beta_scheduler.step()

    def _run_validation(self, validloader):
        """Evaluate on validloader, record/print stats, save the best model."""
        print("\nEvaluating on validation set...\n")
        val_losses, val_accuracies = self.evaluate(validloader)

        # Update stats
        append_dict(self.val_losses, val_losses)
        append_dict(self.val_accuracies, val_accuracies)

        print("Val losses:")
        print(val_losses)
        print("Val accuracies:")
        print(val_accuracies)

        # Save model if VL loss (tot) reached a new minimum
        val_tot_loss = val_losses['tot']
        if val_tot_loss < self.min_val_loss:
            print("\nValidation loss improved.")
            print("Saving new best model to disk...\n")
            # Update before saving so the checkpoint stores the new minimum
            self.min_val_loss = val_tot_loss
            self._save_model('best_model')

        # evaluate() leaves the model in eval mode
        self.model.train()

    def evaluate(self, loader):
        """Compute average losses and accuracies over ``loader``.

        Leaves the model in eval mode.

        Returns:
            (avg_losses, avg_accs): dicts mapping each stat name to its mean
            over the batches of ``loader``.
        """
        losses = defaultdict(list)
        accs = defaultdict(list)

        self.model.eval()
        progress_bar = tqdm(total=len(loader))

        with torch.no_grad():
            for graph in loader:
                # Get the inputs and move them to device
                graph = graph.to(self.device)
                s_tensor, c_tensor = graph.s_tensor, graph.c_tensor

                # Mixed precision only on CUDA, consistently with train()
                with torch.cuda.amp.autocast(enabled=self.cuda):
                    # Forward pass, get the reconstructions
                    (s_logits, c_logits), mu, log_var = self.model(graph)
                    _, losses_b = self._losses(
                        s_tensor, s_logits,
                        c_tensor, c_logits,
                        mu, log_var
                    )

                accs_b = self._accuracies(
                    s_tensor, s_logits,
                    c_tensor, c_logits,
                    graph.is_drum
                )

                # Save losses and accuracies
                append_dict(losses, losses_b)
                append_dict(accs, accs_b)
                progress_bar.update(1)

        progress_bar.close()

        # Compute avg losses and accuracies
        avg_losses = {k: mean(vals) for k, vals in losses.items()}
        avg_accs = {k: mean(vals) for k, vals in accs.items()}

        return avg_losses, avg_accs

    @staticmethod
    def _flatten_s_logits(s_logits):
        """Merge the first two dims of the structure logits.

        Reshapes logits to match s_tensor dimensions:
        n_graphs (in batch) x n_tracks x n_timesteps.
        """
        return s_logits.reshape(-1, *s_logits.shape[2:])

    def _losses(self, s_tensor, s_logits, c_tensor, c_logits, mu, log_var):
        """Compute the beta-VAE loss and its components.

        Returns:
            (tot_loss, losses): ``tot_loss`` is the differentiable total loss
            (reconstruction + beta * KLD); ``losses`` maps each component
            name to its Python float value.
        """
        # Do not consider SOS token
        c_tensor = c_tensor[..., 1:, :]
        c_logits = c_logits.reshape(-1, c_logits.size(-1))
        c_tensor = c_tensor.reshape(-1, c_tensor.size(-1))

        s_logits = self._flatten_s_logits(s_logits)

        # Binary structure tensor loss (binary cross entropy)
        s_loss = self.bce_unreduced(
            s_logits.reshape(-1), s_tensor.reshape(-1).float())
        s_loss = torch.mean(s_loss)

        # Content tensor loss (pitches)
        # argmax is used to obtain token ids from onehot rep
        n_pitch = constants.N_PITCH_TOKENS
        pitch_logits = c_logits[:, :n_pitch]
        pitch_true = c_tensor[:, :n_pitch].argmax(dim=1)
        pitch_loss = self.ce_p(pitch_logits, pitch_true)

        # Content tensor loss (durations)
        dur_logits = c_logits[:, n_pitch:]
        dur_true = c_tensor[:, n_pitch:].argmax(dim=1)
        dur_loss = self.ce_d(dur_logits, dur_true)

        # Kullback-Leibler divergence loss
        # Derivation in Kingma, Diederik P., and Max Welling. "Auto-encoding
        # variational bayes." (2013), Appendix B.
        # (https://arxiv.org/pdf/1312.6114.pdf)
        kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(),
                                    dim=1)
        kld_loss = torch.mean(kld_loss)

        # Reconstruction loss and total loss
        rec_loss = pitch_loss + dur_loss + s_loss
        tot_loss = rec_loss + self.beta*kld_loss

        losses = {
            'tot': tot_loss.item(),
            'pitch': pitch_loss.item(),
            'dur': dur_loss.item(),
            'structure': s_loss.item(),
            'reconstruction': rec_loss.item(),
            'kld': kld_loss.item(),
            'beta*kld': self.beta*kld_loss.item()
        }

        return tot_loss, losses

    @torch.no_grad()
    def _accuracies(self, s_tensor, s_logits, c_tensor, c_logits, is_drum):
        """Compute content and structure accuracy metrics as Python floats.

        Metrics whose denominator is 0 (e.g. no drum tracks in the batch,
        no predicted positives) come out as NaN.
        """
        # Do not consider SOS token
        c_tensor = c_tensor[..., 1:, :]

        s_logits = self._flatten_s_logits(s_logits)

        # Note accuracy considers both pitches and durations
        note_acc = self._note_accuracy(c_logits, c_tensor)

        pitch_acc = self._pitch_accuracy(c_logits, c_tensor)

        # Compute pitch accuracies for drums and non drums separately
        pitch_acc_drums = self._pitch_accuracy(
            c_logits, c_tensor, drums=True, is_drum=is_drum
        )
        pitch_acc_non_drums = self._pitch_accuracy(
            c_logits, c_tensor, drums=False, is_drum=is_drum
        )

        dur_acc = self._duration_accuracy(c_logits, c_tensor)

        s_acc = self._structure_accuracy(s_logits, s_tensor)
        s_precision = self._structure_precision(s_logits, s_tensor)
        s_recall = self._structure_recall(s_logits, s_tensor)
        s_f1 = 2 * s_recall * s_precision / (s_recall + s_precision)

        accs = {
            'note': note_acc.item(),
            'pitch': pitch_acc.item(),
            'pitch_drums': pitch_acc_drums.item(),
            'pitch_non_drums': pitch_acc_non_drums.item(),
            'dur': dur_acc.item(),
            's_acc': s_acc.item(),
            's_precision': s_precision.item(),
            's_recall': s_recall.item(),
            's_f1': s_f1.item()
        }

        return accs

    @staticmethod
    def _token_matches(logits, target, pad_value):
        """Return boolean masks ``(correct, not_pad)`` for one token group.

        Predicted ids are the argmax of ``logits`` and true ids the argmax of
        the one-hot ``target`` (both over the last dim). Positions whose true
        id is ``pad_value`` are never counted as correct.
        """
        # argmax of the raw logits equals argmax of their softmax
        pred = logits.argmax(dim=-1)
        true = target.argmax(dim=-1)
        not_pad = true != pad_value
        correct = torch.logical_and(pred == true, not_pad)
        return correct, not_pad

    def _pitch_matches(self, c_logits, c_tensor):
        """Return ``(correct, not_pad)`` masks for the pitch tokens."""
        n_pitch = constants.N_PITCH_TOKENS
        return self._token_matches(c_logits[..., :n_pitch],
                                   c_tensor[..., :n_pitch],
                                   PitchToken.PAD.value)

    def _duration_matches(self, c_logits, c_tensor):
        """Return ``(correct, not_pad)`` masks for the duration tokens."""
        n_pitch = constants.N_PITCH_TOKENS
        return self._token_matches(c_logits[..., n_pitch:],
                                   c_tensor[..., n_pitch:],
                                   DurationToken.PAD.value)

    def _pitch_accuracy(self, c_logits, c_tensor, drums=None, is_drum=None):
        """Pitch accuracy over non-PAD tokens.

        When ``drums`` is None the global accuracy is returned; otherwise
        only the tracks selected (``drums=True``) or excluded
        (``drums=False``) by ``is_drum`` are considered.
        """
        if drums is not None:
            mask = is_drum if drums else torch.logical_not(is_drum)
            c_logits = c_logits[mask]
            c_tensor = c_tensor[mask]

        correct, not_pad = self._pitch_matches(c_logits, c_tensor)
        return torch.sum(correct) / torch.sum(not_pad)

    def _duration_accuracy(self, c_logits, c_tensor):
        """Duration accuracy over non-PAD tokens."""
        correct, not_pad = self._duration_matches(c_logits, c_tensor)
        return torch.sum(correct) / torch.sum(not_pad)

    def _note_accuracy(self, c_logits, c_tensor):
        """Fraction of non-PAD pitch positions whose pitch AND duration match."""
        correct_p, not_pad_p = self._pitch_matches(c_logits, c_tensor)
        correct_d, _ = self._duration_matches(c_logits, c_tensor)
        both_correct = torch.logical_and(correct_p, correct_d)
        return torch.sum(both_correct) / torch.sum(not_pad_p)

    @staticmethod
    def _binarize(s_logits):
        """Threshold sigmoid(s_logits) at 0.5 into a float32 0/1 tensor.

        float32 keeps the counts below exact even when the logits come from
        autocast in half precision.
        """
        return (torch.sigmoid(s_logits.float()) >= 0.5).float()

    def _structure_accuracy(self, s_logits, s_tensor):
        """Fraction of structure entries whose binarized prediction matches."""
        s_pred = self._binarize(s_logits)
        return torch.sum(s_pred == s_tensor) / s_tensor.numel()

    def _structure_precision(self, s_logits, s_tensor):
        """True positives over predicted positives (NaN if none predicted)."""
        s_pred = self._binarize(s_logits)
        tp = torch.sum(s_tensor[s_pred == 1])
        return tp / torch.sum(s_pred)

    def _structure_recall(self, s_logits, s_tensor):
        """True positives over actual positives (NaN if there are none)."""
        s_pred = self._binarize(s_logits)
        tp = torch.sum(s_tensor[s_pred == 1])
        return tp / torch.sum(s_tensor)

    def _save_model(self, filename):
        """Save model/optimizer state and training stats to model_dir/filename."""
        path = os.path.join(self.model_dir, filename)
        print("Saving model to disk...")

        torch.save({
            'epoch': self.cur_epoch,
            'batch': self.cur_batch_idx,
            'tot_batches': self.tot_batches,
            'betas': self.betas,
            'min_val_loss': self.min_val_loss,
            'print_every': self.print_every,
            'save_every': self.save_every,
            'eval_every': self.eval_every,
            'lrs': self.lrs,
            'tr_losses': self.tr_losses,
            'tr_accuracies': self.tr_accuracies,
            'val_losses': self.val_losses,
            'val_accuracies': self.val_accuracies,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict()
        }, path)

        print("The model has been successfully saved.")

    def _print_stats(self):
        """Print elapsed time and stats averaged over the last print_every batches."""
        hours, rem = divmod(self.times[-1]-self.times[0], 3600)
        minutes, seconds = divmod(rem, 60)
        print("Elapsed time from start (h:m:s): {:0>2}:{:0>2}:{:05.2f}"
              .format(int(hours), int(minutes), seconds))

        # Take mean of the last non-printed batches for each loss and accuracy
        window = self.print_every
        avg_losses = {k: round(mean(vals[-window:]), 2)
                      for k, vals in self.tr_losses.items()}
        avg_accs = {k: round(mean(vals[-window:]), 2)
                    for k, vals in self.tr_accuracies.items()}

        print("Losses:")
        pprint.pprint(avg_losses, indent=2)
        print("Accuracies:")
        pprint.pprint(avg_accs, indent=2)