#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Segmentation-based evaluation of synthetic CT images.

A synthetic CT is segmented through ``MinialTotalSegmentator`` and the
resulting labels are compared against a ground-truth label volume with
class-averaged DICE and 95th-percentile Hausdorff distance (HD95) from MONAI.
"""
import numpy as np
from typing import Optional
import nibabel as nib
import os
import torch
import SimpleITK as sitk
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from nibabel.nifti1 import Nifti1Image

from nnunetv2.analysis.ts_utils import MinialTotalSegmentator


class SegmentationMetrics:
    """Score a predicted label volume against a ground truth, per anatomy.

    The anatomy key is taken from characters 1-2 of the patient id
    (e.g. ``"1ABA005"`` -> ``"AB"``) and selects which TotalSegmentator
    class ids are evaluated (see ``self.classes_to_use``).
    """

    def __init__(self, debug=False):
        """Create the evaluator.

        Args:
            debug: passed as ``verbose`` to ``MinialTotalSegmentator``; when True,
                classes missing from the ground truth are also printed.
        """
        self.debug = debug
        # Fixed wide intensity range (presumably HU); not read anywhere in this module.
        self.dynamic_range = [-1024., 3000.]
        self.my_ts = MinialTotalSegmentator(verbose=self.debug)

        # TotalSegmentator class ids, see
        # https://github.com/wasserth/TotalSegmentator?tab=readme-ov-file#class-details
        # (TotalSegmentator commit cd3d5362245237f13adbb78cdfaee615f54096a1).
        vertebrae = [*range(26, 50 + 1)]
        torso = [
            2,  # kidney right
            3,  # kidney left
            5,  # liver
            6,  # stomach
            *range(10, 14 + 1),  # lungs
            *vertebrae,
            51,  # heart
            79,  # spinal cord
            *range(92, 115 + 1),  # ribs
            116,  # sternum
        ]
        self.classes_to_use = {
            # Abdomen and thorax evaluate the same structures; separate list
            # objects so mutating one entry cannot silently change the other.
            "AB": list(torso),
            "HN": [
                15,  # esophagus
                16,  # trachea
                17,  # thyroid
                *vertebrae,
                79,  # spinal cord
                90,  # brain
                91,  # skull
            ],
            "TH": list(torso),
        }

    def score_patient_ts(self, synthetic_ct_location, mask, gt_segmentation, patient_id,
                         orientation=None, save_pred_seg_path=None):
        """Segment a synthetic CT with TotalSegmentator, then score it.

        Args:
            synthetic_ct_location: path of the synthetic CT, forwarded to
                ``MinialTotalSegmentator.score_patient``.
            mask: forwarded to the segmenter (``score_patient`` does not use it).
            gt_segmentation: ground-truth labels (numpy array or torch tensor).
            patient_id: id whose characters 1-2 select the anatomy.
            orientation: optional ``(spacing, origin, direction)`` tuple, forwarded
                to the segmenter; its spacing is used for HD95.
            save_pred_seg_path: optional path forwarded to the segmenter.

        Returns:
            The dict returned by :meth:`score_patient` (``'DICE'``, ``'HD95'``).
        """
        with torch.no_grad():
            pred_seg = self.my_ts.score_patient(synthetic_ct_location, orientation, mask,
                                                save_pred_seg_path=save_pred_seg_path)
        # The segmenter may return a nibabel image; score_patient needs an array.
        if isinstance(pred_seg, Nifti1Image):
            pred_seg = np.asarray(pred_seg.get_fdata())
        return self.score_patient(gt_segmentation, pred_seg, mask, patient_id, orientation)

    @staticmethod
    def _to_cpu_tensor(data):
        """Return ``data`` (torch tensor or numpy array) as a detached CPU tensor."""
        if torch.is_tensor(data):
            return data.cpu().detach()
        return torch.from_numpy(data).cpu().detach()

    def score_patient(self, gt_segmentation, sct_segmentation, mask, patient_id, orientation=None):
        """Compute class-averaged DICE and HD95 between two label volumes.

        For every class id of the patient's anatomy that occurs in the ground
        truth, the binary masks ``pred == c`` / ``gt == c`` are scored; MONAI
        then averages the per-class values (``reduction="mean"``).

        Args:
            gt_segmentation: ground-truth labels (numpy array or torch tensor).
            sct_segmentation: predicted labels, same shape as the ground truth.
            mask: unused; kept for interface compatibility.
            patient_id: id whose characters 1-2 (upper-cased) are the key into
                ``self.classes_to_use``.
            orientation: optional ``(spacing, origin, direction)``; only the
                spacing is used (for HD95). When None, MONAI assumes unit spacing.

        Returns:
            dict mapping ``'DICE'`` and ``'HD95'`` to floats. Both are NaN when
            none of the anatomy's classes occurs in the ground truth (MONAI's
            ``aggregate()`` would otherwise raise on an empty buffer).

        Raises:
            AssertionError: if the two segmentations differ in shape.
            KeyError: if the anatomy key is not in ``self.classes_to_use``.
        """
        anatomy = patient_id[1:3].upper()

        assert sct_segmentation.shape == gt_segmentation.shape
        gt_seg = self._to_cpu_tensor(gt_segmentation)
        pred_seg = self._to_cpu_tensor(sct_segmentation)
        assert gt_seg.shape == pred_seg.shape

        if orientation is not None:
            spacing, _origin, _direction = orientation
        else:
            spacing = None

        # (result name, MONAI metric, extra keyword arguments for each call)
        metrics = [
            ('DICE', DiceMetric(include_background=True, reduction="mean", get_not_nans=False), {}),
            ('HD95', HausdorffDistanceMetric(include_background=True, reduction="mean",
                                             percentile=95, get_not_nans=False),
             {'spacing': spacing}),
        ]

        evaluated = 0
        for c in self.classes_to_use[anatomy]:
            # Prepend batch and channel dims of size 1, as MONAI metrics expect.
            gt_tensor = (gt_seg == c)[None, None]
            if gt_tensor.sum() == 0:
                # Only classes present in the ground truth are scored.
                if self.debug:
                    print(f"No {c} in {patient_id}")
                continue
            est_tensor = (pred_seg == c)[None, None]
            for _name, metric_fn, kwargs in metrics:
                metric_fn(est_tensor, gt_tensor, **kwargs)
            evaluated += 1

        # Aggregate the mean over the evaluated classes for this patient.
        result = {}
        for name, metric_fn, _kwargs in metrics:
            result[name] = metric_fn.aggregate().item() if evaluated else np.nan
            metric_fn.reset()
        return result


def _sitk_to_xyz_array(image):
    """Return the voxels of a 3-D SimpleITK image as a numpy array indexed (x, y, z).

    ``sitk.GetArrayFromImage`` yields (z, y, x) order, hence the transpose; the
    result is a transposed view of that array.
    """
    return np.transpose(sitk.GetArrayFromImage(image), [2, 1, 0])


def load_image_file_directly(*, location, return_orientation=False, set_orientation=None):
    """Read an image with SimpleITK and return its voxels as an (x, y, z) array.

    Args:
        location: path of a file SimpleITK can read (e.g. ``.mha``); must be a
            3-D scalar image because of the 3-axis transpose.
        return_orientation: when True, also return spacing, origin and
            direction; ``set_orientation`` is then ignored.
        set_orientation: optional ``(spacing, origin, direction)`` written onto
            the SimpleITK image before the array is extracted.

    Returns:
        The array, or ``(array, spacing, origin, direction)`` when
        ``return_orientation`` is True.
    """
    image = sitk.ReadImage(location)
    if return_orientation:
        return (_sitk_to_xyz_array(image), image.GetSpacing(),
                image.GetOrigin(), image.GetDirection())

    if set_orientation is not None:
        spacing, origin, direction = set_orientation
        # NOTE(review): these setters only change SimpleITK metadata; no
        # resampling happens, so the returned voxel array is the same as
        # without them. Confirm this is the intent.
        image.SetSpacing(spacing)
        image.SetOrigin(origin)
        image.SetDirection(direction)
    return _sitk_to_xyz_array(image)


class SegmentationMetricsCompute(SegmentationMetrics):
    """Accumulate per-patient segmentation metrics and summarise them.

    Inherits the scoring methods of ``SegmentationMetrics`` unchanged and adds
    storage: ``add`` one result dict per patient, then ``aggregate`` to get
    NaN-aware summary statistics per metric name.
    """

    _STAT_KEYS = ('mean', 'std', 'max', 'min', '25pc', '50pc', '75pc')

    def __init__(self, debug=False):
        super().__init__(debug=debug)
        self.names = ['DICE', 'HD95']
        # Create storage up front so add()/aggregate()/reset() work without an
        # explicit init_storage() call.
        self.init_storage(self.names)

    def init_storage(self, names: list):
        """(Re)create empty value lists for each metric in ``names`` and clear patient ids."""
        self.names = names
        self.storage = {name: [] for name in names}
        self.storage_id = []

    def add(self, res: dict, patient_id=None):
        """Append one patient's metric values.

        Args:
            res: metric name -> value; every name must have been registered via
                ``init_storage`` (otherwise KeyError).
            patient_id: recorded in ``self.storage_id`` when truthy.
        """
        for key, value in res.items():
            self.storage[key].append(value)
        if patient_id:
            self.storage_id.append(patient_id)

    @classmethod
    def _summarise(cls, values):
        """Return NaN-ignoring summary statistics of ``values`` plus its length.

        An empty ``values`` yields NaN statistics and count 0 (``np.nanmax``
        and ``np.nanmin`` raise on empty input).
        """
        if len(values) == 0:
            stats = dict.fromkeys(cls._STAT_KEYS, np.nan)
            stats['count'] = 0
            return stats
        return {
            'mean': np.nanmean(values),
            'std': np.nanstd(values),
            'max': np.nanmax(values),
            'min': np.nanmin(values),
            '25pc': np.nanpercentile(values, 25),
            '50pc': np.nanpercentile(values, 50),
            '75pc': np.nanpercentile(values, 75),
            # Counts every stored entry, including NaN values.
            'count': len(values),
        }

    def aggregate(self):
        """Return ``{metric name: {statistic: value}}`` for every stored metric."""
        res = {name: {} for name in self.names}
        for key, values in self.storage.items():
            res[key] = self._summarise(values)
        return res

    def reset(self):
        """Empty every metric's value list and the recorded patient ids; names are kept."""
        for key in self.storage:
            self.storage[key] = []
        self.storage_id = []


if __name__ == "__main__":
    # Example usage
    # metrics = SegmentationMetrics(debug=True)
    # gt_segmentation_path = "/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/raw/Dataset800_SEGMENTATION_synthrad2025_task1_CT_AB_aligned_to_Dataset261/labelsTr/1ABA005.mha"
    # sct_segmentation_path = "/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/raw/Dataset800_SEGMENTATION_synthrad2025_task1_CT_AB_aligned_to_Dataset261/labelsTr/1ABA005.mha"
    # gt_segmentation = sitk.GetArrayFromImage(sitk.ReadImage(gt_segmentation_path))
    # sct_segmentation = sitk.GetArrayFromImage(sitk.ReadImage(sct_segmentation_path))
    # mask = None  # Example mask (not used in this example)
    # patient_id = "1ABA005"  # Example patient ID
    # orientation = None  # Example orientation (not used in this example)
    # result = metrics.score_patient(gt_segmentation, sct_segmentation, mask, patient_id, orientation)
    # print(result)

    # Real example with orientation (HD95 then uses the image spacing):
    # _segmentation_evaluator = SegmentationMetrics(debug=True)
    # patient_id = "1ABA011"
    # gt_segmentation_path = f"/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/raw/Dataset251_synthrad2025_task1_CT_AB_pre_v2r_stitched_masked_synseg/labelsTr/{patient_id}.mha"
    # gt_segmentation, spacing, origin, direction = load_image_file_directly(location=gt_segmentation_path, return_orientation=True)
    # # synthetic_ct_location = f"/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/raw/Dataset251_synthrad2025_task1_CT_AB_pre_v2r_stitched_masked_synseg/imagesTr/{patient_id}_0000.mha"
    # synthetic_ct_location = "/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/results/Dataset280_synthrad2025_task1_MR_AB_pre_v2r_stitched/nnUNetTrainerMRCT_track__nnUNetPlans__3d_fullres/fold_0/validation_revert_norm/1ABA011.mha"
    # # mask = None
    # mask = load_image_file_directly(location=f"/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/preprocessed/Dataset260_synthrad2025_task1_MR_AB_pre_v2r_stitched_masked/masks/{patient_id}.mha", set_orientation=(spacing, origin, direction))
    # seg_metrics = _segmentation_evaluator.score_patient_ts(synthetic_ct_location, mask, gt_segmentation, patient_id, orientation=(spacing, origin, direction))
    # print(f"Segmentation metrics for patient {patient_id}: {seg_metrics}")
    # # if we are in test phase, there is a doseplan for every patient in this folder

    # Real example without orientation: no spacing is passed, so HD95 is
    # computed with MONAI's default unit spacing (voxel units).
    _segmentation_evaluator = SegmentationMetrics(debug=True)
    patient_id = "1ABA011"
    gt_segmentation_path = f"/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/raw/Dataset251_synthrad2025_task1_CT_AB_pre_v2r_stitched_masked_synseg/labelsTr/{patient_id}.mha"
    gt_segmentation = load_image_file_directly(location=gt_segmentation_path)
    # synthetic_ct_location = f"/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/raw/Dataset251_synthrad2025_task1_CT_AB_pre_v2r_stitched_masked_synseg/imagesTr/{patient_id}_0000.mha"
    synthetic_ct_location = "/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/results/Dataset280_synthrad2025_task1_MR_AB_pre_v2r_stitched/nnUNetTrainerMRCT_track__nnUNetPlans__3d_fullres/fold_0/validation_revert_norm/1ABA011.mha"
    # mask = None
    mask = load_image_file_directly(location=f"/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/preprocessed/Dataset260_synthrad2025_task1_MR_AB_pre_v2r_stitched_masked/masks/{patient_id}.mha")
    seg_metrics = _segmentation_evaluator.score_patient_ts(synthetic_ct_location, mask, gt_segmentation, patient_id)
    print(f"Segmentation metrics for patient {patient_id}: {seg_metrics}")