import multiprocessing
import os
from copy import deepcopy
from multiprocessing import Pool
from typing import Tuple, List, Union, Optional

import numpy as np
from batchgenerators.utilities.file_and_folder_operations import subfiles, join, save_json, load_json, \
    isfile
from nnunetv2.configuration import default_num_processes
from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json, \
    determine_reader_writer_from_file_ending
from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
# the Evaluator class of the previous nnU-Net was great and all but man was it overengineered. Keep it simple
from nnunetv2.utilities.json_export import recursive_fix_for_json_export
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager


def label_or_region_to_key(label_or_region: Union[int, Tuple[int, ...]]) -> str:
    """Return the string used as JSON key for a label (int) or region (tuple of ints)."""
    return str(label_or_region)


def key_to_label_or_region(key: str) -> Union[int, Tuple[int, ...]]:
    """
    Inverse of label_or_region_to_key.

    A key that parses as an int is returned as int. Otherwise the parentheses are stripped and the
    comma-separated entries are returned as a tuple of ints (empty entries, e.g. from '(1,)', are skipped).
    """
    try:
        return int(key)
    except ValueError:
        stripped = key.replace('(', '').replace(')', '')
        return tuple(int(part) for part in stripped.split(',') if len(part) > 0)


def save_summary_json(results: dict, output_file: str) -> None:
    """
    Write a results summary to output_file as JSON.

    JSON does not support tuples as keys, so the label/region keys in results['mean'] and in each
    results['metric_per_case'][i]['metrics'] are converted to strings first. The conversion is done on a
    deep copy; the passed-in results dict is left untouched.
    """
    results_converted = deepcopy(results)
    # convert keys in mean metrics
    results_converted['mean'] = {label_or_region_to_key(k): v for k, v in results_converted['mean'].items()}
    # convert metric_per_case
    for case in results_converted["metric_per_case"]:
        case['metrics'] = {label_or_region_to_key(k): v for k, v in case['metrics'].items()}
    # sort_keys=True will make foreground_mean the first entry and thus easy to spot
    save_json(results_converted, output_file, sort_keys=True)


def load_summary_json(filename: str) -> dict:
    """
    Load a summary written by save_summary_json and convert the string keys in 'mean' and in each
    'metric_per_case' entry's 'metrics' back to labels (int) or regions (tuple of ints).
    """
    results = load_json(filename)
    # convert keys in mean metrics
    results['mean'] = {key_to_label_or_region(k): v for k, v in results['mean'].items()}
    # convert metric_per_case
    for case in results["metric_per_case"]:
        case['metrics'] = {key_to_label_or_region(k): v for k, v in case['metrics'].items()}
    return results


def labels_to_list_of_regions(labels: List[int]) -> List[Tuple[int, ...]]:
    """Wrap each label in a one-element tuple, i.e. [1, 2] -> [(1,), (2,)]."""
    return [(i,) for i in labels]


def region_or_label_to_mask(segmentation: np.ndarray, region_or_label: Union[int, Tuple[int, ...]]) -> np.ndarray:
    """
    Return a boolean mask that is True where segmentation equals the given label or, for a region,
    where it equals any of the labels contained in the region.
    """
    if np.isscalar(region_or_label):
        return segmentation == region_or_label
    mask = np.zeros_like(segmentation, dtype=bool)
    for label in region_or_label:
        mask[segmentation == label] = True
    return mask


def compute_tp_fp_fn_tn(mask_ref: np.ndarray, mask_pred: np.ndarray, ignore_mask: Optional[np.ndarray] = None):
    """
    Count true positives, false positives, false negatives and true negatives between two boolean masks.

    Entries where ignore_mask is True are excluded from all four counts. If ignore_mask is None,
    every entry is counted.

    Returns:
        (tp, fp, fn, tn)
    """
    if ignore_mask is None:
        use_mask = np.ones_like(mask_ref, dtype=bool)
    else:
        use_mask = ~ignore_mask
    tp = np.sum((mask_ref & mask_pred) & use_mask)
    fp = np.sum(((~mask_ref) & mask_pred) & use_mask)
    fn = np.sum((mask_ref & (~mask_pred)) & use_mask)
    tn = np.sum(((~mask_ref) & (~mask_pred)) & use_mask)
    return tp, fp, fn, tn


def compute_metrics(reference_file: str, prediction_file: str, image_reader_writer: BaseReaderWriter,
                    labels_or_regions: Union[List[int], List[Union[int, Tuple[int, ...]]]],
                    ignore_label: Optional[int] = None) -> dict:
    """
    Compute per-label/per-region metrics for one reference/prediction pair.

    Both files are read with image_reader_writer.read_seg. If ignore_label is given, all entries where
    the reference equals ignore_label are excluded from the counts.

    Returns:
        dict with 'reference_file', 'prediction_file' and 'metrics'. 'metrics' maps each entry of
        labels_or_regions to a dict with the keys Dice, IoU, FP, TP, FN, TN, n_pred, n_ref. Dice and IoU
        are nan if the label/region is present in neither reference nor prediction (tp + fp + fn == 0).
    """
    # load images; the second return value of read_seg is not needed here
    seg_ref, _ = image_reader_writer.read_seg(reference_file)
    seg_pred, _ = image_reader_writer.read_seg(prediction_file)

    ignore_mask = seg_ref == ignore_label if ignore_label is not None else None

    results = {}
    results['reference_file'] = reference_file
    results['prediction_file'] = prediction_file
    results['metrics'] = {}
    for r in labels_or_regions:
        mask_ref = region_or_label_to_mask(seg_ref, r)
        mask_pred = region_or_label_to_mask(seg_pred, r)
        tp, fp, fn, tn = compute_tp_fp_fn_tn(mask_ref, mask_pred, ignore_mask)
        if tp + fp + fn == 0:
            # neither reference nor prediction contain this label/region: Dice and IoU are undefined
            dice = np.nan
            iou = np.nan
        else:
            dice = 2 * tp / (2 * tp + fp + fn)
            iou = tp / (tp + fp + fn)
        results['metrics'][r] = {
            'Dice': dice,
            'IoU': iou,
            'FP': fp,
            'TP': tp,
            'FN': fn,
            'TN': tn,
            'n_pred': fp + tp,
            'n_ref': fn + tp,
        }
    return results


def compute_metrics_on_folder(folder_ref: str, folder_pred: str, output_file: str,
                              image_reader_writer: BaseReaderWriter,
                              file_ending: str,
                              regions_or_labels: Union[List[int], List[Union[int, Tuple[int, ...]]]],
                              ignore_label: Optional[int] = None,
                              num_processes: int = default_num_processes,
                              chill: bool = True) -> dict:
    """
    Evaluate all files in folder_pred (with the given file_ending) against the files of the same name in
    folder_ref and aggregate the results.

    output_file must end with .json; can be None (then nothing is written).

    If chill is False, an AssertionError is raised when a file of folder_ref is missing in folder_pred.
    If chill is True, only the files that are present in folder_pred are evaluated.

    Returns:
        dict with
        - 'metric_per_case': list of the per-case results of compute_metrics
        - 'mean': per label/region, the nanmean of each metric over all cases
        - 'foreground_mean': per metric, the mean of the 'mean' values over all labels/regions except 0
    """
    if output_file is not None:
        assert output_file.endswith('.json'), 'output_file should end with .json'
    files_pred = subfiles(folder_pred, suffix=file_ending, join=False)
    files_ref = subfiles(folder_ref, suffix=file_ending, join=False)
    if not chill:
        present = [isfile(join(folder_pred, i)) for i in files_ref]
        # this checks that every reference file has a prediction, so the message must say exactly that
        assert all(present), "Not all files in folder_ref exist in folder_pred"
    # the evaluated cases are defined by what is present in folder_pred
    files_ref = [join(folder_ref, i) for i in files_pred]
    files_pred = [join(folder_pred, i) for i in files_pred]

    starmap_args = [(ref, pred, image_reader_writer, regions_or_labels, ignore_label)
                    for ref, pred in zip(files_ref, files_pred)]
    with multiprocessing.get_context("spawn").Pool(num_processes) as pool:
        results = pool.starmap(compute_metrics, starmap_args)

    # mean metric per class
    metric_list = list(results[0]['metrics'][regions_or_labels[0]].keys())
    means = {}
    for r in regions_or_labels:
        means[r] = {}
        for m in metric_list:
            means[r][m] = np.nanmean([i['metrics'][r][m] for i in results])

    # foreground mean: average over all labels/regions, skipping the background label 0
    foreground_mean = {}
    for m in metric_list:
        values = []
        for k in means.keys():
            if k == 0 or k == '0':
                continue
            values.append(means[k][m])
        foreground_mean[m] = np.mean(values)

    # return values are discarded here, so this presumably converts in place - TODO confirm
    for case_result in results:
        recursive_fix_for_json_export(case_result)
    recursive_fix_for_json_export(means)
    recursive_fix_for_json_export(foreground_mean)
    result = {'metric_per_case': results, 'mean': means, 'foreground_mean': foreground_mean}
    if output_file is not None:
        save_summary_json(result, output_file)
    return result


def compute_metrics_on_folder2(folder_ref: str, folder_pred: str, dataset_json_file: str, plans_file: str,
                               output_file: str = None,
                               num_processes: int = default_num_processes,
                               chill: bool = False):
    """
    Evaluate folder_pred against folder_ref, taking file ending, reader/writer class and the labels/regions
    to evaluate from the given dataset.json and plans file.

    If output_file is None, the summary is written to folder_pred/summary.json.

    Returns:
        the result dict of compute_metrics_on_folder
    """
    dataset_json = load_json(dataset_json_file)
    # get file ending
    file_ending = dataset_json['file_ending']

    # get reader writer class
    example_file = subfiles(folder_ref, suffix=file_ending, join=True)[0]
    rw = determine_reader_writer_from_dataset_json(dataset_json, example_file)()

    # maybe auto set output file
    if output_file is None:
        output_file = join(folder_pred, 'summary.json')

    lm = PlansManager(plans_file).get_label_manager(dataset_json)
    return compute_metrics_on_folder(folder_ref, folder_pred, output_file, rw, file_ending,
                                     lm.foreground_regions if lm.has_regions else lm.foreground_labels,
                                     lm.ignore_label, num_processes, chill=chill)


def compute_metrics_on_folder_simple(folder_ref: str, folder_pred: str, labels: Union[Tuple[int, ...], List[int]],
                                     output_file: str = None,
                                     num_processes: int = default_num_processes,
                                     ignore_label: Optional[int] = None,
                                     chill: bool = False):
    """
    Evaluate folder_pred against folder_ref for an explicitly given list of labels. No dataset.json or plans
    file is needed: the file ending is taken from the first file found in folder_ref and the reader/writer
    class is determined from that file.

    If output_file is None, the summary is written to folder_pred/summary.json.

    Returns:
        the result dict of compute_metrics_on_folder
    """
    example_file = subfiles(folder_ref, join=True)[0]
    # NOTE: os.path.splitext only yields the last extension (e.g. '.gz' for '.nii.gz')
    file_ending = os.path.splitext(example_file)[-1]
    rw = determine_reader_writer_from_file_ending(file_ending, example_file, allow_nonmatching_filename=True,
                                                  verbose=False)()
    # maybe auto set output file
    if output_file is None:
        output_file = join(folder_pred, 'summary.json')
    return compute_metrics_on_folder(folder_ref, folder_pred, output_file, rw, file_ending,
                                     labels, ignore_label=ignore_label, num_processes=num_processes,
                                     chill=chill)


def evaluate_folder_entry_point():
    """Command line entry point: parse the arguments and run compute_metrics_on_folder2."""
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('gt_folder', type=str, help='folder with gt segmentations')
    parser.add_argument('pred_folder', type=str, help='folder with predicted segmentations')
    parser.add_argument('-djfile', type=str, required=True,
                        help='dataset.json file')
    parser.add_argument('-pfile', type=str, required=True,
                        help='plans.json file')
    parser.add_argument('-o', type=str, required=False, default=None,
                        help='Output file. Optional. Default: pred_folder/summary.json')
    parser.add_argument('-np', type=int, required=False, default=default_num_processes,
                        help=f'number of processes used. Optional. Default: {default_num_processes}')
    parser.add_argument('--chill', action='store_true',
                        help='dont crash if folder_pred does not have all files that are present in folder_gt')
    args = parser.parse_args()
    compute_metrics_on_folder2(args.gt_folder, args.pred_folder, args.djfile, args.pfile, args.o, args.np,
                               chill=args.chill)


def evaluate_simple_entry_point():
    """Command line entry point: parse the arguments and run compute_metrics_on_folder_simple."""
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('gt_folder', type=str, help='folder with gt segmentations')
    parser.add_argument('pred_folder', type=str, help='folder with predicted segmentations')
    parser.add_argument('-l', type=int, nargs='+', required=True,
                        help='list of labels')
    parser.add_argument('-il', type=int, required=False, default=None,
                        help='ignore label')
    parser.add_argument('-o', type=str, required=False, default=None,
                        help='Output file. Optional. Default: pred_folder/summary.json')
    parser.add_argument('-np', type=int, required=False, default=default_num_processes,
                        help=f'number of processes used. Optional. Default: {default_num_processes}')
    parser.add_argument('--chill', action='store_true',
                        help='dont crash if folder_pred does not have all files that are present in folder_gt')
    args = parser.parse_args()
    compute_metrics_on_folder_simple(args.gt_folder, args.pred_folder, args.l, args.o, args.np, args.il,
                                     chill=args.chill)


if __name__ == '__main__':
    # developer example with machine-specific paths
    folder_ref = '/media/fabian/data/nnUNet_raw/Dataset004_Hippocampus/labelsTr'
    folder_pred = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetModule__nnUNetPlans__3d_fullres/fold_0/validation'
    output_file = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetModule__nnUNetPlans__3d_fullres/fold_0/validation/summary.json'
    image_reader_writer = SimpleITKIO()
    file_ending = '.nii.gz'
    regions = labels_to_list_of_regions([1, 2])
    ignore_label = None
    num_processes = 12
    compute_metrics_on_folder(folder_ref, folder_pred, output_file, image_reader_writer, file_ending, regions,
                              ignore_label, num_processes)