"""Plot mean-IoU-versus-number-of-clicks curves from pickled evaluation logs.

Each input ``.pickle`` file is expected to hold a dict with the keys
``'dataset_name'``, ``'model_name'`` and ``'all_ious'``.  One PNG figure is
written per dataset, containing one curve per model.
"""
import sys
import pickle
import argparse
from pathlib import Path
from collections import defaultdict

import numpy as np
import matplotlib
matplotlib.use('Agg')  # select the non-interactive backend before importing pyplot
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter

sys.path.insert(0, '.')
from isegm.utils.exp import load_config_file


def parse_args():
    """Parse the command line and load the experiment config.

    Exactly one of ``--folder``, ``--files``, ``--model-dirs`` or
    ``--exp-models`` must be given to select the input pickle files.

    Returns:
        A tuple ``(args, cfg)``.  ``args.datasets`` is converted from the
        comma-separated string into a list of names, and ``args.plots_path``
        is a ``Path`` to an output directory that has been created.
        ``cfg.EXPS_PATH`` is converted to a ``Path``.
    """
    parser = argparse.ArgumentParser()

    group_pkl_path = parser.add_mutually_exclusive_group(required=True)
    group_pkl_path.add_argument('--folder', type=str, default=None,
                                help='Path to folder with .pickle files.')
    group_pkl_path.add_argument('--files', nargs='+', default=None,
                                help='List of paths to .pickle files separated by space.')
    group_pkl_path.add_argument('--model-dirs', nargs='+', default=None,
                                help="List of paths to model directories with 'plots' folder "
                                     "containing .pickle files separated by space.")
    group_pkl_path.add_argument('--exp-models', nargs='+', default=None,
                                help='List of experiments paths suffixes (relative to cfg.EXPS_PATH/evaluation_logs). '
                                     'For each experiment, the checkpoint prefix must be specified '
                                     'by using the ":" delimiter at the end.')

    parser.add_argument('--mode', choices=['NoBRS', 'RGB-BRS', 'DistMap-BRS',
                                           'f-BRS-A', 'f-BRS-B', 'f-BRS-C'],
                        default=None, nargs='*', help='')
    # The original help text concatenated two literals without a separator
    # ("...iou analysisDatasets are..."); a ". " is inserted between them.
    parser.add_argument('--datasets', type=str, default='GrabCut,Berkeley,DAVIS,COCO_MVal,SBD',
                        help='List of datasets for plotting the iou analysis. '
                             'Datasets are separated by a comma. Possible choices: '
                             'GrabCut, Berkeley, DAVIS, COCO_MVal, SBD')
    parser.add_argument('--config-path', type=str, default='./config.yml',
                        help='The path to the config file.')
    parser.add_argument('--n-clicks', type=int, default=-1,
                        help='Maximum number of clicks to plot.')
    parser.add_argument('--plots-path', type=str, default='',
                        help='The path to the evaluation logs. '
                             'Default path: cfg.EXPS_PATH/evaluation_logs/iou_analysis.')

    args = parser.parse_args()

    cfg = load_config_file(args.config_path, return_edict=True)
    cfg.EXPS_PATH = Path(cfg.EXPS_PATH)

    # Strip whitespace and drop empty entries: an empty name would match every
    # file stem in get_files_list() because '' is a substring of any string.
    args.datasets = [name.strip() for name in args.datasets.split(',') if name.strip()]

    if args.plots_path == '':
        args.plots_path = cfg.EXPS_PATH / 'evaluation_logs/iou_analysis'
    else:
        args.plots_path = Path(args.plots_path)
    print(args.plots_path)
    args.plots_path.mkdir(parents=True, exist_ok=True)

    return args, cfg


# Maps the 'model_name' stored in a pickle file to the legend label.
model_name_mapper = {
    'sbd_vitb_epoch_54_NoBRS': 'Ours-ViT-B (SBD)',
    'sbd_vitl_epoch_54_NoBRS': 'Ours-ViT-L (SBD)',
    'sbd_vith_epoch_54_NoBRS': 'Ours-ViT-H (SBD)',
    'cocolvis_vitb_epoch_54_NoBRS': 'Ours-ViT-B (C+L)',
    'cocolvis_vitl_epoch_54_NoBRS': 'Ours-ViT-L (C+L)',
    'cocolvis_vith_epoch_52_NoBRS': 'Ours-ViT-H (C+L)',
    '052_NoBRS': 'Ours-ViT-H (C+L)',
    'sbd_h18_itermask_NoBRS': 'RITM-HRNet18 (SBD)',
    'coco_lvis_h32_itermask_NoBRS': 'RITM-HRNet32 (C+L)',
    'cocolvis_segformer_b3_s2_FocalClick': 'FocalClick-SegF-B3 (C+L)',
    'cocolvis_segformer_b0_s2_FocalClick': 'FocalClick-SegF-B0 (C+L)',
    'sbd_cdnet_resnet34_CDNet': 'CDNet-ResNet-34 (SBD)',
    'cocolvis_cdnet_resnet34_CDNet': 'CDNet-ResNet-34 (C+L)'
}

# Maps a legend label to a (color, linestyle) pair.
color_style_mapper = {
    'Ours-ViT-B (SBD)': ('#0000ff', '-'),
    'Ours-ViT-L (SBD)': ('#008000', '-'),
    'Ours-ViT-H (SBD)': ('#ff0000', '-'),
    'Ours-ViT-B (C+L)': ('#0080ff', '-'),
    'Ours-ViT-L (C+L)': ('#8000ff', '-'),
    'Ours-ViT-H (C+L)': ('#ff8000', '-'),
    'RITM-HRNet18 (SBD)': ('#000000', ':'),
    'RITM-HRNet32 (C+L)': ('#444444', ':'),
    'FocalClick-SegF-B0 (C+L)': ('#888888', ':'),
    'FocalClick-SegF-B3 (C+L)': ('#888888', ':'),
    'CDNet-ResNet-34 (SBD)': ('', ':'),
    'CDNet-ResNet-34 (C+L)': ('', ':')
}

# Maps a dataset display name to the (start, stop, step) passed to np.arange
# for the y-axis ticks (in percent).
range_mapper = {
    'SBD': (65, 96, 3),
    'DAVIS': (66, 97, 3),
    'Pascal VOC': (66, 100, 3),
    'COCO_MVal': (60, 97, 3),
    'BraTS': (10, 100, 10),
    'OAIZIB': (0, 85, 10),
    'ssTEM': (5, 100, 10),
    'GrabCut': (80, 100, 2),
    'Berkeley': (80, 100, 2)
}


def main():
    """Load the selected pickle files and save one mIoU plot per dataset."""
    args, cfg = parse_args()

    files_list = get_files_list(args, cfg)
    if not files_list:
        print('No matching .pickle files found.')
        return

    # Dict of dicts with mapping dataset_name -> model_name -> results
    aggregated_plot_data = defaultdict(dict)
    for file in files_list:
        # NOTE: pickle.load can execute arbitrary code; only pass trusted
        # evaluation logs to this script.
        with open(file, 'rb') as f:
            data = pickle.load(f)

        all_ious = data['all_ious']
        if args.n_clicks != -1:
            all_ious = [x[:args.n_clicks] for x in all_ious]
        # Average over the first axis; presumably one row per evaluated sample
        # and one column per click -- verify against the evaluation script.
        aggregated_plot_data[data['dataset_name']][data['model_name']] = \
            np.array(all_ious).mean(0)

    for dataset_name, dataset_results in aggregated_plot_data.items():
        fig = plt.figure(figsize=(12, 7))

        max_clicks = 0
        for model_name, model_results in dataset_results.items():
            # Curves were already truncated to --n-clicks while loading.
            model_results = model_results * 100  # fraction -> percent

            n_clicks = len(model_results)
            max_clicks = max(max_clicks, n_clicks)

            # Values are already in percent, so format with ':.2f' plus a
            # literal '%'; the ':.2%' spec would multiply by 100 a second time.
            miou_str = ' '.join([f'mIoU@{click_id}={model_results[click_id-1]:.2f}%;'
                                 for click_id in [1, 3, 5, 10, 20] if click_id <= n_clicks])
            print(f'{model_name} on {dataset_name}:\n{miou_str}\n')

            label = model_name_mapper.get(model_name, model_name)
            # NOTE(review): the mapped color has never been passed to plt.plot
            # (and two entries hold an empty string), so curves use
            # matplotlib's default color cycle.  Behaviour kept as is.
            _color, style = color_style_mapper.get(label, (None, None))

            plt.plot(1 + np.arange(n_clicks), model_results, linewidth=2,
                     label=label, linestyle=style)

        display_name = 'Pascal VOC' if dataset_name == 'PascalVOC' else dataset_name

        plt.title(f'{display_name}', fontsize=22)
        plt.grid()
        plt.legend(loc=4, fontsize='x-large')

        if display_name in range_mapper:
            tick_start, tick_stop, tick_step = range_mapper[display_name]
            plt.yticks(np.arange(tick_start, tick_stop, step=tick_step), fontsize='xx-large')
        else:
            # No hard-coded range for this dataset: keep matplotlib's automatic
            # ticks instead of failing with a KeyError.
            plt.yticks(fontsize='xx-large')
        plt.xticks(1 + np.arange(max_clicks), fontsize='xx-large')
        plt.xlabel('Number of Clicks', fontsize='xx-large')
        plt.ylabel('mIoU score (%)', fontsize='xx-large')
        plt.gca().yaxis.set_major_formatter(FormatStrFormatter('%.0f'))

        fig_path = get_target_file_path(args.plots_path, display_name)
        plt.savefig(str(fig_path))
        # Release the figure; pyplot otherwise keeps every figure alive.
        plt.close(fig)


def get_target_file_path(plots_path, dataset_name):
    """Return the next unused ``<dataset_name>_<NNN>.png`` path in *plots_path*.

    Args:
        plots_path: ``Path`` of the directory holding previously saved plots.
        dataset_name: Name used as the file-name prefix.

    Returns:
        The target path as a string.  The index is one greater than the
        largest numeric suffix among existing ``<dataset_name>_*.png`` files,
        or 0 when there is none.
    """
    # Compare suffixes numerically (lexicographic order breaks past 999) and
    # skip files whose suffix is not a plain number.
    used_indices = []
    for plot_file in plots_path.glob(f'{dataset_name}_*.png'):
        suffix = plot_file.stem.split('_')[-1]
        if suffix.isdigit():
            used_indices.append(int(suffix))
    index = max(used_indices) + 1 if used_indices else 0

    return str(plots_path / f'{dataset_name}_{index:03d}.png')


def get_files_list(args, cfg):
    """Collect the pickle files selected on the command line.

    Args:
        args: Parsed arguments; reads ``folder``, ``files``, ``model_dirs``,
            ``exp_models``, ``mode`` and ``datasets``.
        cfg: Config object; ``cfg.EXPS_PATH`` is used for ``--exp-models``.

    Returns:
        A list of ``Path`` objects whose stem contains at least one of the
        requested dataset names and, if ``--mode`` values were given, at least
        one of the requested modes.

    Raises:
        ValueError: If an ``--exp-models`` entry has no ``:`` delimiter or
            does not resolve to exactly one experiment directory.
    """
    files_list = []
    if args.folder is not None:
        # Sorted for a deterministic order, like the --exp-models branch.
        files_list = sorted(Path(args.folder).glob('*.pickle'))
    elif args.files is not None:
        files_list = [Path(file) for file in args.files]
    elif args.model_dirs is not None:
        for folder in args.model_dirs:
            folder = Path(folder) / 'plots'
            files_list.extend(sorted(folder.glob('*.pickle')))
    elif args.exp_models is not None:
        for exp_model in args.exp_models:
            # Split on the last ':' only, so a path containing ':' still works.
            rel_exp_path, delimiter, checkpoint_prefix = exp_model.rpartition(':')
            if not delimiter:
                raise ValueError(
                    f"Invalid --exp-models entry '{exp_model}': "
                    "expected '<experiment path>:<checkpoint prefix>'.")

            exp_path_prefix = cfg.EXPS_PATH / 'evaluation_logs' / rel_exp_path
            candidates = list(exp_path_prefix.parent.glob(exp_path_prefix.stem + '*'))
            # Explicit check instead of assert, which is stripped under -O.
            if len(candidates) != 1:
                raise ValueError(
                    f"Invalid experiment path '{rel_exp_path}': "
                    f"found {len(candidates)} matching directories, expected exactly 1.")
            exp_path = candidates[0]
            files_list.extend(sorted((exp_path / 'plots').glob(checkpoint_prefix + '*.pickle')))

    # '--mode' given without values parses to [], which would discard every
    # file; treat it the same as not passing the option.
    if args.mode:
        files_list = [file for file in files_list
                      if any(mode in file.stem for mode in args.mode)]
    files_list = [file for file in files_list
                  if any(dataset in file.stem for dataset in args.datasets)]

    return files_list


if __name__ == '__main__':
    main()