|
|
import matplotlib.pyplot as plt |
|
|
import seaborn as sns |
|
|
from pandas.api.types import CategoricalDtype |
|
|
|
|
|
def plot_metrics(metrics, condition_to_latex, title, color_palette):
    """Plot per-condition mel loss (box plot) and FAD (bar chart) in one figure.

    A new two-row figure is created with ``plt.subplots``: the top axes shows
    the distribution of ``mel`` per condition, the bottom axes shows the mean
    of ``frechet`` per condition with the standard deviation as error bars.
    Nothing is returned.

    Args:
        metrics: DataFrame with a ``condition`` column and numeric ``mel`` and
            ``frechet`` columns. It is not modified.
        condition_to_latex: Dict mapping each raw condition name to the label
            shown on the x-axis. The order of its values sets the
            left-to-right order of the conditions.
        title: Figure-level title.
        color_palette: Dict mapping each display label (a value of
            ``condition_to_latex``) to a color.

    Raises:
        KeyError: If a display label has no entry in ``color_palette``.
    """
    # Categories must be unique; dict.fromkeys drops repeated labels while
    # keeping first-seen order, and the list avoids handing pandas a dict view.
    categories = list(dict.fromkeys(condition_to_latex.values()))
    cat_type = CategoricalDtype(categories=categories, ordered=True)

    # assign() returns a new frame, so the caller's DataFrame does not gain a
    # helper column as a side effect of plotting.
    labelled = metrics.assign(
        condition_latex=metrics['condition'].map(condition_to_latex).astype(cat_type)
    )

    # observed=False keeps every declared category in the result (even ones
    # with no rows) so the index stays aligned with the declared label order.
    grouped = (
        labelled.groupby('condition_latex', observed=False)[['mel', 'frechet']]
        .agg(['mean', 'std'])
    )
    labels = list(grouped.index)
    positions = list(range(len(labels)))

    fig, axs = plt.subplots(2, 1, figsize=(7, 5.25))
    fig.suptitle(title, fontsize=16)
    mel_ax, fad_ax = axs

    bar_colors = [color_palette[label] for label in labels]

    # Top panel: mel loss distribution per condition, outliers hidden.
    sns.boxplot(
        x='condition_latex',
        y='mel',
        data=labelled,
        ax=mel_ax,
        palette=color_palette,
        showfliers=False,
    )
    mel_ax.set_ylabel('Mel Spectrogram Loss \u2190')
    mel_ax.set_xlabel('')
    # Pin tick positions before relabelling so labels cannot drift from ticks.
    mel_ax.set_xticks(positions)
    mel_ax.set_xticklabels(labels, rotation=0, ha='center')

    # Bottom panel: mean FAD per condition with std error bars, drawn at
    # explicit integer positions so ticks and bars are guaranteed to line up.
    fad_stats = grouped['frechet']
    fad_ax.bar(positions, fad_stats['mean'], yerr=fad_stats['std'], color=bar_colors)
    fad_ax.set_ylabel('FAD \u2190')
    fad_ax.set_xlabel('')
    fad_ax.set_xticks(positions)
    fad_ax.set_xticklabels(labels, rotation=0, ha='center')

    # NOTE(review): tight_layout() recomputes subplot spacing, so the hspace
    # set here is presumably overridden; kept to preserve the original call
    # sequence - confirm and drop if it has no visible effect.
    plt.subplots_adjust(hspace=0.1)
    # Reserve the top strip of the figure for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.subplots_adjust(top=0.92)