|
|
from glob import glob |
|
|
import os |
|
|
from tqdm import tqdm |
|
|
import SimpleITK as sitk |
|
|
import pandas as pd |
|
|
import json |
|
|
from image_metrics import ImageMetricsCompute |
|
|
import matplotlib.pyplot as plt |
|
|
from pathlib import Path |
|
|
import shutil |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ValidationResults():
    """
    Class to analyze the results of the predictions.

    It computes the metrics and saves the results in a folder.

    This is used directly in the nnUNetTrainerMRCT class.

    Predictions, ground truths and masks are paired by file name: for every
    ``<patient_id>.mha`` found in ``pred_path`` a file with the same name is
    read from ``gt_path`` and ``mask_path``.
    """

    def __init__(self, pred_path, gt_path, mask_path, save_path=None):
        """
        Args:
            pred_path: Folder (str or path-like) holding the predicted
                ``*.mha`` volumes.
            gt_path: Folder holding the ground-truth volumes.
            mask_path: Folder holding the masks passed to the metric
                computation.
            save_path: Output folder. When falsy, ``<pred_path>_analysis`` is
                used, created next to ``pred_path`` (see
                ``_default_save_path``).
        """
        if not save_path:
            save_path = self._default_save_path(pred_path)
        print(f'Save path: {save_path}')
        os.makedirs(save_path, exist_ok=True)
        self.save_path = save_path

        self.pred_path = pred_path
        self.gt_path = gt_path
        self.mask_path = mask_path

        # Patient ids are the prediction file names without the '.mha'
        # suffix; sorting the file list makes the processing order stable.
        pred_files = sorted(glob(os.path.join(pred_path, '*.mha')))
        self.patient_ids = [Path(pred_file).stem for pred_file in pred_files]

        self.test_metrics = ImageMetricsCompute()
        self.test_metrics.init_storage(["mae", "psnr", "ms_ssim"])

    @staticmethod
    def _default_save_path(pred_path):
        """Return ``<pred_path>_analysis`` as a sibling of ``pred_path``.

        ``normpath`` strips a trailing separator so that ``'preds/'`` maps to
        ``'preds_analysis'`` instead of ``'preds/_analysis'`` (a folder inside
        the prediction directory). ``fspath`` lets ``pathlib.Path`` inputs
        work as well as plain strings.
        """
        return os.path.normpath(os.fspath(pred_path)) + '_analysis'

    def process_patients(self):
        """Score every patient, then aggregate and save the results.

        Returns:
            The aggregated metrics dict returned by ``analysis_patients``.
        """
        for patient_id in tqdm(self.patient_ids):
            self.process_a_patient(patient_id)
        dict_metric = self.analysis_patients()
        return dict_metric

    def analysis_patients(self):
        """Aggregate the stored per-patient metrics and write them to disk.

        Writes ``results_overall_masked.json`` (the output of
        ``self.test_metrics.aggregate()``) and ``results_individual.csv``
        (built from ``storage_id`` and the per-metric ``storage`` lists,
        presumably one row per scored patient) into ``self.save_path``, then
        prints the mean of each metric.

        Returns:
            The aggregated metrics dict; ``dict_metric[name]['mean']`` is read
            for ``mae``, ``psnr`` and ``ms_ssim``.
        """
        dict_metric = self.test_metrics.aggregate()
        with open(os.path.join(self.save_path, 'results_overall_masked.json'), 'w') as f:
            json.dump(dict_metric, f, indent=4)

        df = pd.DataFrame(
            {
                'patient_id': self.test_metrics.storage_id,
                'mae': self.test_metrics.storage['mae'],
                'ms_ssim': self.test_metrics.storage['ms_ssim'],
                'psnr': self.test_metrics.storage['psnr'],
            }
        )
        df.to_csv(os.path.join(self.save_path, 'results_individual.csv'), index=True)

        print("mean mae:", dict_metric['mae']['mean'])
        print("mean psnr:", dict_metric['psnr']['mean'])
        print("mean ms_ssim:", dict_metric['ms_ssim']['mean'])
        return dict_metric

    def process_a_patient(self, patient_id):
        """Read one patient's prediction, ground truth and mask and score them.

        The prediction and ground truth are read as float32 and the mask as
        uint8. The score is stored in ``self.test_metrics`` under
        ``patient_id``.
        """
        pred_path = os.path.join(self.pred_path, f'{patient_id}.mha')
        gt_path = os.path.join(self.gt_path, f'{patient_id}.mha')
        mask_path = os.path.join(self.mask_path, f'{patient_id}.mha')

        img_pred = sitk.ReadImage(pred_path, sitk.sitkFloat32)
        img_gt = sitk.ReadImage(gt_path, sitk.sitkFloat32)
        img_mask = sitk.ReadImage(mask_path, sitk.sitkUInt8)

        array_pred = sitk.GetArrayFromImage(img_pred)
        array_gt = sitk.GetArrayFromImage(img_gt)
        array_mask = sitk.GetArrayFromImage(img_mask)

        res = self.test_metrics.score_patient(array_gt, array_pred, array_mask)
        self.test_metrics.add(res, patient_id)
|
|
|
|
|
|
|
|
class TestingResults():
    """
    Evaluate predictions against the raw dataset and save visual QA outputs.

    For every ``<patient_id>.mha`` in ``pred_path``, the source
    (``mr.mha``), ground truth (``ct_deformed_padded.mha``) and mask
    (``mask.mha``) are read from ``<raw_image_path>/<patient_id>/``. For each
    patient the class:

    - records metrics in ``self.test_metrics``;
    - writes a masked absolute-error volume;
    - copies all volumes into ``<save_path>/all_3d_img``;
    - saves a PNG of three orthogonal slices for each percentage in
      ``slice_pc_to_save``.
    """

    def __init__(self, pred_path, raw_image_path, save_path=None):
        """
        Args:
            pred_path: Folder (str or path-like) holding the predicted
                ``*.mha`` volumes.
            raw_image_path: Folder with one sub-folder per patient id.
            save_path: Output folder. When falsy, ``<pred_path>_analysis`` is
                used, created next to ``pred_path`` (see
                ``_default_save_path``).
        """
        if not save_path:
            save_path = self._default_save_path(pred_path)
        print(f'Save path: {save_path}')
        os.makedirs(save_path, exist_ok=True)

        self.pred_path = pred_path
        self.save_path = save_path
        self.raw_image_path = raw_image_path

        # Patient ids are the prediction file names without the '.mha' suffix.
        pred_files = sorted(glob(os.path.join(pred_path, '*.mha')))
        self.patient_ids = [Path(pred_file).stem for pred_file in pred_files]
        # Titles of the PNG panel columns, in display order.
        self.col_names = ['src', 'pred', 'gt', 'mask', 'error']

        self.test_metrics = ImageMetricsCompute()
        self.test_metrics.init_storage(["mae", "psnr", "ms_ssim"])

        # Slice positions (percent along each axis) rendered to PNG per patient.
        self.slice_pc_to_save = [25, 50, 75]
        for pc in self.slice_pc_to_save:
            save_path_pc = os.path.join(self.save_path, '{}pc_png'.format(pc))
            if not os.path.exists(save_path_pc):
                os.makedirs(save_path_pc, exist_ok=True)
                print('Make path: {}'.format(save_path_pc))

        self.save_path_all_3d_img = os.path.join(self.save_path, 'all_3d_img')
        os.makedirs(self.save_path_all_3d_img, exist_ok=True)

    @staticmethod
    def _default_save_path(pred_path):
        """Return ``<pred_path>_analysis`` as a sibling of ``pred_path``.

        ``normpath`` strips a trailing separator so that ``'preds/'`` maps to
        ``'preds_analysis'`` instead of ``'preds/_analysis'``. ``fspath`` lets
        ``pathlib.Path`` inputs work as well as plain strings.
        """
        return os.path.normpath(os.fspath(pred_path)) + '_analysis'

    def process_patients(self):
        """Run ``process_a_patient`` for every patient id found in ``pred_path``."""
        for patient_id in tqdm(self.patient_ids):
            self.process_a_patient(patient_id)

    def analysis_patients(self):
        """Aggregate the stored per-patient metrics and write them to disk.

        Writes ``results_overall_masked.json`` and ``results_individual.csv``
        into ``self.save_path`` and prints the mean of each metric.

        Returns:
            The aggregated metrics dict from ``self.test_metrics.aggregate()``
            (returned for parity with ``ValidationResults.analysis_patients``).
        """
        dict_metric = self.test_metrics.aggregate()
        with open(os.path.join(self.save_path, 'results_overall_masked.json'), 'w') as f:
            json.dump(dict_metric, f, indent=4)

        df = pd.DataFrame(
            {
                'patient_id': self.test_metrics.storage_id,
                'mae': self.test_metrics.storage['mae'],
                'ms_ssim': self.test_metrics.storage['ms_ssim'],
                'psnr': self.test_metrics.storage['psnr'],
            }
        )
        df.to_csv(os.path.join(self.save_path, 'results_individual.csv'), index=True)

        print("mean mae:", dict_metric['mae']['mean'])
        print("mean psnr:", dict_metric['psnr']['mean'])
        print("mean ms_ssim:", dict_metric['ms_ssim']['mean'])
        return dict_metric

    def process_a_patient(self, patient_id):
        """Score one patient and write its error volume, copies and PNG slices."""
        pred_path = os.path.join(self.pred_path, f'{patient_id}.mha')
        patient_dir = os.path.join(self.raw_image_path, patient_id)
        src_path = os.path.join(patient_dir, 'mr.mha')
        gt_path = os.path.join(patient_dir, 'ct_deformed_padded.mha')
        mask_path = os.path.join(patient_dir, 'mask.mha')

        img_src = sitk.ReadImage(src_path)
        img_pred = sitk.ReadImage(pred_path, sitk.sitkFloat32)
        img_gt = sitk.ReadImage(gt_path, sitk.sitkFloat32)
        img_mask = sitk.ReadImage(mask_path, sitk.sitkUInt8)

        array_src = sitk.GetArrayFromImage(img_src)
        array_pred = sitk.GetArrayFromImage(img_pred)
        array_gt = sitk.GetArrayFromImage(img_gt)
        array_mask = sitk.GetArrayFromImage(img_mask)

        res = self.test_metrics.score_patient(array_gt, array_pred, array_mask)
        self.test_metrics.add(res, patient_id)

        self._save_error_image(img_pred, img_gt, img_mask, patient_id)
        self._copy_images(pred_path, src_path, gt_path, mask_path, patient_id)

        # Render every configured percentage (the folders were created in
        # __init__ from the same list). Close only the figures made here so
        # that figures owned by the caller are left untouched.
        for pc in self.slice_pc_to_save:
            fig = self._save_png_slice(array_src, array_pred, array_gt, array_mask, patient_id, pc=pc)
            plt.close(fig)

    def _save_error_image(self, img_pred, img_gt, img_mask, patient_id):
        """Write |pred - gt| with voxels outside the mask set to 0.

        The output is ``<save_path_all_3d_img>/<patient_id>_error.mha``.
        """
        img_err = sitk.AbsoluteValueDifference(img_pred, img_gt)
        img_err = sitk.Mask(img_err, img_mask, outsideValue=0)
        img_err.CopyInformation(img_pred)
        sitk.WriteImage(img_err, os.path.join(self.save_path_all_3d_img, f'{patient_id}_error.mha'))

    def _copy_images(self, pred_path, src_path, gt_path, mask_path, patient_id):
        """Copy the input volumes to ``all_3d_img`` as ``<patient_id>_<role>.mha``."""
        for role, path in (('pred', pred_path), ('src', src_path),
                           ('gt', gt_path), ('mask', mask_path)):
            shutil.copy(path, os.path.join(self.save_path_all_3d_img, f'{patient_id}_{role}.mha'))

    @staticmethod
    def _slice_index(size, pc):
        """Return the index at ``pc`` percent along an axis of length ``size``.

        The result is clamped to ``[0, size - 1]``, so ``pc=100`` selects the
        last slice instead of raising ``IndexError``.
        """
        return min(max(int(size * pc / 100), 0), max(size - 1, 0))

    def _save_png_slice(self, array_src, array_pred, array_gt, array_mask, patient_id, pc=50):
        """Save a 3 x 5 PNG panel of orthogonal slices at ``pc`` percent.

        Row ``r`` shows the slice taken at ``pc`` percent along array axis
        ``r`` (the axis lengths are taken from ``array_gt``). The columns are
        src, pred, gt, mask and the signed error ``gt - pred``, which is
        zeroed outside the mask. The figure is saved to
        ``<save_path>/<pc>pc_png/<patient_id>.png`` and returned; the caller
        is responsible for closing it.
        """
        volumes = (array_src, array_pred, array_gt, array_mask)
        n_cols = len(volumes) + 1  # + error panel
        row_slices = [self._slice_index(array_gt.shape[axis], pc) for axis in range(3)]

        fig, ax = plt.subplots(3, n_cols, figsize=(15, 10))
        for row, idx in enumerate(row_slices):
            # Build the index tuple that selects slice ``idx`` along axis ``row``.
            selector = [slice(None)] * 3
            selector[row] = idx
            selector = tuple(selector)

            src, pred, gt, mask = (vol[selector] for vol in volumes)
            error = gt - pred
            error[~mask.astype('bool')] = 0
            panels = [src, pred, gt, mask, error]

            for col, panel in enumerate(panels):
                if col in (1, 2):
                    # pred and gt share a fixed intensity window (presumably
                    # HU) so that the two panels are directly comparable.
                    ax[row, col].imshow(panel, cmap='gray', vmin=-1024, vmax=2000)
                elif col < 4:
                    ax[row, col].imshow(panel, cmap='gray')
                else:
                    ax[row, col].imshow(panel, cmap='twilight_shifted')
                ax[row, col].axis('off')
                ax[row, col].set_title('{}_slice{}'.format(self.col_names[col], idx))

        fig.subplots_adjust(wspace=0.05, top=0.8)
        save_path = os.path.join(self.save_path, '{}pc_png'.format(pc))
        os.makedirs(save_path, exist_ok=True)
        fig.savefig(os.path.join(save_path, '{}.png'.format(patient_id)))
        return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Raw data root: process_a_patient reads
    # <input_path>/<patient_id>/{mr,ct_deformed_padded,mask}.mha.
    input_path = "/datasets/work/hb-synthrad2023/source/synthrad2025_data_v2/synthRAD2025_Task1_Train/Task1/AB"
    # Folder of <patient_id>.mha predictions (presumably already mapped back
    # from the normalised intensity range, per the folder name).
    pred_path_revert_norm = "/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/results/Dataset340_synthrad2025_task1_MR_AB_mednext/nnUNetTrainerV2_MedNeXt_L_kernel5__nnUNetPlans__3d_fullres/fold_0/validation_revert_norm"

    # TestingResults takes (pred_path, raw_image_path, save_path=None); it has
    # no ``task``/``region`` parameters, so the raw data root is passed instead.
    vs = TestingResults(pred_path_revert_norm, raw_image_path=input_path)
    vs.process_a_patient('1ABA011')
|
|
|
|
|
|
|
|
|
|
|
|