|
|
import os |
|
|
from copy import deepcopy |
|
|
from typing import Union, List |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice |
|
|
from batchgenerators.utilities.file_and_folder_operations import load_json, isfile, save_pickle |
|
|
|
|
|
from nnunetv2.configuration import default_num_processes |
|
|
from nnunetv2.utilities.label_handling.label_handling import LabelManager |
|
|
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager |
|
|
|
|
|
import SimpleITK as sitk |
|
|
def convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits: Union[torch.Tensor, np.ndarray],
                                                                plans_manager: PlansManager,
                                                                configuration_manager: ConfigurationManager,
                                                                label_manager: LabelManager,
                                                                properties_dict: dict,
                                                                return_probabilities: bool = False,
                                                                num_threads_torch: int = default_num_processes):
    """Convert network logits into a segmentation in the original image geometry.

    The logits are resampled to the shape the image had after cropping but before
    resampling. The label manager's inference nonlinearity is applied and the result is
    converted into a segmentation. The segmentation is pasted into an array of the
    pre-cropping shape, and the preprocessing axis transposition is undone.

    :param predicted_logits: network output (torch tensor or numpy array).
    :param plans_manager: provides ``transpose_backward``.
    :param configuration_manager: provides ``spacing`` and ``resampling_fn_probabilities``.
    :param label_manager: provides ``apply_inference_nonlin``,
        ``convert_probabilities_to_segmentation``, ``foreground_labels`` and
        ``revert_cropping_on_probabilities``.
    :param properties_dict: preprocessing properties. Keys read here:
        'shape_after_cropping_and_before_resampling', 'spacing', 'shape_before_cropping',
        'bbox_used_for_cropping'.
    :param return_probabilities: if True, also return the probabilities in original geometry.
    :param num_threads_torch: torch thread count used while this function runs. The
        previous value is always restored, even if an exception is raised.
    :return: the segmentation as a numpy array, or ``(segmentation, probabilities)`` when
        ``return_probabilities`` is True.
    """
    old_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads_torch)
    # torch's thread count is process-wide state: restore it on every exit path.
    try:
        shape_before_resampling = properties_dict['shape_after_cropping_and_before_resampling']
        # When the configuration spacing has fewer axes than the data, the first axis'
        # spacing is taken from the original image spacing. This is presumably a 2D
        # configuration applied to 3D data.
        current_spacing = configuration_manager.spacing if \
            len(configuration_manager.spacing) == len(shape_before_resampling) else \
            [properties_dict['spacing'][0], *configuration_manager.spacing]
        predicted_logits = configuration_manager.resampling_fn_probabilities(predicted_logits,
                                                                             shape_before_resampling,
                                                                             current_spacing,
                                                                             properties_dict['spacing'])

        predicted_probabilities = label_manager.apply_inference_nonlin(predicted_logits)
        del predicted_logits
        segmentation = label_manager.convert_probabilities_to_segmentation(predicted_probabilities)
        if isinstance(segmentation, torch.Tensor):
            segmentation = segmentation.cpu().numpy()

        # Paste the segmentation back into the full (pre-cropping) volume. Voxels outside
        # the crop bounding box stay 0.
        segmentation_reverted_cropping = np.zeros(properties_dict['shape_before_cropping'],
                                                  dtype=np.uint8 if len(label_manager.foreground_labels) < 255
                                                  else np.uint16)
        slicer = bounding_box_to_slice(properties_dict['bbox_used_for_cropping'])
        segmentation_reverted_cropping[slicer] = segmentation
        del segmentation

        # Undo the axis permutation applied during preprocessing.
        segmentation_reverted_cropping = segmentation_reverted_cropping.transpose(plans_manager.transpose_backward)

        if not return_probabilities:
            return segmentation_reverted_cropping

        predicted_probabilities = label_manager.revert_cropping_on_probabilities(
            predicted_probabilities,
            properties_dict['bbox_used_for_cropping'],
            properties_dict['shape_before_cropping'])
        if isinstance(predicted_probabilities, torch.Tensor):
            predicted_probabilities = predicted_probabilities.cpu().numpy()
        # Axis 0 (presumably the class axis) stays in place; only the spatial axes are
        # permuted back.
        predicted_probabilities = predicted_probabilities.transpose(
            [0] + [i + 1 for i in plans_manager.transpose_backward])
        return segmentation_reverted_cropping, predicted_probabilities
    finally:
        torch.set_num_threads(old_threads)
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
|
|
|
def convert_predicted_image_to_original_shape(predicted_image: Union[torch.Tensor, np.ndarray],
                                              plans_manager: PlansManager,
                                              configuration_manager: ConfigurationManager,
                                              properties_dict: dict,
                                              num_threads_torch: int = default_num_processes):
    """Bring a predicted image back to the original image geometry.

    The prediction is resampled with ``configuration_manager.resampling_fn_data`` to the
    shape the image had after cropping but before resampling. Channel 0 is taken, pasted
    into a zero array of the pre-cropping shape, and the preprocessing axis transposition
    is undone.

    :param predicted_image: network output (torch tensor or numpy array). Only channel 0
        of the resampled result is kept; a single-channel prediction is assumed (TODO confirm).
    :param plans_manager: provides ``transpose_backward``.
    :param configuration_manager: provides ``spacing`` and ``resampling_fn_data``.
    :param properties_dict: preprocessing properties. Keys read here:
        'shape_after_cropping_and_before_resampling', 'spacing', 'shape_before_cropping',
        'bbox_used_for_cropping'.
    :param num_threads_torch: torch thread count used while this function runs. The
        previous value is always restored, even if an exception is raised.
    :return: numpy array with the pre-cropping shape (transposed back), with the same
        dtype as the resampled prediction. Voxels outside the crop bounding box are 0.
    """
    old_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads_torch)
    # torch's thread count is process-wide state: restore it on every exit path.
    try:
        shape_before_resampling = properties_dict['shape_after_cropping_and_before_resampling']
        # When the configuration spacing has fewer axes than the data, the first axis'
        # spacing is taken from the original image spacing.
        current_spacing = configuration_manager.spacing if \
            len(configuration_manager.spacing) == len(shape_before_resampling) else \
            [properties_dict['spacing'][0], *configuration_manager.spacing]
        predicted_resampled = configuration_manager.resampling_fn_data(predicted_image,
                                                                       shape_before_resampling,
                                                                       current_spacing,
                                                                       properties_dict['spacing'])
        if isinstance(predicted_resampled, torch.Tensor):
            predicted_resampled = predicted_resampled.cpu().numpy()

        # Drop the leading (channel) axis by keeping channel 0.
        predicted_resampled = predicted_resampled[0]

        # Paste the prediction back into the full (pre-cropping) volume.
        original_shape_image = np.zeros(properties_dict['shape_before_cropping'],
                                        dtype=predicted_resampled.dtype)
        slicer = bounding_box_to_slice(properties_dict['bbox_used_for_cropping'])
        original_shape_image[slicer] = predicted_resampled
        del predicted_resampled

        # Undo the axis permutation applied during preprocessing.
        return original_shape_image.transpose(plans_manager.transpose_backward)
    finally:
        torch.set_num_threads(old_threads)
|
|
|
|
|
|
|
|
|
|
|
def export_prediction_from_logits(predicted_array_or_file: Union[np.ndarray, torch.Tensor], properties_dict: dict,
                                  configuration_manager: ConfigurationManager,
                                  plans_manager: PlansManager,
                                  dataset_json_dict_or_file: Union[dict, str], output_file_truncated: str,
                                  save_probabilities: bool = False):
    """Restore a predicted image to the original geometry and write it to disk.

    :param predicted_array_or_file: network prediction (numpy array or torch tensor).
    :param properties_dict: preprocessing properties of the case, passed on to
        ``convert_predicted_image_to_original_shape`` and to the writer.
    :param configuration_manager: configuration used for resampling.
    :param plans_manager: provides the transposition and ``image_reader_writer_class``.
    :param dataset_json_dict_or_file: dataset.json content, or a path to it. Its
        'file_ending' entry is appended to ``output_file_truncated``.
    :param output_file_truncated: output path without the file ending.
    :param save_probabilities: not supported for image predictions (see Raises).
    :raises ValueError: if ``save_probabilities`` is True.
        ``convert_predicted_image_to_original_shape`` returns a single image and no
        probabilities, so there is nothing to save. Previously this surfaced as an
        unpacking error after all the work was done, or silently split the image along
        axis 0 when that axis had length 2.
    """
    if save_probabilities:
        raise ValueError('save_probabilities=True is not supported for image predictions: '
                         'convert_predicted_image_to_original_shape does not return probabilities.')

    if isinstance(dataset_json_dict_or_file, str):
        dataset_json_dict_or_file = load_json(dataset_json_dict_or_file)

    image_final = convert_predicted_image_to_original_shape(
        predicted_array_or_file, plans_manager, configuration_manager, properties_dict
    )
    # Drop this reference early to lower peak memory before writing.
    del predicted_array_or_file

    # NOTE(review): write_seg is the segmentation writer. Confirm that the configured
    # reader/writer keeps the image dtype (e.g. floats) rather than casting to an integer
    # label type.
    rw = plans_manager.image_reader_writer_class()
    rw.write_seg(image_final, output_file_truncated + dataset_json_dict_or_file['file_ending'],
                 properties_dict)
|
|
|
|
|
|
|
|
def resample_and_save(predicted: Union[torch.Tensor, np.ndarray], target_shape: List[int], output_file: str,
                      plans_manager: PlansManager, configuration_manager: ConfigurationManager, properties_dict: dict,
                      dataset_json_dict_or_file: Union[dict, str], num_threads_torch: int = default_num_processes) \
        -> None:
    """Resample a logit prediction to ``target_shape``, segment it and save it as ``.npz``.

    :param predicted: logits (torch tensor or numpy array).
    :param target_shape: spatial shape to resample to.
    :param output_file: path of the ``.npz`` file. The segmentation is stored under key 'seg'.
    :param plans_manager: used to build the label manager.
    :param configuration_manager: provides ``spacing`` and ``resampling_fn_probabilities``.
    :param properties_dict: preprocessing properties. Keys read here:
        'shape_after_cropping_and_before_resampling', 'spacing'.
    :param dataset_json_dict_or_file: dataset.json content, or a path to it.
    :param num_threads_torch: torch thread count used while this function runs. The
        previous value is always restored, even if an exception is raised.
    """
    old_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads_torch)
    # torch's thread count is process-wide state: restore it on every exit path.
    try:
        if isinstance(dataset_json_dict_or_file, str):
            dataset_json_dict_or_file = load_json(dataset_json_dict_or_file)

        # Source and target spacing are the same expression, so the resampling only changes
        # the shape. When the configuration spacing has fewer axes than the data, the first
        # axis' spacing is taken from the original image spacing.
        current_spacing = configuration_manager.spacing if \
            len(configuration_manager.spacing) == len(properties_dict['shape_after_cropping_and_before_resampling']) else \
            [properties_dict['spacing'][0], *configuration_manager.spacing]
        target_spacing = current_spacing
        predicted_array_or_file = configuration_manager.resampling_fn_probabilities(predicted,
                                                                                   target_shape,
                                                                                   current_spacing,
                                                                                   target_spacing)

        label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file)
        segmentation = label_manager.convert_logits_to_segmentation(predicted_array_or_file)
        if isinstance(segmentation, torch.Tensor):
            segmentation = segmentation.cpu().numpy()

        # Same dtype rule as convert_predicted_logits_to_segmentation_with_correct_shape.
        # A blanket uint8 cast would silently wrap label values >= 256.
        seg_dtype = np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16
        np.savez_compressed(output_file, seg=segmentation.astype(seg_dtype))
    finally:
        torch.set_num_threads(old_threads)
|
|
|