# FelixzeroSun's picture
# Upload folder using huggingface_hub
# 19c1f58 verified
import argparse
import multiprocessing
import shutil
from multiprocessing import Pool
from typing import Union, Tuple, List, Callable
import numpy as np
from acvl_utils.morphology.morphology_helper import remove_all_but_largest_component
from batchgenerators.utilities.file_and_folder_operations import load_json, subfiles, maybe_mkdir_p, join, isfile, \
isdir, save_pickle, load_pickle, save_json
from nnunetv2.configuration import default_num_processes
from nnunetv2.evaluation.accumulate_cv_results import accumulate_cv_results
from nnunetv2.evaluation.evaluate_predictions import region_or_label_to_mask, compute_metrics_on_folder, \
load_summary_json, label_or_region_to_key
from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
from nnunetv2.paths import nnUNet_raw
from nnunetv2.utilities.file_path_utilities import folds_tuple_to_string
from nnunetv2.utilities.json_export import recursive_fix_for_json_export
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
def remove_all_but_largest_component_from_segmentation(segmentation: np.ndarray,
                                                       labels_or_regions: Union[int, Tuple[int, ...],
                                                                                List[Union[int, Tuple[int, ...]]]],
                                                       background_label: int = 0) -> np.ndarray:
    """
    Keep only the largest component of the union of the given labels/regions.

    All selected labels/regions are merged into one boolean mask. Every voxel of that mask that is not part of the
    component kept by remove_all_but_largest_component is set to background_label. Voxels outside the mask are left
    untouched.

    :param segmentation: label map to postprocess. It is not modified; a modified copy is returned.
    :param labels_or_regions: a single label (int), a single region (tuple of labels) or a list of labels/regions.
        Only a list is treated as a collection of entries; a tuple is passed on as one region.
    :param background_label: value written into the removed voxels
    :return: postprocessed copy of segmentation
    """
    if isinstance(labels_or_regions, list):
        selection = labels_or_regions
    else:
        selection = [labels_or_regions]

    union_mask = np.zeros_like(segmentation, dtype=bool)
    for entry in selection:
        union_mask |= region_or_label_to_mask(segmentation, entry)

    largest_only = remove_all_but_largest_component(union_mask)

    result = np.copy(segmentation)  # the caller's array must stay untouched
    result[union_mask & ~largest_only] = background_label
    return result
def apply_postprocessing(segmentation: np.ndarray, pp_fns: List[Callable], pp_fn_kwargs: List[dict]):
    """
    Run a chain of postprocessing functions on a segmentation.

    Each function is called as fn(current, **kwargs) with the kwargs dict at the same position in pp_fn_kwargs. Its
    return value becomes the input of the next function. The two lists are paired with zip, so surplus entries in
    the longer list are ignored.

    :return: output of the last function, or the unchanged input if pp_fns is empty
    """
    current = segmentation
    steps = zip(pp_fns, pp_fn_kwargs)
    for step_fn, step_kwargs in steps:
        current = step_fn(current, **step_kwargs)
    return current
def load_postprocess_save(segmentation_file: str,
                          output_fname: str,
                          image_reader_writer: BaseReaderWriter,
                          pp_fns: List[Callable],
                          pp_fn_kwargs: List[dict]):
    """
    Read one segmentation file, run the postprocessing chain on it and write the result.

    :param segmentation_file: path of the segmentation to read
    :param output_fname: path the postprocessed segmentation is written to
    :param image_reader_writer: reader/writer used for both reading and writing
    :param pp_fns: postprocessing functions, applied in order (see apply_postprocessing)
    :param pp_fn_kwargs: keyword arguments for the entry of pp_fns at the same position
    """
    loaded, properties = image_reader_writer.read_seg(segmentation_file)
    # only element [0] of what read_seg returns is postprocessed; presumably a leading channel axis
    # (TODO confirm against BaseReaderWriter.read_seg)
    postprocessed = apply_postprocessing(loaded[0], pp_fns, pp_fn_kwargs)
    image_reader_writer.write_seg(postprocessed, output_fname, properties)
def _postprocess_and_evaluate(pool, source_folder: str, target_folder: str, files: List[str], rw,
                              pp_fn: Callable, pp_kwargs: dict, folder_ref: str, file_ending: str,
                              labels_or_regions, ignore_label, num_processes: int) -> None:
    """
    Apply one postprocessing function to each of `files` in source_folder, write the results (same file names) into
    target_folder and evaluate them against folder_ref. The metrics are written to target_folder/summary.json.
    """
    maybe_mkdir_p(target_folder)
    num_files = len(files)
    pool.starmap(
        load_postprocess_save,
        zip(
            [join(source_folder, i) for i in files],
            [join(target_folder, i) for i in files],
            [rw] * num_files,
            [[pp_fn]] * num_files,
            [[pp_kwargs]] * num_files
        )
    )
    compute_metrics_on_folder(folder_ref,
                              target_folder,
                              join(target_folder, 'summary.json'),
                              rw,
                              file_ending,
                              labels_or_regions,
                              ignore_label,
                              num_processes)


def determine_postprocessing(folder_predictions: str,
                             folder_ref: str,
                             plans_file_or_dict: Union[str, dict],
                             dataset_json_file_or_dict: Union[str, dict],
                             num_processes: int = default_num_processes,
                             keep_postprocessed_files: bool = True):
    """
    Determine nnUNet postprocessing for the predictions in folder_predictions. The only postprocessing considered
    is "remove all but the largest component". Candidates are judged by the Dice score against the reference
    segmentations in folder_ref.

    Candidates are tried in this order:
    1. Keep only the largest component of all foreground labels combined. This is accepted only if the foreground
       mean Dice improves AND no individual class/region gets worse.
    2. Only if there is more than one label/region: for each label/region (in the order of labels_or_regions), keep
       only its largest component. Each candidate is applied on top of the best result so far. It is accepted if
       the Dice of that label/region improves.

    Outputs:
    - folder_predictions/postprocessing.pkl: pickled (pp_fns, pp_fn_kwargs), usable with
      apply_postprocessing_to_folder
    - folder_predictions/postprocessing.json: Dice summaries before/after, plus names and kwargs of the chosen
      functions
    - folder_predictions/summary.json: evaluation of the raw predictions (only computed if it does not exist yet)
    - folder_predictions/postprocessed: the postprocessed files. Set keep_postprocessed_files=False to delete
      these files after this function is done. Temp files are created and deleted regardless.

    If plans_file_or_dict or dataset_json_file_or_dict are None, we look for plans.json / dataset.json in
    folder_predictions.

    :return: (pp_fns, pp_fn_kwargs) - the chosen postprocessing functions and their keyword arguments
    :raises RuntimeError: if plans.json / dataset.json are needed but missing from folder_predictions
    """
    output_folder = join(folder_predictions, 'postprocessed')

    if plans_file_or_dict is None:
        expected_plans_file = join(folder_predictions, 'plans.json')
        if not isfile(expected_plans_file):
            raise RuntimeError(f"Expected plans file missing: {expected_plans_file}. The plans files should have been "
                               f"created while running nnUNetv2_predict. Sadge.")
        plans_file_or_dict = load_json(expected_plans_file)
    plans_manager = PlansManager(plans_file_or_dict)

    if dataset_json_file_or_dict is None:
        expected_dataset_json_file = join(folder_predictions, 'dataset.json')
        if not isfile(expected_dataset_json_file):
            raise RuntimeError(
                f"Expected dataset.json file missing: {expected_dataset_json_file}. The dataset.json file should "
                f"have been created while running nnUNetv2_predict. Sadge.")
        dataset_json_file_or_dict = load_json(expected_dataset_json_file)

    if not isinstance(dataset_json_file_or_dict, dict):
        dataset_json = load_json(dataset_json_file_or_dict)
    else:
        dataset_json = dataset_json_file_or_dict

    rw = plans_manager.image_reader_writer_class()
    label_manager = plans_manager.get_label_manager(dataset_json)
    labels_or_regions = label_manager.foreground_regions if label_manager.has_regions else label_manager.foreground_labels
    file_ending = dataset_json['file_ending']
    ignore_label = label_manager.ignore_label

    predicted_files = subfiles(folder_predictions, suffix=file_ending, join=False)
    ref_files = subfiles(folder_ref, suffix=file_ending, join=False)
    # we should print a warning if not all files from folder_ref are present in folder_predictions
    predicted_files_set = set(predicted_files)
    if not all(i in predicted_files_set for i in ref_files):
        print(f'WARNING: Not all files in folder_ref were found in folder_predictions. Determining postprocessing '
              f'should always be done on the entire dataset!')

    # before we start we should evaluate the images in the source folder
    baseline_summary_file = join(folder_predictions, 'summary.json')
    if not isfile(baseline_summary_file):
        compute_metrics_on_folder(folder_ref,
                                  folder_predictions,
                                  baseline_summary_file,
                                  rw,
                                  file_ending,
                                  labels_or_regions,
                                  ignore_label,
                                  num_processes)

    # the accepted postprocessing functions and their kwargs, in order of application
    pp_fns = []
    pp_fn_kwargs = []

    # pool party!
    with multiprocessing.get_context("spawn").Pool(num_processes) as pool:
        # step 1: does removing all but the largest foreground region (all foreground labels combined) help?
        output_here = join(output_folder, 'temp', 'keep_largest_fg')
        pp_fn = remove_all_but_largest_component_from_segmentation
        kwargs = {
            'labels_or_regions': label_manager.foreground_labels,
        }
        _postprocess_and_evaluate(pool, folder_predictions, output_here, predicted_files, rw, pp_fn, kwargs,
                                  folder_ref, file_ending, labels_or_regions, ignore_label, num_processes)

        # Accept only if the foreground mean Dice improves. This is implemented defensively: if any single
        # class got worse, step 1 is rejected even when the mean improved.
        baseline_results = load_summary_json(baseline_summary_file)
        pp_results = load_summary_json(join(output_here, 'summary.json'))
        do_this = pp_results['foreground_mean']['Dice'] > baseline_results['foreground_mean']['Dice']
        if do_this:
            do_this = not any(pp_results['mean'][class_id]['Dice'] < baseline_results['mean'][class_id]['Dice']
                              for class_id in pp_results['mean'].keys())
        if do_this:
            print(f'Results were improved by removing all but the largest foreground region. '
                  f'Mean dice before: {round(baseline_results["foreground_mean"]["Dice"], 5)} '
                  f'after: {round(pp_results["foreground_mean"]["Dice"], 5)}')
            source = output_here
            pp_fns.append(pp_fn)
            pp_fn_kwargs.append(kwargs)
        else:
            print(f'Removing all but the largest foreground region did not improve results!')
            source = folder_predictions

        # Step 2: per label/region. Unlike the old nnU-Net, we cannot apply this to all classes at once and then
        # judge each class, because regions can overlap and therefore interact. The order of application could
        # matter too. We stick to the order in which labels/regions are declared in dataset.json (see
        # region_class_order). Each accepted step becomes the new `source` for the following ones.
        if len(labels_or_regions) > 1:
            output_here = join(output_folder, 'temp', 'keep_largest_perClassOrRegion')
            current_best = join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest')
            for label_or_region in labels_or_regions:
                pp_fn = remove_all_but_largest_component_from_segmentation
                kwargs = {
                    'labels_or_regions': label_or_region,
                }
                _postprocess_and_evaluate(pool, source, output_here, predicted_files, rw, pp_fn, kwargs,
                                          folder_ref, file_ending, labels_or_regions, ignore_label, num_processes)
                baseline_results = load_summary_json(join(source, 'summary.json'))
                pp_results = load_summary_json(join(output_here, 'summary.json'))
                do_this = pp_results['mean'][label_or_region]['Dice'] > baseline_results['mean'][label_or_region]['Dice']
                if do_this:
                    print(f'Results were improved by removing all but the largest component for {label_or_region}. '
                          f'Dice before: {round(baseline_results["mean"][label_or_region]["Dice"], 5)} '
                          f'after: {round(pp_results["mean"][label_or_region]["Dice"], 5)}')
                    # source may be current_best itself; it has been fully read at this point, so replacing it is safe
                    if isdir(current_best):
                        shutil.rmtree(current_best)
                    shutil.move(output_here, current_best, )
                    source = current_best
                    pp_fns.append(pp_fn)
                    pp_fn_kwargs.append(kwargs)
                else:
                    print(f'Removing all but the largest component for {label_or_region} did not improve results! '
                          f'Dice before: {round(baseline_results["mean"][label_or_region]["Dice"], 5)} '
                          f'after: {round(pp_results["mean"][label_or_region]["Dice"], 5)}')

    # the best result (including its summary.json) becomes the content of output_folder
    for filename in subfiles(source, join=False):
        shutil.copy(join(source, filename), join(output_folder, filename))
    save_pickle((pp_fns, pp_fn_kwargs), join(folder_predictions, 'postprocessing.pkl'))

    baseline_results = load_summary_json(baseline_summary_file)
    final_results = load_summary_json(join(output_folder, 'summary.json'))
    tmp = {
        'input_folder': {i: baseline_results[i] for i in ['foreground_mean', 'mean']},
        'postprocessed': {i: final_results[i] for i in ['foreground_mean', 'mean']},
        'postprocessing_fns': [i.__name__ for i in pp_fns],
        'postprocessing_kwargs': pp_fn_kwargs,
    }
    # json cannot use tuples (regions) as dict keys, so convert them to string keys
    tmp['input_folder']['mean'] = {label_or_region_to_key(k): tmp['input_folder']['mean'][k] for k in
                                   tmp['input_folder']['mean'].keys()}
    tmp['postprocessed']['mean'] = {label_or_region_to_key(k): tmp['postprocessed']['mean'][k] for k in
                                    tmp['postprocessed']['mean'].keys()}
    # numpy scalars (e.g. int64) are not JSON serializable either
    recursive_fix_for_json_export(tmp)
    save_json(tmp, join(folder_predictions, 'postprocessing.json'))

    shutil.rmtree(join(output_folder, 'temp'))

    if not keep_postprocessed_files:
        shutil.rmtree(output_folder)
    return pp_fns, pp_fn_kwargs
def apply_postprocessing_to_folder(input_folder: str,
                                   output_folder: str,
                                   pp_fns: List[Callable],
                                   pp_fn_kwargs: List[dict],
                                   plans_file_or_dict: Union[str, dict] = None,
                                   dataset_json_file_or_dict: Union[str, dict] = None,
                                   num_processes: int = 8) -> None:
    """
    Apply a postprocessing chain to every file in input_folder whose name ends with dataset_json['file_ending']. The
    chain is (pp_fns, pp_fn_kwargs), e.g. as returned by determine_postprocessing or stored in postprocessing.pkl.
    Results are written under the same file names into output_folder, which is created if needed.

    If plans_file_or_dict or dataset_json_file_or_dict are None, we will look for plans.json / dataset.json in
    input_folder.

    :raises RuntimeError: if plans.json / dataset.json are needed but missing from input_folder
    """
    if plans_file_or_dict is None:
        expected_plans_file = join(input_folder, 'plans.json')
        if not isfile(expected_plans_file):
            raise RuntimeError(f"Expected plans file missing: {expected_plans_file}. The plans file should have been "
                               f"created while running nnUNetv2_predict. Sadge. If the folder you want to apply "
                               f"postprocessing to was created from an ensemble then just specify one of the "
                               f"plans files of the ensemble members in plans_file_or_dict")
        plans_file_or_dict = load_json(expected_plans_file)
    plans_manager = PlansManager(plans_file_or_dict)

    if dataset_json_file_or_dict is None:
        expected_dataset_json_file = join(input_folder, 'dataset.json')
        if not isfile(expected_dataset_json_file):
            raise RuntimeError(
                f"Expected dataset.json file missing: {expected_dataset_json_file}. The dataset.json should have been "
                f"copied while running nnUNetv2_predict/nnUNetv2_ensemble. Sadge.")
        dataset_json_file_or_dict = load_json(expected_dataset_json_file)

    if not isinstance(dataset_json_file_or_dict, dict):
        dataset_json = load_json(dataset_json_file_or_dict)
    else:
        dataset_json = dataset_json_file_or_dict

    rw = plans_manager.image_reader_writer_class()
    maybe_mkdir_p(output_folder)

    files = subfiles(input_folder, suffix=dataset_json['file_ending'], join=False)
    num_files = len(files)
    with multiprocessing.get_context("spawn").Pool(num_processes) as p:
        # every file gets the complete chain (pp_fns, pp_fn_kwargs)
        p.starmap(load_postprocess_save,
                  zip(
                      [join(input_folder, i) for i in files],
                      [join(output_folder, i) for i in files],
                      [rw] * num_files,
                      [pp_fns] * num_files,
                      [pp_fn_kwargs] * num_files
                  ))
def entry_point_determine_postprocessing_folder():
    """
    Command line entry point for determine_postprocessing.

    Determines postprocessing for the predictions in -i against the ground truth in -ref. Writes
    postprocessing.pkl and postprocessing.json into the input folder. The postprocessed files are kept in
    <input folder>/postprocessed unless --remove_postprocessed is set.
    """
    # The text must be passed as `description`. The first positional argument of ArgumentParser is `prog`; passing
    # it there would replace the program name in the usage line.
    parser = argparse.ArgumentParser(description='Writes postprocessing.pkl and postprocessing.json in input_folder.')
    parser.add_argument('-i', type=str, required=True, help='Input folder')
    parser.add_argument('-ref', type=str, required=True, help='Folder with gt labels')
    parser.add_argument('-plans_json', type=str, required=False, default=None,
                        help="plans file to use. If not specified we will look for the plans.json file in the "
                             "input folder (input_folder/plans.json)")
    parser.add_argument('-dataset_json', type=str, required=False, default=None,
                        help="dataset.json file to use. If not specified we will look for the dataset.json file in the "
                             "input folder (input_folder/dataset.json)")
    parser.add_argument('-np', type=int, required=False, default=default_num_processes,
                        help=f"number of processes to use. Default: {default_num_processes}")
    parser.add_argument('--remove_postprocessed', action='store_true', required=False,
                        help='set this if you don\'t want to keep the postprocessed files')
    args = parser.parse_args()
    determine_postprocessing(args.i, args.ref, args.plans_json, args.dataset_json, args.np,
                             not args.remove_postprocessed)
def entry_point_apply_postprocessing():
    """
    Command line entry point for apply_postprocessing_to_folder.

    Loads (pp_fns, pp_fn_kwargs) from -pp_pkl_file and applies them to the segmentations in -i. The results are
    written to -o.
    """
    # The text must be passed as `description`. The first positional argument of ArgumentParser is `prog`.
    parser = argparse.ArgumentParser(description='Applies postprocessing specified in pp_pkl_file to input folder.')
    parser.add_argument('-i', type=str, required=True, help='Input folder')
    parser.add_argument('-o', type=str, required=True, help='Output folder')
    parser.add_argument('-pp_pkl_file', type=str, required=True, help='postprocessing.pkl file')
    parser.add_argument('-np', type=int, required=False, default=default_num_processes,
                        help=f"number of processes to use. Default: {default_num_processes}")
    parser.add_argument('-plans_json', type=str, required=False, default=None,
                        help="plans file to use. If not specified we will look for the plans.json file in the "
                             "input folder (input_folder/plans.json)")
    parser.add_argument('-dataset_json', type=str, required=False, default=None,
                        help="dataset.json file to use. If not specified we will look for the dataset.json file in the "
                             "input folder (input_folder/dataset.json)")
    args = parser.parse_args()
    # NOTE(review): load_pickle is assumed to be pickle-based, as its name suggests. Unpickling can execute
    # arbitrary code, so only pass postprocessing.pkl files from a trusted source.
    pp_fns, pp_fn_kwargs = load_pickle(args.pp_pkl_file)
    apply_postprocessing_to_folder(args.i, args.o, pp_fns, pp_fn_kwargs, args.plans_json, args.dataset_json, args.np)
if __name__ == '__main__':
    # Developer run on the Hippocampus dataset (hard-coded, machine-specific paths). It merges the cross-validation
    # predictions, determines postprocessing on them, applies it and evaluates the postprocessed result.
    num_workers = 8
    trained_model_folder = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetTrainer__nnUNetPlans__3d_fullres'
    labelstr = join(nnUNet_raw, 'Dataset004_Hippocampus', 'labelsTr')
    plans_manager = PlansManager(join(trained_model_folder, 'plans.json'))
    dataset_json = load_json(join(trained_model_folder, 'dataset.json'))
    folds = (0, 1, 2, 3, 4)
    label_manager = plans_manager.get_label_manager(dataset_json)

    merged_output_folder = join(trained_model_folder, f'crossval_results_folds_{folds_tuple_to_string(folds)}')
    postprocessed_folder = merged_output_folder + '_pp'
    pp_pkl_file = join(trained_model_folder, 'postprocessing.pkl')

    accumulate_cv_results(trained_model_folder, merged_output_folder, folds, num_workers, False)

    fns, kwargs = determine_postprocessing(merged_output_folder, labelstr, plans_manager.plans,
                                           dataset_json, num_workers, keep_postprocessed_files=True)
    save_pickle((fns, kwargs), pp_pkl_file)
    # reload what was just saved, the same way a consumer of postprocessing.pkl would
    fns, kwargs = load_pickle(pp_pkl_file)

    apply_postprocessing_to_folder(merged_output_folder, postprocessed_folder, fns, kwargs,
                                   plans_manager.plans, dataset_json, num_workers)

    reader_writer = plans_manager.image_reader_writer_class()
    if label_manager.has_regions:
        eval_targets = label_manager.foreground_regions
    else:
        eval_targets = label_manager.foreground_labels
    compute_metrics_on_folder(labelstr, postprocessed_folder, join(postprocessed_folder, 'summary.json'),
                              reader_writer, dataset_json['file_ending'], eval_targets,
                              label_manager.ignore_label, num_workers)