# Source: synthrad2025_docker/docker_task_2/nnunetv2/analysis/result_analysis_not_used.py
# Uploaded by FelixzeroSun ("Upload folder using huggingface_hub"), commit 19c1f58 (verified).
from glob import glob
import os
from tqdm import tqdm
import SimpleITK as sitk
import pandas as pd
import json
from image_metrics import ImageMetricsCompute
import matplotlib.pyplot as plt
from pathlib import Path
import shutil
# def compute_folder_metrics(pred_path, gt_path, mask_path):
# """
# Analyze the results of the predictions.
# """
# pred_files = sorted(glob(os.path.join(pred_path, '*.mha')))
# # gt_path = os.path.join(raw_data_path, "gt_segmentations")
# # mask_path = os.path.join(raw_data_path, "masks")
# testing_metrics = ImageMetricsCompute()
# testing_metrics.init_storage(["mae", "psnr", "ms_ssim"])
# for pred_file in tqdm(pred_files):
# filename = os.path.basename(pred_file)
# gt_file = os.path.join(gt_path, filename)
# mask_file = os.path.join(mask_path, filename)
# img_pred = sitk.ReadImage(pred_file)
# img_gt = sitk.ReadImage(gt_file)
# img_mask = sitk.ReadImage(mask_file, sitk.sitkUInt8)
# array_pred = sitk.GetArrayFromImage(img_pred)
# array_gt = sitk.GetArrayFromImage(img_gt)
# array_mask = sitk.GetArrayFromImage(img_mask)
# res = testing_metrics.score_patient(array_gt, array_pred, array_mask)
# testing_metrics.add(res, filename)
# # aggregate results
# results = testing_metrics.aggregate()
# df = pd.DataFrame(
# {
# 'patient_id': testing_metrics.storage_id,
# 'mae': testing_metrics.storage['mae'],
# 'ms_ssim': testing_metrics.storage['ms_ssim'],
# 'psnr': testing_metrics.storage['psnr'],
# }
# )
# # save results, and df in the folder "results"
# results_path = os.path.join(pred_path, "results.json")
# df_path = os.path.join(pred_path, "results_individual.csv")
# os.makedirs(os.path.dirname(results_path), exist_ok=True)
# with open(results_path, 'w') as f:
# json.dump(results, f, indent=4)
# df.to_csv(df_path, index=False)
# return results, df
class ValidationResults():
    """
    Class to analyze the results of the predictions.
    It computes the metrics and saves the results in a folder.
    This is used directly in the nnUNetTrainerMRCT class.

    Every ``*.mha`` file directly inside ``pred_path`` is treated as one
    patient. The ground truth and mask are looked up under the same file name
    in ``gt_path`` and ``mask_path``. "mae", "psnr" and "ms_ssim" scores are
    accumulated in an ``ImageMetricsCompute`` instance; the mask volume is
    passed to ``score_patient`` alongside ground truth and prediction.
    """

    def __init__(self, pred_path, gt_path, mask_path, save_path=None):
        """
        Args:
            pred_path: folder containing predicted ``<patient_id>.mha`` volumes
                (``str`` or ``os.PathLike``).
            gt_path: folder containing ground-truth ``<patient_id>.mha`` volumes.
            mask_path: folder containing ``<patient_id>.mha`` mask volumes.
            save_path: output folder. When falsy, defaults to
                ``<pred_path>_analysis``, a sibling of ``pred_path``.
        """
        if not save_path:
            save_path = self._default_save_path(pred_path)
        print(f'Save path: {save_path}')
        os.makedirs(save_path, exist_ok=True)
        self.save_path = save_path
        self.pred_path = pred_path
        self.gt_path = gt_path
        self.mask_path = mask_path
        pred_files = sorted(glob(os.path.join(pred_path, '*.mha')))
        # The file name without its ".mha" suffix is used as the patient id.
        self.patient_ids = [Path(pred_file).stem for pred_file in pred_files]
        # init image metrics
        self.test_metrics = ImageMetricsCompute()
        self.test_metrics.init_storage(["mae", "psnr", "ms_ssim"])

    @staticmethod
    def _default_save_path(pred_path):
        """Return ``<pred_path>_analysis`` as a sibling folder of ``pred_path``.

        Accepts ``str`` or ``os.PathLike``. Trailing path separators are
        stripped first, so ``"preds/"`` yields ``"preds_analysis"`` instead of
        ``"preds/_analysis"`` (a folder nested inside the prediction folder).
        """
        base = os.fspath(pred_path)
        trimmed = base.rstrip(os.sep + (os.altsep or ''))
        # A path made only of separators (e.g. "/") is kept unchanged.
        return (trimmed or base) + '_analysis'

    def process_patients(self):
        """Score every patient found in ``pred_path``, then save and return the aggregate metrics.

        Returns:
            The dictionary produced by ``analysis_patients``.
        """
        for patient_id in tqdm(self.patient_ids):
            self.process_a_patient(patient_id)
        dict_metric = self.analysis_patients()
        return dict_metric

    def analysis_patients(self):
        """Write aggregated and per-patient metrics to ``save_path`` and print the means.

        Writes ``results_overall_masked.json`` (output of
        ``ImageMetricsCompute.aggregate``) and ``results_individual.csv`` (one
        row per scored patient, DataFrame index included as first column).

        Returns:
            The aggregated metrics dict; it is indexed as
            ``dict_metric[<metric>]['mean']`` below.
        """
        # save aggregated metrics
        dict_metric = self.test_metrics.aggregate()
        with open(os.path.join(self.save_path, 'results_overall_masked.json'), 'w') as f:
            json.dump(dict_metric, f, indent=4)
        # save individual metric
        df = pd.DataFrame(
            {
                'patient_id': self.test_metrics.storage_id,
                'mae': self.test_metrics.storage['mae'],
                'ms_ssim': self.test_metrics.storage['ms_ssim'],
                'psnr': self.test_metrics.storage['psnr'],
            }
        )
        df.to_csv(os.path.join(self.save_path, 'results_individual.csv'), index=True)
        print("mean mae:", dict_metric['mae']['mean'])
        print("mean psnr:", dict_metric['psnr']['mean'])
        print("mean ms_ssim:", dict_metric['ms_ssim']['mean'])
        return dict_metric

    def process_a_patient(self, patient_id):
        """Score one patient and record the result under ``patient_id``.

        Prediction and ground truth are read as float32, the mask as uint8.
        """
        pred_path = os.path.join(self.pred_path, f'{patient_id}.mha')
        # NOTE: ground-truth file naming is hard-coded here ("to be changed").
        gt_path = os.path.join(self.gt_path, f'{patient_id}.mha')
        mask_path = os.path.join(self.mask_path, f'{patient_id}.mha')
        # read images
        img_pred = sitk.ReadImage(pred_path, sitk.sitkFloat32)
        img_gt = sitk.ReadImage(gt_path, sitk.sitkFloat32)
        img_mask = sitk.ReadImage(mask_path, sitk.sitkUInt8)
        # compute scores
        array_pred = sitk.GetArrayFromImage(img_pred)
        array_gt = sitk.GetArrayFromImage(img_gt)
        array_mask = sitk.GetArrayFromImage(img_mask)
        res = self.test_metrics.score_patient(array_gt, array_pred, array_mask)
        self.test_metrics.add(res, patient_id)
class TestingResults():
    """
    Evaluate a folder of predictions against raw per-patient data and save
    metrics, 3D image copies, masked error volumes and PNG slice overviews.

    For every ``<patient_id>.mha`` directly inside ``pred_path``, the files
    ``mr.mha``, ``ct_deformed_padded.mha`` and ``mask.mha`` are read from
    ``<raw_image_path>/<patient_id>/`` as source, ground truth and mask.
    """

    def __init__(self, pred_path, raw_image_path, save_path=None):
        """
        Args:
            pred_path: folder containing predicted ``<patient_id>.mha`` volumes
                (``str`` or ``os.PathLike``).
            raw_image_path: folder with one sub-folder per patient id.
            save_path: output folder. When falsy, defaults to
                ``<pred_path>_analysis``, a sibling of ``pred_path``.
        """
        if not save_path:
            save_path = self._default_save_path(pred_path)
        print(f'Save path: {save_path}')
        os.makedirs(save_path, exist_ok=True)
        self.pred_path = pred_path
        self.save_path = save_path
        self.raw_image_path = raw_image_path
        pred_files = sorted(glob(os.path.join(pred_path, '*.mha')))
        # The file name without its ".mha" suffix is used as the patient id.
        self.patient_ids = [Path(pred_file).stem for pred_file in pred_files]
        # Column titles of the PNG overview, in plotting order.
        self.col_names = ['src', 'pred', 'gt', 'mask', 'error']
        # init image metrics
        self.test_metrics = ImageMetricsCompute()
        self.test_metrics.init_storage(["mae", "psnr", "ms_ssim"])
        # Slice positions (percent along each array axis) rendered to PNG,
        # one "<pc>pc_png" sub-folder per entry.
        self.slice_pc_to_save = [25, 50, 75]
        for pc in self.slice_pc_to_save:
            save_path_pc = os.path.join(self.save_path, '{}pc_png'.format(pc))
            if not os.path.exists(save_path_pc):
                os.makedirs(save_path_pc, exist_ok=True)
                print('Make path: {}'.format(save_path_pc))
        # all 3d images for analysis
        self.save_path_all_3d_img = os.path.join(self.save_path, 'all_3d_img')
        os.makedirs(self.save_path_all_3d_img, exist_ok=True)

    @staticmethod
    def _default_save_path(pred_path):
        """Return ``<pred_path>_analysis`` as a sibling folder of ``pred_path``.

        Accepts ``str`` or ``os.PathLike``. Trailing path separators are
        stripped first, so ``"preds/"`` yields ``"preds_analysis"`` instead of
        ``"preds/_analysis"`` (a folder nested inside the prediction folder).
        """
        base = os.fspath(pred_path)
        trimmed = base.rstrip(os.sep + (os.altsep or ''))
        # A path made only of separators (e.g. "/") is kept unchanged.
        return (trimmed or base) + '_analysis'

    @staticmethod
    def _slice_index(size, pc):
        """Index of the slice at ``pc`` percent of an axis of length ``size``, clamped to ``[0, size - 1]``."""
        return min(max(int(size * pc / 100), 0), max(size - 1, 0))

    def process_patients(self):
        """Run ``process_a_patient`` for every patient id (does not aggregate)."""
        for patient_id in tqdm(self.patient_ids):
            self.process_a_patient(patient_id)

    def analysis_patients(self):
        """Write aggregated and per-patient metrics to ``save_path`` and print the means.

        Writes ``results_overall_masked.json`` (output of
        ``ImageMetricsCompute.aggregate``) and ``results_individual.csv`` (one
        row per scored patient, DataFrame index included as first column).

        Returns:
            The aggregated metrics dict; it is indexed as
            ``dict_metric[<metric>]['mean']`` below.
        """
        # save aggregated metrics
        dict_metric = self.test_metrics.aggregate()
        with open(os.path.join(self.save_path, 'results_overall_masked.json'), 'w') as f:
            json.dump(dict_metric, f, indent=4)
        # save individual metric
        df = pd.DataFrame(
            {
                'patient_id': self.test_metrics.storage_id,
                'mae': self.test_metrics.storage['mae'],
                'ms_ssim': self.test_metrics.storage['ms_ssim'],
                'psnr': self.test_metrics.storage['psnr'],
            }
        )
        df.to_csv(os.path.join(self.save_path, 'results_individual.csv'), index=True)
        print("mean mae:", dict_metric['mae']['mean'])
        print("mean psnr:", dict_metric['psnr']['mean'])
        print("mean ms_ssim:", dict_metric['ms_ssim']['mean'])
        return dict_metric

    def process_a_patient(self, patient_id):
        """Score one patient, save its 3D images and error volume, and render PNG overviews."""
        pred_path = os.path.join(self.pred_path, f'{patient_id}.mha')
        src_path = os.path.join(self.raw_image_path, patient_id, 'mr.mha')
        # NOTE: ground-truth file name is hard-coded here ("to be changed").
        gt_path = os.path.join(self.raw_image_path, patient_id, 'ct_deformed_padded.mha')
        mask_path = os.path.join(self.raw_image_path, patient_id, 'mask.mha')
        # read images
        img_src = sitk.ReadImage(src_path)
        img_pred = sitk.ReadImage(pred_path, sitk.sitkFloat32)
        img_gt = sitk.ReadImage(gt_path, sitk.sitkFloat32)
        img_mask = sitk.ReadImage(mask_path, sitk.sitkUInt8)
        # compute scores
        array_src = sitk.GetArrayFromImage(img_src)
        array_pred = sitk.GetArrayFromImage(img_pred)
        array_gt = sitk.GetArrayFromImage(img_gt)
        array_mask = sitk.GetArrayFromImage(img_mask)
        res = self.test_metrics.score_patient(array_gt, array_pred, array_mask)
        self.test_metrics.add(res, patient_id)
        # save error images
        self._save_error_image(img_pred, img_gt, img_mask, patient_id)
        self._copy_images(pred_path, src_path, gt_path, mask_path, patient_id)
        # Render one PNG per configured slice position (same list that
        # created the "<pc>pc_png" folders in __init__).
        for pc in self.slice_pc_to_save:
            self._save_png_slice(array_src, array_pred, array_gt, array_mask, patient_id, pc=pc)
        # Figures are returned by _save_png_slice and only released here.
        plt.close('all')

    def _save_error_image(self, img_pred, img_gt, img_mask, patient_id):
        """Save ``|pred - gt|`` with voxels outside the mask set to 0 as ``<patient_id>_error.mha``."""
        img_err = sitk.AbsoluteValueDifference(img_pred, img_gt)
        img_err = sitk.Mask(img_err, img_mask, outsideValue=0)
        img_err.CopyInformation(img_pred)
        sitk.WriteImage(img_err, os.path.join(self.save_path_all_3d_img, f'{patient_id}_error.mha'))

    def _copy_images(self, pred_path, src_path, gt_path, mask_path, patient_id):
        """Copy prediction, source, ground truth and mask into ``all_3d_img`` with role suffixes."""
        shutil.copy(pred_path, os.path.join(self.save_path_all_3d_img, f'{patient_id}_pred.mha'))
        shutil.copy(src_path, os.path.join(self.save_path_all_3d_img, f'{patient_id}_src.mha'))
        shutil.copy(gt_path, os.path.join(self.save_path_all_3d_img, f'{patient_id}_gt.mha'))
        shutil.copy(mask_path, os.path.join(self.save_path_all_3d_img, f'{patient_id}_mask.mha'))

    def _save_png_slice(self, array_src, array_pred, array_gt, array_mask, patient_id, pc=50):
        """Plot a 3x5 grid (src, pred, gt, mask, error) of slices at ``pc`` percent along each array axis.

        Row ``k`` slices array axis ``k`` (labelled axial / coronal / sagittal
        in the original code). The error panel shows ``gt - pred`` with
        voxels outside the mask set to 0. The figure is saved to
        ``<save_path>/<pc>pc_png/<patient_id>.png`` and returned (not closed).
        """
        # Clamp so pc=100 selects the last slice instead of raising IndexError.
        row_slices = [self._slice_index(size, pc) for size in array_gt.shape[:3]]
        rows = []
        for axis, idx in enumerate(row_slices):
            selector = [slice(None)] * 3
            selector[axis] = idx
            selector = tuple(selector)
            slice_src = array_src[selector]
            slice_pred = array_pred[selector]
            slice_gt = array_gt[selector]
            slice_mask = array_mask[selector]
            slice_error = slice_gt - slice_pred
            slice_error[~slice_mask.astype('bool')] = 0
            rows.append([slice_src, slice_pred, slice_gt, slice_mask, slice_error])
        # plot
        n_cols = len(rows[0])
        fig, ax = plt.subplots(3, n_cols, figsize=(15, 10))
        for row, row_images in enumerate(rows):
            for col, image in enumerate(row_images):
                panel = ax[row, col]
                if col in (1, 2):
                    # pred / gt share a fixed window (-1024..2000, presumably HU).
                    panel.imshow(image, cmap='gray', vmin=-1024, vmax=2000)
                elif col < 4:
                    panel.imshow(image, cmap='gray')
                else:
                    # Symmetric limits put zero error at the midpoint of the
                    # diverging 'twilight_shifted' map; fall back to 1.0 for
                    # an all-zero slice so zero still maps to the midpoint.
                    limit = float(abs(image).max()) if image.size else 0.0
                    limit = limit or 1.0
                    panel.imshow(image, cmap='twilight_shifted', vmin=-limit, vmax=limit)
                panel.axis('off')
                panel.set_title('{}_slice{}'.format(self.col_names[col], row_slices[row]))
        fig.subplots_adjust(wspace=0.05, top=0.8)
        save_path = os.path.join(self.save_path, '{}pc_png'.format(pc))
        os.makedirs(save_path, exist_ok=True)
        fig.savefig(os.path.join(save_path, '{}.png'.format(patient_id)))
        return fig
if __name__ == '__main__':
    # pred_path = "/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/results/Dataset206_synthrad2025_task1_MR_mednextL/nnUNetTrainerV2_MedNeXt_L_kernel3__nnUNetPlans__3d_fullres/fold_0/validation"
    # pred_path_revert_norm = pred_path + "_revert_norm"
    # raw_data_path = f"/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/preprocessed/Dataset206_synthrad2025_task1_MR_mednextL"
    # gt_path = os.path.join(raw_data_path, "gt_segmentations")
    # mask_path = os.path.join(raw_data_path, "masks")
    # results, df = compute_folder_metrics(pred_path_revert_norm, gt_path, mask_path)
    # print("mean mae:", results['mae']['mean'])
    # print("mean psnr:", results['psnr']['mean'])
    # print("mean ms_ssim:", results['ms_ssim']['mean'])

    # Raw data folder: contains <p_id>/ct.mha, mask.mha, mr.mha.
    # NOTE(review): TestingResults reads '<p_id>/ct_deformed_padded.mha' as the
    # ground truth, not 'ct.mha' -- confirm that file exists in this folder.
    input_path = "/datasets/work/hb-synthrad2023/source/synthrad2025_data_v2/synthRAD2025_Task1_Train/Task1/AB"
    pred_path_revert_norm = "/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/results/Dataset340_synthrad2025_task1_MR_AB_mednext/nnUNetTrainerV2_MedNeXt_L_kernel5__nnUNetPlans__3d_fullres/fold_0/validation_revert_norm"
    # TestingResults takes (pred_path, raw_image_path, save_path=None); it has
    # no `task`/`region` parameters, so passing them raised TypeError.
    vs = TestingResults(pred_path_revert_norm, input_path)
    vs.process_a_patient('1ABA011')
    # vs.process_patients()