# Provenance: uploaded by FelixzeroSun via "Upload folder using huggingface_hub",
# commit 19c1f58 (verified).
import torch
from tqdm.auto import tqdm
import os
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
from batchgenerators.utilities.file_and_folder_operations import load_json,join
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from acvl_utils.cropping_and_padding.padding import pad_nd_image
from nnunetv2.inference.sliding_window_prediction import compute_gaussian, compute_steps_for_sliding_window
from totalsegmentator.alignment import undo_canonical
from totalsegmentator.resampling import change_spacing
from totalsegmentator.postprocessing import remove_auxiliary_labels
from nnunetv2.utilities.helpers import empty_cache
import nnunetv2
from pathlib import Path
import nibabel as nib
import time
import SimpleITK
# from resampling import change_spacing
import tempfile
import numpy as np
from nnunetv2.inference.export_prediction import convert_predicted_logits_to_segmentation_with_correct_shape
from nnunetv2.training.nnUNetTrainer.nnUNetTSTrainer import nnUNetTSTrainer
class MinialTotalSegmentator():
    """Minimal single-model TotalSegmentator inference wrapper around nnU-Net v2.

    On construction it loads ``fold_0/checkpoint_final.pth`` from one nnU-Net
    training output directory (by default the 3 mm
    ``Dataset297_TotalSegmentator_total_3mm_1559subj`` model) and builds the
    network. :meth:`score_patient` then segments a single image file.
    Inference runs on the CPU (``self.device`` is fixed to ``cpu``).
    """

    # Root of the nnU-Net results tree; exported as nnUNet_raw,
    # nnUNet_preprocessed and nnUNet_results. Path updated by bx.
    NNUNET_RESULTS_DIR = '/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/ref/evaluation/.totalsegmentator/nnunet/results'
    # Training output folder loaded when no other folder is given. Path updated by bx.
    DEFAULT_MODEL_TRAINING_OUTPUT_DIR = '/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/ref/evaluation/.totalsegmentator/nnunet/results/Dataset297_TotalSegmentator_total_3mm_1559subj/nnUNetTrainer_4000epochs_NoMirroring__nnUNetPlans__3d_fullres'

    def __init__(self, verbose=False, model_training_output_dir=None):
        """Load plans, dataset description and fold-0 weights, and build the network.

        Args:
            verbose: enables progress prints, tqdm bars and preprocessor verbosity.
            model_training_output_dir: nnU-Net training output folder containing
                ``dataset.json``, ``plans.json`` and ``fold_0/checkpoint_final.pth``.
                ``None`` selects ``DEFAULT_MODEL_TRAINING_OUTPUT_DIR``.
        """
        super().__init__()
        # NOTE(review): nnunetv2 is imported at module level before these are set;
        # if it reads them at import time, setting them here has no effect — confirm.
        os.environ['nnUNet_raw'] = self.NNUNET_RESULTS_DIR
        os.environ['nnUNet_preprocessed'] = self.NNUNET_RESULTS_DIR
        os.environ['nnUNet_results'] = self.NNUNET_RESULTS_DIR
        os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
        os.environ['KMP_INIT_AT_FORK'] = 'FALSE'
        self.verbose = verbose
        self.verbose_preprocessing = verbose
        self.allow_tqdm = verbose
        self.plans_manager, self.configuration_manager, self.list_of_parameters, self.network, self.dataset_json, \
            self.trainer_name, self.allowed_mirroring_axes, self.label_manager = None, None, None, None, None, None, None, None
        self.tile_step_size = 0.5
        self.use_gaussian = True
        self.use_mirroring = False

        device = torch.device('cpu')
        # Keeping everything on the device is only supported for CUDA. Deriving
        # the flag from the device type guarantees it is always bound.
        perform_everything_on_device = device.type == 'cuda'
        if perform_everything_on_device:
            torch.backends.cudnn.benchmark = True
        else:
            print('perform_everything_on_device=True is only supported for cuda devices! Setting this to False')
        self.device = device
        self.perform_everything_on_device = perform_everything_on_device

        if model_training_output_dir is None:
            model_training_output_dir = self.DEFAULT_MODEL_TRAINING_OUTPUT_DIR
        checkpoint_name = 'checkpoint_final.pth'
        dataset_json = load_json(join(model_training_output_dir, 'dataset.json'))
        plans = load_json(join(model_training_output_dir, 'plans.json'))
        plans_manager = PlansManager(plans)
        with torch.serialization.safe_globals([np.core.multiarray.scalar, np.dtype, np.dtypes.Float64DType, np.dtypes.Float32DType]):
            checkpoint = torch.load(join(model_training_output_dir, 'fold_0', checkpoint_name),
                                    map_location=torch.device('cpu'), weights_only=False)
        configuration_name = checkpoint['init_args']['configuration']
        trainer_name = checkpoint['trainer_name']
        configuration_manager = plans_manager.get_configuration(configuration_name)
        num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
        trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
                                                    trainer_name, 'nnunetv2.training.nnUNetTrainer')
        inference_allowed_mirroring_axes = checkpoint.get('inference_allowed_mirroring_axes')
        label_manager = plans_manager.get_label_manager(dataset_json)
        self.network = trainer_class.build_network_architecture(
            configuration_manager.network_arch_class_name,
            configuration_manager.network_arch_init_kwargs,
            configuration_manager.network_arch_init_kwargs_req_import,
            num_input_channels,
            label_manager.num_segmentation_heads,
            enable_deep_supervision=False
        )
        self.plans_manager = plans_manager
        self.configuration_manager = configuration_manager
        self.dataset_json = dataset_json
        self.trainer_name = trainer_name
        self.allowed_mirroring_axes = inference_allowed_mirroring_axes
        self.label_manager = label_manager
        self.network.load_state_dict(checkpoint['network_weights'])
        self.network.eval()
        # self.network = torch.compile(self.network, backend="openvino")
        self.network.share_memory()

    def _internal_get_sliding_window_slicers(self, image_size):
        """Return the index tuples of all sliding-window patches over ``image_size``.

        ``image_size`` is the spatial shape (channel axis excluded). When the
        configured patch has one fewer dimension than ``image_size``, patches
        are 2D and are taken slice by slice along the first axis. Each tuple
        starts with ``slice(None)`` so that it also selects all channels.
        """
        patch_size = self.configuration_manager.patch_size
        slicers = []
        if len(patch_size) < len(image_size):
            assert len(patch_size) == len(image_size) - 1, 'if tile_size has less entries than image_size, ' \
                                                           'len(tile_size) ' \
                                                           'must be one shorter than len(image_size) ' \
                                                           '(only dimension ' \
                                                           'discrepancy of 1 allowed).'
            steps = compute_steps_for_sliding_window(image_size[1:], patch_size, self.tile_step_size)
            for d in range(image_size[0]):
                for sx in steps[0]:
                    for sy in steps[1]:
                        slicers.append(
                            tuple([slice(None), d, *[slice(si, si + ti) for si, ti in
                                                     zip((sx, sy), patch_size)]]))
        else:
            steps = compute_steps_for_sliding_window(image_size, patch_size, self.tile_step_size)
            for sx in steps[0]:
                for sy in steps[1]:
                    for sz in steps[2]:
                        slicers.append(
                            tuple([slice(None), *[slice(si, si + ti) for si, ti in
                                                  zip((sx, sy, sz), patch_size)]]))
        return slicers

    def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x``, optionally averaging over mirrored copies.

        When ``self.use_mirroring`` is set and ``self.allowed_mirroring_axes``
        is not None, the network is also applied to ``x`` flipped along every
        non-empty combination of those spatial axes. Each output is flipped
        back, and all outputs are averaged.
        """
        import itertools  # not imported at file level; needed for axis combinations

        mirror_axes = self.allowed_mirroring_axes if self.use_mirroring else None
        prediction = self.network(x)
        if mirror_axes is not None:
            # check for invalid numbers in mirror_axes
            # x should be 5d for 3d images and 4d for 2d. so the max value of mirror_axes cannot exceed len(x.shape) - 3
            assert max(mirror_axes) <= x.ndim - 3, 'mirror_axes does not match the dimension of the input!'
            # Spatial axis indices are offset by the batch and channel axes.
            mirror_axes = [m + 2 for m in mirror_axes]
            axes_combinations = [
                c for i in range(len(mirror_axes)) for c in itertools.combinations(mirror_axes, i + 1)
            ]
            for axes in axes_combinations:
                prediction += torch.flip(self.network(torch.flip(x, axes)), axes)
            prediction /= (len(axes_combinations) + 1)
        return prediction

    def _internal_predict_sliding_window_return_logits(self,
                                                       data: torch.Tensor,
                                                       slicers,
                                                       do_on_device: bool = True,
                                                       ):
        """Predict logits for ``data`` by accumulating patch predictions.

        Args:
            data: tensor or numpy array with a leading channel axis.
            slicers: patch index tuples, as produced by
                ``_internal_get_sliding_window_slicers``.
            do_on_device: keep the accumulation buffers on ``self.device``
                (True) or on the CPU (False).

        Returns:
            fp16 tensor ``(num_segmentation_heads, *data.shape[1:])`` holding
            the (optionally Gaussian-weighted) average of the overlapping
            patch predictions.

        Raises:
            RuntimeError: if the averaged logits contain inf.
        """
        predicted_logits = n_predictions = prediction = gaussian = workon = None
        results_device = self.device if do_on_device else torch.device('cpu')
        try:
            empty_cache(self.device)
            # Pure inference: without no_grad, autograd would record a graph
            # spanning every patch and the returned logits would require grad.
            with torch.no_grad():
                # move data to device
                if self.verbose:
                    print(f'move image to device {results_device}')
                data = data.to(results_device) if torch.is_tensor(data) else torch.from_numpy(data).to(results_device)
                # preallocate arrays
                if self.verbose:
                    print(f'preallocating results arrays on device {results_device}')
                predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),
                                               dtype=torch.half,
                                               device=results_device)
                n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)
                if self.use_gaussian:
                    gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
                                                value_scaling_factor=10,
                                                device=results_device)
                else:
                    gaussian = 1
                if not self.allow_tqdm and self.verbose:
                    print(f'running prediction: {len(slicers)} steps')
                for sl in tqdm(slicers, disable=not self.allow_tqdm):
                    workon = data[sl][None].to(self.device)
                    prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device)
                    if self.use_gaussian:
                        prediction *= gaussian
                    predicted_logits[sl] += prediction
                    # sl[1:] drops the channel selector: n_predictions has no channel axis.
                    n_predictions[sl[1:]] += gaussian
                predicted_logits /= n_predictions
            # check for infs
            if torch.any(torch.isinf(predicted_logits)):
                raise RuntimeError('Encountered inf in predicted array. Aborting... If this problem persists, '
                                   'reduce value_scaling_factor in compute_gaussian or increase the dtype of '
                                   'predicted_logits to fp32')
        except Exception:
            del predicted_logits, n_predictions, prediction, gaussian, workon
            empty_cache(self.device)
            empty_cache(results_device)
            raise
        return predicted_logits

    def score_patient(self, file_in, orientation, mask, resample=3.0, nr_threads_resampling=1, save_pred_seg_path=None):
        """Segment one image file and return the label map as a uint8 array.

        Args:
            file_in: path of the input image. It is read with SimpleITK and
                converted to a temporary NIfTI file.
            orientation: ``None``, or a ``(spacing, origin, direction)`` tuple
                written onto the image before conversion.
            mask: ``None``, or an array compared against the nibabel data array
                of the converted image. Voxels where ``mask == 0`` are set to
                -1024 before inference. NOTE(review): the mask must therefore
                use nibabel's axis order, not SimpleITK's — confirm with callers.
            resample: target spacing passed to ``change_spacing`` before
                inference and used to resample the prediction back. ``None``
                skips both resampling steps.
            nr_threads_resampling: ``nr_cpus`` passed to ``change_spacing``.
            save_pred_seg_path: if given, directory in which
                ``<stem of file_in>.nii.gz`` is saved. For a ``*.nii.gz`` input
                the stem still ends in ``.nii``.

        Returns:
            uint8 numpy array of labels on the voxel grid of the converted input.
        """
        fname_id = Path(str(file_in)).stem  # str(file_in).split('/')[-1].split('.')[0]
        with tempfile.TemporaryDirectory(prefix=f"nnunet_tmp_{fname_id}") as tmp_folder:
            tmp_dir = Path(tmp_folder)
            (tmp_dir / "mha").mkdir()
            converted_path = tmp_dir / "mha" / "converted_mha.nii.gz"

            # mha to nifti, optionally overriding the image geometry
            sitk_image = SimpleITK.ReadImage(str(file_in))
            if orientation is not None:
                spacing, origin, direction = orientation
                sitk_image.SetSpacing(spacing)
                sitk_image.SetOrigin(origin)
                sitk_image.SetDirection(direction)
            SimpleITK.WriteImage(sitk_image, str(converted_path))

            img_in_orig = nib.load(converted_path)
            img_data = img_in_orig.get_fdata()
            if mask is not None:
                img_data = np.where(mask == 0, -1024, img_data)
            img_in = nib.as_closest_canonical(nib.Nifti1Image(img_data, img_in_orig.affine))

            if resample is not None:
                img_in_shape = img_in.shape
                img_in_rsp = change_spacing(img_in, resample,
                                            order=3, dtype=np.int32, nr_cpus=nr_threads_resampling)  # 4 cpus instead of 1 makes it a bit slower
                if self.verbose:
                    print(f" from shape {img_in.shape} to shape {img_in_rsp.shape}")
            else:
                img_in_rsp = img_in

            preprocessor = self.configuration_manager.preprocessor_class(verbose=self.verbose)
            # NOTE(review): spacing is hard-coded to 3 mm, matching the default
            # model and resample=3.0. Other resample values (or None) are not
            # reflected here — confirm before changing either.
            data_properties = {
                'nibabel_stuff': {
                    'original_affine': img_in_rsp.affine,
                    'reoriented_affine': img_in_rsp.affine
                },
                'spacing': [3.0, 3.0, 3.0],
            }
            # .T reverses the nibabel axis order; [np.newaxis] adds the channel axis.
            data, _ = preprocessor.run_case_npy(data=img_in_rsp.get_fdata().T[np.newaxis, ...], seg=None,
                                                properties=data_properties,
                                                plans_manager=self.plans_manager,
                                                configuration_manager=self.configuration_manager,
                                                dataset_json=self.dataset_json)
            data = torch.from_numpy(data).to(dtype=torch.float32, memory_format=torch.contiguous_format)
            data, slicer_revert_padding = pad_nd_image(image=data, new_shape=self.configuration_manager.patch_size,
                                                       mode='constant', return_slicer=True,
                                                       shape_must_be_divisible_by=None)
            slicers = self._internal_get_sliding_window_slicers(data.shape[1:])
            predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, False)
            empty_cache(self.device)
            # revert padding
            predicted_logits = predicted_logits[(slice(None), *slicer_revert_padding[1:])]
            img_pred = convert_predicted_logits_to_segmentation_with_correct_shape(
                predicted_logits, self.plans_manager, self.configuration_manager, self.label_manager, data_properties, False)
            img_pred = nib.Nifti1Image(img_pred.T, img_in_rsp.affine)
            empty_cache(self.device)

            if resample is not None:
                if self.verbose:
                    print(f" back from {img_pred.shape} to original shape: {img_in_shape}")
                # Use force_affine otherwise output affine sometimes slightly off (which then is even increased
                # by undo_canonical)
                img_pred = change_spacing(img_pred, resample, img_in_shape,
                                          order=0, dtype=np.uint8, nr_cpus=nr_threads_resampling,
                                          force_affine=img_in.affine)
            if self.verbose:
                print("Undoing canonical...")
            empty_cache(self.device)
            img_pred = undo_canonical(img_pred, img_in_orig)
            empty_cache(self.device)
            img_data = img_pred.get_fdata().astype(np.uint8)
            if save_pred_seg_path is not None:
                if self.verbose:
                    print(f"Saving predicted segmentation to {save_pred_seg_path}")
                nib.save(img_pred, os.path.join(save_pred_seg_path, f"{fname_id}.nii.gz"))
            return img_data