| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| from multiprocessing.pool import Pool |
| from time import sleep |
|
|
| import matplotlib |
| from nnunet.postprocessing.connected_components import determine_postprocessing |
| from nnunet.training.data_augmentation.default_data_augmentation import get_default_augmentation |
| from nnunet.training.dataloading.dataset_loading import DataLoader3D, unpack_dataset |
| from nnunet.evaluation.evaluator import aggregate_scores |
| from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer |
| from nnunet.network_architecture.neural_network import SegmentationNetwork |
| from nnunet.paths import network_training_output_dir |
| from nnunet.inference.segmentation_export import save_segmentation_nifti_from_softmax |
| from batchgenerators.utilities.file_and_folder_operations import * |
| import numpy as np |
| from nnunet.utilities.one_hot_encoding import to_one_hot |
| import shutil |
|
|
| matplotlib.use("agg") |
|
|
|
|
class nnUNetTrainerCascadeFullRes(nnUNetTrainer):
    """
    Trainer for the second (full resolution) stage of the nnU-Net cascade.

    Extends nnUNetTrainer by locating the predicted segmentations of the
    corresponding 3d_lowres stage and feeding them into training as extra
    one-hot input channels.
    """

    def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
                 unpack_data=True, deterministic=True, previous_trainer="nnUNetTrainer", fp16=False):
        super(nnUNetTrainerCascadeFullRes, self).__init__(plans_file, fold, output_folder, dataset_directory,
                                                          batch_dice, stage, unpack_data, deterministic, fp16)
        # remember all ctor arguments (including previous_trainer, which the base
        # class does not know about) so the trainer can be re-instantiated later
        self.init_args = (plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
                          deterministic, previous_trainer, fp16)

        if self.output_folder is not None:
            # expected layout: .../<task>/<trainer>__<plans_identifier>/fold_X
            # NOTE(review): splitting on "/" assumes POSIX-style path separators — confirm on Windows
            task = self.output_folder.split("/")[-3]
            plans_identifier = self.output_folder.split("/")[-2].split("__")[-1]

            # the 3d_lowres run of `previous_trainer` must have exported its
            # next-stage predictions into "pred_next_stage" before we can train
            folder_with_segs_prev_stage = join(network_training_output_dir, "3d_lowres",
                                               task, previous_trainer + "__" + plans_identifier, "pred_next_stage")
            if not isdir(folder_with_segs_prev_stage):
                raise RuntimeError(
                    "Cannot run final stage of cascade. Run corresponding 3d_lowres first and predict the "
                    "segmentations for the next stage")
            self.folder_with_segs_from_prev_stage = folder_with_segs_prev_stage
        else:
            # inference-only use without an output folder: no previous stage available
            self.folder_with_segs_from_prev_stage = None
|
|
| def do_split(self): |
| super(nnUNetTrainerCascadeFullRes, self).do_split() |
| for k in self.dataset: |
| self.dataset[k]['seg_from_prev_stage_file'] = join(self.folder_with_segs_from_prev_stage, |
| k + "_segFromPrevStage.npz") |
| assert isfile(self.dataset[k]['seg_from_prev_stage_file']), \ |
| "seg from prev stage missing: %s" % (self.dataset[k]['seg_from_prev_stage_file']) |
| for k in self.dataset_val: |
| self.dataset_val[k]['seg_from_prev_stage_file'] = join(self.folder_with_segs_from_prev_stage, |
| k + "_segFromPrevStage.npz") |
| for k in self.dataset_tr: |
| self.dataset_tr[k]['seg_from_prev_stage_file'] = join(self.folder_with_segs_from_prev_stage, |
| k + "_segFromPrevStage.npz") |
|
|
| def get_basic_generators(self): |
| self.load_dataset() |
| self.do_split() |
| if self.threeD: |
| dl_tr = DataLoader3D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size, |
| True, oversample_foreground_percent=self.oversample_foreground_percent) |
| dl_val = DataLoader3D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size, True, |
| oversample_foreground_percent=self.oversample_foreground_percent) |
| else: |
| raise NotImplementedError |
| return dl_tr, dl_val |
|
|
| def process_plans(self, plans): |
| super(nnUNetTrainerCascadeFullRes, self).process_plans(plans) |
| self.num_input_channels += (self.num_classes - 1) |
|
|
| def setup_DA_params(self): |
| super().setup_DA_params() |
| self.data_aug_params['move_last_seg_chanel_to_data'] = True |
| self.data_aug_params['cascade_do_cascade_augmentations'] = True |
|
|
| self.data_aug_params['cascade_random_binary_transform_p'] = 0.4 |
| self.data_aug_params['cascade_random_binary_transform_p_per_label'] = 1 |
| self.data_aug_params['cascade_random_binary_transform_size'] = (1, 8) |
|
|
| self.data_aug_params['cascade_remove_conn_comp_p'] = 0.2 |
| self.data_aug_params['cascade_remove_conn_comp_max_size_percent_threshold'] = 0.15 |
| self.data_aug_params['cascade_remove_conn_comp_fill_with_other_class_p'] = 0.0 |
|
|
| |
| |
| self.data_aug_params['selected_seg_channels'] = [0, 1] |
| |
| self.data_aug_params['all_segmentation_labels'] = list(range(1, self.num_classes)) |
|
|
| def initialize(self, training=True, force_load_plans=False): |
| """ |
| For prediction of test cases just set training=False, this will prevent loading of training data and |
| training batchgenerator initialization |
| :param training: |
| :return: |
| """ |
| if force_load_plans or (self.plans is None): |
| self.load_plans_file() |
|
|
| self.process_plans(self.plans) |
|
|
| self.setup_DA_params() |
|
|
| self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] + |
| "_stage%d" % self.stage) |
| if training: |
| self.setup_DA_params() |
|
|
| if self.folder_with_preprocessed_data is not None: |
| self.dl_tr, self.dl_val = self.get_basic_generators() |
|
|
| if self.unpack_data: |
| print("unpacking dataset") |
| unpack_dataset(self.folder_with_preprocessed_data) |
| print("done") |
| else: |
| print( |
| "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you " |
| "will wait all winter for your model to finish!") |
|
|
| self.tr_gen, self.val_gen = get_default_augmentation(self.dl_tr, self.dl_val, |
| self.data_aug_params[ |
| 'patch_size_for_spatialtransform'], |
| self.data_aug_params) |
| self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys()))) |
| self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys()))) |
| else: |
| pass |
| self.initialize_network() |
| assert isinstance(self.network, SegmentationNetwork) |
| self.was_initialized = True |
|
|
    def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True,
                 step_size: float = 0.5,
                 save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
                 validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
                 segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
        """
        Run validation on the validation split, exporting predictions as nifti,
        aggregating scores and (optionally) determining postprocessing.

        Cascade-specific: the previous stage's predicted segmentation for each
        case is loaded, one-hot encoded and concatenated to the image channels
        before prediction.

        :param do_mirroring: use test-time mirroring along the trained mirror axes
        :param use_sliding_window: sliding-window inference
        :param step_size: sliding-window step size (fraction of patch size)
        :param save_softmax: also export softmax probabilities as .npz
        :param use_gaussian: gaussian importance weighting of window centers
        :param overwrite: NOTE(review): accepted but never read in this body — existing
            predictions are always recomputed; confirm against base-class behavior
        :param validation_folder_name: subfolder of self.output_folder for the results
        :param debug: forwarded to determine_postprocessing
        :param all_in_gpu: keep intermediate results on GPU during inference
        :param segmentation_export_kwargs: overrides force_separate_z /
            interpolation_order / interpolation_order_z; if None, values come from
            the plans file or fall back to defaults
        :param run_postprocessing_on_folds: run determine_postprocessing afterwards
        """
        # remember train/eval mode so it can be restored at the end
        current_mode = self.network.training
        self.network.eval()

        assert self.was_initialized, "must initialize, ideally with checkpoint (or train first)"
        if self.dataset_val is None:
            self.load_dataset()
            self.do_split()

        # resolve export/resampling parameters: explicit kwargs > plans > defaults
        if segmentation_export_kwargs is None:
            if 'segmentation_export_params' in self.plans.keys():
                force_separate_z = self.plans['segmentation_export_params']['force_separate_z']
                interpolation_order = self.plans['segmentation_export_params']['interpolation_order']
                interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z']
            else:
                force_separate_z = None
                interpolation_order = 1
                interpolation_order_z = 0
        else:
            force_separate_z = segmentation_export_kwargs['force_separate_z']
            interpolation_order = segmentation_export_kwargs['interpolation_order']
            interpolation_order_z = segmentation_export_kwargs['interpolation_order_z']

        output_folder = join(self.output_folder, validation_folder_name)
        maybe_mkdir_p(output_folder)

        if do_mirroring:
            mirror_axes = self.data_aug_params['mirror_axes']
        else:
            mirror_axes = ()

        # (predicted nifti, ground truth nifti) pairs for score aggregation
        pred_gt_tuples = []

        # small worker pool for asynchronous nifti export while the GPU predicts
        export_pool = Pool(2)
        results = []

        transpose_backward = self.plans.get('transpose_backward')

        for k in self.dataset_val.keys():
            properties = load_pickle(self.dataset[k]['properties_file'])
            data = np.load(self.dataset[k]['data_file'])['data']

            # previous-stage prediction for this case; [None] adds a leading axis
            seg_from_prev_stage = np.load(join(self.folder_with_segs_from_prev_stage,
                                               k + "_segFromPrevStage.npz"))['data'][None]

            print(data.shape)
            # last channel is the ground truth seg; map the -1 "outside" label to background
            data[-1][data[-1] == -1] = 0
            # image channels + one-hot previous-stage segmentation (foreground labels only)
            data_for_net = np.concatenate((data[:-1], to_one_hot(seg_from_prev_stage[0], range(1, self.num_classes))))

            # [1] selects the softmax output (index 0 is the argmax segmentation)
            softmax_pred = self.predict_preprocessed_data_return_seg_and_softmax(data_for_net,
                                                                                 do_mirroring=do_mirroring,
                                                                                 mirror_axes=mirror_axes,
                                                                                 use_sliding_window=use_sliding_window,
                                                                                 step_size=step_size,
                                                                                 use_gaussian=use_gaussian,
                                                                                 all_in_gpu=all_in_gpu,
                                                                                 mixed_precision=self.fp16)[1]

            if transpose_backward is not None:
                # NOTE(review): this re-fetch inside the loop is redundant — the value
                # was already read before the loop and does not change between cases
                transpose_backward = self.plans.get('transpose_backward')
                # undo the transpose applied during preprocessing (axis 0 is the class axis)
                softmax_pred = softmax_pred.transpose([0] + [i + 1 for i in transpose_backward])

            # strip the trailing "_XXXX.nii.gz" (12 chars) to recover the case identifier
            # NOTE(review): splitting on "/" assumes POSIX paths — confirm on Windows
            fname = properties['list_of_data_files'][0].split("/")[-1][:-12]

            if save_softmax:
                softmax_fname = join(output_folder, fname + ".npz")
            else:
                softmax_fname = None

            """There is a problem with python process communication that prevents us from communicating objects
            larger than 2 GB between processes (basically when the length of the pickle string that will be sent is
            communicated by the multiprocessing.Pipe object then the placeholder (I think) does not allow for long
            enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually
            patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will
            then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either
            filename or np.ndarray and will handle this automatically"""
            # threshold ~= 2e9 bytes / 4 bytes per float32, with a 15% safety margin
            if np.prod(softmax_pred.shape) > (2e9 / 4 * 0.85):
                # NOTE(review): this saves relative to the current working directory,
                # not output_folder — verify this is intended (the worker deletes it)
                np.save(fname + ".npy", softmax_pred)
                softmax_pred = fname + ".npy"

            results.append(export_pool.starmap_async(save_segmentation_nifti_from_softmax,
                                                     ((softmax_pred, join(output_folder, fname + ".nii.gz"),
                                                       properties, interpolation_order, self.regions_class_order,
                                                       None, None,
                                                       softmax_fname, None, force_separate_z,
                                                       interpolation_order_z),
                                                      )
                                                     )
                           )

            pred_gt_tuples.append([join(output_folder, fname + ".nii.gz"),
                                   join(self.gt_niftis_folder, fname + ".nii.gz")])

        # block until all asynchronous exports have finished
        _ = [i.get() for i in results]

        # aggregate per-case metrics into summary.json
        task = self.dataset_directory.split("/")[-1]
        job_name = self.experiment_name
        _ = aggregate_scores(pred_gt_tuples, labels=list(range(self.num_classes)),
                             json_output_file=join(output_folder, "summary.json"), json_name=job_name,
                             json_author="Fabian", json_description="",
                             json_task=task)

        if run_postprocessing_on_folds:
            # determine which postprocessing (e.g. largest connected component)
            # improves the scores; results go to "<validation_folder_name>_postprocessed"
            self.print_to_log_file("determining postprocessing")
            determine_postprocessing(self.output_folder, self.gt_niftis_folder, validation_folder_name,
                                     final_subf_name=validation_folder_name + "_postprocessed", debug=debug)

        # copy ground truth niftis next to the predictions; retried because a
        # concurrent fold may be copying the same files at the same time
        gt_nifti_folder = join(self.output_folder_base, "gt_niftis")
        maybe_mkdir_p(gt_nifti_folder)
        for f in subfiles(self.gt_niftis_folder, suffix=".nii.gz"):
            success = False
            attempts = 0
            while not success and attempts < 10:
                try:
                    shutil.copy(f, gt_nifti_folder)
                    success = True
                except OSError:
                    attempts += 1
                    sleep(1)

        # restore the network's original train/eval state and shut down the pool
        self.network.train(current_mode)
        export_pool.close()
        export_pool.join()