|
|
|
|
|
|
|
|
import numpy as np |
|
|
from typing import Optional |
|
|
import nibabel as nib |
|
|
import os |
|
|
import torch |
|
|
import SimpleITK as sitk |
|
|
from monai.metrics import DiceMetric, HausdorffDistanceMetric |
|
|
from nibabel.nifti1 import Nifti1Image |
|
|
from nnunetv2.analysis.ts_utils import MinialTotalSegmentator |
|
|
|
|
|
|
|
|
|
|
|
class SegmentationMetrics():
    """Per-class Dice and 95th-percentile Hausdorff distance between two label maps.

    The label IDs scored for a patient are taken from ``classes_to_use`` using the
    two-letter anatomy code found at ``patient_id[1:3]`` (e.g. "1ABA011" -> "AB").
    """

    def __init__(self, debug=False):
        """Set up the evaluator.

        Args:
            debug: when True, print which classes are missing from a patient's ground
                truth; also forwarded as ``verbose`` to ``MinialTotalSegmentator``.
        """
        self.debug = debug

        # Not used by this class; presumably an intensity window consumed elsewhere.
        self.dynamic_range = [-1024., 3000.]

        self.my_ts = MinialTotalSegmentator(verbose=self.debug)

        # Label IDs shared by the "AB" and "TH" entries below (the two lists are identical).
        # NOTE(review): presumably TotalSegmentator label indices — confirm against ts_utils.
        ab_th_classes = [
            2,
            3,
            5,
            6,
            *range(10, 14 + 1),
            *range(26, 50 + 1),
            51,
            79,
            *range(92, 115 + 1),
            116,
        ]
        self.classes_to_use = {
            # Separate copies so mutating one anatomy's list never changes the other.
            "AB": list(ab_th_classes),
            "HN": [
                15,
                16,
                17,
                *range(26, 50 + 1),
                79,
                90,
                91,
            ],
            "TH": list(ab_th_classes),
        }

    def score_patient_ts(self, synthetic_ct_location, mask, gt_segmentation, patient_id, orientation=None, save_pred_seg_path=None):
        """Segment a synthetic CT with ``my_ts`` and score it against ``gt_segmentation``.

        Args:
            synthetic_ct_location: image location forwarded to ``my_ts.score_patient``.
            mask: forwarded to ``my_ts.score_patient`` and to ``score_patient``.
            gt_segmentation: ground-truth label map (numpy array or torch tensor).
            patient_id: patient identifier; characters 1-2 select the class list.
            orientation: optional ``(spacing, origin, direction)`` tuple, forwarded to
                ``my_ts.score_patient`` and used for HD95 spacing.
            save_pred_seg_path: forwarded to ``my_ts.score_patient``.

        Returns:
            The dict produced by ``score_patient``.
        """
        # Inference only: no autograd graph is needed anywhere in this path.
        with torch.no_grad():
            pred_seg = self.my_ts.score_patient(synthetic_ct_location, orientation, mask, save_pred_seg_path=save_pred_seg_path)

            # The segmenter may hand back a NIfTI image; score_patient needs an array.
            if isinstance(pred_seg, Nifti1Image):
                pred_seg = np.array(pred_seg.get_fdata())

            return self.score_patient(gt_segmentation, pred_seg, mask, patient_id, orientation)

    @staticmethod
    def _as_cpu_tensor(seg):
        """Return ``seg`` (torch tensor or numpy array) as a detached CPU tensor."""
        if torch.is_tensor(seg):
            return seg.cpu().detach()
        return torch.from_numpy(seg).cpu().detach()

    def score_patient(self, gt_segmentation, sct_segmentation, mask, patient_id, orientation=None):
        """Compute mean Dice and mean HD95 over the anatomy's classes present in the GT.

        Each label ID in ``classes_to_use[anatomy]`` that occurs in ``gt_segmentation``
        is turned into a binary mask for both label maps and scored; classes absent from
        the ground truth are skipped. The per-class scores are averaged by MONAI.

        Args:
            gt_segmentation: ground-truth label map (numpy array or torch tensor).
            sct_segmentation: predicted label map with the same shape.
            mask: accepted for interface compatibility; not used here.
            patient_id: identifier whose characters 1-2 (upper-cased) select the class
                list, e.g. "1ABA011" -> "AB".
            orientation: optional ``(spacing, origin, direction)``; only ``spacing`` is
                used, passed to HD95. With ``None`` the distance is computed without
                spacing (presumably in voxel units).

        Returns:
            ``{'DICE': float, 'HD95': float}``. Both are NaN when none of the selected
            classes occur in the ground truth.

        Raises:
            KeyError: if the anatomy code is not a key of ``classes_to_use``.
            ValueError: if the two label maps have different shapes.
        """
        anatomy = patient_id[1:3].upper()
        try:
            class_ids = self.classes_to_use[anatomy]
        except KeyError:
            raise KeyError(
                f"Unknown anatomy code {anatomy!r} derived from patient_id {patient_id!r}; "
                f"expected one of {sorted(self.classes_to_use)}"
            ) from None

        if sct_segmentation.shape != gt_segmentation.shape:
            raise ValueError(
                f"Shape mismatch for {patient_id}: prediction {tuple(sct_segmentation.shape)} "
                f"vs ground truth {tuple(gt_segmentation.shape)}"
            )

        gt_seg = self._as_cpu_tensor(gt_segmentation)
        pred_seg = self._as_cpu_tensor(sct_segmentation)

        if orientation is not None:
            spacing, _origin, _direction = orientation
        else:
            spacing = None

        metrics = [
            {
                'name': 'DICE',
                'f': DiceMetric(include_background=True, reduction="mean", get_not_nans=False),
                'kwargs': {},
            }, {
                'name': 'HD95',
                'f': HausdorffDistanceMetric(include_background=True, reduction="mean", percentile=95, get_not_nans=False),
                'kwargs': {'spacing': spacing},
            },
        ]

        n_scored = 0
        for c in class_ids:
            # [None, None] adds the (batch, channel) dims MONAI expects: (1, 1, *spatial).
            gt_mask = (gt_seg == c)[None, None]
            if not gt_mask.any():
                if self.debug:
                    print(f"No {c} in {patient_id}")
                continue
            pred_mask = (pred_seg == c)[None, None]
            # NOTE(review): a class present in the GT but absent from the prediction likely
            # yields an infinite HD95 in MONAI, which then dominates the mean — confirm intended.
            for metric in metrics:
                metric['f'](pred_mask, gt_mask, **metric['kwargs'])
            n_scored += 1

        result = {}
        for metric in metrics:
            if n_scored:
                result[metric['name']] = metric['f'].aggregate().item()
            else:
                # Nothing was accumulated; MONAI's aggregate() cannot reduce an empty
                # buffer, so report "no data" as NaN (nan-aware downstream aggregation).
                result[metric['name']] = float('nan')
            metric['f'].reset()

        return result
|
|
|
|
|
def load_image_file_directly(*, location, return_orientation=False, set_orientation=None):
    """Read an image with SimpleITK and return its voxel array with axes reversed.

    ``sitk.GetArrayFromImage`` returns axes in (z, y, x) order. The ``[2, 1, 0]``
    transpose reorders them to (x, y, z), matching the order of ``GetSpacing()`` and
    ``GetOrigin()``.

    Args:
        location: path of the image file to read.
        return_orientation: when True, also return the image's spacing, origin and
            direction (``set_orientation`` is then ignored).
        set_orientation: optional ``(spacing, origin, direction)`` tuple written onto
            the loaded image when ``return_orientation`` is False.
            NOTE(review): this only changes metadata of an image object that is then
            discarded. The returned array is unaffected; the only observable effect is
            that SimpleITK may reject invalid values.

    Returns:
        ``img_arr``, or ``(img_arr, spacing, origin, direction)`` when
        ``return_orientation`` is True.
    """
    image = sitk.ReadImage(location)

    # GetArrayFromImage copies the pixel buffer, and the metadata setters below never
    # touch pixels, so a single conversion serves both branches.
    img_arr = np.transpose(sitk.GetArrayFromImage(image), [2, 1, 0])

    if return_orientation:
        return img_arr, image.GetSpacing(), image.GetOrigin(), image.GetDirection()

    if set_orientation is not None:
        spacing, origin, direction = set_orientation
        image.SetSpacing(spacing)
        image.SetOrigin(origin)
        image.SetDirection(direction)

    return img_arr
|
|
|
|
|
|
|
|
class SegmentationMetricsCompute(SegmentationMetrics):
    """Accumulate per-patient segmentation scores and summarise them.

    Extends SegmentationMetrics with a simple store. Call ``add`` with each result dict
    (e.g. the return value of ``score_patient`` / ``score_patient_ts``), then
    ``aggregate`` for per-metric summary statistics.
    """

    def __init__(self, debug=False):
        super().__init__(debug=debug)
        # Create the store up front so add()/aggregate()/reset() work on a fresh
        # instance without requiring an explicit init_storage() call first.
        self.init_storage(['DICE', 'HD95'])

    def init_storage(self, names: list):
        """(Re)create an empty value list per metric name and clear stored patient IDs.

        Args:
            names: metric names; ``add`` only accepts result dicts with these keys.
        """
        self.storage = {name: [] for name in names}
        self.storage_id = []
        self.names = names

    def add(self, res: dict, patient_id=None):
        """Append one patient's scores.

        Args:
            res: mapping of metric name -> value; every key must be a name passed to
                ``init_storage`` (otherwise ``KeyError``).
            patient_id: optional identifier recorded in ``storage_id``.
        """
        for key, value in res.items():
            self.storage[key].append(value)
        if patient_id is not None:
            self.storage_id.append(patient_id)

    def aggregate(self):
        """Summarise each metric's stored values, ignoring NaNs.

        Returns:
            ``{name: {'mean', 'std', 'max', 'min', '25pc', '50pc', '75pc', 'count'}}``.
            The statistics are NaN when a metric has no non-NaN values. ``count`` is the
            number of stored values, NaNs included.
        """
        stat_keys = ['mean', 'std', 'max', 'min', '25pc', '50pc', '75pc']
        res = {name: {} for name in self.names}

        for key, values in self.storage.items():
            arr = np.asarray(values, dtype=float)
            if arr.size == 0 or np.all(np.isnan(arr)):
                # np.nanmax/np.nanmin raise on empty input and warn on all-NaN input;
                # report "no data" explicitly instead.
                stats = dict.fromkeys(stat_keys, np.nan)
            else:
                stats = {
                    'mean': np.nanmean(arr),
                    'std': np.nanstd(arr),
                    'max': np.nanmax(arr),
                    'min': np.nanmin(arr),
                    '25pc': np.nanpercentile(arr, 25),
                    '50pc': np.nanpercentile(arr, 50),
                    '75pc': np.nanpercentile(arr, 75),
                }
            stats['count'] = len(values)
            res[key] = stats

        return res

    def reset(self):
        """Empty every metric's value list and the stored patient IDs; names are kept."""
        for key in self.storage:
            self.storage[key] = []
        # Clear IDs too, so they stay aligned with the values added after a reset.
        self.storage_id = []
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Ad-hoc run for a single patient using fixed dataset locations.
    _segmentation_evaluator = SegmentationMetrics(debug=True)

    patient_id = "1ABA011"
    data_root = "/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct"

    gt_segmentation_path = f"{data_root}/raw/Dataset251_synthrad2025_task1_CT_AB_pre_v2r_stitched_masked_synseg/labelsTr/{patient_id}.mha"
    gt_segmentation = load_image_file_directly(location=gt_segmentation_path)

    # All three inputs are built from patient_id so they always refer to the same patient.
    synthetic_ct_location = f"{data_root}/results/Dataset280_synthrad2025_task1_MR_AB_pre_v2r_stitched/nnUNetTrainerMRCT_track__nnUNetPlans__3d_fullres/fold_0/validation_revert_norm/{patient_id}.mha"

    mask = load_image_file_directly(location=f"{data_root}/preprocessed/Dataset260_synthrad2025_task1_MR_AB_pre_v2r_stitched_masked/masks/{patient_id}.mha")

    # NOTE(review): no orientation is passed, so score_patient uses spacing=None for HD95.
    seg_metrics = _segmentation_evaluator.score_patient_ts(synthetic_ct_location, mask, gt_segmentation, patient_id)

    print(f"Segmentation metrics for patient {patient_id}: {seg_metrics}")