import multiprocessing
import os
from copy import deepcopy
from typing import Tuple, List, Union, Optional

import numpy as np
from batchgenerators.utilities.file_and_folder_operations import subfiles, join, save_json, load_json, \
    isfile
from nnunetv2.configuration import default_num_processes
from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json, \
    determine_reader_writer_from_file_ending
from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
from nnunetv2.utilities.json_export import recursive_fix_for_json_export
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager


def label_or_region_to_key(label_or_region: Union[int, Tuple[int, ...]]) -> str:
    return str(label_or_region)


def key_to_label_or_region(key: str) -> Union[int, Tuple[int, ...]]:
    try:
        return int(key)
    except ValueError:
        key = key.replace('(', '')
        key = key.replace(')', '')
        split = key.split(',')
        # tuple reprs such as '(1,)' leave an empty string after the split, hence the length check
        return tuple([int(i) for i in split if len(i) > 0])
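# Round-trip examples, derived from the two conversions above:
#   label_or_region_to_key(3)        -> '3'
#   label_or_region_to_key((1, 2))   -> '(1, 2)'
#   key_to_label_or_region('3')      -> 3
#   key_to_label_or_region('(1, 2)') -> (1, 2)
#   key_to_label_or_region('(1,)')   -> (1,)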


def save_summary_json(results: dict, output_file: str):
    """
    JSON does not support tuples as keys, so we have to convert them to strings ourselves
    """
    results_converted = deepcopy(results)
    results_converted['mean'] = {label_or_region_to_key(k): results['mean'][k] for k in results['mean'].keys()}
    for i in range(len(results_converted["metric_per_case"])):
        results_converted["metric_per_case"][i]['metrics'] = \
            {label_or_region_to_key(k): results["metric_per_case"][i]['metrics'][k]
             for k in results["metric_per_case"][i]['metrics'].keys()}
    save_json(results_converted, output_file, sort_keys=True)


def load_summary_json(filename: str):
    results = load_json(filename)
    # convert the string keys back to labels (int) or regions (tuple of ints)
    results['mean'] = {key_to_label_or_region(k): results['mean'][k] for k in results['mean'].keys()}
    for i in range(len(results["metric_per_case"])):
        results["metric_per_case"][i]['metrics'] = \
            {key_to_label_or_region(k): results["metric_per_case"][i]['metrics'][k]
             for k in results["metric_per_case"][i]['metrics'].keys()}
    return results
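# save_summary_json and load_summary_json are inverses of each other: a results dict with int
# or tuple keys that is saved and loaded again comes back with the original key types, e.g.
# {'mean': {(1, 2): {...}}, 'metric_per_case': [...]} survives the round trip intact.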


def labels_to_list_of_regions(labels: List[int]):
    return [(i,) for i in labels]
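# Example: labels_to_list_of_regions([1, 2]) -> [(1,), (2,)]. Each label becomes a
# single-label region so that downstream code can treat labels and regions uniformly.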


def region_or_label_to_mask(segmentation: np.ndarray, region_or_label: Union[int, Tuple[int, ...]]) -> np.ndarray:
    if np.isscalar(region_or_label):
        return segmentation == region_or_label
    else:
        # a region is the union of its member labels
        mask = np.zeros_like(segmentation, dtype=bool)
        for r in region_or_label:
            mask[segmentation == r] = True
        return mask
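# Example on a toy array: region_or_label_to_mask(np.array([0, 1, 2]), 2) gives
# [False, False, True], while the region (1, 2) gives [False, True, True].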


def compute_tp_fp_fn_tn(mask_ref: np.ndarray, mask_pred: np.ndarray, ignore_mask: np.ndarray = None):
    if ignore_mask is None:
        use_mask = np.ones_like(mask_ref, dtype=bool)
    else:
        use_mask = ~ignore_mask
    tp = np.sum((mask_ref & mask_pred) & use_mask)
    fp = np.sum(((~mask_ref) & mask_pred) & use_mask)
    fn = np.sum((mask_ref & (~mask_pred)) & use_mask)
    tn = np.sum(((~mask_ref) & (~mask_pred)) & use_mask)
    return tp, fp, fn, tn
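# Worked example: with ref = [True, True, False, False] and pred = [True, False, True, False]
# (and no ignore_mask), compute_tp_fp_fn_tn returns tp=1, fp=1, fn=1, tn=1. Voxels where
# ignore_mask is True are excluded from all four counts.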


def compute_metrics(reference_file: str, prediction_file: str, image_reader_writer: BaseReaderWriter,
                    labels_or_regions: Union[List[int], List[Union[int, Tuple[int, ...]]]],
                    ignore_label: Optional[int] = None) -> dict:
    # load reference and prediction with the same reader/writer
    seg_ref, seg_ref_dict = image_reader_writer.read_seg(reference_file)
    seg_pred, seg_pred_dict = image_reader_writer.read_seg(prediction_file)

    ignore_mask = seg_ref == ignore_label if ignore_label is not None else None

    results = {}
    results['reference_file'] = reference_file
    results['prediction_file'] = prediction_file
    results['metrics'] = {}
    for r in labels_or_regions:
        results['metrics'][r] = {}
        mask_ref = region_or_label_to_mask(seg_ref, r)
        mask_pred = region_or_label_to_mask(seg_pred, r)
        tp, fp, fn, tn = compute_tp_fp_fn_tn(mask_ref, mask_pred, ignore_mask)
        if tp + fp + fn == 0:
            # label/region is absent from both reference and prediction -> Dice and IoU are undefined
            results['metrics'][r]['Dice'] = np.nan
            results['metrics'][r]['IoU'] = np.nan
        else:
            results['metrics'][r]['Dice'] = 2 * tp / (2 * tp + fp + fn)
            results['metrics'][r]['IoU'] = tp / (tp + fp + fn)
        results['metrics'][r]['FP'] = fp
        results['metrics'][r]['TP'] = tp
        results['metrics'][r]['FN'] = fn
        results['metrics'][r]['TN'] = tn
        results['metrics'][r]['n_pred'] = fp + tp
        results['metrics'][r]['n_ref'] = fn + tp
    return results
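# The returned dict has this shape (a sketch, with keys taken from the code above):
# {
#     'reference_file': ..., 'prediction_file': ...,
#     'metrics': {label_or_region: {'Dice': ..., 'IoU': ..., 'TP': ..., 'FP': ...,
#                                   'FN': ..., 'TN': ..., 'n_pred': ..., 'n_ref': ...}}
# }
# where Dice = 2*TP / (2*TP + FP + FN) and IoU = TP / (TP + FP + FN).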


def compute_metrics_on_folder(folder_ref: str, folder_pred: str, output_file: Optional[str],
                              image_reader_writer: BaseReaderWriter,
                              file_ending: str,
                              regions_or_labels: Union[List[int], List[Union[int, Tuple[int, ...]]]],
                              ignore_label: Optional[int] = None,
                              num_processes: int = default_num_processes,
                              chill: bool = True) -> dict:
    """
    output_file must end with .json; can be None
    """
    if output_file is not None:
        assert output_file.endswith('.json'), 'output_file should end with .json'
    files_pred = subfiles(folder_pred, suffix=file_ending, join=False)
    files_ref = subfiles(folder_ref, suffix=file_ending, join=False)
    if not chill:
        present = [isfile(join(folder_pred, i)) for i in files_ref]
        assert all(present), "Not all files in folder_ref exist in folder_pred"
    files_ref = [join(folder_ref, i) for i in files_pred]
    files_pred = [join(folder_pred, i) for i in files_pred]
    with multiprocessing.get_context("spawn").Pool(num_processes) as pool:
        # one worker call per case; reader/writer, labels/regions and ignore_label are shared
        results = pool.starmap(
            compute_metrics,
            list(zip(files_ref, files_pred, [image_reader_writer] * len(files_pred),
                     [regions_or_labels] * len(files_pred), [ignore_label] * len(files_pred)))
        )

    # mean metric per label/region across cases (NaNs from absent labels are ignored)
    metric_list = list(results[0]['metrics'][regions_or_labels[0]].keys())
    means = {}
    for r in regions_or_labels:
        means[r] = {}
        for m in metric_list:
            means[r][m] = np.nanmean([i['metrics'][r][m] for i in results])

    # foreground mean: average over all labels/regions except background (key 0)
    foreground_mean = {}
    for m in metric_list:
        values = []
        for k in means.keys():
            if k == 0 or k == '0':
                continue
            values.append(means[k][m])
        foreground_mean[m] = np.mean(values)

    [recursive_fix_for_json_export(i) for i in results]
    recursive_fix_for_json_export(means)
    recursive_fix_for_json_export(foreground_mean)
    result = {'metric_per_case': results, 'mean': means, 'foreground_mean': foreground_mean}
    if output_file is not None:
        save_summary_json(result, output_file)
    return result
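# The resulting summary (and summary.json, if output_file is given) thus has three sections:
# 'metric_per_case' (one entry per evaluated pair of files), 'mean' (per label/region, averaged
# over cases) and 'foreground_mean' (averaged over all non-background labels/regions).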


def compute_metrics_on_folder2(folder_ref: str, folder_pred: str, dataset_json_file: str, plans_file: str,
                               output_file: Optional[str] = None,
                               num_processes: int = default_num_processes,
                               chill: bool = False):
    dataset_json = load_json(dataset_json_file)
    file_ending = dataset_json['file_ending']
    # pick any reference file to determine the appropriate reader/writer class
    example_file = subfiles(folder_ref, suffix=file_ending, join=True)[0]
    rw = determine_reader_writer_from_dataset_json(dataset_json, example_file)()

    if output_file is None:
        output_file = join(folder_pred, 'summary.json')

    lm = PlansManager(plans_file).get_label_manager(dataset_json)
    compute_metrics_on_folder(folder_ref, folder_pred, output_file, rw, file_ending,
                              lm.foreground_regions if lm.has_regions else lm.foreground_labels, lm.ignore_label,
                              num_processes, chill=chill)
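# A typical call might look like this (the paths are hypothetical placeholders):
# compute_metrics_on_folder2('/data/Dataset004/labelsTr', '/results/fold_0/validation',
#                            '/data/Dataset004/dataset.json', '/results/plans.json')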


def compute_metrics_on_folder_simple(folder_ref: str, folder_pred: str, labels: Union[Tuple[int, ...], List[int]],
                                     output_file: Optional[str] = None,
                                     num_processes: int = default_num_processes,
                                     ignore_label: Optional[int] = None,
                                     chill: bool = False):
    example_file = subfiles(folder_ref, join=True)[0]
    file_ending = os.path.splitext(example_file)[-1]
    rw = determine_reader_writer_from_file_ending(file_ending, example_file, allow_nonmatching_filename=True,
                                                  verbose=False)()

    if output_file is None:
        output_file = join(folder_pred, 'summary.json')
    compute_metrics_on_folder(folder_ref, folder_pred, output_file, rw, file_ending,
                              labels, ignore_label=ignore_label, num_processes=num_processes, chill=chill)
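# Same evaluation without dataset.json/plans.json; labels must be given explicitly and the
# reader/writer is inferred from the file ending alone, e.g. (hypothetical paths):
# compute_metrics_on_folder_simple('/data/gt', '/data/pred', labels=[1, 2])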


def evaluate_folder_entry_point():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('gt_folder', type=str, help='folder with gt segmentations')
    parser.add_argument('pred_folder', type=str, help='folder with predicted segmentations')
    parser.add_argument('-djfile', type=str, required=True,
                        help='dataset.json file')
    parser.add_argument('-pfile', type=str, required=True,
                        help='plans.json file')
    parser.add_argument('-o', type=str, required=False, default=None,
                        help='Output file. Optional. Default: pred_folder/summary.json')
    parser.add_argument('-np', type=int, required=False, default=default_num_processes,
                        help=f'number of processes used. Optional. Default: {default_num_processes}')
    parser.add_argument('--chill', action='store_true',
                        help="don't crash if folder_pred does not have all files that are present in folder_gt")
    args = parser.parse_args()
    compute_metrics_on_folder2(args.gt_folder, args.pred_folder, args.djfile, args.pfile, args.o, args.np,
                               chill=args.chill)
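# Assuming this entry point is registered as a console script (as in nnU-Net's packaging),
# an invocation might look like:
# nnUNetv2_evaluate_folder GT_FOLDER PRED_FOLDER -djfile dataset.json -pfile plans.json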


def evaluate_simple_entry_point():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('gt_folder', type=str, help='folder with gt segmentations')
    parser.add_argument('pred_folder', type=str, help='folder with predicted segmentations')
    parser.add_argument('-l', type=int, nargs='+', required=True,
                        help='list of labels')
    parser.add_argument('-il', type=int, required=False, default=None,
                        help='ignore label')
    parser.add_argument('-o', type=str, required=False, default=None,
                        help='Output file. Optional. Default: pred_folder/summary.json')
    parser.add_argument('-np', type=int, required=False, default=default_num_processes,
                        help=f'number of processes used. Optional. Default: {default_num_processes}')
    parser.add_argument('--chill', action='store_true',
                        help="don't crash if folder_pred does not have all files that are present in folder_gt")

    args = parser.parse_args()
    compute_metrics_on_folder_simple(args.gt_folder, args.pred_folder, args.l, args.o, args.np, args.il,
                                     chill=args.chill)
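# The simple variant only needs the labels to evaluate, e.g. (assuming a console script
# registration analogous to the one above):
# nnUNetv2_evaluate_simple GT_FOLDER PRED_FOLDER -l 1 2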


if __name__ == '__main__':
    folder_ref = '/media/fabian/data/nnUNet_raw/Dataset004_Hippocampus/labelsTr'
    folder_pred = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetModule__nnUNetPlans__3d_fullres/fold_0/validation'
    output_file = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetModule__nnUNetPlans__3d_fullres/fold_0/validation/summary.json'
    image_reader_writer = SimpleITKIO()
    file_ending = '.nii.gz'
    regions = labels_to_list_of_regions([1, 2])
    ignore_label = None
    num_processes = 12
    compute_metrics_on_folder(folder_ref, folder_pred, output_file, image_reader_writer, file_ending, regions,
                              ignore_label, num_processes)