# Source: CellPilot / SAMHI / resources / SimpleClick / scripts / plot_ious_analysis.py
# Uploaded by philippendres using huggingface_hub (commit 907462b, verified); file size: 8.98 kB.
import sys
import pickle
import argparse
from pathlib import Path
from collections import defaultdict
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
sys.path.insert(0, '.')
from isegm.utils.exp import load_config_file
def parse_args():
    """Parse the command line and load the experiment config.

    Exactly one of ``--folder``, ``--files``, ``--model-dirs`` or
    ``--exp-models`` must be given; they select where the ``.pickle`` result
    files are read from.

    Side effects: prints the resolved plots directory and creates it
    (including parents) if it does not exist yet.

    Returns:
        A ``(args, cfg)`` tuple. ``args.datasets`` is converted from the
        comma-separated string into a list of names and ``args.plots_path``
        into a ``Path``. ``cfg`` is whatever ``load_config_file`` returns for
        ``--config-path``, with ``cfg.EXPS_PATH`` converted to a ``Path``.
    """
    parser = argparse.ArgumentParser()

    group_pkl_path = parser.add_mutually_exclusive_group(required=True)
    group_pkl_path.add_argument('--folder', type=str, default=None,
                                help='Path to folder with .pickle files.')
    group_pkl_path.add_argument('--files', nargs='+', default=None,
                                help='List of paths to .pickle files separated by space.')
    group_pkl_path.add_argument('--model-dirs', nargs='+', default=None,
                                help="List of paths to model directories with 'plots' folder "
                                     "containing .pickle files separated by space.")
    group_pkl_path.add_argument('--exp-models', nargs='+', default=None,
                                help='List of experiments paths suffixes (relative to cfg.EXPS_PATH/evaluation_logs). '
                                     'For each experiment, the checkpoint prefix must be specified '
                                     'by using the ":" delimiter at the end.')

    parser.add_argument('--mode', choices=['NoBRS', 'RGB-BRS', 'DistMap-BRS',
                                           'f-BRS-A', 'f-BRS-B', 'f-BRS-C'],
                        default=None, nargs='*',
                        help='Keep only result files whose name contains one of these modes. '
                             'By default no mode filtering is applied.')
    # The first fragment used to lack a trailing '. ', which rendered as
    # "...iou analysisDatasets are separated..." in --help.
    parser.add_argument('--datasets', type=str, default='GrabCut,Berkeley,DAVIS,COCO_MVal,SBD',
                        help='List of datasets for plotting the iou analysis. '
                             'Datasets are separated by a comma. Possible choices: '
                             'GrabCut, Berkeley, DAVIS, COCO_MVal, SBD')
    parser.add_argument('--config-path', type=str, default='./config.yml',
                        help='The path to the config file.')
    parser.add_argument('--n-clicks', type=int, default=-1,
                        help='Maximum number of clicks to plot.')
    parser.add_argument('--plots-path', type=str, default='',
                        help='The path to the evaluation logs. '
                             'Default path: cfg.EXPS_PATH/evaluation_logs/iou_analysis.')

    args = parser.parse_args()

    cfg = load_config_file(args.config_path, return_edict=True)
    cfg.EXPS_PATH = Path(cfg.EXPS_PATH)

    # Strip whitespace so that "GrabCut, Berkeley" is understood as two clean
    # names; an unstripped " Berkeley" would never match a file stem.
    args.datasets = [name.strip() for name in args.datasets.split(',')]

    if args.plots_path == '':
        args.plots_path = cfg.EXPS_PATH / 'evaluation_logs/iou_analysis'
    else:
        args.plots_path = Path(args.plots_path)
    print(args.plots_path)
    args.plots_path.mkdir(parents=True, exist_ok=True)

    return args, cfg
# Maps the `model_name` stored in a result pickle to the label shown in the
# plot legend. Names that are not listed here are used as the label verbatim.
model_name_mapper = {
    # SimpleClick ViT backbones.
    'sbd_vitb_epoch_54_NoBRS': 'Ours-ViT-B (SBD)',
    'sbd_vitl_epoch_54_NoBRS': 'Ours-ViT-L (SBD)',
    'sbd_vith_epoch_54_NoBRS': 'Ours-ViT-H (SBD)',
    'cocolvis_vitb_epoch_54_NoBRS': 'Ours-ViT-B (C+L)',
    'cocolvis_vitl_epoch_54_NoBRS': 'Ours-ViT-L (C+L)',
    'cocolvis_vith_epoch_52_NoBRS': 'Ours-ViT-H (C+L)',
    '052_NoBRS': 'Ours-ViT-H (C+L)',
    # Baselines.
    'sbd_h18_itermask_NoBRS': 'RITM-HRNet18 (SBD)',
    'coco_lvis_h32_itermask_NoBRS': 'RITM-HRNet32 (C+L)',
    'cocolvis_segformer_b3_s2_FocalClick': 'FocalClick-SegF-B3 (C+L)',
    'cocolvis_segformer_b0_s2_FocalClick': 'FocalClick-SegF-B0 (C+L)',
    'sbd_cdnet_resnet34_CDNet': 'CDNet-ResNet-34 (SBD)',
    'cocolvis_cdnet_resnet34_CDNet': 'CDNet-ResNet-34 (C+L)',
}
# Maps a legend label to a `(color, linestyle)` pair.
# The CDNet entries carry an empty color string.
color_style_mapper = {
    # Solid lines.
    'Ours-ViT-B (SBD)': ('#0000ff', '-'),
    'Ours-ViT-L (SBD)': ('#008000', '-'),
    'Ours-ViT-H (SBD)': ('#ff0000', '-'),
    'Ours-ViT-B (C+L)': ('#0080ff', '-'),
    'Ours-ViT-L (C+L)': ('#8000ff', '-'),
    'Ours-ViT-H (C+L)': ('#ff8000', '-'),
    # Dotted lines.
    'RITM-HRNet18 (SBD)': ('#000000', ':'),
    'RITM-HRNet32 (C+L)': ('#444444', ':'),
    'FocalClick-SegF-B0 (C+L)': ('#888888', ':'),
    'FocalClick-SegF-B3 (C+L)': ('#888888', ':'),
    'CDNet-ResNet-34 (SBD)': ('', ':'),
    'CDNet-ResNet-34 (C+L)': ('', ':'),
}
# Per-dataset y-axis tick specification as `(start, stop, step)`; the triple
# is handed to `np.arange`, so `stop` itself is excluded from the ticks.
range_mapper = {
    'SBD': (65, 96, 3),
    'DAVIS': (66, 97, 3),
    'Pascal VOC': (66, 100, 3),
    'COCO_MVal': (60, 97, 3),
    'BraTS': (10, 100, 10),
    'OAIZIB': (0, 85, 10),
    'ssTEM': (5, 100, 10),
    'GrabCut': (80, 100, 2),
    'Berkeley': (80, 100, 2),
}
def main():
    """Plot mean IoU versus number of clicks, one figure per dataset.

    Reads every result pickle selected on the command line, averages the
    per-sample IoU curves (``all_ious``) over axis 0 for each
    (dataset, model) pair, prints the mIoU at clicks 1/3/5/10/20 and saves one
    PNG per dataset into ``args.plots_path``.
    """
    args, cfg = parse_args()

    files_list = get_files_list(args, cfg)

    # Dict of dicts with mapping dataset_name -> model_name -> results
    aggregated_plot_data = defaultdict(dict)
    for file in files_list:
        # SECURITY NOTE: pickle.load can execute arbitrary code embedded in the
        # file. Only point this script at evaluation logs you trust.
        with open(file, 'rb') as f:
            data = pickle.load(f)
        all_ious = data['all_ious']
        if args.n_clicks != -1:
            all_ious = [x[:args.n_clicks] for x in all_ious]
        aggregated_plot_data[data['dataset_name']][data['model_name']] = np.array(all_ious).mean(0)

    for dataset_name, dataset_results in aggregated_plot_data.items():
        plt.figure(figsize=(12, 7))

        max_clicks = 0
        # Observed score range; only used for the y ticks of datasets that
        # have no preset in range_mapper.
        min_val, max_val = 100, -1
        for model_name, model_results in dataset_results.items():
            if args.n_clicks != -1:
                model_results = model_results[:args.n_clicks]
            # From here on the scores are percentages.
            model_results = model_results * 100

            min_val = min(min_val, min(model_results))
            max_val = max(max_val, max(model_results))

            n_clicks = len(model_results)
            max_clicks = max(max_clicks, n_clicks)

            # The values are already scaled by 100, so format them with 'f'
            # plus a literal percent sign; the '%' presentation type would
            # multiply by 100 a second time (85.00 -> "8500.00%").
            miou_str = ' '.join([f'mIoU@{click_id}={model_results[click_id-1]:.2f}%;'
                                 for click_id in [1, 3, 5, 10, 20] if click_id <= len(model_results)])
            print(f'{model_name} on {dataset_name}:\n{miou_str}\n')

            label = model_name_mapper.get(model_name, model_name)
            # Only the line style of the mapping is applied; the line color is
            # left to matplotlib's default color cycle, as before.
            _, style = color_style_mapper.get(label, (None, None))

            plt.plot(1 + np.arange(n_clicks), model_results, linewidth=2, label=label, linestyle=style)

        # The display name is used for the title, the tick preset lookup and
        # the output file name.
        display_name = 'Pascal VOC' if dataset_name == 'PascalVOC' else dataset_name

        plt.title(f'{display_name}', fontsize=22)
        plt.grid()
        plt.legend(loc=4, fontsize='x-large')

        if display_name in range_mapper:
            tick_min, tick_max, step = range_mapper[display_name]
        else:
            # No preset for this dataset: derive roughly ten ticks from the
            # plotted data instead of failing with a KeyError.
            tick_min = int(np.floor(min_val))
            tick_max = int(np.ceil(max_val)) + 1
            step = max(1, (tick_max - tick_min) // 10)
        plt.yticks(np.arange(tick_min, tick_max, step=step), fontsize='xx-large')
        plt.xticks(1 + np.arange(max_clicks), fontsize='xx-large')
        plt.xlabel('Number of Clicks', fontsize='xx-large')
        plt.ylabel('mIoU score (%)', fontsize='xx-large')
        plt.gca().yaxis.set_major_formatter(FormatStrFormatter('%.0f'))

        fig_path = get_target_file_path(args.plots_path, display_name)
        plt.savefig(str(fig_path))
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so that memory does not grow with the number of datasets.
        plt.close()
def get_target_file_path(plots_path, dataset_name):
    """Return the path of the next free ``<dataset_name>_<index>.png`` file.

    Args:
        plots_path: ``Path`` of the directory that holds the plots.
        dataset_name: Name used as the file name prefix.

    Returns:
        The target path as a string. The index is one greater than the largest
        index already present for this dataset (``000`` if there is none) and
        is zero-padded to at least three digits.
    """
    prefix = f'{dataset_name}_'
    used_indices = []
    for plot in plots_path.glob(f'{dataset_name}_*.png'):
        suffix = plot.stem[len(prefix):]
        # Count only files named exactly "<dataset_name>_<digits>.png". This
        # skips hand-named files ("DAVIS_notes.png") and plots of datasets that
        # merely share the prefix ("COCO" vs "COCO_MVal_000.png").
        if plot.stem.startswith(prefix) and suffix.isdecimal():
            used_indices.append(int(suffix))

    # Take the numeric maximum: a lexicographic sort would rank "999" above
    # "1000" and make the next save overwrite an existing plot.
    index = max(used_indices, default=-1) + 1

    return str(plots_path / f'{dataset_name}_{index:03d}.png')
def get_files_list(args, cfg):
    """Collect the result ``.pickle`` files selected on the command line.

    The first option that is set wins, in this order: ``args.folder``,
    ``args.files``, ``args.model_dirs`` (each directory's ``plots``
    sub-folder), ``args.exp_models`` (entries of the form
    ``<experiment path prefix>:<checkpoint prefix>``, resolved relative to
    ``cfg.EXPS_PATH / 'evaluation_logs'``). The result is then filtered to
    files whose stem contains one of ``args.mode`` (when given) and one of
    ``args.datasets``.

    Args:
        args: Parsed command line namespace.
        cfg: Config object; only ``cfg.EXPS_PATH`` is read, and only for
            ``args.exp_models``.

    Returns:
        A list of ``Path`` objects. Files found by globbing are sorted so the
        order is deterministic; explicitly listed files keep their order.

    Raises:
        ValueError: If an ``exp_models`` entry has no ``:`` delimiter or its
            path prefix does not match exactly one experiment directory.
    """
    if args.folder is not None:
        files_list = sorted(Path(args.folder).glob('*.pickle'))
    elif args.files is not None:
        files_list = [Path(file) for file in args.files]
    elif args.model_dirs is not None:
        files_list = []
        for folder in args.model_dirs:
            files_list.extend(sorted((Path(folder) / 'plots').glob('*.pickle')))
    elif args.exp_models is not None:
        files_list = []
        for exp_model in args.exp_models:
            # Split on the LAST colon so that paths which themselves contain
            # a colon (e.g. a Windows drive letter) are still accepted.
            rel_exp_path, delimiter, checkpoint_prefix = exp_model.rpartition(':')
            if not delimiter:
                raise ValueError(f"Invalid experiment specification '{exp_model}': "
                                 f"expected '<experiment path>:<checkpoint prefix>'.")

            exp_path_prefix = cfg.EXPS_PATH / 'evaluation_logs' / rel_exp_path
            # Use .name, not .stem: .stem drops everything after the last dot,
            # so a prefix like "000_vit.b" would also match "000_vit.l".
            candidates = list(exp_path_prefix.parent.glob(exp_path_prefix.name + '*'))
            # Explicit raise instead of assert: asserts are stripped under -O.
            if len(candidates) != 1:
                raise ValueError(f"Invalid experiment path '{rel_exp_path}': expected exactly one "
                                 f"matching directory, found {len(candidates)}.")
            exp_path = candidates[0]
            files_list.extend(sorted((exp_path / 'plots').glob(checkpoint_prefix + '*.pickle')))
    else:
        # No source selected (possible when called programmatically).
        files_list = []

    if args.mode is not None:
        files_list = [file for file in files_list
                      if any(mode in file.stem for mode in args.mode)]
    files_list = [file for file in files_list
                  if any(dataset in file.stem for dataset in args.datasets)]

    return files_list
# Run the plotting pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()