|
|
import inspect |
|
|
import itertools |
|
|
import multiprocessing |
|
|
import os |
|
|
from copy import deepcopy |
|
|
from time import sleep |
|
|
from typing import Tuple, Union, List, Optional |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from acvl_utils.cropping_and_padding.padding import pad_nd_image |
|
|
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter |
|
|
from batchgenerators.utilities.file_and_folder_operations import load_json, join, isfile, maybe_mkdir_p, isdir, subdirs, \ |
|
|
save_json |
|
|
from torch import nn |
|
|
from torch._dynamo import OptimizedModule |
|
|
from torch.nn.parallel import DistributedDataParallel |
|
|
from tqdm import tqdm |
|
|
|
|
|
import nnunetv2 |
|
|
from nnunetv2.configuration import default_num_processes |
|
|
from nnunetv2.inference.data_iterators import PreprocessAdapterFromNpy, preprocessing_iterator_fromfiles, \ |
|
|
preprocessing_iterator_fromnpy |
|
|
from nnunetv2.inference.export_prediction import export_prediction_from_logits, \ |
|
|
convert_predicted_logits_to_segmentation_with_correct_shape |
|
|
from nnunetv2.inference.sliding_window_prediction import compute_gaussian, \ |
|
|
compute_steps_for_sliding_window |
|
|
from nnunetv2.utilities.file_path_utilities import get_output_folder, check_workers_alive_and_busy |
|
|
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class |
|
|
from nnunetv2.utilities.helpers import empty_cache, dummy_context |
|
|
from nnunetv2.utilities.json_export import recursive_fix_for_json_export |
|
|
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels |
|
|
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager |
|
|
from nnunetv2.utilities.utils import create_lists_from_splitted_dataset_folder |
|
|
|
|
|
import pickle |
|
|
|
|
|
class nnUNetPredictor(object): |
|
|
def __init__(self, |
|
|
tile_step_size: float = 0.5, |
|
|
use_gaussian: bool = True, |
|
|
use_mirroring: bool = True, |
|
|
perform_everything_on_device: bool = True, |
|
|
device: torch.device = torch.device('cuda'), |
|
|
verbose: bool = False, |
|
|
verbose_preprocessing: bool = False, |
|
|
allow_tqdm: bool = True): |
|
|
self.verbose = verbose |
|
|
self.verbose_preprocessing = verbose_preprocessing |
|
|
self.allow_tqdm = allow_tqdm |
|
|
|
|
|
self.plans_manager, self.configuration_manager, self.list_of_parameters, self.network, self.dataset_json, \ |
|
|
self.trainer_name, self.allowed_mirroring_axes, self.label_manager = None, None, None, None, None, None, None, None |
|
|
|
|
|
self.tile_step_size = tile_step_size |
|
|
|
|
|
print("tile : ", self.tile_step_size ) |
|
|
self.use_gaussian = use_gaussian |
|
|
self.use_mirroring = use_mirroring |
|
|
if device.type == 'cuda': |
|
|
|
|
|
pass |
|
|
if device.type != 'cuda': |
|
|
print(f'perform_everything_on_device=True is only supported for cuda devices! Setting this to False') |
|
|
perform_everything_on_device = False |
|
|
self.device = device |
|
|
self.perform_everything_on_device = perform_everything_on_device |
|
|
|
|
|
def initialize_from_trained_model_folder(self, model_training_output_dir: str,
                                         use_folds: Union[Tuple[Union[int, str]], None],
                                         checkpoint_name: str = 'checkpoint_final.pth'):
    """
    This is used when making predictions with a trained model.

    Reads dataset.json and plans.json from model_training_output_dir, loads checkpoint_name from every
    requested fold_X subfolder, rebuilds the network architecture via the trainer class named in the first
    checkpoint and stores everything on the predictor.

    :param model_training_output_dir: folder containing dataset.json, plans.json and the fold_X subfolders
    :param use_folds: folds to load (ints or 'all'). A single fold may be passed as a bare int/str.
        None triggers auto detection via auto_detect_available_folds
    :param checkpoint_name: file name of the checkpoint inside each fold folder
    :raises RuntimeError: if no fold is available or the trainer class cannot be located
    """
    if use_folds is None:
        use_folds = nnUNetPredictor.auto_detect_available_folds(model_training_output_dir, checkpoint_name)

    dataset_json = load_json(join(model_training_output_dir, 'dataset.json'))
    plans = load_json(join(model_training_output_dir, 'plans.json'))
    plans_manager = PlansManager(plans)

    # a single fold given as bare value is wrapped so that the loop below works
    if isinstance(use_folds, (str, int)):
        use_folds = [use_folds]

    if len(use_folds) == 0:
        # without this check the code below fails with an UnboundLocalError on configuration_name
        raise RuntimeError(f'No folds to load from {model_training_output_dir} '
                           f'(checkpoint name: {checkpoint_name})')

    parameters = []
    trainer_name = configuration_name = inference_allowed_mirroring_axes = None
    for i, f in enumerate(use_folds):
        f = int(f) if f != 'all' else f
        # NOTE: weights_only=False unpickles arbitrary python objects. Only load checkpoints you trust!
        checkpoint = torch.load(join(model_training_output_dir, f'fold_{f}', checkpoint_name),
                                map_location=torch.device('cpu'), weights_only=False)
        if i == 0:
            # trainer/configuration/mirroring info is taken from the first fold only
            trainer_name = checkpoint['trainer_name']
            configuration_name = checkpoint['init_args']['configuration']
            inference_allowed_mirroring_axes = checkpoint['inference_allowed_mirroring_axes'] if \
                'inference_allowed_mirroring_axes' in checkpoint.keys() else None

        parameters.append(checkpoint['network_weights'])

    configuration_manager = plans_manager.get_configuration(configuration_name)
    label_manager = plans_manager.get_label_manager(dataset_json)

    # restore network
    num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
    trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
                                                trainer_name, 'nnunetv2.training.nnUNetTrainer')
    if trainer_class is None:
        raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. '
                           f'Please place it there (in any .py file)!')

    network = trainer_class.build_network_architecture(
        configuration_manager.network_arch_class_name,
        configuration_manager.network_arch_init_kwargs,
        configuration_manager.network_arch_init_kwargs_req_import,
        num_input_channels,
        label_manager.num_segmentation_heads,
        enable_deep_supervision=False
    )

    self.plans_manager = plans_manager
    self.configuration_manager = configuration_manager
    self.list_of_parameters = parameters
    self.network = network
    self.dataset_json = dataset_json
    self.trainer_name = trainer_name
    self.allowed_mirroring_axes = inference_allowed_mirroring_axes
    self.label_manager = label_manager

    # torch.compile is opt-in via the nnUNet_compile environment variable
    if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \
            and not isinstance(self.network, OptimizedModule):
        print('Using torch.compile')
        self.network = torch.compile(self.network)
|
|
|
|
|
def manual_initialization(self, network: nn.Module, plans_manager: PlansManager,
                          configuration_manager: ConfigurationManager, parameters: Optional[List[dict]],
                          dataset_json: dict, trainer_name: str,
                          inference_allowed_mirroring_axes: Optional[Tuple[int, ...]]):
    """
    This is used by the nnUNetTrainer to initialize nnUNetPredictor for the final validation.

    All arguments are stored on the predictor as-is; the label manager is derived from plans_manager and
    dataset_json. The network is optionally wrapped with torch.compile (see below).
    """
    self.plans_manager = plans_manager
    self.configuration_manager = configuration_manager
    self.list_of_parameters = parameters
    self.network = network
    self.dataset_json = dataset_json
    self.trainer_name = trainer_name
    self.allowed_mirroring_axes = inference_allowed_mirroring_axes
    self.label_manager = plans_manager.get_label_manager(dataset_json)

    # compiling is opt-in via the nnUNet_compile environment variable and skipped for compiled networks
    compile_requested = os.environ.get('nnUNet_compile', '').lower() in ('true', '1', 't')
    should_compile = compile_requested and not isinstance(self.network, OptimizedModule)
    if isinstance(self.network, DistributedDataParallel):
        # NOTE(review): this only allows compiling a DDP wrapper whose inner module is ALREADY compiled.
        # That looks inverted (one would expect `not isinstance(...)`) - confirm the intent before changing.
        should_compile = should_compile and isinstance(self.network.module, OptimizedModule)

    if should_compile:
        print('Using torch.compile')
        self.network = torch.compile(self.network)
|
|
|
|
|
@staticmethod
def auto_detect_available_folds(model_training_output_dir, checkpoint_name):
    """
    Return the integer ids of all fold_X subfolders of model_training_output_dir that contain
    checkpoint_name. The 'fold_all' folder is never part of the result.
    """
    print('use_folds is None, attempting to auto detect available folds')
    detected = []
    for folder in subdirs(model_training_output_dir, prefix='fold_', join=False):
        if folder == 'fold_all':
            continue
        if not isfile(join(model_training_output_dir, folder, checkpoint_name)):
            continue
        # the fold id is whatever follows the last underscore of the folder name
        detected.append(int(folder.split('_')[-1]))
    print(f'found the following folds: {detected}')
    return detected
|
|
|
|
|
def _manage_input_and_output_lists(self, list_of_lists_or_source_folder: Union[str, List[List[str]]], |
|
|
output_folder_or_list_of_truncated_output_files: Union[None, str, List[str]], |
|
|
folder_with_segs_from_prev_stage: str = None, |
|
|
overwrite: bool = True, |
|
|
part_id: int = 0, |
|
|
num_parts: int = 1, |
|
|
save_probabilities: bool = False): |
|
|
if isinstance(list_of_lists_or_source_folder, str): |
|
|
list_of_lists_or_source_folder = create_lists_from_splitted_dataset_folder(list_of_lists_or_source_folder, |
|
|
self.dataset_json['file_ending']) |
|
|
print(f'There are {len(list_of_lists_or_source_folder)} cases in the source folder') |
|
|
print(list_of_lists_or_source_folder) |
|
|
list_of_lists_or_source_folder = list_of_lists_or_source_folder[part_id::num_parts] |
|
|
|
|
|
|
|
|
caseids = [os.path.basename(i[0])[:-(len(self.dataset_json['file_ending']) + 5)] for i in list_of_lists_or_source_folder if len(i) > 0 and len(os.path.basename(i[0])) > len(self.dataset_json['file_ending']) + 5] |
|
|
|
|
|
|
|
|
print(f'There are {len(caseids)} cases that I would like to predict') |
|
|
|
|
|
if isinstance(output_folder_or_list_of_truncated_output_files, str): |
|
|
output_filename_truncated = [join(output_folder_or_list_of_truncated_output_files, i) for i in caseids] |
|
|
else: |
|
|
output_filename_truncated = output_folder_or_list_of_truncated_output_files |
|
|
|
|
|
seg_from_prev_stage_files = [join(folder_with_segs_from_prev_stage, i + self.dataset_json['file_ending']) if |
|
|
folder_with_segs_from_prev_stage is not None else None for i in caseids] |
|
|
|
|
|
if not overwrite and output_filename_truncated is not None: |
|
|
tmp = [isfile(i + self.dataset_json['file_ending']) for i in output_filename_truncated] |
|
|
if save_probabilities: |
|
|
tmp2 = [isfile(i + '.npz') for i in output_filename_truncated] |
|
|
tmp = [i and j for i, j in zip(tmp, tmp2)] |
|
|
not_existing_indices = [i for i, j in enumerate(tmp) if not j] |
|
|
|
|
|
output_filename_truncated = [output_filename_truncated[i] for i in not_existing_indices] |
|
|
list_of_lists_or_source_folder = [list_of_lists_or_source_folder[i] for i in not_existing_indices] |
|
|
seg_from_prev_stage_files = [seg_from_prev_stage_files[i] for i in not_existing_indices] |
|
|
print(f'overwrite was set to {overwrite}, so I am only working on cases that haven\'t been predicted yet. ' |
|
|
f'That\'s {len(not_existing_indices)} cases.') |
|
|
return list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files |
|
|
|
|
|
def predict_from_files(self,
                       list_of_lists_or_source_folder: Union[str, List[List[str]]],
                       output_folder_or_list_of_truncated_output_files: Union[str, None, List[str]],
                       save_probabilities: bool = False,
                       overwrite: bool = True,
                       num_processes_preprocessing: int = default_num_processes,
                       num_processes_segmentation_export: int = default_num_processes,
                       folder_with_segs_from_prev_stage: str = None,
                       num_parts: int = 1,
                       part_id: int = 0,
                       reconstruction_mode:str = "mean"):
    """
    This is nnU-Net's default function for making predictions. It works best for batch predictions
    (predicting many images at once).

    If an output folder can be determined, the call arguments as well as dataset.json and plans.json are
    written into it before prediction starts. Returns None if there is nothing to predict, otherwise
    whatever predict_from_data_iterator returns.
    """
    # figure out where (if anywhere) the bookkeeping files go
    if isinstance(output_folder_or_list_of_truncated_output_files, str):
        output_folder = output_folder_or_list_of_truncated_output_files
    elif isinstance(output_folder_or_list_of_truncated_output_files, list):
        output_folder = os.path.dirname(output_folder_or_list_of_truncated_output_files[0])
    else:
        output_folder = None

    if output_folder is not None:
        # record the arguments this call was made with. locals() must be evaluated here, in the
        # function scope, and before any of the arguments is reassigned further down
        call_locals = locals()
        argument_names = inspect.signature(self.predict_from_files).parameters.keys()
        my_init_kwargs = deepcopy({name: call_locals[name] for name in argument_names})
        recursive_fix_for_json_export(my_init_kwargs)
        maybe_mkdir_p(output_folder)
        save_json(my_init_kwargs, join(output_folder, 'predict_from_raw_data_args.json'))

        # dataset.json and plans.json are stored next to the predictions as well
        save_json(self.dataset_json, join(output_folder, 'dataset.json'), sort_keys=False)
        save_json(self.plans_manager.plans, join(output_folder, 'plans.json'), sort_keys=False)

    # cascaded configurations cannot run without the previous stage segmentations
    if self.configuration_manager.previous_stage_name is not None:
        assert folder_with_segs_from_prev_stage is not None, \
            f'The requested configuration is a cascaded network. It requires the segmentations of the previous ' \
            f'stage ({self.configuration_manager.previous_stage_name}) as input. Please provide the folder where' \
            f' they are located via folder_with_segs_from_prev_stage'

    managed = self._manage_input_and_output_lists(list_of_lists_or_source_folder,
                                                  output_folder_or_list_of_truncated_output_files,
                                                  folder_with_segs_from_prev_stage, overwrite, part_id, num_parts,
                                                  save_probabilities)
    list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files = managed
    if len(list_of_lists_or_source_folder) == 0:
        return

    data_iterator = self._internal_get_data_iterator_from_lists_of_filenames(list_of_lists_or_source_folder,
                                                                             seg_from_prev_stage_files,
                                                                             output_filename_truncated,
                                                                             num_processes_preprocessing)

    return self.predict_from_data_iterator(data_iterator, save_probabilities, num_processes_segmentation_export,
                                           reconstruction_mode)
|
|
|
|
|
def _internal_get_data_iterator_from_lists_of_filenames(self,
                                                        input_list_of_lists: List[List[str]],
                                                        seg_from_prev_stage_files: Union[List[str], None],
                                                        output_filenames_truncated: Union[List[str], None],
                                                        num_processes: int):
    """
    Build the file-based preprocessing iterator for the given cases using the plans, dataset json and
    configuration stored on this predictor.
    """
    # presumably the pin-memory switch of the iterator; it is only enabled on cuda - TODO confirm
    device_is_cuda = self.device.type == 'cuda'
    iterator = preprocessing_iterator_fromfiles(input_list_of_lists, seg_from_prev_stage_files,
                                                output_filenames_truncated, self.plans_manager,
                                                self.dataset_json, self.configuration_manager,
                                                num_processes, device_is_cuda,
                                                self.verbose_preprocessing)
    return iterator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_data_iterator_from_raw_npy_data(self,
                                        image_or_list_of_images: Union[np.ndarray, List[np.ndarray]],
                                        segs_from_prev_stage_or_list_of_segs_from_prev_stage: Union[None,
                                                                                                    np.ndarray,
                                                                                                    List[
                                                                                                        np.ndarray]],
                                        properties_or_list_of_properties: Union[dict, List[dict]],
                                        truncated_ofname: Union[str, List[str], None],
                                        num_processes: int = 3):
    """
    Build the preprocessing iterator for images that are already in memory.

    Every argument may be given for a single case (array / dict / str) or as a list with one entry per
    case; single-case arguments are wrapped into one-element lists. No more worker processes than
    images are requested.
    """
    if isinstance(image_or_list_of_images, list):
        list_of_images = image_or_list_of_images
    else:
        list_of_images = [image_or_list_of_images]

    prev_stage_segs = segs_from_prev_stage_or_list_of_segs_from_prev_stage
    if isinstance(prev_stage_segs, np.ndarray):
        prev_stage_segs = [prev_stage_segs]

    if isinstance(truncated_ofname, str):
        truncated_ofname = [truncated_ofname]

    list_of_properties = properties_or_list_of_properties
    if isinstance(list_of_properties, dict):
        list_of_properties = [list_of_properties]

    num_processes = min(num_processes, len(list_of_images))
    return preprocessing_iterator_fromnpy(
        list_of_images,
        prev_stage_segs,
        list_of_properties,
        truncated_ofname,
        self.plans_manager,
        self.dataset_json,
        self.configuration_manager,
        num_processes,
        self.device.type == 'cuda',
        self.verbose_preprocessing
    )
|
|
|
|
|
def predict_from_list_of_npy_arrays(self,
                                    image_or_list_of_images: Union[np.ndarray, List[np.ndarray]],
                                    segs_from_prev_stage_or_list_of_segs_from_prev_stage: Union[None,
                                                                                                np.ndarray,
                                                                                                List[
                                                                                                    np.ndarray]],
                                    properties_or_list_of_properties: Union[dict, List[dict]],
                                    truncated_ofname: Union[str, List[str], None],
                                    num_processes: int = 3,
                                    save_probabilities: bool = False,
                                    num_processes_segmentation_export: int = default_num_processes,
                                    reconstruction_mode: str = "mean"):
    """
    Predict one or several in-memory images.

    The inputs are wrapped into a preprocessing iterator (get_data_iterator_from_raw_npy_data) which is
    then consumed by predict_from_data_iterator.

    :param num_processes: number of preprocessing workers
    :param save_probabilities: forwarded to predict_from_data_iterator
    :param num_processes_segmentation_export: size of the export worker pool
    :param reconstruction_mode: "mean" or "median", forwarded to predict_from_data_iterator. Defaults to
        "mean", which is what this method always used before the parameter existed
    :return: the list returned by predict_from_data_iterator
    """
    iterator = self.get_data_iterator_from_raw_npy_data(image_or_list_of_images,
                                                        segs_from_prev_stage_or_list_of_segs_from_prev_stage,
                                                        properties_or_list_of_properties,
                                                        truncated_ofname,
                                                        num_processes)
    return self.predict_from_data_iterator(iterator, save_probabilities, num_processes_segmentation_export,
                                           reconstruction_mode)
|
|
|
|
|
def predict_from_data_iterator(self,
                               data_iterator,
                               save_probabilities: bool = False,
                               num_processes_segmentation_export: int = default_num_processes,
                               reconstruction_mode:str = "mean"):
    """
    each element returned by data_iterator must be a dict with 'data', 'ofile' and 'data_properties' keys!
    If 'ofile' is None, the result will be returned instead of written to a file

    Prediction runs in this process; resampling/export of each prediction is handed to a pool of
    spawned worker processes. Returns one entry per case (the first element of each worker result).
    """
    spawn_context = multiprocessing.get_context("spawn")
    with spawn_context.Pool(num_processes_segmentation_export) as export_pool:
        worker_list = list(export_pool._pool)
        pending = []
        for preprocessed in data_iterator:
            data = preprocessed['data']
            if isinstance(data, str):
                # the preprocessed array was handed over as a .npy file: load it, then delete the file
                npy_file = data
                data = torch.from_numpy(np.load(npy_file))
                os.remove(npy_file)

            ofile = preprocessed['ofile']
            write_to_file = ofile is not None
            if write_to_file:
                print(f'\nPredicting {os.path.basename(ofile)}:')
            else:
                print(f'\nPredicting image of shape {data.shape}:')

            print(f'perform_everything_on_device: {self.perform_everything_on_device}')

            properties = preprocessed['data_properties']

            # do not start the next prediction while the export workers are still backed up
            while check_workers_alive_and_busy(export_pool, worker_list, pending, allowed_num_queued=2):
                sleep(0.1)

            prediction = self.predict_logits_from_preprocessed_data(
                data, reconstruction_mode = reconstruction_mode).cpu()

            if write_to_file:
                print('sending off prediction to background worker for resampling and export')
                export_args = (prediction, properties, self.configuration_manager, self.plans_manager,
                               self.dataset_json, ofile, save_probabilities)
                pending.append(export_pool.starmap_async(export_prediction_from_logits, (export_args,)))
                print(f'done with {os.path.basename(ofile)}')
            else:
                print('sending off prediction to background worker for resampling')
                convert_args = (prediction, self.plans_manager,
                                self.configuration_manager, self.label_manager,
                                properties,
                                save_probabilities)
                pending.append(export_pool.starmap_async(
                    convert_predicted_logits_to_segmentation_with_correct_shape, (convert_args,)))
                print(f'\nDone with image of shape {data.shape}:')

        # each async result is a one-element list because exactly one argument tuple was submitted
        ret = [job.get()[0] for job in pending]

    if isinstance(data_iterator, MultiThreadedAugmenter):
        data_iterator._finish()

    # clear the cache of compute_gaussian
    compute_gaussian.cache_clear()
    # clear device cache
    empty_cache(self.device)
    return ret
|
|
|
|
|
def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict,
                             segmentation_previous_stage: np.ndarray = None,
                             output_file_truncated: str = None,
                             save_or_return_probabilities: bool = False,
                             reconstruction_mode: str = "mean"):
    """
    Preprocess, predict and resample a single in-memory image without any background workers.

    image_properties must only have a 'spacing' key!

    :param input_image: image to predict
    :param image_properties: properties dict of the image (see above)
    :param segmentation_previous_stage: previous stage segmentation, only for cascaded configurations
    :param output_file_truncated: output file name without file ending. If given, the result is exported
        to file and None is returned
    :param save_or_return_probabilities: also save (file output) or return (in-memory output) the
        probabilities
    :param reconstruction_mode: "mean" or "median", forwarded to predict_logits_from_preprocessed_data
    :return: None if output_file_truncated is given. Otherwise the result of
        convert_predicted_logits_to_segmentation_with_correct_shape, unpacked into a 2-tuple if
        save_or_return_probabilities is set
    """
    ppa = PreprocessAdapterFromNpy([input_image], [segmentation_previous_stage], [image_properties],
                                   [output_file_truncated],
                                   self.plans_manager, self.dataset_json, self.configuration_manager,
                                   num_threads_in_multithreaded=1, verbose=self.verbose)
    if self.verbose:
        print('preprocessing')
    dct = next(ppa)

    if self.verbose:
        print('predicting')
    predicted_logits = self.predict_logits_from_preprocessed_data(
        dct['data'], reconstruction_mode=reconstruction_mode).cpu()

    if self.verbose:
        print('resampling to original shape')

    if output_file_truncated is not None:
        # file export has no in-memory result. Returning here keeps the code below from touching an
        # unbound `ret`
        export_prediction_from_logits(predicted_logits, dct['data_properties'], self.configuration_manager,
                                      self.plans_manager, self.dataset_json, output_file_truncated,
                                      save_or_return_probabilities)
        return None

    ret = convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits, self.plans_manager,
                                                                      self.configuration_manager,
                                                                      self.label_manager,
                                                                      dct['data_properties'],
                                                                      return_probabilities=
                                                                      save_or_return_probabilities)
    if save_or_return_probabilities:
        return ret[0], ret[1]
    return ret
|
|
|
|
|
def predict_logits_from_preprocessed_data(self, data: torch.Tensor, reconstruction_mode:str = "mean") -> torch.Tensor:
    """
    IMPORTANT! IF YOU ARE RUNNING THE CASCADE, THE SEGMENTATION FROM THE PREVIOUS STAGE MUST ALREADY BE STACKED ON
    TOP OF THE IMAGE AS ONE-HOT REPRESENTATION! SEE PreprocessAdapter ON HOW THIS SHOULD BE DONE!

    RETURNED LOGITS HAVE THE SHAPE OF THE INPUT. THEY MUST BE CONVERTED BACK TO THE ORIGINAL IMAGE SIZE.
    SEE convert_predicted_logits_to_segmentation_with_correct_shape

    The sliding window prediction is run once per entry of self.list_of_parameters (one set of weights
    per fold) and the results are averaged on the CPU.

    :param data: preprocessed image
    :param reconstruction_mode: "mean" or "median", forwarded to predict_sliding_window_return_logits
    :return: averaged prediction as CPU tensor
    :raises RuntimeError: if no network weights are available
    """
    if not self.list_of_parameters:
        raise RuntimeError('No network weights available. Initialize the predictor with a non-empty list of '
                           'parameters before calling predict_logits_from_preprocessed_data')

    # temporarily cap the number of torch threads; the previous value is restored even on failure
    n_threads = torch.get_num_threads()
    torch.set_num_threads(min(default_num_processes, n_threads))
    try:
        with torch.no_grad():
            prediction = None

            for params in self.list_of_parameters:
                # a compiled network keeps the actual module (and its state dict) in _orig_mod
                if not isinstance(self.network, OptimizedModule):
                    self.network.load_state_dict(params)
                else:
                    self.network._orig_mod.load_state_dict(params)

                fold_prediction = self.predict_sliding_window_return_logits(
                    data, reconstruction_mode=reconstruction_mode).to('cpu')
                if prediction is None:
                    prediction = fold_prediction
                else:
                    prediction += fold_prediction

            if len(self.list_of_parameters) > 1:
                prediction /= len(self.list_of_parameters)

            if self.verbose: print('Prediction done')
    finally:
        torch.set_num_threads(n_threads)
    return prediction
|
|
|
|
|
def _internal_get_sliding_window_slicers(self, image_size: Tuple[int, ...]):
    """
    Compute the index tuples of all sliding window tiles for an image of the given spatial size.

    Every slicer starts with slice(None) for the channel axis. If the patch size has one entry less than
    image_size, each index along the first image axis is tiled separately and that axis is addressed
    with an integer index. Otherwise the patch is moved along three axes.
    """
    patch_size = self.configuration_manager.patch_size
    slicers = []
    if len(patch_size) < len(image_size):
        assert len(patch_size) == len(image_size) - 1, \
            'if tile_size has less entries than image_size, ' \
            'len(tile_size) ' \
            'must be one shorter than len(image_size) ' \
            '(only dimension ' \
            'discrepancy of 1 allowed).'
        steps = compute_steps_for_sliding_window(image_size[1:], patch_size, self.tile_step_size)
        if self.verbose:
            print(f'n_steps {image_size[0] * len(steps[0]) * len(steps[1])}, image size is'
                  f' {image_size}, tile_size {self.configuration_manager.patch_size}, '
                  f'tile_step_size {self.tile_step_size}\nsteps:\n{steps}')
        # itertools.product varies the last iterable fastest, which matches the nested loop order
        for d, sx, sy in itertools.product(range(image_size[0]), steps[0], steps[1]):
            in_plane = [slice(start, start + size) for start, size in zip((sx, sy), patch_size)]
            slicers.append(tuple([slice(None), d, *in_plane]))
    else:
        steps = compute_steps_for_sliding_window(image_size, patch_size, self.tile_step_size)
        if self.verbose:
            print(
                f'n_steps {np.prod([len(i) for i in steps])}, image size is {image_size}, tile_size {self.configuration_manager.patch_size}, '
                f'tile_step_size {self.tile_step_size}\nsteps:\n{steps}')
        for sx, sy, sz in itertools.product(steps[0], steps[1], steps[2]):
            window = [slice(start, start + size) for start, size in zip((sx, sy, sz), patch_size)]
            slicers.append(tuple([slice(None), *window]))
    return slicers
|
|
|
|
|
def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor: |
|
|
mirror_axes = self.allowed_mirroring_axes if self.use_mirroring else None |
|
|
prediction = self.network(x) |
|
|
if mirror_axes is not None: |
|
|
|
|
|
|
|
|
assert max(mirror_axes) <= x.ndim - 3, 'mirror_axes does not match the dimension of the input!' |
|
|
|
|
|
axes_combinations = [ |
|
|
c for i in range(len(mirror_axes)) for c in itertools.combinations([m + 2 for m in mirror_axes], i + 1) |
|
|
] |
|
|
for axes in axes_combinations: |
|
|
prediction += torch.flip(self.network(torch.flip(x, (*axes,))), (*axes,)) |
|
|
prediction /= (len(axes_combinations) + 1) |
|
|
return prediction |
|
|
|
|
|
def rec_mean(self, slicers, data):
    """
    Mean reconstruction: predict every tile and average overlapping tiles voxel-wise.

    Only output channel 0 of the network prediction is used; it is accumulated into a CPU volume that
    has the shape of `data` (i.e. it is broadcast over the channel axis of `data`).
    NOTE(review): presumably the network has a single output channel - confirm.

    :param slicers: index tuples as returned by _internal_get_sliding_window_slicers
    :param data: image tensor with the channel axis first
    :return: half precision CPU tensor with the shape of `data`
    """
    vol = torch.zeros((data.shape),dtype=torch.half)
    n_predictions = torch.zeros(data.shape[1:], dtype=torch.half)
    # the progress bar honors allow_tqdm like the rest of the predictor configuration
    for sl in tqdm(slicers, disable=not self.allow_tqdm):
        workon = data[sl][None].to(self.device, non_blocking=False)
        # [0] drops the batch axis, the second [0] selects the first output channel
        patch = self._internal_maybe_mirror_and_predict(workon)[0].detach().cpu()[0]
        vol[sl] += patch
        n_predictions[sl[1:]] += 1
    # voxels covered by several tiles are averaged
    vol /= n_predictions
    return vol
|
|
|
|
|
def rec_median(self, slicers, data, max_layers=50):
    """
    Median reconstruction: predict every tile and take the voxel-wise median over overlapping tiles.

    Tiles are stored in a stack of `max_layers` float32 volumes. Each tile goes into the first layer
    whose values in the tile region sum to zero (such a region counts as unused). Afterwards all zeros
    are treated as "no prediction" and ignored by the median. Only output channel 0 of the network
    prediction is used.
    NOTE(review): a prediction that is exactly 0 is indistinguishable from "no prediction" here.

    :param slicers: index tuples as returned by _internal_get_sliding_window_slicers
    :param data: image tensor with the channel axis first
    :param max_layers: number of layers in the stack
    :return: half precision CPU tensor with the shape of `data`
    :raises Exception: if max_layers is too low to hold all overlapping tiles
    """
    vol = torch.zeros((max_layers, *data.shape),dtype=torch.float32)
    # the progress bar honors allow_tqdm like the rest of the predictor configuration
    for sl in tqdm(slicers, disable=not self.allow_tqdm):
        workon = data[sl][None].to(self.device, non_blocking=False)
        # [0] drops the batch axis, the second [0] selects the first output channel
        patch = self._internal_maybe_mirror_and_predict(workon)[0].detach().cpu()[0]

        for layer in range(max_layers):
            if torch.sum(vol[layer][sl])==0:
                vol[layer][sl] = patch
                break
        else:
            # no free layer left for this tile: fail loudly instead of silently dropping it
            raise Exception("max_layers in median reconstruction is too low!")

    # report how many layers were needed; an empty layer only in the very last slot counts as too tight
    for layer in range(max_layers):
        if torch.sum(vol[layer])==0:
            if layer >= max_layers-1:
                raise Exception("max_layers in median reconstruction is too low!")
            print("nb layer used for rec_median : ", layer)
            break

    # zeros become NaN so that nanmedian skips them
    vol = torch.where(vol == 0, torch.tensor(float('nan')), vol)
    median_vol = torch.nanmedian(vol, dim=0)
    return median_vol.values.half()
|
|
|
|
|
def _internal_predict_sliding_window_return_logits(self,
                                                   data: torch.Tensor,
                                                   slicers,
                                                   do_on_device: bool = True,
                                                   reconstruction_mode:str = "mean",
                                                   ):
    """
    Predict all tiles described by `slicers` and merge them with the requested reconstruction.

    :param data: padded image tensor with the channel axis first
    :param slicers: index tuples as returned by _internal_get_sliding_window_slicers
    :param do_on_device: if True the image is moved to self.device before tiling, otherwise to the CPU
    :param reconstruction_mode: "mean" (rec_mean) or "median" (rec_median)
    :return: the merged prediction as returned by rec_mean / rec_median
    :raises ValueError: for an unknown reconstruction_mode
    :raises RuntimeError: if the merged prediction contains inf
    """
    if reconstruction_mode not in ("mean", "median"):
        raise ValueError(f"Unknown reconstruction mode: {reconstruction_mode}")

    predicted_logits = None
    results_device = self.device if do_on_device else torch.device('cpu')

    try:
        empty_cache(self.device)

        # move data to device
        if self.verbose:
            print(f'move image to device {results_device}')
        data = data.to(results_device)

        # no result buffers are preallocated here: rec_mean / rec_median allocate their own,
        # so preallocating full-size arrays on the device would only waste memory

        if self.verbose: print('running prediction')
        if not self.allow_tqdm and self.verbose: print(f'{len(slicers)} steps')

        if reconstruction_mode == "mean":
            print("Reconstruction: MEAN")
            predicted_logits = self.rec_mean(slicers, data)
        else:
            print("Reconstruction: MEDIAN")
            predicted_logits = self.rec_median(slicers, data)

        if torch.any(torch.isinf(predicted_logits)):
            raise RuntimeError('Encountered inf in predicted array. Aborting... If this problem persists, '
                               'reduce value_scaling_factor in compute_gaussian or increase the dtype of '
                               'predicted_logits to fp32')
    except Exception:
        # release whatever was allocated before handing the error to the caller
        del predicted_logits
        empty_cache(self.device)
        empty_cache(results_device)
        raise

    return predicted_logits
|
|
|
|
|
def predict_sliding_window_return_logits(self, input_image: torch.Tensor, reconstruction_mode:str = "mean") \
        -> Union[np.ndarray, torch.Tensor]:
    """
    Sliding window prediction of a single preprocessed image.

    The network is moved to self.device and set to eval mode. The image is padded to at least the patch
    size, predicted tile by tile and the padding is removed again from the result. Autocast is used on
    cuda devices only.

    :param input_image: 4D tensor (c, x, y, z)
    :param reconstruction_mode: "mean" or "median", forwarded to
        _internal_predict_sliding_window_return_logits
    """
    assert isinstance(input_image, torch.Tensor)
    self.network = self.network.to(self.device)
    self.network.eval()

    empty_cache(self.device)

    with torch.no_grad():
        with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
            assert input_image.ndim == 4, 'input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)'

            if self.verbose:
                print(f'Input shape: {input_image.shape}')
                print("step_size:", self.tile_step_size)
                print("mirror_axes:", self.allowed_mirroring_axes if self.use_mirroring else None)

            # pad so that the image is at least as large as the patch size; the returned slicer
            # undoes the padding later
            padded, slicer_revert_padding = pad_nd_image(input_image, self.configuration_manager.patch_size,
                                                         'constant', {'value': 0}, True,
                                                         None)

            slicers = self._internal_get_sliding_window_slicers(padded.shape[1:])

            # the on-device and the fallback path issue the very same call, so no branching is needed
            predicted_logits = self._internal_predict_sliding_window_return_logits(
                padded, slicers, self.perform_everything_on_device, reconstruction_mode)

            empty_cache(self.device)

            # revert padding
            unpad = (slice(None), *slicer_revert_padding[1:])
            predicted_logits = predicted_logits[unpad]
    return predicted_logits
|
|
|
|
|
|
|
|
def predict_entry_point_modelfolder():
    """
    Command line entry point: run inference with a trained model that is given as a folder (-m).

    Parses the command line, makes sure the output folder exists, configures torch threading for the
    requested device, builds an nnUNetPredictor, loads the requested folds/checkpoint from the model
    folder and predicts every case of the input folder in a single part (num_parts=1, part_id=0).
    """
    import argparse
    import multiprocessing

    cli = argparse.ArgumentParser(
        description='Use this to run inference with nnU-Net. This function is used when '
                    'you want to manually specify a folder containing a trained nnU-Net '
                    'model. This is useful when the nnunet environment variables '
                    '(nnUNet_results) are not set.')

    # --- input / output / model location ---
    cli.add_argument(
        '-i', type=str, required=True,
        help='input folder. Remember to use the correct channel numberings for your files (_0000 etc). '
             'File endings must be the same as the training dataset!')
    cli.add_argument(
        '-o', type=str, required=True,
        help='Output folder. If it does not exist it will be created. Predicted segmentations will '
             'have the same name as their source images.')
    cli.add_argument(
        '-m', type=str, required=True,
        help='Folder in which the trained model is. Must have subfolders fold_X for the different '
             'folds you trained')
    cli.add_argument(
        '-f', nargs='+', type=str, required=False, default=(0, 1, 2, 3, 4),
        help='Specify the folds of the trained model that should be used for prediction. '
             'Default: (0, 1, 2, 3, 4)')

    # --- inference behaviour ---
    cli.add_argument(
        '-step_size', type=float, required=False, default=0.5,
        help='Step size for sliding window prediction. The larger it is the faster but less accurate '
             'the prediction. Default: 0.5. Cannot be larger than 1. We recommend the default.')
    cli.add_argument(
        '--disable_tta', action='store_true', required=False, default=False,
        help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, '
             'but less accurate inference. Not recommended.')
    cli.add_argument(
        '--verbose', action='store_true',
        help="Set this if you like being talked to. You will have "
             "to be a good listener/reader.")
    cli.add_argument(
        '--save_probabilities', action='store_true',
        help='Set this to export predicted class "probabilities". Required if you want to ensemble '
             'multiple configurations.')
    cli.add_argument(
        '--continue_prediction', '--c', action='store_true',
        help='Continue an aborted previous prediction (will not overwrite existing files)')
    cli.add_argument(
        '-chk', type=str, required=False, default='checkpoint_final.pth',
        help='Name of the checkpoint you want to use. Default: checkpoint_final.pth')

    # --- resources ---
    cli.add_argument(
        '-npp', type=int, required=False, default=3,
        help='Number of processes used for preprocessing. More is not always better. Beware of '
             'out-of-RAM issues. Default: 3')
    cli.add_argument(
        '-nps', type=int, required=False, default=3,
        help='Number of processes used for segmentation export. More is not always better. Beware of '
             'out-of-RAM issues. Default: 3')
    cli.add_argument(
        '-prev_stage_predictions', type=str, required=False, default=None,
        help='Folder containing the predictions of the previous stage. Required for cascaded models.')
    cli.add_argument(
        '-device', type=str, default='cuda', required=False,
        help="Use this to set the device the inference should run with. Available options are 'cuda' "
             "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! "
             "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!")
    cli.add_argument(
        '--disable_progress_bar', action='store_true', required=False, default=False,
        help='Set this flag to disable progress bar. Recommended for HPC environments (non interactive '
             'jobs)')
    cli.add_argument(
        '--rec', type=str, default='mean', choices=['mean', 'median'],
        help='Method of reconstruction: mean or median. Default is mean.')

    # the citation banner is printed before parsing, so it is also shown together with -h / usage errors
    print(
        "\n#######################################################################\nPlease cite the following paper "
        "when using nnU-Net:\n"
        "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). "
        "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. "
        "Nature methods, 18(2), 203-211.\n#######################################################################\n")

    opts = cli.parse_args()

    # folds arrive as strings from the command line (or as ints from the default tuple);
    # the literal 'all' is kept, everything else becomes an int
    folds = [fold if fold == 'all' else int(fold) for fold in opts.f]

    if not isdir(opts.o):
        maybe_mkdir_p(opts.o)

    assert opts.device in ['cpu', 'cuda',
                           'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {opts.device}.'

    if opts.device == 'cuda':
        # GPU inference: restrict torch to a single intra-op and a single inter-op thread
        torch.set_num_threads(1)
        torch.set_num_interop_threads(1)
        device = torch.device('cuda')
    elif opts.device == 'cpu':
        # CPU inference: let torch use as many threads as there are CPUs
        torch.set_num_threads(multiprocessing.cpu_count())
        device = torch.device('cpu')
    else:
        device = torch.device('mps')

    predictor = nnUNetPredictor(
        tile_step_size=opts.step_size,
        use_gaussian=True,
        use_mirroring=not opts.disable_tta,
        perform_everything_on_device=True,
        device=device,
        verbose=opts.verbose,
        allow_tqdm=not opts.disable_progress_bar,
        verbose_preprocessing=opts.verbose,
    )
    predictor.initialize_from_trained_model_folder(opts.m, folds, opts.chk)
    predictor.predict_from_files(
        opts.i, opts.o,
        save_probabilities=opts.save_probabilities,
        overwrite=not opts.continue_prediction,
        num_processes_preprocessing=opts.npp,
        num_processes_segmentation_export=opts.nps,
        folder_with_segs_from_prev_stage=opts.prev_stage_predictions,
        num_parts=1, part_id=0,
        reconstruction_mode=opts.rec,
    )
|
|
|
|
|
def predict_entry_point():
    """
    Command line entry point: run inference with a model that is located via dataset, trainer,
    plans and configuration (-d, -tr, -p, -c).

    Parses the command line, resolves the model folder with get_output_folder, makes sure the output
    folder exists, configures torch threading for the requested device, builds an nnUNetPredictor,
    loads the requested folds/checkpoint and predicts the cases of the input folder.

    The work can be split over several independent calls with -num_parts / -part_id; both values are
    forwarded to predict_from_files so that each call handles its own part.

    :raises AssertionError: if -part_id is not smaller than -num_parts or if -device is not one of
        'cpu', 'cuda', 'mps'.
    """
    import argparse
    parser = argparse.ArgumentParser(description='Use this to run inference with nnU-Net. The trained model is '
                                                 'looked up from the dataset (-d), trainer (-tr), plans (-p) '
                                                 'and configuration (-c) you specify.')
    parser.add_argument('-i', type=str, required=True,
                        help='input folder. Remember to use the correct channel numberings for your files (_0000 etc). '
                             'File endings must be the same as the training dataset!')
    parser.add_argument('-o', type=str, required=True,
                        help='Output folder. If it does not exist it will be created. Predicted segmentations will '
                             'have the same name as their source images.')
    parser.add_argument('-d', type=str, required=True,
                        help='Dataset with which you would like to predict. You can specify either dataset name or id')
    parser.add_argument('-p', type=str, required=False, default='nnUNetPlans',
                        help='Plans identifier. Specify the plans in which the desired configuration is located. '
                             'Default: nnUNetPlans')
    parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer',
                        help='What nnU-Net trainer class was used for training? Default: nnUNetTrainer')
    parser.add_argument('-c', type=str, required=True,
                        help='nnU-Net configuration that should be used for prediction. Config must be located '
                             'in the plans specified with -p')
    parser.add_argument('-f', nargs='+', type=str, required=False, default=(0, 1, 2, 3, 4),
                        help='Specify the folds of the trained model that should be used for prediction. '
                             'Default: (0, 1, 2, 3, 4)')
    parser.add_argument('-step_size', type=float, required=False, default=0.5,
                        help='Step size for sliding window prediction. The larger it is the faster but less accurate '
                             'the prediction. Default: 0.5. Cannot be larger than 1. We recommend the default.')
    parser.add_argument('--disable_tta', action='store_true', required=False, default=False,
                        help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, '
                             'but less accurate inference. Not recommended.')
    parser.add_argument('--verbose', action='store_true', help="Set this if you like being talked to. You will have "
                                                               "to be a good listener/reader.")
    parser.add_argument('--save_probabilities', action='store_true',
                        help='Set this to export predicted class "probabilities". Required if you want to ensemble '
                             'multiple configurations.')
    parser.add_argument('--continue_prediction', action='store_true',
                        help='Continue an aborted previous prediction (will not overwrite existing files)')
    parser.add_argument('-chk', type=str, required=False, default='checkpoint_final.pth',
                        help='Name of the checkpoint you want to use. Default: checkpoint_final.pth')
    parser.add_argument('-npp', type=int, required=False, default=3,
                        help='Number of processes used for preprocessing. More is not always better. Beware of '
                             'out-of-RAM issues. Default: 3')
    parser.add_argument('-nps', type=int, required=False, default=3,
                        help='Number of processes used for segmentation export. More is not always better. Beware of '
                             'out-of-RAM issues. Default: 3')
    parser.add_argument('-prev_stage_predictions', type=str, required=False, default=None,
                        help='Folder containing the predictions of the previous stage. Required for cascaded models.')
    parser.add_argument('-num_parts', type=int, required=False, default=1,
                        help='Number of separate nnUNetv2_predict call that you will be making. Default: 1 (= this one '
                             'call predicts everything)')
    parser.add_argument('-part_id', type=int, required=False, default=0,
                        help='If multiple nnUNetv2_predict exist, which one is this? IDs start with 0 can end with '
                             'num_parts - 1. So when you submit 5 nnUNetv2_predict calls you need to set -num_parts '
                             '5 and use -part_id 0, 1, 2, 3 and 4. Simple, right? Note: You are yourself responsible '
                             'to make these run on separate GPUs! Use CUDA_VISIBLE_DEVICES (google, yo!)')
    parser.add_argument('-device', type=str, default='cuda', required=False,
                        help="Use this to set the device the inference should run with. Available options are 'cuda' "
                             "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! "
                             "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!")
    parser.add_argument('--disable_progress_bar', action='store_true', required=False, default=False,
                        help='Set this flag to disable progress bar. Recommended for HPC environments (non interactive '
                             'jobs)')
    parser.add_argument('--rec', type=str, default='mean', choices=['mean', 'median'],
                        help='Method of reconstruction: mean or median. Default is mean.')

    # the citation banner is printed before parsing, so it is also shown together with -h / usage errors
    print(
        "\n#######################################################################\nPlease cite the following paper "
        "when using nnU-Net:\n"
        "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). "
        "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. "
        "Nature methods, 18(2), 203-211.\n#######################################################################\n")

    args = parser.parse_args()
    # folds arrive as strings from the command line (or as ints from the default tuple);
    # the literal 'all' is kept, everything else becomes an int
    args.f = [i if i == 'all' else int(i) for i in args.f]

    model_folder = get_output_folder(args.d, args.tr, args.p, args.c)

    if not isdir(args.o):
        maybe_mkdir_p(args.o)

    assert args.part_id < args.num_parts, 'Do you even read the documentation? See nnUNetv2_predict -h.'

    assert args.device in ['cpu', 'cuda',
                           'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.'
    if args.device == 'cpu':
        # CPU inference: let torch use as many threads as there are CPUs
        import multiprocessing
        torch.set_num_threads(multiprocessing.cpu_count())
        device = torch.device('cpu')
    elif args.device == 'cuda':
        # GPU inference: restrict torch to a single intra-op and a single inter-op thread
        torch.set_num_threads(1)
        torch.set_num_interop_threads(1)
        device = torch.device('cuda')
    else:
        device = torch.device('mps')

    predictor = nnUNetPredictor(tile_step_size=args.step_size,
                                use_gaussian=True,
                                use_mirroring=not args.disable_tta,
                                perform_everything_on_device=True,
                                device=device,
                                verbose=args.verbose,
                                verbose_preprocessing=args.verbose,
                                allow_tqdm=not args.disable_progress_bar)
    predictor.initialize_from_trained_model_folder(
        model_folder,
        args.f,
        checkpoint_name=args.chk
    )

    # part_id has to be forwarded together with num_parts, otherwise the -part_id option would be
    # parsed and validated but have no effect on which cases this call predicts
    predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities,
                                 overwrite=not args.continue_prediction,
                                 num_processes_preprocessing=args.npp,
                                 num_processes_segmentation_export=args.nps,
                                 folder_with_segs_from_prev_stage=args.prev_stage_predictions,
                                 num_parts=args.num_parts,
                                 part_id=args.part_id,
                                 reconstruction_mode=args.rec)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Ad-hoc debugging script: load one hard-coded trained model and predict one hard-coded case.
    from nnunetv2.paths import nnUNet_results, nnUNet_raw

    # --- hard-coded experiment settings ---
    dataset_name = "Dataset540_synthrad2025_task2_CBCT_AB_pre_v2r_stitched_masked_both"
    result_folder = "nnUNetTrainerMRCT_loss_masked_perception_masked__nnUNetResEncUNetLPlans__3d_fullres"
    FOLD = (0, 1, 2, 3, 4)
    IMG_NAME = '2ABA033_0000.mha'
    # NOTE(review): OUTPUT_FILE is assigned but never read anywhere in this block - confirm whether it
    # was meant to be used for the prediction output below.
    OUTPUT_FILE = '/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/export_models/testing_dataset540_fold0/2ABA033_before_norm.mha'

    # --- build the predictor and load the trained weights ---
    predictor_settings = dict(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=True,
        perform_everything_on_device=True,
        device=torch.device('cuda', 0),
        verbose=True,
        verbose_preprocessing=True,
        allow_tqdm=True,
    )
    predictor = nnUNetPredictor(**predictor_settings)

    trained_model_dir = join(nnUNet_results, f'{dataset_name}/{result_folder}')
    predictor.initialize_from_trained_model_folder(
        trained_model_dir,
        use_folds=FOLD,
        checkpoint_name='checkpoint_final.pth',
    )

    # --- read the image and predict it ---
    from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO

    image_file = join(nnUNet_raw, f'{dataset_name}/imagesTr/{IMG_NAME}')
    img, props = SimpleITKIO().read_images([image_file])
    # NOTE(review): the literal 'TRUNCATED' is passed as the fourth positional argument; its meaning
    # depends on the signature of predict_single_npy_array, which is not visible here - verify.
    ret = predictor.predict_single_npy_array(img, props, None, 'TRUNCATED', False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|