File size: 12,556 Bytes
19c1f58 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 |
import multiprocessing
import os
from copy import deepcopy
from multiprocessing import Pool
from typing import Tuple, List, Union, Optional
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import subfiles, join, save_json, load_json, \
isfile
from nnunetv2.configuration import default_num_processes
from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json, \
determine_reader_writer_from_file_ending
from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
# the Evaluator class of the previous nnU-Net was great and all but man was it overengineered. Keep it simple
from nnunetv2.utilities.json_export import recursive_fix_for_json_export
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
def label_or_region_to_key(label_or_region: Union[int, Tuple[int]]):
    """
    Turn a label (int) or a region (tuple of ints) into the string that is used as its key in the summary json.
    """
    # !s forces the str() conversion, e.g. 1 -> '1' and (1, 2) -> '(1, 2)'
    return f"{label_or_region!s}"
def key_to_label_or_region(key: str):
    """
    Inverse of the key conversion used for the summary json: '3' becomes 3, '(1, 2)' becomes (1, 2).

    A key that parses as an int is returned as int. Everything else is treated as the string form of a tuple of
    ints and returned as such.
    """
    try:
        return int(key)
    except ValueError:
        # not a plain int -> drop all parentheses and parse the comma separated entries
        without_parentheses = key.translate(str.maketrans('', '', '()'))
        # empty entries show up for one-element tuples such as '(1,)' and are skipped
        return tuple(int(entry) for entry in without_parentheses.split(',') if entry)
def save_summary_json(results: dict, output_file: str):
    """
    Write an evaluation summary to output_file.

    json cannot use tuples as keys, so the label/region keys of results['mean'] and of the 'metrics' entry of each
    case in results['metric_per_case'] are converted to strings first. The conversion is done on a copy, results
    itself keeps its keys.
    """
    serializable = deepcopy(results)

    # keys of the mean metrics
    serializable['mean'] = {label_or_region_to_key(key): value for key, value in results['mean'].items()}

    # keys of the per-case metrics
    for case_copy, case_original in zip(serializable["metric_per_case"], results["metric_per_case"]):
        case_copy['metrics'] = {label_or_region_to_key(key): value
                                for key, value in case_original['metrics'].items()}

    # sort_keys=True will make foreground_mean the first entry and thus easy to spot
    save_json(serializable, output_file, sort_keys=True)
def load_summary_json(filename: str):
    """
    Read a summary json and convert the string keys of its 'mean' entry and of the 'metrics' entry of each case in
    'metric_per_case' back to labels (int) or regions (tuple of int).
    """
    summary = load_json(filename)

    # keys of the mean metrics
    summary['mean'] = {key_to_label_or_region(key): value for key, value in summary['mean'].items()}

    # keys of the per-case metrics
    for case in summary["metric_per_case"]:
        case['metrics'] = {key_to_label_or_region(key): value for key, value in case['metrics'].items()}

    return summary
def labels_to_list_of_regions(labels: List[int]):
    """
    Express each label as a region consisting of just that label: [1, 2] -> [(1,), (2,)].
    """
    # zip over a single iterable yields one-element tuples
    return list(zip(labels))
def region_or_label_to_mask(segmentation: np.ndarray, region_or_label: Union[int, Tuple[int, ...]]) -> np.ndarray:
    """
    Return a boolean mask of segmentation that is True wherever the label, or any label of the region, is found.
    """
    if not np.isscalar(region_or_label):
        # region: union of the masks of all labels it contains
        combined = np.zeros_like(segmentation, dtype=bool)
        for label in region_or_label:
            combined |= segmentation == label
        return combined
    # single label
    return segmentation == region_or_label
def compute_tp_fp_fn_tn(mask_ref: np.ndarray, mask_pred: np.ndarray, ignore_mask: np.ndarray = None):
    """
    Count true positives, false positives, false negatives and true negatives of mask_pred with respect to
    mask_ref. Elements where ignore_mask is True are left out of all four counts.

    Returns the counts in the order tp, fp, fn, tn.
    """
    # elements that take part in the counting
    valid = np.ones_like(mask_ref, dtype=bool) if ignore_mask is None else ~ignore_mask

    not_ref = ~mask_ref
    not_pred = ~mask_pred

    tp = np.sum((mask_ref & mask_pred) & valid)
    fp = np.sum((not_ref & mask_pred) & valid)
    fn = np.sum((mask_ref & not_pred) & valid)
    tn = np.sum((not_ref & not_pred) & valid)
    return tp, fp, fn, tn
def compute_metrics(reference_file: str, prediction_file: str, image_reader_writer: BaseReaderWriter,
                    labels_or_regions: Union[List[int], List[Union[int, Tuple[int, ...]]]],
                    ignore_label: int = None) -> dict:
    """
    Evaluate one predicted segmentation against its reference.

    Both files are read with image_reader_writer.read_seg. For every entry of labels_or_regions the result holds
    'Dice', 'IoU', 'FP', 'TP', 'FN', 'TN', 'n_pred' and 'n_ref'. If ignore_label is given, everything that has
    this label in the reference is excluded from the counts. Dice and IoU are nan if the label/region is found in
    neither the reference nor the prediction.

    Returns a dict with the keys 'reference_file', 'prediction_file' and 'metrics' (label/region -> metrics).
    """
    # load both segmentations; the second return value of read_seg is not needed here
    seg_ref, _ = image_reader_writer.read_seg(reference_file)
    seg_pred, _ = image_reader_writer.read_seg(prediction_file)

    if ignore_label is None:
        ignore_mask = None
    else:
        ignore_mask = seg_ref == ignore_label

    metrics_per_region = {}
    for region in labels_or_regions:
        tp, fp, fn, tn = compute_tp_fp_fn_tn(
            region_or_label_to_mask(seg_ref, region),
            region_or_label_to_mask(seg_pred, region),
            ignore_mask,
        )
        if tp + fp + fn == 0:
            # present in neither reference nor prediction -> overlap is undefined
            dice = np.nan
            iou = np.nan
        else:
            dice = 2 * tp / (2 * tp + fp + fn)
            iou = tp / (tp + fp + fn)
        metrics_per_region[region] = {
            'Dice': dice,
            'IoU': iou,
            'FP': fp,
            'TP': tp,
            'FN': fn,
            'TN': tn,
            'n_pred': fp + tp,
            'n_ref': fn + tp,
        }

    return {
        'reference_file': reference_file,
        'prediction_file': prediction_file,
        'metrics': metrics_per_region,
    }
def compute_metrics_on_folder(folder_ref: str, folder_pred: str, output_file: Optional[str],
                              image_reader_writer: BaseReaderWriter,
                              file_ending: str,
                              regions_or_labels: Union[List[int], List[Union[int, Tuple[int, ...]]]],
                              ignore_label: Optional[int] = None,
                              num_processes: int = default_num_processes,
                              chill: bool = True) -> dict:
    """
    Evaluate every file in folder_pred that ends with file_ending against the file of the same name in folder_ref
    and aggregate the per-case metrics.

    :param folder_ref: folder with the reference segmentations
    :param folder_pred: folder with the predicted segmentations
    :param output_file: must end with .json; can be None, in which case no summary is written
    :param image_reader_writer: handed to compute_metrics for reading the segmentations
    :param file_ending: only files with this suffix are evaluated
    :param regions_or_labels: labels (int) or regions (tuple of int) to compute metrics for
    :param ignore_label: handed to compute_metrics
    :param num_processes: number of worker processes of the pool
    :param chill: if False, every file of folder_ref must also be present in folder_pred
    :return: dict with 'metric_per_case' (list of compute_metrics results), 'mean' (nanmean over all cases for
        each label/region and metric) and 'foreground_mean' (mean over the entries of 'mean', skipping label 0)
    :raises AssertionError: if output_file does not end with .json, or if chill is False and predictions are
        missing
    :raises RuntimeError: if folder_pred contains no file ending with file_ending
    """
    if output_file is not None:
        assert output_file.endswith('.json'), 'output_file should end with .json'
    files_pred = subfiles(folder_pred, suffix=file_ending, join=False)
    files_ref = subfiles(folder_ref, suffix=file_ending, join=False)
    if not chill:
        # every reference file needs a prediction; name the missing ones so the problem is easy to track down
        missing = [i for i in files_ref if not isfile(join(folder_pred, i))]
        assert len(missing) == 0, f"Not all files in folder_ref exist in folder_pred. Missing: {missing}"
    if len(files_pred) == 0:
        # without this the aggregation below fails with an uninformative IndexError on results[0]
        raise RuntimeError(f"No files ending with {file_ending} found in {folder_pred}, nothing to evaluate")
    # the evaluation is driven by the predictions: the reference path is built from the prediction's file name
    files_ref = [join(folder_ref, i) for i in files_pred]
    files_pred = [join(folder_pred, i) for i in files_pred]
    num_cases = len(files_pred)
    with multiprocessing.get_context("spawn").Pool(num_processes) as pool:
        results = pool.starmap(
            compute_metrics,
            list(zip(files_ref, files_pred, [image_reader_writer] * num_cases, [regions_or_labels] * num_cases,
                     [ignore_label] * num_cases))
        )

    # mean metric per class
    metric_list = list(results[0]['metrics'][regions_or_labels[0]].keys())
    means = {}
    for r in regions_or_labels:
        means[r] = {}
        for m in metric_list:
            # nanmean: cases where a metric is undefined (nan) do not contribute to the mean
            means[r][m] = np.nanmean([i['metrics'][r][m] for i in results])

    # foreground mean
    foreground_mean = {}
    for m in metric_list:
        values = []
        for k in means.keys():
            # label 0 is not part of the foreground mean
            if k == 0 or k == '0':
                continue
            values.append(means[k][m])
        foreground_mean[m] = np.mean(values)

    # recursive_fix_for_json_export is called for its effect on the passed object (return value is not used);
    # presumably it converts values that json cannot serialize - confirm in nnunetv2.utilities.json_export
    for i in results:
        recursive_fix_for_json_export(i)
    recursive_fix_for_json_export(means)
    recursive_fix_for_json_export(foreground_mean)
    result = {'metric_per_case': results, 'mean': means, 'foreground_mean': foreground_mean}
    if output_file is not None:
        save_summary_json(result, output_file)
    return result
def compute_metrics_on_folder2(folder_ref: str, folder_pred: str, dataset_json_file: str, plans_file: str,
                               output_file: str = None,
                               num_processes: int = default_num_processes,
                               chill: bool = False):
    """
    Run compute_metrics_on_folder with everything that can be derived from the dataset.json and plans file:
    file ending, reader/writer class, labels or regions to evaluate and the ignore label.

    If output_file is None the summary is written to folder_pred/summary.json.
    """
    dataset_json = load_json(dataset_json_file)
    file_ending = dataset_json['file_ending']

    # the reader/writer class is determined with the help of the first file found in folder_ref
    first_reference_file = subfiles(folder_ref, suffix=file_ending, join=True)[0]
    reader_writer_class = determine_reader_writer_from_dataset_json(dataset_json, first_reference_file)
    reader_writer = reader_writer_class()

    # maybe auto set output file
    if output_file is None:
        output_file = join(folder_pred, 'summary.json')

    label_manager = PlansManager(plans_file).get_label_manager(dataset_json)
    if label_manager.has_regions:
        targets = label_manager.foreground_regions
    else:
        targets = label_manager.foreground_labels

    compute_metrics_on_folder(folder_ref, folder_pred, output_file, reader_writer, file_ending,
                              targets, label_manager.ignore_label,
                              num_processes, chill=chill)
def compute_metrics_on_folder_simple(folder_ref: str, folder_pred: str, labels: Union[Tuple[int, ...], List[int]],
                                     output_file: str = None,
                                     num_processes: int = default_num_processes,
                                     ignore_label: int = None,
                                     chill: bool = False):
    """
    Run compute_metrics_on_folder without dataset.json and plans file. File ending and reader/writer class are
    derived from the first file found in folder_ref, the labels to evaluate are given explicitly.

    If output_file is None the summary is written to folder_pred/summary.json.
    """
    first_reference_file = subfiles(folder_ref, join=True)[0]
    # NOTE(review): os.path.splitext only returns the last extension ('.gz' for 'x.nii.gz') - confirm that this
    # is what determine_reader_writer_from_file_ending expects
    _, file_ending = os.path.splitext(first_reference_file)

    reader_writer_class = determine_reader_writer_from_file_ending(file_ending, first_reference_file,
                                                                   allow_nonmatching_filename=True,
                                                                   verbose=False)
    reader_writer = reader_writer_class()

    # maybe auto set output file
    if output_file is None:
        output_file = join(folder_pred, 'summary.json')

    compute_metrics_on_folder(folder_ref, folder_pred, output_file, reader_writer, file_ending,
                              labels, ignore_label=ignore_label, num_processes=num_processes, chill=chill)
def evaluate_folder_entry_point():
    """
    Command line entry point: parse the arguments and hand them to compute_metrics_on_folder2.
    """
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('gt_folder', type=str, help='folder with gt segmentations')
    cli.add_argument('pred_folder', type=str, help='folder with predicted segmentations')
    cli.add_argument('-djfile', type=str, required=True,
                     help='dataset.json file')
    cli.add_argument('-pfile', type=str, required=True,
                     help='plans.json file')
    cli.add_argument('-o', type=str, required=False, default=None,
                     help='Output file. Optional. Default: pred_folder/summary.json')
    cli.add_argument('-np', type=int, required=False, default=default_num_processes,
                     help=f'number of processes used. Optional. Default: {default_num_processes}')
    cli.add_argument('--chill', action='store_true',
                     help='dont crash if folder_pred does not have all files that are present in folder_gt')
    parsed = cli.parse_args()

    compute_metrics_on_folder2(
        folder_ref=parsed.gt_folder,
        folder_pred=parsed.pred_folder,
        dataset_json_file=parsed.djfile,
        plans_file=parsed.pfile,
        output_file=parsed.o,
        num_processes=parsed.np,
        chill=parsed.chill,
    )
def evaluate_simple_entry_point():
    """
    Command line entry point: parse the arguments and hand them to compute_metrics_on_folder_simple.
    """
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('gt_folder', type=str, help='folder with gt segmentations')
    cli.add_argument('pred_folder', type=str, help='folder with predicted segmentations')
    cli.add_argument('-l', type=int, nargs='+', required=True,
                     help='list of labels')
    cli.add_argument('-il', type=int, required=False, default=None,
                     help='ignore label')
    cli.add_argument('-o', type=str, required=False, default=None,
                     help='Output file. Optional. Default: pred_folder/summary.json')
    cli.add_argument('-np', type=int, required=False, default=default_num_processes,
                     help=f'number of processes used. Optional. Default: {default_num_processes}')
    cli.add_argument('--chill', action='store_true',
                     help='dont crash if folder_pred does not have all files that are present in folder_gt')
    parsed = cli.parse_args()

    compute_metrics_on_folder_simple(
        folder_ref=parsed.gt_folder,
        folder_pred=parsed.pred_folder,
        labels=parsed.l,
        output_file=parsed.o,
        num_processes=parsed.np,
        ignore_label=parsed.il,
        chill=parsed.chill,
    )
if __name__ == '__main__':
    # manual run on a local Hippocampus validation folder; the paths only exist on the original developer machine
    folder_ref = '/media/fabian/data/nnUNet_raw/Dataset004_Hippocampus/labelsTr'
    folder_pred = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetModule__nnUNetPlans__3d_fullres/fold_0/validation'
    output_file = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetModule__nnUNetPlans__3d_fullres/fold_0/validation/summary.json'

    file_ending = '.nii.gz'
    image_reader_writer = SimpleITKIO()

    # labels 1 and 2, each evaluated as its own region
    regions = labels_to_list_of_regions([1, 2])
    ignore_label = None
    num_processes = 12

    compute_metrics_on_folder(
        folder_ref,
        folder_pred,
        output_file,
        image_reader_writer,
        file_ending,
        regions_or_labels=regions,
        ignore_label=ignore_label,
        num_processes=num_processes,
    )
|