Spaces:
Runtime error
Runtime error
| """ | |
| This script evaluates the contribution of a technique from the ablation study for | |
| improving the masker evaluation metrics. The differences in the metrics are computed | |
| for all images of paired models, that is those which only differ in the inclusion or | |
| not of the given technique. Then, statistical inference is performed through the | |
| percentile bootstrap to obtain robust estimates of the differences in the metrics and | |
| confidence intervals. The script plots the distribution of the bootrstraped estimates. | |
| """ | |
| print("Imports...", end="") | |
| from argparse import ArgumentParser | |
| import yaml | |
| import numpy as np | |
| import pandas as pd | |
| import seaborn as sns | |
| import os | |
| from pathlib import Path | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as mpatches | |
| import matplotlib.transforms as transforms | |
| # ----------------------- | |
| # ----- Constants ----- | |
| # ----------------------- | |
| dict_models = { | |
| "md": 11, | |
| "dada_ms, msd, pseudo": 9, | |
| "msd, pseudo": 4, | |
| "dada, msd_spade, pseudo": 7, | |
| "msd": 13, | |
| "dada_m, msd": 17, | |
| "dada, msd_spade": 16, | |
| "msd_spade, pseudo": 5, | |
| "dada_ms, msd": 18, | |
| "dada, msd, pseudo": 6, | |
| "ms": 12, | |
| "dada, msd": 15, | |
| "dada_m, msd, pseudo": 8, | |
| "msd_spade": 14, | |
| "m": 10, | |
| "md, pseudo": 2, | |
| "ms, pseudo": 3, | |
| "m, pseudo": 1, | |
| "ground": "G", | |
| "instagan": "I", | |
| } | |
| dict_metrics = { | |
| "names": { | |
| "tpr": "TPR, Recall, Sensitivity", | |
| "tnr": "TNR, Specificity, Selectivity", | |
| "fpr": "FPR", | |
| "fpt": "False positives relative to image size", | |
| "fnr": "FNR, Miss rate", | |
| "fnt": "False negatives relative to image size", | |
| "mpr": "May positive rate (MPR)", | |
| "mnr": "May negative rate (MNR)", | |
| "accuracy": "Accuracy (ignoring may)", | |
| "error": "Error", | |
| "f05": "F05 score", | |
| "precision": "Precision", | |
| "edge_coherence": "Edge coherence", | |
| "accuracy_must_may": "Accuracy (ignoring cannot)", | |
| }, | |
| "key_metrics": ["f05", "error", "edge_coherence"], | |
| } | |
| dict_techniques = { | |
| "depth": "depth", | |
| "segmentation": "seg", | |
| "seg": "seg", | |
| "dada_s": "dada_seg", | |
| "dada_seg": "dada_seg", | |
| "dada_segmentation": "dada_seg", | |
| "dada_m": "dada_masker", | |
| "dada_masker": "dada_masker", | |
| "spade": "spade", | |
| "pseudo": "pseudo", | |
| "pseudo-labels": "pseudo", | |
| "pseudo_labels": "pseudo", | |
| } | |
| # Markers | |
| dict_markers = {"error": "o", "f05": "s", "edge_coherence": "^"} | |
| # Model features | |
| model_feats = [ | |
| "masker", | |
| "seg", | |
| "depth", | |
| "dada_seg", | |
| "dada_masker", | |
| "spade", | |
| "pseudo", | |
| "ground", | |
| "instagan", | |
| ] | |
| # Colors | |
| palette_colorblind = sns.color_palette("colorblind") | |
| color_climategan = palette_colorblind[0] | |
| color_munit = palette_colorblind[1] | |
| color_cyclegan = palette_colorblind[6] | |
| color_instagan = palette_colorblind[8] | |
| color_maskinstagan = palette_colorblind[2] | |
| color_paintedground = palette_colorblind[3] | |
| color_cat1 = palette_colorblind[0] | |
| color_cat2 = palette_colorblind[1] | |
| palette_lightest = [ | |
| sns.light_palette(color_cat1, n_colors=20)[3], | |
| sns.light_palette(color_cat2, n_colors=20)[3], | |
| ] | |
| palette_light = [ | |
| sns.light_palette(color_cat1, n_colors=3)[1], | |
| sns.light_palette(color_cat2, n_colors=3)[1], | |
| ] | |
| palette_medium = [color_cat1, color_cat2] | |
| palette_dark = [ | |
| sns.dark_palette(color_cat1, n_colors=3)[1], | |
| sns.dark_palette(color_cat2, n_colors=3)[1], | |
| ] | |
| palette_cat1 = [ | |
| palette_lightest[0], | |
| palette_light[0], | |
| palette_medium[0], | |
| palette_dark[0], | |
| ] | |
| palette_cat2 = [ | |
| palette_lightest[1], | |
| palette_light[1], | |
| palette_medium[1], | |
| palette_dark[1], | |
| ] | |
| color_cat1_light = palette_light[0] | |
| color_cat2_light = palette_light[1] | |
| def parsed_args(): | |
| """ | |
| Parse and returns command-line args | |
| Returns: | |
| argparse.Namespace: the parsed arguments | |
| """ | |
| parser = ArgumentParser() | |
| parser.add_argument( | |
| "--input_csv", | |
| default="ablations_metrics_20210311.csv", | |
| type=str, | |
| help="CSV containing the results of the ablation study", | |
| ) | |
| parser.add_argument( | |
| "--output_dir", | |
| default=None, | |
| type=str, | |
| help="Output directory", | |
| ) | |
| parser.add_argument( | |
| "--models", | |
| default="all", | |
| type=str, | |
| help="Models to display: all, pseudo, no_dada_masker, no_baseline", | |
| ) | |
| parser.add_argument( | |
| "--dpi", | |
| default=200, | |
| type=int, | |
| help="DPI for the output images", | |
| ) | |
| parser.add_argument( | |
| "--n_bs", | |
| default=1e6, | |
| type=int, | |
| help="Number of bootrstrap samples", | |
| ) | |
| parser.add_argument( | |
| "--alpha", | |
| default=0.99, | |
| type=float, | |
| help="Confidence level", | |
| ) | |
| parser.add_argument( | |
| "--bs_seed", | |
| default=17, | |
| type=int, | |
| help="Bootstrap random seed, for reproducibility", | |
| ) | |
| return parser.parse_args() | |
| def plot_median_metrics( | |
| df, do_stripplot=True, dpi=200, bs_seed=37, n_bs=1000, **snskwargs | |
| ): | |
| def plot_metric( | |
| ax, df, metric, do_stripplot=True, dpi=200, bs_seed=37, marker="o", **snskwargs | |
| ): | |
| y_labels = [dict_models[f] for f in df.model_feats.unique()] | |
| # Labels | |
| y_labels_int = np.sort([el for el in y_labels if isinstance(el, int)]).tolist() | |
| y_order_int = [ | |
| k for vs in y_labels_int for k, vu in dict_models.items() if vs == vu | |
| ] | |
| y_labels_int = [str(el) for el in y_labels_int] | |
| y_labels_str = sorted([el for el in y_labels if not isinstance(el, int)]) | |
| y_order_str = [ | |
| k for vs in y_labels_str for k, vu in dict_models.items() if vs == vu | |
| ] | |
| y_labels = y_labels_int + y_labels_str | |
| y_order = y_order_int + y_order_str | |
| # Palette | |
| palette = len(y_labels_int) * [color_climategan] | |
| for y in y_labels_str: | |
| if y == "G": | |
| palette = palette + [color_paintedground] | |
| if y == "I": | |
| palette = palette + [color_maskinstagan] | |
| # Error | |
| sns.pointplot( | |
| ax=ax, | |
| data=df, | |
| x=metric, | |
| y="model_feats", | |
| order=y_order, | |
| markers=marker, | |
| estimator=np.median, | |
| ci=99, | |
| seed=bs_seed, | |
| n_boot=n_bs, | |
| join=False, | |
| scale=0.6, | |
| errwidth=1.5, | |
| capsize=0.1, | |
| palette=palette, | |
| ) | |
| xlim = ax.get_xlim() | |
| if do_stripplot: | |
| sns.stripplot( | |
| ax=ax, | |
| data=df, | |
| x=metric, | |
| y="model_feats", | |
| size=1.5, | |
| palette=palette, | |
| alpha=0.2, | |
| ) | |
| ax.set_xlim(xlim) | |
| # Set X-label | |
| ax.set_xlabel(dict_metrics["names"][metric], rotation=0, fontsize="medium") | |
| # Set Y-label | |
| ax.set_ylabel(None) | |
| ax.set_yticklabels(y_labels, fontsize="medium") | |
| # Change spines | |
| sns.despine(ax=ax, left=True, bottom=True) | |
| # Draw gray area on final model | |
| xlim = ax.get_xlim() | |
| ylim = ax.get_ylim() | |
| trans = transforms.blended_transform_factory(ax.transAxes, ax.transData) | |
| rect = mpatches.Rectangle( | |
| xy=(0.0, 5.5), | |
| width=1, | |
| height=1, | |
| transform=trans, | |
| linewidth=0.0, | |
| edgecolor="none", | |
| facecolor="gray", | |
| alpha=0.05, | |
| ) | |
| ax.add_patch(rect) | |
| # Set up plot | |
| sns.set(style="whitegrid") | |
| plt.rcParams.update({"font.family": "serif"}) | |
| plt.rcParams.update( | |
| { | |
| "font.serif": [ | |
| "Computer Modern Roman", | |
| "Times New Roman", | |
| "Utopia", | |
| "New Century Schoolbook", | |
| "Century Schoolbook L", | |
| "ITC Bookman", | |
| "Bookman", | |
| "Times", | |
| "Palatino", | |
| "Charter", | |
| "serif" "Bitstream Vera Serif", | |
| "DejaVu Serif", | |
| ] | |
| } | |
| ) | |
| fig_h = 0.4 * len(df.model_feats.unique()) | |
| fig, axes = plt.subplots( | |
| nrows=1, ncols=3, sharey=True, dpi=dpi, figsize=(18, fig_h) | |
| ) | |
| # Error | |
| plot_metric( | |
| axes[0], | |
| df, | |
| "error", | |
| do_stripplot=do_stripplot, | |
| dpi=dpi, | |
| bs_seed=bs_seed, | |
| marker=dict_markers["error"], | |
| ) | |
| axes[0].set_ylabel("Models") | |
| # F05 | |
| plot_metric( | |
| axes[1], | |
| df, | |
| "f05", | |
| do_stripplot=do_stripplot, | |
| dpi=dpi, | |
| bs_seed=bs_seed, | |
| marker=dict_markers["f05"], | |
| ) | |
| # Edge coherence | |
| plot_metric( | |
| axes[2], | |
| df, | |
| "edge_coherence", | |
| do_stripplot=do_stripplot, | |
| dpi=dpi, | |
| bs_seed=bs_seed, | |
| marker=dict_markers["edge_coherence"], | |
| ) | |
| xticks = axes[2].get_xticks() | |
| xticklabels = ["{:.3f}".format(x) for x in xticks] | |
| axes[2].set(xticks=xticks, xticklabels=xticklabels) | |
| plt.subplots_adjust(wspace=0.12) | |
| return fig | |
| if __name__ == "__main__": | |
| # ----------------------------- | |
| # ----- Parse arguments ----- | |
| # ----------------------------- | |
| args = parsed_args() | |
| print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()])) | |
| # Determine output dir | |
| if args.output_dir is None: | |
| output_dir = Path(os.environ["SLURM_TMPDIR"]) | |
| else: | |
| output_dir = Path(args.output_dir) | |
| if not output_dir.exists(): | |
| output_dir.mkdir(parents=True, exist_ok=False) | |
| # Store args | |
| output_yml = output_dir / "ablation_comparison_{}.yml".format(args.models) | |
| with open(output_yml, "w") as f: | |
| yaml.dump(vars(args), f) | |
| # Read CSV | |
| df = pd.read_csv(args.input_csv, index_col="model_img_idx") | |
| # Determine models | |
| if "all" in args.models.lower(): | |
| pass | |
| else: | |
| if "no_baseline" in args.models.lower(): | |
| df = df.loc[(df.ground == False) & (df.instagan == False)] | |
| if "pseudo" in args.models.lower(): | |
| df = df.loc[ | |
| (df.pseudo == True) | (df.ground == True) | (df.instagan == True) | |
| ] | |
| if "no_dada_mask" in args.models.lower(): | |
| df = df.loc[ | |
| (df.dada_masker == False) | (df.ground == True) | (df.instagan == True) | |
| ] | |
| fig = plot_median_metrics(df, do_stripplot=True, dpi=args.dpi, bs_seed=args.bs_seed) | |
| # Save figure | |
| output_fig = output_dir / "ablation_comparison_{}.png".format(args.models) | |
| fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight") | |