import argparse
import multiprocessing
import shutil
from multiprocessing import Pool
from typing import Union, Tuple, List, Callable

import numpy as np
from acvl_utils.morphology.morphology_helper import remove_all_but_largest_component
from batchgenerators.utilities.file_and_folder_operations import load_json, subfiles, maybe_mkdir_p, join, isfile, \
    isdir, save_pickle, load_pickle, save_json
from nnunetv2.configuration import default_num_processes
from nnunetv2.evaluation.accumulate_cv_results import accumulate_cv_results
from nnunetv2.evaluation.evaluate_predictions import region_or_label_to_mask, compute_metrics_on_folder, \
    load_summary_json, label_or_region_to_key
from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
from nnunetv2.paths import nnUNet_raw
from nnunetv2.utilities.file_path_utilities import folds_tuple_to_string
from nnunetv2.utilities.json_export import recursive_fix_for_json_export
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager


def remove_all_but_largest_component_from_segmentation(segmentation: np.ndarray,
                                                       labels_or_regions: Union[int, Tuple[int, ...],
                                                                                List[Union[int, Tuple[int, ...]]]],
                                                       background_label: int = 0) -> np.ndarray:
    """
    Merges the given labels/regions into one foreground mask, keeps only the largest component of that mask (as
    determined by acvl_utils' remove_all_but_largest_component) and sets all other voxels of the mask to
    background_label.

    :param segmentation: label map. It is not modified; a modified copy is returned.
    :param labels_or_regions: a single label (int), a single region (tuple of labels) or a list of those. All entries
        are OR-ed into a single mask before the largest component is determined.
    :param background_label: value written into the removed voxels.
    :return: postprocessed copy of segmentation
    """
    mask = np.zeros_like(segmentation, dtype=bool)
    # a bare int or a tuple (= one region) is wrapped so the loop below treats it as a single entry
    if not isinstance(labels_or_regions, list):
        labels_or_regions = [labels_or_regions]
    for l_or_r in labels_or_regions:
        mask |= region_or_label_to_mask(segmentation, l_or_r)
    mask_keep = remove_all_but_largest_component(mask)
    ret = np.copy(segmentation)  # do not modify the input!
    # only voxels that are in the mask but not in the kept component are cleared
    ret[mask & ~mask_keep] = background_label
    return ret


def apply_postprocessing(segmentation: np.ndarray, pp_fns: List[Callable], pp_fn_kwargs: List[dict]):
    """
    Applies the postprocessing functions sequentially: pp_fns[i] is called on the output of pp_fns[i - 1] with the
    keyword arguments pp_fn_kwargs[i].

    :param segmentation: segmentation passed to the first function
    :param pp_fns: postprocessing functions, applied in order
    :param pp_fn_kwargs: keyword arguments for each function in pp_fns (same length as pp_fns)
    :return: output of the last function (segmentation unchanged if pp_fns is empty)
    :raises ValueError: if pp_fns and pp_fn_kwargs differ in length (zip would otherwise silently drop entries)
    """
    if len(pp_fns) != len(pp_fn_kwargs):
        raise ValueError(f'pp_fns and pp_fn_kwargs must have the same length, got {len(pp_fns)} and '
                         f'{len(pp_fn_kwargs)}')
    for fn, kwargs in zip(pp_fns, pp_fn_kwargs):
        segmentation = fn(segmentation, **kwargs)
    return segmentation


def load_postprocess_save(segmentation_file: str,
                          output_fname: str,
                          image_reader_writer: BaseReaderWriter,
                          pp_fns: List[Callable],
                          pp_fn_kwargs: List[dict]):
    """
    Reads segmentation_file, applies the postprocessing chain (see apply_postprocessing) and writes the result to
    output_fname using the properties returned by read_seg. Defined at module level so it can be used as a worker
    function of the 'spawn' multiprocessing pools in this file.
    """
    seg, props = image_reader_writer.read_seg(segmentation_file)
    # NOTE(review): read_seg output is indexed with [0]; presumably it has a leading channel axis — confirm
    seg = apply_postprocessing(seg[0], pp_fns, pp_fn_kwargs)
    image_reader_writer.write_seg(seg, output_fname, props)


def _postprocess_files(pool,
                       input_folder: str,
                       output_folder: str,
                       files: List[str],
                       image_reader_writer: BaseReaderWriter,
                       pp_fns: List[Callable],
                       pp_fn_kwargs: List[dict]) -> None:
    """
    Runs load_postprocess_save via pool.starmap for every file name in files (relative to input_folder). Results are
    written with the same file names into output_folder. Blocks until all files are done.
    """
    pool.starmap(
        load_postprocess_save,
        zip(
            [join(input_folder, i) for i in files],
            [join(output_folder, i) for i in files],
            [image_reader_writer] * len(files),
            [pp_fns] * len(files),
            [pp_fn_kwargs] * len(files)
        )
    )


def determine_postprocessing(folder_predictions: str,
                             folder_ref: str,
                             plans_file_or_dict: Union[str, dict],
                             dataset_json_file_or_dict: Union[str, dict],
                             num_processes: int = default_num_processes,
                             keep_postprocessed_files: bool = True):
    """
    Determines nnU-Net postprocessing. The test is whether removing all but the largest component improves the Dice
    scores of the predictions in folder_predictions, evaluated against the reference segmentations in folder_ref.

    Two kinds of postprocessing are tested, in this order:
    1. Keep only the largest component of the whole foreground. This is accepted only if the foreground mean Dice
       increases and no individual label/region gets a lower mean Dice.
    2. Only if there is more than one label/region: keep only the largest component of each label/region, one after
       the other, in the order of label_manager.foreground_regions / foreground_labels. A step is accepted if the
       Dice of that label/region increases. Each step operates on the output of the previously accepted steps.

    Outputs (in folder_predictions):
    - postprocessing.pkl: tuple (pp_fns, pp_fn_kwargs), usable with apply_postprocessing_to_folder
    - postprocessing.json: Dice before/after and the names + kwargs of the chosen functions
    - postprocessed/: the postprocessed files (including their summary.json). Set keep_postprocessed_files=False to
      delete these files after this function is done. Temp files in postprocessed/temp are created and deleted
      regardless.
    If folder_predictions/summary.json does not exist yet it is computed first; if it exists it is reused.

    If plans_file_or_dict or dataset_json_file_or_dict are None, we will look for plans.json / dataset.json in
    folder_predictions.

    :return: pp_fns, pp_fn_kwargs (the accepted postprocessing functions and their keyword arguments)
    """
    output_folder = join(folder_predictions, 'postprocessed')

    if plans_file_or_dict is None:
        expected_plans_file = join(folder_predictions, 'plans.json')
        if not isfile(expected_plans_file):
            raise RuntimeError(f"Expected plans file missing: {expected_plans_file}. The plans files should have been "
                               f"created while running nnUNetv2_predict. Sadge.")
        plans_file_or_dict = load_json(expected_plans_file)
    plans_manager = PlansManager(plans_file_or_dict)

    if dataset_json_file_or_dict is None:
        expected_dataset_json_file = join(folder_predictions, 'dataset.json')
        if not isfile(expected_dataset_json_file):
            raise RuntimeError(
                f"Expected dataset.json file missing: {expected_dataset_json_file}. The dataset.json should have "
                f"been copied while running nnUNetv2_predict. Sadge.")
        dataset_json_file_or_dict = load_json(expected_dataset_json_file)

    if not isinstance(dataset_json_file_or_dict, dict):
        dataset_json = load_json(dataset_json_file_or_dict)
    else:
        dataset_json = dataset_json_file_or_dict

    rw = plans_manager.image_reader_writer_class()
    label_manager = plans_manager.get_label_manager(dataset_json)
    labels_or_regions = label_manager.foreground_regions if label_manager.has_regions else \
        label_manager.foreground_labels
    file_ending = dataset_json['file_ending']

    predicted_files = subfiles(folder_predictions, suffix=file_ending, join=False)
    ref_files = subfiles(folder_ref, suffix=file_ending, join=False)
    # we should print a warning if not all files from folder_ref are present in folder_predictions
    if not set(ref_files).issubset(predicted_files):
        print(f'WARNING: Not all files in folder_ref were found in folder_predictions. Determining postprocessing '
              f'should always be done on the entire dataset!')

    # before we start we need the metrics of the unprocessed predictions (an existing summary.json is reused)
    baseline_summary_file = join(folder_predictions, 'summary.json')
    if not isfile(baseline_summary_file):
        compute_metrics_on_folder(folder_ref, folder_predictions, baseline_summary_file, rw, file_ending,
                                  labels_or_regions, label_manager.ignore_label, num_processes)

    # the accepted postprocessing functions and their kwargs are collected here
    pp_fns = []
    pp_fn_kwargs = []

    # pool party!
    with multiprocessing.get_context("spawn").Pool(num_processes) as pool:
        # now let's see whether removing all but the largest foreground region improves the scores
        output_here = join(output_folder, 'temp', 'keep_largest_fg')
        maybe_mkdir_p(output_here)
        pp_fn = remove_all_but_largest_component_from_segmentation
        # NOTE(review): foreground_labels is used here even for region-based models (whereas evaluation uses
        # labels_or_regions); presumably because the predicted files store label maps — confirm
        kwargs = {
            'labels_or_regions': label_manager.foreground_labels,
        }

        _postprocess_files(pool, folder_predictions, output_here, predicted_files, rw, [pp_fn], [kwargs])
        compute_metrics_on_folder(folder_ref, output_here, join(output_here, 'summary.json'), rw, file_ending,
                                  labels_or_regions, label_manager.ignore_label, num_processes)

        # now we need to figure out if doing this improved the dice scores. This is implemented defensively: if a
        # single class got worse as a result we won't do this.
        baseline_results = load_summary_json(baseline_summary_file)
        pp_results = load_summary_json(join(output_here, 'summary.json'))
        # 'not any(... < ...)' deliberately mirrors a strict '<' check per class (NaN never counts as worse)
        do_this = pp_results['foreground_mean']['Dice'] > baseline_results['foreground_mean']['Dice'] and \
            not any(pp_results['mean'][class_id]['Dice'] < baseline_results['mean'][class_id]['Dice']
                    for class_id in pp_results['mean'].keys())
        if do_this:
            print(f'Results were improved by removing all but the largest foreground region. '
                  f'Mean dice before: {round(baseline_results["foreground_mean"]["Dice"], 5)} '
                  f'after: {round(pp_results["foreground_mean"]["Dice"], 5)}')
            source = output_here
            pp_fns.append(pp_fn)
            pp_fn_kwargs.append(kwargs)
        else:
            print(f'Removing all but the largest foreground region did not improve results!')
            source = folder_predictions

        # In the old nnU-Net we could apply all-but-largest component removal to all classes at once and then evaluate
        # per class whether this improved results. That is no longer possible because region-based predictions are
        # supported and regions can overlap, causing interactions. In principle the order in which the
        # postprocessing is applied to the regions matters as well; we stick to the order in which they are declared
        # in dataset.json (see region_class_order).
        if len(labels_or_regions) > 1:
            current_best_folder = join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest')
            for label_or_region in labels_or_regions:
                pp_fn = remove_all_but_largest_component_from_segmentation
                kwargs = {
                    'labels_or_regions': label_or_region,
                }

                # re-created every iteration because an accepted result is moved away below
                output_here = join(output_folder, 'temp', 'keep_largest_perClassOrRegion')
                maybe_mkdir_p(output_here)

                # always start from the best result so far (source)
                _postprocess_files(pool, source, output_here, predicted_files, rw, [pp_fn], [kwargs])
                compute_metrics_on_folder(folder_ref, output_here, join(output_here, 'summary.json'), rw,
                                          file_ending, labels_or_regions, label_manager.ignore_label,
                                          num_processes)
                baseline_results = load_summary_json(join(source, 'summary.json'))
                pp_results = load_summary_json(join(output_here, 'summary.json'))
                do_this = pp_results['mean'][label_or_region]['Dice'] > \
                    baseline_results['mean'][label_or_region]['Dice']
                if do_this:
                    print(f'Results were improved by removing all but the largest component for {label_or_region}. '
                          f'Dice before: {round(baseline_results["mean"][label_or_region]["Dice"], 5)} '
                          f'after: {round(pp_results["mean"][label_or_region]["Dice"], 5)}')
                    if isdir(current_best_folder):
                        shutil.rmtree(current_best_folder)
                    shutil.move(output_here, current_best_folder, )
                    source = current_best_folder
                    pp_fns.append(pp_fn)
                    pp_fn_kwargs.append(kwargs)
                else:
                    print(f'Removing all but the largest component for {label_or_region} did not improve results! '
                          f'Dice before: {round(baseline_results["mean"][label_or_region]["Dice"], 5)} '
                          f'after: {round(pp_results["mean"][label_or_region]["Dice"], 5)}')

    # the best result (including its summary.json) becomes the content of output_folder
    for i in subfiles(source, join=False):
        shutil.copy(join(source, i), join(output_folder, i))
    save_pickle((pp_fns, pp_fn_kwargs), join(folder_predictions, 'postprocessing.pkl'))

    baseline_results = load_summary_json(baseline_summary_file)
    final_results = load_summary_json(join(output_folder, 'summary.json'))
    tmp = {
        'input_folder': {i: baseline_results[i] for i in ['foreground_mean', 'mean']},
        'postprocessed': {i: final_results[i] for i in ['foreground_mean', 'mean']},
        'postprocessing_fns': [i.__name__ for i in pp_fns],
        'postprocessing_kwargs': pp_fn_kwargs,
    }
    # json can't handle tuples (regions) as dict keys, so convert them
    tmp['input_folder']['mean'] = {label_or_region_to_key(k): v for k, v in tmp['input_folder']['mean'].items()}
    tmp['postprocessed']['mean'] = {label_or_region_to_key(k): v for k, v in tmp['postprocessed']['mean'].items()}
    # json also can't serialize numpy scalars ("TypeError: Object of type int64 is not JSON serializable")
    recursive_fix_for_json_export(tmp)
    save_json(tmp, join(folder_predictions, 'postprocessing.json'))

    shutil.rmtree(join(output_folder, 'temp'))

    if not keep_postprocessed_files:
        shutil.rmtree(output_folder)
    return pp_fns, pp_fn_kwargs


def apply_postprocessing_to_folder(input_folder: str,
                                   output_folder: str,
                                   pp_fns: List[Callable],
                                   pp_fn_kwargs: List[dict],
                                   plans_file_or_dict: Union[str, dict] = None,
                                   dataset_json_file_or_dict: Union[str, dict] = None,
                                   num_processes=8) -> None:
    """
    Applies the postprocessing chain (pp_fns, pp_fn_kwargs), e.g. as returned by determine_postprocessing, to all
    files in input_folder ending with dataset_json['file_ending']. The results are written with the same file names
    into output_folder, which is created if needed.

    If plans_file_or_dict or dataset_json_file_or_dict are None, we will look for them in input_folder
    """
    if plans_file_or_dict is None:
        expected_plans_file = join(input_folder, 'plans.json')
        if not isfile(expected_plans_file):
            raise RuntimeError(f"Expected plans file missing: {expected_plans_file}. The plans file should have been "
                               f"created while running nnUNetv2_predict. Sadge. If the folder you want to apply "
                               f"postprocessing to was created from an ensemble then just specify one of the "
                               f"plans files of the ensemble members in plans_file_or_dict")
        plans_file_or_dict = load_json(expected_plans_file)
    plans_manager = PlansManager(plans_file_or_dict)

    if dataset_json_file_or_dict is None:
        expected_dataset_json_file = join(input_folder, 'dataset.json')
        if not isfile(expected_dataset_json_file):
            raise RuntimeError(
                f"Expected dataset.json file missing: {expected_dataset_json_file}. The dataset.json should have been "
                f"copied while running nnUNetv2_predict/nnUNetv2_ensemble. Sadge.")
        dataset_json_file_or_dict = load_json(expected_dataset_json_file)

    if not isinstance(dataset_json_file_or_dict, dict):
        dataset_json = load_json(dataset_json_file_or_dict)
    else:
        dataset_json = dataset_json_file_or_dict

    rw = plans_manager.image_reader_writer_class()

    maybe_mkdir_p(output_folder)
    files = subfiles(input_folder, suffix=dataset_json['file_ending'], join=False)
    with multiprocessing.get_context("spawn").Pool(num_processes) as p:
        _postprocess_files(p, input_folder, output_folder, files, rw, pp_fns, pp_fn_kwargs)


def entry_point_determine_postprocessing_folder():
    """Command line entry point for determine_postprocessing."""
    # description= (not the first positional arg, which argparse interprets as the program name 'prog')
    parser = argparse.ArgumentParser(
        description='Writes postprocessing.pkl and postprocessing.json in input_folder.')
    parser.add_argument('-i', type=str, required=True, help='Input folder')
    parser.add_argument('-ref', type=str, required=True, help='Folder with gt labels')
    parser.add_argument('-plans_json', type=str, required=False, default=None,
                        help="plans file to use. If not specified we will look for the plans.json file in the "
                             "input folder (input_folder/plans.json)")
    parser.add_argument('-dataset_json', type=str, required=False, default=None,
                        help="dataset.json file to use. If not specified we will look for the dataset.json file in "
                             "the input folder (input_folder/dataset.json)")
    parser.add_argument('-np', type=int, required=False, default=default_num_processes,
                        help=f"number of processes to use. Default: {default_num_processes}")
    parser.add_argument('--remove_postprocessed', action='store_true', required=False,
                        help='set this if you don\'t want to keep the postprocessed files')

    args = parser.parse_args()
    determine_postprocessing(args.i, args.ref, args.plans_json, args.dataset_json, args.np,
                             not args.remove_postprocessed)


def entry_point_apply_postprocessing():
    """Command line entry point for apply_postprocessing_to_folder."""
    # description= (not the first positional arg, which argparse interprets as the program name 'prog')
    parser = argparse.ArgumentParser(
        description='Applies postprocessing specified in pp_pkl_file to input folder.')
    parser.add_argument('-i', type=str, required=True, help='Input folder')
    parser.add_argument('-o', type=str, required=True, help='Output folder')
    parser.add_argument('-pp_pkl_file', type=str, required=True, help='postprocessing.pkl file')
    parser.add_argument('-np', type=int, required=False, default=default_num_processes,
                        help=f"number of processes to use. Default: {default_num_processes}")
    parser.add_argument('-plans_json', type=str, required=False, default=None,
                        help="plans file to use. If not specified we will look for the plans.json file in the "
                             "input folder (input_folder/plans.json)")
    parser.add_argument('-dataset_json', type=str, required=False, default=None,
                        help="dataset.json file to use. If not specified we will look for the dataset.json file in "
                             "the input folder (input_folder/dataset.json)")
    args = parser.parse_args()
    # SECURITY NOTE: unpickling can execute arbitrary code. Only use postprocessing.pkl files from trusted sources.
    pp_fns, pp_fn_kwargs = load_pickle(args.pp_pkl_file)
    apply_postprocessing_to_folder(args.i, args.o, pp_fns, pp_fn_kwargs, args.plans_json, args.dataset_json, args.np)


if __name__ == '__main__':
    # developer example: determine and apply postprocessing on the cross-validation results of one trained model
    trained_model_folder = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetTrainer__nnUNetPlans__3d_fullres'
    labelstr = join(nnUNet_raw, 'Dataset004_Hippocampus', 'labelsTr')
    plans_manager = PlansManager(join(trained_model_folder, 'plans.json'))
    dataset_json = load_json(join(trained_model_folder, 'dataset.json'))
    folds = (0, 1, 2, 3, 4)
    label_manager = plans_manager.get_label_manager(dataset_json)

    merged_output_folder = join(trained_model_folder, f'crossval_results_folds_{folds_tuple_to_string(folds)}')
    accumulate_cv_results(trained_model_folder, merged_output_folder, folds, 8, False)

    fns, kwargs = determine_postprocessing(merged_output_folder, labelstr, plans_manager.plans,
                                           dataset_json, 8, keep_postprocessed_files=True)
    save_pickle((fns, kwargs), join(trained_model_folder, 'postprocessing.pkl'))
    fns, kwargs = load_pickle(join(trained_model_folder, 'postprocessing.pkl'))

    apply_postprocessing_to_folder(merged_output_folder, merged_output_folder + '_pp', fns, kwargs,
                                   plans_manager.plans, dataset_json,
                                   8)
    compute_metrics_on_folder(labelstr,
                              merged_output_folder + '_pp',
                              join(merged_output_folder + '_pp', 'summary.json'),
                              plans_manager.image_reader_writer_class(),
                              dataset_json['file_ending'],
                              label_manager.foreground_regions if label_manager.has_regions else
                              label_manager.foreground_labels,
                              label_manager.ignore_label,
                              8)