File size: 10,767 Bytes
19c1f58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
from typing import Optional
import nibabel as nib
import os
import torch
import SimpleITK as sitk
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from nibabel.nifti1 import Nifti1Image
from nnunetv2.analysis.ts_utils import MinialTotalSegmentator



class SegmentationMetrics:
    """Score synthetic-CT segmentations against ground-truth label maps.

    A synthetic CT can be segmented with the ``MinialTotalSegmentator`` wrapper
    and scored in one go (``score_patient_ts``), or an existing predicted label
    map can be scored directly (``score_patient``). Each TotalSegmentator class
    configured for the patient's anatomical region is scored with MONAI's Dice
    and 95th-percentile Hausdorff distance (HD95), and the per-class values are
    averaged for the patient.
    """

    # Keys of the dict returned by ``score_patient``, in order.
    METRIC_NAMES = ('DICE', 'HD95')

    def __init__(self, debug=False):
        """Set up the evaluator.

        Args:
            debug: Passed to the TotalSegmentator wrapper as ``verbose``; when
                True, ``score_patient`` also prints a message for every
                configured class that is absent from a patient's ground truth.
        """
        self.debug = debug
        # Use fixed wide dynamic range
        self.dynamic_range = [-1024., 3000.]
        self.my_ts = MinialTotalSegmentator(verbose=self.debug)

        # TotalSegmentator classes. See here https://github.com/wasserth/TotalSegmentator?tab=readme-ov-file#class-details (TotalSegmenator commit cd3d5362245237f13adbb78cdfaee615f54096a1)
        # Abdomen ("AB") and thorax ("TH") are scored on the same structures;
        # each region gets its own copy of the list so entries stay independent.
        abdomen_thorax_classes = [
            2,  # kidney right
            3,  # kidney left
            5,  # liver
            6,  # stomach
            *range(10, 14 + 1),  # lungs
            *range(26, 50 + 1),  # vertebrae
            51,  # heart
            79,  # spinal cord
            *range(92, 115 + 1),  # ribs
            116,  # sternum
        ]
        self.classes_to_use = {
            "AB": list(abdomen_thorax_classes),
            "HN": [
                15,  # esophagus
                16,  # trachea
                17,  # thyroid
                *range(26, 50 + 1),  # vertebrae
                79,  # spinal cord
                90,  # brain
                91,  # skull
            ],
            "TH": list(abdomen_thorax_classes),
        }

    def score_patient_ts(self, synthetic_ct_location, mask, gt_segmentation, patient_id, orientation=None, save_pred_seg_path=None):
        """Segment a synthetic CT with TotalSegmentator and score it against the ground truth.

        Args:
            synthetic_ct_location: Image location passed to the TotalSegmentator wrapper.
            mask: Mask forwarded to the TotalSegmentator wrapper.
            gt_segmentation: Ground-truth label map (NumPy array or torch tensor).
            patient_id: Patient identifier; characters 1-2 select the anatomy.
            orientation: Optional ``(spacing, origin, direction)`` tuple, forwarded
                to the wrapper and used for the HD95 spacing.
            save_pred_seg_path: Forwarded to the wrapper (optional output path for
                the predicted segmentation).

        Returns:
            The metric dict produced by ``score_patient``.
        """
        with torch.no_grad():
            pred_seg = self.my_ts.score_patient(synthetic_ct_location, orientation, mask, save_pred_seg_path=save_pred_seg_path)
        # The wrapper may hand back a nibabel image; score its voxel data.
        if isinstance(pred_seg, Nifti1Image):
            pred_seg = np.array(pred_seg.get_fdata())

        return self.score_patient(gt_segmentation, pred_seg, mask, patient_id, orientation)

    @staticmethod
    def _to_cpu_tensor(array):
        """Return ``array`` as a detached CPU torch tensor (NumPy input is wrapped via ``torch.from_numpy``)."""
        tensor = array if torch.is_tensor(array) else torch.from_numpy(array)
        return tensor.cpu().detach()

    @staticmethod
    def _build_metrics(spacing):
        """Return fresh MONAI metric objects keyed by name, each paired with its extra call kwargs."""
        return {
            'DICE': (
                DiceMetric(include_background=True, reduction="mean", get_not_nans=False),
                {},
            ),
            'HD95': (
                HausdorffDistanceMetric(include_background=True, reduction="mean", percentile=95, get_not_nans=False),
                {'spacing': spacing},
            ),
        }

    def score_patient(self, gt_segmentation, sct_segmentation, mask, patient_id, orientation=None):
        """Compare a predicted label map with the ground truth, class by class.

        Only the classes listed in ``self.classes_to_use`` for the patient's
        anatomy that actually occur in the ground truth are scored; the
        resulting per-class values are averaged by MONAI (``reduction="mean"``).

        Args:
            gt_segmentation: Ground-truth label map (NumPy array or torch tensor).
            sct_segmentation: Predicted label map with the same shape.
            mask: Unused here; kept for interface compatibility.
            patient_id: Patient identifier; ``patient_id[1:3]`` (upper-cased) is
                the anatomy key, e.g. ``"1ABA011"`` -> ``"AB"``.
            orientation: Optional ``(spacing, origin, direction)``; only the
                spacing is used, as the ``spacing`` argument of HD95.

        Returns:
            ``{'DICE': float, 'HD95': float}``. Both values are NaN when none of
            the configured classes occur in the ground truth.

        Raises:
            AssertionError: if the two label maps differ in shape.
            KeyError: if the anatomy key is not configured.
        """
        anatomy = patient_id[1:3].upper()

        assert sct_segmentation.shape == gt_segmentation.shape, (
            f"Shape mismatch for {patient_id}: prediction {tuple(sct_segmentation.shape)} "
            f"vs ground truth {tuple(gt_segmentation.shape)}"
        )

        # MONAI metrics operate on torch tensors.
        gt_seg = self._to_cpu_tensor(gt_segmentation)
        pred_seg = self._to_cpu_tensor(sct_segmentation)

        try:
            class_ids = self.classes_to_use[anatomy]
        except KeyError:
            raise KeyError(
                f"Unknown anatomy {anatomy!r} derived from patient_id {patient_id!r}; "
                f"expected one of {sorted(self.classes_to_use)}"
            ) from None

        if orientation is not None:
            spacing, _origin, _direction = orientation
        else:
            spacing = None

        # Metrics are created on the first class present in the ground truth so
        # that a patient without any scorable class never reaches aggregate().
        metrics = None
        for c in class_ids:
            # One-hot mask for this class, shaped (batch=1, channel=1, *spatial).
            gt_tensor = (gt_seg == c).view(1, 1, *gt_seg.shape)
            if not gt_tensor.any():
                if self.debug:
                    print(f"No {c} in {patient_id}")
                continue
            if metrics is None:
                metrics = self._build_metrics(spacing)
            est_tensor = (pred_seg == c).view(1, 1, *pred_seg.shape)
            for metric_fn, kwargs in metrics.values():
                metric_fn(est_tensor, gt_tensor, **kwargs)

        if metrics is None:
            # Nothing was scored; MONAI's aggregate() raises on an empty buffer.
            return {name: float('nan') for name in self.METRIC_NAMES}

        # Aggregate the mean over the scored classes for this patient.
        result = {}
        for name, (metric_fn, _kwargs) in metrics.items():
            result[name] = metric_fn.aggregate().item()
            metric_fn.reset()
        return result
    
def load_image_file_directly(*, location, return_orientation=False, set_orientation=None):
    """Read an image with SimpleITK and return its voxels as a NumPy array in (x, y, z) axis order.

    SimpleITK's ``GetArrayFromImage`` yields arrays indexed (z, y, x); the axes
    are reversed so indexing matches SimpleITK's (x, y, z) convention.

    Args:
        location: Path of the image, passed to ``sitk.ReadImage``.
        return_orientation: If True, also return the image's spacing, origin and
            direction as read from the file.
        set_orientation: Optional ``(spacing, origin, direction)`` written onto
            the SimpleITK image before conversion. Ignored when
            ``return_orientation`` is True. This only rewrites image metadata;
            the returned voxel array is not resampled.

    Returns:
        The voxel array, or ``(array, spacing, origin, direction)`` when
        ``return_orientation`` is True.
    """
    image = sitk.ReadImage(location)

    if return_orientation:
        # Reverse (z, y, x) -> (x, y, z).
        img_arr = np.transpose(sitk.GetArrayFromImage(image), [2, 1, 0])
        return img_arr, image.GetSpacing(), image.GetOrigin(), image.GetDirection()

    # If desired, force the orientation on the image before converting to NumPy.
    if set_orientation is not None:
        spacing, origin, direction = set_orientation
        image.SetSpacing(spacing)
        image.SetOrigin(origin)
        image.SetDirection(direction)

    # Convert exactly once (GetArrayFromImage copies the whole volume);
    # reverse (z, y, x) -> (x, y, z).
    return np.transpose(sitk.GetArrayFromImage(image), [2, 1, 0])


class SegmentationMetricsCompute(SegmentationMetrics):
    """Accumulate per-patient segmentation metrics and summarize them over a cohort.

    Inherits the scoring methods of ``SegmentationMetrics`` unchanged and adds
    simple storage: ``add`` appends one patient's metric dict, ``aggregate``
    computes NaN-aware summary statistics per metric, and ``reset`` clears the
    stored values and patient ids.
    """

    def __init__(self, debug=False):
        super().__init__(debug=debug)
        # Start with empty storage for the default metrics so add()/aggregate()
        # work without an explicit init_storage() call.
        self.init_storage(['DICE', 'HD95'])

    def init_storage(self, names: list):
        """(Re)initialize empty storage for ``names`` and clear the stored patient ids."""
        self.storage = {name: [] for name in names}
        self.storage_id = []
        self.names = names

    def add(self, res: dict, patient_id=None):
        """Append one patient's metric values.

        Args:
            res: Mapping metric name -> value; every key must have been
                registered via ``init_storage`` (otherwise ``KeyError``).
            patient_id: Optional identifier recorded alongside the values.
        """
        for key, value in res.items():
            self.storage[key].append(value)
        if patient_id is not None:
            self.storage_id.append(patient_id)

    @staticmethod
    def _summarize(values):
        """Return NaN-aware summary statistics of ``values``; all NaN (count 0) when empty."""
        if len(values) == 0:
            # np.nanmax/np.nanmin raise on empty input, so short-circuit here.
            nan = float('nan')
            return {
                'mean': nan, 'std': nan, 'max': nan, 'min': nan,
                '25pc': nan, '50pc': nan, '75pc': nan, 'count': 0,
            }
        arr = np.asarray(values, dtype=float)
        p25, p50, p75 = np.nanpercentile(arr, [25, 50, 75])
        return {
            'mean': np.nanmean(arr),
            'std': np.nanstd(arr),
            'max': np.nanmax(arr),
            'min': np.nanmin(arr),
            '25pc': p25,
            '50pc': p50,
            '75pc': p75,
            # count includes NaN entries (e.g. patients with nothing scorable).
            'count': len(values),
        }

    def aggregate(self):
        """Summarize the stored values of every metric.

        Returns:
            dict metric name -> dict with keys 'mean', 'std', 'max', 'min',
            '25pc', '50pc', '75pc' (NaNs ignored) and 'count'.
        """
        res = {name: {} for name in self.names}
        for key, values in self.storage.items():
            res[key] = self._summarize(values)
        return res

    def reset(self):
        """Clear all stored values and patient ids, keeping the registered metric names."""
        for key in self.storage:
            self.storage[key] = []
        # Keep ids aligned with the (now empty) value lists.
        self.storage_id = []



if __name__ == "__main__":
    # Example: score one patient's synthetic CT against its ground-truth label
    # map. Paths follow the authors' cluster layout; adjust for your setup.
    data_root = "/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct"
    patient_id = "1ABA011"

    # When True, read (spacing, origin, direction) from the ground-truth image
    # and pass it to the evaluator; when False, no orientation is passed.
    use_orientation = False

    # All paths are derived from patient_id so changing the id scores the
    # matching files consistently.
    gt_segmentation_path = f"{data_root}/raw/Dataset251_synthrad2025_task1_CT_AB_pre_v2r_stitched_masked_synseg/labelsTr/{patient_id}.mha"
    synthetic_ct_location = f"{data_root}/results/Dataset280_synthrad2025_task1_MR_AB_pre_v2r_stitched/nnUNetTrainerMRCT_track__nnUNetPlans__3d_fullres/fold_0/validation_revert_norm/{patient_id}.mha"
    mask_path = f"{data_root}/preprocessed/Dataset260_synthrad2025_task1_MR_AB_pre_v2r_stitched_masked/masks/{patient_id}.mha"

    _segmentation_evaluator = SegmentationMetrics(debug=True)

    if use_orientation:
        gt_segmentation, spacing, origin, direction = load_image_file_directly(location=gt_segmentation_path, return_orientation=True)
        orientation = (spacing, origin, direction)
        mask = load_image_file_directly(location=mask_path, set_orientation=orientation)
    else:
        gt_segmentation = load_image_file_directly(location=gt_segmentation_path)
        orientation = None
        mask = load_image_file_directly(location=mask_path)

    seg_metrics = _segmentation_evaluator.score_patient_ts(synthetic_ct_location, mask, gt_segmentation, patient_id, orientation=orientation)
    print(f"Segmentation metrics for patient {patient_id}: {seg_metrics}")