Spaces:
Runtime error
Runtime error
| import argparse | |
| import os | |
| import pickle | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
def load_sdrs(workspace, task_name, filename, config, gpus, source_type):
    """Load the median SDR of one source for every stored test evaluation.

    The statistics are read from
    ``<workspace>/statistics/<task_name>/<filename>/config=<config>,gpus=<gpus>/statistics.pkl``.

    Args:
        workspace: Root directory that contains the ``statistics`` folder.
        task_name: Sub-directory under ``statistics`` (e.g. ``"musdb18"``).
        filename: Sub-directory under ``task_name`` (e.g. ``"train"``).
        config: Value substituted into the ``config=...`` directory name.
        gpus: Value substituted into the ``gpus=...`` directory name.
        source_type: Key looked up in each entry's ``'median_sdr_dict'``.

    Returns:
        A list with ``entry['median_sdr_dict'][source_type]`` for each entry
        of the pickled dict's ``'test'`` sequence, in stored order.

    Raises:
        OSError: If the statistics file cannot be opened.
        KeyError: If ``'test'``, ``'median_sdr_dict'`` or ``source_type`` is
            missing from the loaded data.
    """
    stat_path = os.path.join(
        workspace,
        "statistics",
        task_name,
        filename,
        "config={},gpus={}".format(config, gpus),
        "statistics.pkl",
    )

    # NOTE: unpickling can execute arbitrary code; only load statistics files
    # that come from a trusted training run.
    # The context manager guarantees the handle is closed even if loading fails.
    with open(stat_path, 'rb') as stat_file:
        stat_dict = pickle.load(stat_file)

    return [entry['median_sdr_dict'][source_type] for entry in stat_dict['test']]
def plot_statistics(args):
    """Plot median SDR curves for one experiment group and save them as a PDF.

    For the selected group, one curve per configuration is loaded with
    ``load_sdrs`` and drawn on a single axes. The figure is written to
    ``results/musdb18/sdr_<select>.pdf`` (relative to the current directory).

    Args:
        args: Object with two attributes:
            ``workspace`` - root directory passed on to ``load_sdrs``;
            ``select`` - which group to plot, one of ``'1a'``, ``'1b'``,
            ``'1c'`` or ``'1d'``.

    Raises:
        ValueError: If ``args.select`` is not one of the supported groups.
            (``ValueError`` is an ``Exception`` subclass, so callers catching
            ``Exception`` keep working.)
    """
    # arguments & parameters
    workspace = args.workspace
    select = args.select
    task_name = "musdb18"
    filename = "train"

    # select -> (source_type, y-axis upper limit, curves);
    # each curve is (config, gpus, legend label).
    # A single lookup replaces two separate if-chains: previously '1a' and
    # '1b' were plotted and then rejected by the else-branch of the second
    # chain, so those selections always raised.
    plot_specs = {
        '1a': (
            "vocals",
            15,
            [('vocals-accompaniment,unet', 1, 'UNet,l1_wav')],
        ),
        '1b': (
            "accompaniment",
            20,
            [('accompaniment-vocals,unet', 1, 'UNet,l1_wav')],
        ),
        '1c': (
            "vocals",
            15,
            [
                ('vocals-accompaniment,unet', 1, 'UNet,l1_wav'),
                ('vocals-accompaniment,resunet', 2, 'ResUNet_ISMIR2021,l1_wav'),
                ('vocals-accompaniment,unet_subbandtime', 1, 'unet_subband,l1_wav'),
                ('vocals-accompaniment,resunet_subbandtime', 1, 'resunet_subband,l1_wav'),
            ],
        ),
        '1d': (
            "accompaniment",
            20,
            [
                ('accompaniment-vocals,unet', 1, 'UNet,l1_wav'),
                ('accompaniment-vocals,resunet', 2, 'ResUNet_ISMIR2021,l1_wav'),
                # The 'accompaniment-vocals,unet_subbandtime' curve was
                # disabled in the original script and is still not plotted.
                ('accompaniment-vocals,resunet_subbandtime', 1, 'ResUNet_subbtandtime,l1_wav'),
            ],
        ),
    }

    # Validate before touching the filesystem or creating a figure.
    if select not in plot_specs:
        raise ValueError(
            "Unknown select={!r}; expected one of {}".format(
                select, sorted(plot_specs)
            )
        )
    source_type, ylim, curves = plot_specs[select]

    # paths
    fig_path = os.path.join('results', task_name, "sdr_{}.pdf".format(select))
    os.makedirs(os.path.dirname(fig_path), exist_ok=True)

    linewidth = 1
    lines = []
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))

    for config, gpus, label in curves:
        sdrs = load_sdrs(
            workspace,
            task_name,
            filename,
            config=config,
            gpus=gpus,
            source_type=source_type,
        )
        (line,) = ax.plot(sdrs, label=label, linewidth=linewidth)
        lines.append(line)

    # The x positions are indices into the loaded SDR list; tick labels show
    # index * eval_every_iterations.
    eval_every_iterations = 10000
    total_ticks = 50
    ticks_freq = 10

    ax.set_ylim(0, ylim)
    ax.set_xlim(0, total_ticks)
    ax.xaxis.set_ticks(np.arange(0, total_ticks + 1, ticks_freq))
    ax.xaxis.set_ticklabels(
        np.arange(
            0,
            total_ticks * eval_every_iterations + 1,
            ticks_freq * eval_every_iterations,
        )
    )
    ax.yaxis.set_ticks(np.arange(ylim + 1))
    ax.yaxis.set_ticklabels(np.arange(ylim + 1))
    ax.grid(color='b', linestyle='solid', linewidth=0.3)

    # Use the explicit axes/figure objects rather than pyplot's implicit
    # "current" state; loc=4 places the legend in the lower right.
    ax.legend(handles=lines, loc=4)
    fig.savefig(fig_path)
    print('Save figure to {}'.format(fig_path))
if __name__ == '__main__':
    # Command-line entry point: both options are mandatory string values.
    cli = argparse.ArgumentParser()
    for option in ('--workspace', '--select'):
        cli.add_argument(option, type=str, required=True)

    # Parse sys.argv and hand the resulting namespace to the plotting routine.
    plot_statistics(cli.parse_args())