File size: 10,767 Bytes
19c1f58 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 |
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
from typing import Optional
import nibabel as nib
import os
import torch
import SimpleITK as sitk
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from nibabel.nifti1 import Nifti1Image
from nnunetv2.analysis.ts_utils import MinialTotalSegmentator
class SegmentationMetrics:
    """Score segmentations of a synthetic CT against a ground-truth label map.

    For each TotalSegmentator label selected by the patient's anatomy code,
    DICE and 95th-percentile Hausdorff distance (HD95) are computed with MONAI
    and averaged over the labels present in the ground truth.
    """

    def __init__(self, debug=False):
        """Create the evaluator.

        Args:
            debug: If True, print the classes missing from a ground truth and
                pass ``verbose=True`` to the TotalSegmentator wrapper.
        """
        self.debug = debug
        # Use fixed wide dynamic range (presumably Hounsfield units). Kept as a
        # public attribute; it is not read by the methods of this class.
        self.dynamic_range = [-1024., 3000.]
        self.my_ts = MinialTotalSegmentator(verbose=self.debug)
        # TotalSegmentator classes. See
        # https://github.com/wasserth/TotalSegmentator?tab=readme-ov-file#class-details
        # (TotalSegmentator commit cd3d5362245237f13adbb78cdfaee615f54096a1).
        # "AB" and "TH" score the same structures, so they share one list
        # (each key still gets its own copy so they can be edited independently).
        trunk_classes = [
            2,  # kidney right
            3,  # kidney left
            5,  # liver
            6,  # stomach
            *range(10, 14 + 1),  # lungs
            *range(26, 50 + 1),  # vertebrae
            51,  # heart
            79,  # spinal cord
            *range(92, 115 + 1),  # ribs
            116,  # sternum
        ]
        self.classes_to_use = {
            "AB": list(trunk_classes),
            "HN": [
                15,  # esophagus
                16,  # trachea
                17,  # thyroid
                *range(26, 50 + 1),  # vertebrae
                79,  # spinal cord
                90,  # brain
                91,  # skull
            ],
            "TH": list(trunk_classes),
        }

    def score_patient_ts(self, synthetic_ct_location, mask, gt_segmentation, patient_id, orientation=None, save_pred_seg_path=None):
        """Segment the synthetic CT with TotalSegmentator, then score it.

        Args:
            synthetic_ct_location: Path of the synthetic CT given to the segmenter.
            mask: Forwarded to the segmenter (and to ``score_patient``, which ignores it).
            gt_segmentation: Ground-truth label volume (NumPy array or torch tensor).
            patient_id: Patient identifier; ``patient_id[1:3]`` selects the anatomy.
            orientation: Optional ``(spacing, origin, direction)``; forwarded to
                the segmenter and used as the HD95 spacing.
            save_pred_seg_path: Optional path forwarded to the segmenter
                (presumably where the predicted segmentation is written).

        Returns:
            dict with keys ``'DICE'`` and ``'HD95'`` as returned by ``score_patient``.
        """
        with torch.no_grad():
            pred_seg = self.my_ts.score_patient(synthetic_ct_location, orientation, mask, save_pred_seg_path=save_pred_seg_path)
        # The segmenter may return a nibabel image; extract its voxel data.
        if isinstance(pred_seg, Nifti1Image):
            pred_seg = np.array(pred_seg.get_fdata())
        return self.score_patient(gt_segmentation, pred_seg, mask, patient_id, orientation)

    @staticmethod
    def _to_cpu_tensor(volume):
        """Return *volume* as a detached CPU tensor (NumPy input is wrapped, not copied)."""
        if torch.is_tensor(volume):
            return volume.cpu().detach()
        return torch.from_numpy(volume).cpu().detach()

    @staticmethod
    def _build_metrics(spacing):
        """Return fresh MONAI DICE and HD95 accumulators; HD95 is called with *spacing*."""
        return [
            {
                'name': 'DICE',
                'f': DiceMetric(include_background=True, reduction="mean", get_not_nans=False),
            },
            {
                'name': 'HD95',
                'f': HausdorffDistanceMetric(include_background=True, reduction="mean", percentile=95, get_not_nans=False),
                'kwargs': {'spacing': spacing},
            },
        ]

    def score_patient(self, gt_segmentation, sct_segmentation, mask, patient_id, orientation=None):
        """Compute mean DICE and HD95 over the anatomy's classes present in the ground truth.

        Args:
            gt_segmentation: Ground-truth label volume (NumPy array or torch tensor).
            sct_segmentation: Predicted label volume with the same shape.
            mask: Not used; accepted so the signature matches ``score_patient_ts``.
            patient_id: Patient identifier; ``patient_id[1:3].upper()`` selects
                the class list (e.g. ``"1ABA005"`` -> ``"AB"``).
            orientation: Optional ``(spacing, origin, direction)``; only the
                spacing is used, as the HD95 ``spacing`` argument.

        Returns:
            dict mapping ``'DICE'`` and ``'HD95'`` to the mean over the scored
            classes, or to NaN when none of the anatomy's classes occur in the
            ground truth.

        Raises:
            AssertionError: if the two volumes have different shapes.
            KeyError: if the anatomy code has no class list.
        """
        anatomy = patient_id[1:3].upper()
        assert sct_segmentation.shape == gt_segmentation.shape
        # Convert to PyTorch tensors for MONAI
        gt_seg = self._to_cpu_tensor(gt_segmentation)
        pred_seg = self._to_cpu_tensor(sct_segmentation)
        if orientation is not None:
            spacing, _origin, _direction = orientation
        else:
            spacing = None

        # Metric accumulators are created on the first class that is actually
        # scored, so a patient with nothing to score never touches MONAI.
        metrics = None
        for c in self.classes_to_use[anatomy]:
            # [None, None] adds batch and channel dims regardless of strides.
            gt_mask = (gt_seg == c)[None, None]
            if not gt_mask.any():
                if self.debug:
                    print(f"No {c} in {patient_id}")
                continue
            if metrics is None:
                metrics = self._build_metrics(spacing)
            pred_mask = (pred_seg == c)[None, None]
            for metric in metrics:
                metric['f'](pred_mask, gt_mask, **metric.get('kwargs', {}))

        if metrics is None:
            # MONAI's aggregate() raises on an empty buffer; NaN lets
            # NaN-aware reductions (e.g. np.nanmean) skip this patient.
            return {'DICE': float('nan'), 'HD95': float('nan')}

        # Mean of each metric over the scored classes.
        return {metric['name']: metric['f'].aggregate().item() for metric in metrics}
def load_image_file_directly(*, location, return_orientation=False, set_orientation=None):
    """Read an image with SimpleITK and return its voxel array in XYZ order.

    Args:
        location: Path passed to ``sitk.ReadImage``.
        return_orientation: If True, also return the spacing, origin and
            direction read from the file.
        set_orientation: Optional ``(spacing, origin, direction)`` written onto
            the image before its array is extracted. Ignored when
            ``return_orientation`` is True.

    Returns:
        The voxel array, or ``(array, spacing, origin, direction)`` when
        ``return_orientation`` is True.
    """
    image = sitk.ReadImage(location)

    if return_orientation:
        # SimpleITK's NumPy view is indexed ZYX; transpose to XYZ.
        img_arr = np.transpose(sitk.GetArrayFromImage(image), [2, 1, 0])
        return img_arr, image.GetSpacing(), image.GetOrigin(), image.GetDirection()

    if set_orientation is not None:
        # NOTE(review): SetSpacing/SetOrigin/SetDirection change image metadata;
        # the extracted voxel array presumably does not depend on them, so this
        # mainly validates the tuple — confirm whether more is intended.
        spacing, origin, direction = set_orientation
        image.SetSpacing(spacing)
        image.SetOrigin(origin)
        image.SetDirection(direction)

    # Convert once (the array is only needed here); ZYX -> XYZ as above.
    return np.transpose(sitk.GetArrayFromImage(image), [2, 1, 0])
class SegmentationMetricsCompute(SegmentationMetrics):
    """Accumulate per-patient segmentation metrics and summarise them.

    Inherits the scoring methods of SegmentationMetrics unchanged and adds
    per-metric value storage (``add``), summary statistics (``aggregate``) and
    clearing (``reset``).
    """

    def __init__(self, debug=False):
        super().__init__(debug=debug)
        # Create the storage immediately so add()/aggregate()/reset() work
        # without a prior explicit init_storage() call.
        self.init_storage(['DICE', 'HD95'])

    def init_storage(self, names: list):
        """(Re)initialise an empty value list per metric name and clear stored patient ids."""
        self.storage = {name: [] for name in names}
        self.storage_id = []
        self.names = names

    def add(self, res: dict, patient_id=None):
        """Append one patient's metric values.

        Args:
            res: Mapping of metric name to value; every name must have been
                registered via ``init_storage`` (otherwise KeyError).
            patient_id: Recorded in ``storage_id`` when truthy.
        """
        for key, value in res.items():
            self.storage[key].append(value)
        if patient_id:
            self.storage_id.append(patient_id)

    @staticmethod
    def _summarize(values):
        """Return NaN-ignoring summary statistics of *values*; all NaN (count 0) when empty."""
        if len(values) == 0:
            # np.nanmax/np.nanmin raise on empty input; report NaN instead.
            nan = float('nan')
            return {'mean': nan, 'std': nan, 'max': nan, 'min': nan,
                    '25pc': nan, '50pc': nan, '75pc': nan, 'count': 0}
        return {
            'mean': np.nanmean(values),
            'std': np.nanstd(values),
            'max': np.nanmax(values),
            'min': np.nanmin(values),
            '25pc': np.nanpercentile(values, 25),
            '50pc': np.nanpercentile(values, 50),
            '75pc': np.nanpercentile(values, 75),
            # Number of stored values, NaNs included.
            'count': len(values),
        }

    def aggregate(self):
        """Summarise every stored metric.

        Returns:
            dict mapping each metric name to a dict with keys 'mean', 'std',
            'max', 'min', '25pc', '50pc', '75pc' (NaN-ignoring) and 'count'.
        """
        res = {name: {} for name in self.names}
        for key, values in self.storage.items():
            res[key] = self._summarize(values)
        return res

    def reset(self):
        """Clear all stored values and patient ids, keeping the registered metric names."""
        for key in self.storage:
            self.storage[key] = []
        # Ids must be cleared too, otherwise they no longer line up with the values.
        self.storage_id = []
if __name__ == "__main__":
    # Example usage: segment one synthetic CT with TotalSegmentator and score it
    # against the ground-truth label map. Running without arguments reproduces
    # the original hard-coded example for patient 1ABA011. For a quick sanity
    # check without the segmenter, SegmentationMetrics.score_patient can also be
    # called directly with two label arrays.
    import argparse

    _BASE = "/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct"
    parser = argparse.ArgumentParser(description="Compute TotalSegmentator-based segmentation metrics for one patient.")
    parser.add_argument("--patient-id", default="1ABA011",
                        help="Patient identifier; characters 1-2 select the anatomy class list.")
    parser.add_argument("--gt-dir",
                        default=f"{_BASE}/raw/Dataset251_synthrad2025_task1_CT_AB_pre_v2r_stitched_masked_synseg/labelsTr",
                        help="Directory holding <patient_id>.mha ground-truth label maps.")
    parser.add_argument("--sct-dir",
                        default=f"{_BASE}/results/Dataset280_synthrad2025_task1_MR_AB_pre_v2r_stitched/nnUNetTrainerMRCT_track__nnUNetPlans__3d_fullres/fold_0/validation_revert_norm",
                        help="Directory holding <patient_id>.mha synthetic CTs.")
    parser.add_argument("--mask-dir",
                        default=f"{_BASE}/preprocessed/Dataset260_synthrad2025_task1_MR_AB_pre_v2r_stitched_masked/masks",
                        help="Directory holding <patient_id>.mha masks.")
    parser.add_argument("--with-orientation", action="store_true",
                        help="Read spacing/origin/direction from the ground truth, force it on the mask "
                             "and pass it to the evaluator (HD95 then uses that spacing).")
    args = parser.parse_args()

    patient_id = args.patient_id
    # All three paths follow the same <dir>/<patient_id>.mha pattern.
    gt_segmentation_path = os.path.join(args.gt_dir, f"{patient_id}.mha")
    synthetic_ct_location = os.path.join(args.sct_dir, f"{patient_id}.mha")
    mask_path = os.path.join(args.mask_dir, f"{patient_id}.mha")

    _segmentation_evaluator = SegmentationMetrics(debug=True)
    if args.with_orientation:
        gt_segmentation, spacing, origin, direction = load_image_file_directly(
            location=gt_segmentation_path, return_orientation=True)
        orientation = (spacing, origin, direction)
        mask = load_image_file_directly(location=mask_path, set_orientation=orientation)
    else:
        gt_segmentation = load_image_file_directly(location=gt_segmentation_path)
        orientation = None
        mask = load_image_file_directly(location=mask_path)

    seg_metrics = _segmentation_evaluator.score_patient_ts(
        synthetic_ct_location, mask, gt_segmentation, patient_id, orientation=orientation)
    print(f"Segmentation metrics for patient {patient_id}: {seg_metrics}")