Spaces:
Sleeping
Sleeping
| import os | |
| from matplotlib import pyplot as plt | |
| import matplotlib as mpl | |
| import muspy | |
| import torch | |
| import constants | |
def plot_pianoroll(muspy_song, save_dir=None, name='pianoroll'):
    """Draw a multitrack piano roll with ``muspy.show_pianoroll``.

    The figure has ``constants.N_TRACKS`` vertically stacked subplots that
    share the x axis and have no vertical spacing between them.

    Args:
        muspy_song: the song to draw; passed to ``muspy.show_pianoroll`` as
            ``music``.
        save_dir: if truthy, the image is written to
            ``<save_dir>/<name>.png`` at 200 dpi; otherwise nothing is saved.
        name: file name (without extension) used when saving.
    """
    lines_linewidth = 4
    axes_linewidth = 4
    font_size = 34

    fformat = 'png'
    xticklabel = False
    label = 'y'
    figsize = (20, 10)
    dpi = 200

    style = {'lines.linewidth': lines_linewidth,
             'axes.linewidth': axes_linewidth,
             'font.size': font_size}
    with mpl.rc_context(style):
        # squeeze=False always returns a 2-D array of Axes. Without it, a
        # single track yields a bare Axes object, which has no .tolist().
        fig, axs_ = plt.subplots(constants.N_TRACKS, sharex=True,
                                 figsize=figsize, squeeze=False)
        fig.subplots_adjust(hspace=0)
        axs = axs_.ravel().tolist()
        muspy.show_pianoroll(music=muspy_song, yticklabel='off', xtick='off',
                             label=label, xticklabel=xticklabel,
                             grid_axis='off', axs=axs, preset='full')

        if save_dir:
            # Save this figure explicitly rather than whatever pyplot
            # currently considers the "current" figure.
            fig.savefig(os.path.join(save_dir, f"{name}.{fformat}"),
                        format=fformat, dpi=dpi)
def plot_structure(s_tensor, save_dir=None, name='structure'):
    """Draw a structure tensor as a grid of cells (tracks x time).

    Args:
        s_tensor: 3-D torch tensor indexed as (bar, track, timestep). This
            follows from how its dimensions are used below: dim 0 is the
            number of bars, dim 2 the timesteps per bar, and dim 1 becomes
            the rows labelled with ``constants.TRACKS``. It may require
            grad or live on a GPU; it is detached and copied to the CPU.
        save_dir: if truthy, the figure is written to
            ``<save_dir>/<name>.svg``; otherwise nothing is saved.
        name: file name (without extension) used when saving.

    The x axis is labelled with beat numbers 1..4*n_bars (4 beats per bar).
    The y axis shows one track label per row, top to bottom.
    """
    lines_linewidth = 1
    axes_linewidth = 1
    font_size = 14

    fformat = 'svg'
    dpi = 200
    beats_per_bar = 4

    n_bars = s_tensor.shape[0]
    n_timesteps = s_tensor.size(2)
    figsize = (3 * n_bars, 3)

    # (bars, tracks, steps) -> (tracks, bars * steps): one row per track,
    # with the bars laid out one after another along the time axis.
    grid = s_tensor.detach().cpu().permute(1, 0, 2)
    grid = grid.reshape(grid.shape[0], -1)
    n_rows = grid.shape[0]

    # Exact beat boundaries in timestep units. This matches the old integer
    # stepping when n_timesteps is a multiple of 4. It also keeps exactly
    # one tick per label (and never a zero step) when it is not.
    steps_per_beat = n_timesteps / beats_per_bar
    beat_ticks = [b * steps_per_beat for b in range(beats_per_bar * n_bars)]
    beat_labels = range(1, beats_per_bar * n_bars + 1)

    style = {'lines.linewidth': lines_linewidth,
             'axes.linewidth': axes_linewidth,
             'font.size': font_size}
    with mpl.rc_context(style):
        fig = plt.figure(figsize=figsize)
        plt.pcolormesh(grid.numpy(), edgecolors='k', linewidth=1)
        ax = plt.gca()
        plt.xticks(beat_ticks, beat_labels)
        # With no X/Y given, pcolormesh draws row i over [i, i+1]. Centre
        # each track label on its row instead of on the row's top edge.
        plt.yticks([row + 0.5 for row in range(n_rows)], constants.TRACKS)
        ax.invert_yaxis()

        if save_dir:
            fig.savefig(os.path.join(save_dir, f"{name}.{fformat}"),
                        format=fformat, dpi=dpi)
def plot_stats(stat_names, stats_tr, stats_val=None, eval_every=None,
               labels=None, rx=None, ry=None):
    """Plot training (and optionally validation) curves on the current axes.

    Args:
        stat_names: keys into ``stats_tr`` (and ``stats_val``) to plot.
        stats_tr: mapping from stat name to a sequence of training values.
            Value k (0-based) is drawn at x = k + 1, as a line labelled
            "<label> (TR)".
        stats_val: optional mapping from stat name to validation values.
            They are drawn as dots at x = eval_every, 2*eval_every, ... and
            labelled "<label> (VL)".
        eval_every: spacing between validation points. Required when
            ``stats_val`` is given.
        labels: optional legend labels, parallel to ``stat_names``. When
            falsy, the stat names themselves are used.
        rx, ry: optional x / y limits passed to ``plt.xlim`` / ``plt.ylim``.
            When falsy, only the lower limit is pinned to 0.

    Raises:
        TypeError: if ``stats_val`` is given without ``eval_every``.
    """
    if stats_val and eval_every is None:
        # Without a spacing there is no x position for the validation
        # points; fail with a clear message instead of deep inside range().
        raise TypeError("plot_stats: eval_every is required when stats_val "
                        "is given")

    for i, stat in enumerate(stat_names):
        label = labels[i] if labels else stat
        n_steps = len(stats_tr[stat])
        plt.plot(range(1, n_steps + 1), stats_tr[stat],
                 label=label + ' (TR)')
        if stats_val:
            plt.plot(range(eval_every, n_steps + 1, eval_every),
                     stats_val[stat], '.', label=label + ' (VL)')

    plt.grid()
    if ry:
        plt.ylim(ry)
    else:
        plt.ylim(0)
    if rx:
        plt.xlim(rx)
    else:
        plt.xlim(0)
    plt.legend()
# Legend label for each loss statistic name. The names are the keys passed
# as ``losses`` to the loss-plotting helper, in the order listed here.
loss_labels = dict([
    ('tot', 'Total Loss'),
    ('structure', 'Structure'),
    ('pitch', 'Pitches'),
    ('dur', 'Duration'),
    ('reconstruction', 'Reconstruction Term'),
    ('kld', 'KLD'),
    ('beta*kld', 'beta * KLD'),
])
def plot_losses(model_dir, losses, plot_val=False):
    """Plot loss curves stored in the checkpoint file ``<model_dir>/checkpoint``.

    Args:
        model_dir: directory containing a file named ``checkpoint`` that
            ``torch.load`` can read. It must hold a ``'tr_losses'`` mapping,
            plus ``'val_losses'`` and ``'eval_every'`` when ``plot_val``
            is set.
        losses: loss names to plot, used as keys into those mappings.
            Legend labels come from ``loss_labels``; a name without an entry
            there is shown as-is.
        plot_val: if truthy, validation losses are plotted as well.
    """
    checkpoint_path = os.path.join(model_dir, 'checkpoint')
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    labels = [loss_labels.get(loss, loss) for loss in losses]
    tr_losses = checkpoint['tr_losses']
    val_losses = checkpoint['val_losses'] if plot_val else None
    eval_every = checkpoint['eval_every'] if plot_val else None

    # No rx is passed, so the x axis starts at 0 (plot_stats' default).
    plot_stats(losses, tr_losses, stats_val=val_losses,
               eval_every=eval_every, labels=labels)
# Legend label for each accuracy statistic name. The names are the keys
# passed as ``accuracies`` to the accuracy-plotting helper, in the order
# listed here.
accuracy_labels = dict([
    ('s_acc', 'Struct. Accuracy'),
    ('s_precision', 'Struct. Precision'),
    ('s_recall', 'Struct. Recall'),
    ('s_f1', 'Struct. F1'),
    ('pitch', 'Pitch Accuracy'),
    ('pitch_drums', 'Pitch Accuracy (Drums)'),
    ('pitch_non_drums', 'Pitch Accuracy (Non Drums)'),
    ('dur', 'Duration Accuracy'),
    ('note', 'Note Accuracy'),
])
def plot_accuracies(model_dir, accuracies, plot_val=False):
    """Plot accuracy curves stored in the checkpoint file ``<model_dir>/checkpoint``.

    Args:
        model_dir: directory containing a file named ``checkpoint`` that
            ``torch.load`` can read. It must hold a ``'tr_accuracies'``
            mapping, plus ``'val_accuracies'`` and ``'eval_every'`` when
            ``plot_val`` is set.
        accuracies: accuracy names to plot, used as keys into those
            mappings. Legend labels come from ``accuracy_labels``; a name
            without an entry there is shown as-is.
        plot_val: if truthy, validation accuracies are plotted as well.

    The y axis is fixed to (0, 1).
    """
    checkpoint_path = os.path.join(model_dir, 'checkpoint')
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    labels = [accuracy_labels.get(acc, acc) for acc in accuracies]
    tr_accuracies = checkpoint['tr_accuracies']
    val_accuracies = checkpoint['val_accuracies'] if plot_val else None
    eval_every = checkpoint['eval_every'] if plot_val else None

    plot_stats(accuracies, tr_accuracies, stats_val=val_accuracies,
               eval_every=eval_every, labels=labels, ry=(0, 1))