File size: 2,874 Bytes
19c1f58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import sys
import SimpleITK as sitk
import json
import glob
import os
from tqdm import tqdm
import numpy as np
import torch


# revert normalisation
def get_ct_normalisation_values(ct_plan_path, channel="0"):
    """
    Get the mean and standard deviation for CT normalisation.

    Reads the nnU-Net plans JSON at ``ct_plan_path`` and returns the ``mean``
    and ``std`` stored under
    ``foreground_intensity_properties_per_channel[channel]``.

    Args:
        ct_plan_path: Path to the nnU-Net plans ``.json`` file.
        channel: Input-channel key to read. JSON object keys are strings, so an
            int (e.g. ``0``) is converted to its string form. Defaults to ``"0"``,
            the single CT channel.

    Returns:
        Tuple ``(mean, std)`` exactly as stored in the plans file.

    Raises:
        OSError: if the plans file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        KeyError: if the intensity properties or the requested channel are missing.
    """
    # Load the nnUNet plans file for CT
    with open(ct_plan_path, "r", encoding="utf-8") as f:
        ct_plan = json.load(f)

    # JSON object keys are always strings, so normalise e.g. 0 -> "0".
    channel_props = ct_plan['foreground_intensity_properties_per_channel'][str(channel)]
    ct_mean = channel_props['mean']
    ct_std = channel_props['std']
    print(f"CT mean: {ct_mean}, CT std: {ct_std}")
    return ct_mean, ct_std

def revert_normalisation(pred_path, ct_mean, ct_std, save_path=None, mask_path=None, mask_outside_value=-1000):
    """
    Revert the normalisation of a CT image.

    Every ``*.mha`` file directly inside ``pred_path`` (non-recursive) is read,
    mapped voxel-wise to ``x * ct_std + ct_mean`` and written under the same
    file name into ``save_path``. The written image carries the source image's
    spatial metadata.

    Args:
        pred_path: Directory containing the normalised ``.mha`` images.
        ct_mean: Mean used by the forward normalisation.
        ct_std: Standard deviation used by the forward normalisation.
        save_path: Output directory, created if missing. Defaults to
            ``pred_path`` (without any trailing separator) with
            ``'_revert_norm'`` appended, i.e. a sibling directory.
        mask_path: Optional directory of masks. For an image named
            ``<case>_0000.mha`` or ``<case>.mha``, the mask ``<case>.mha`` is
            read from this directory and applied with ``sitk.Mask``.
        mask_outside_value: Value given to voxels outside the mask
            (presumably air in HU for CT; -1000 by default).

    Raises:
        RuntimeError: propagated from SimpleITK if an image or mask cannot be
            read, or if a mask's geometry does not match its image.
    """
    if save_path is None:
        # normpath drops a trailing separator, so "preds/" yields the sibling
        # "preds_revert_norm" instead of "preds/_revert_norm" inside the input.
        save_path = os.path.normpath(pred_path) + '_revert_norm'
    os.makedirs(save_path, exist_ok=True)
    # Sorted so the processing order is deterministic (glob order is arbitrary).
    imgs = sorted(glob.glob(os.path.join(pred_path, "*.mha")))
    if not imgs:
        print(f"Warning: no .mha images found in {pred_path}")
    if mask_path:
        print(f"Applying mask from {mask_path} with outside value {mask_outside_value}")
    else:
        print("No mask provided, normalisation will be applied to all images.")
    for img in tqdm(imgs):
        img_sitk = sitk.ReadImage(img)
        img_array = sitk.GetArrayFromImage(img_sitk)
        # Inverse of the z-score mapping (x - mean) / std.
        # NOTE(review): if the forward preprocessing also clipped intensities,
        # that clipping cannot be undone here.
        img_array = img_array * ct_std + ct_mean
        img_sitk_reverted = sitk.GetImageFromArray(img_array)
        # GetImageFromArray loses origin/spacing/direction; copy them back.
        img_sitk_reverted.CopyInformation(img_sitk)

        filename = os.path.basename(img)
        # if mask_path is provided, apply the mask
        if mask_path:
            stem, ext = os.path.splitext(filename)
            # Strip only the trailing nnU-Net channel suffix ("case_0000.mha" ->
            # "case.mha"). str.replace would also remove "_0000" appearing
            # elsewhere in the case id and look up the wrong mask.
            if stem.endswith('_0000'):
                stem = stem[:-len('_0000')]
            mask_itk = sitk.ReadImage(os.path.join(mask_path, stem + ext))
            img_sitk_reverted = sitk.Mask(img_sitk_reverted, mask_itk, outsideValue=mask_outside_value)
        sitk.WriteImage(img_sitk_reverted, os.path.join(save_path, filename))
    
if __name__ == "__main__":
    import argparse

    # The defaults reproduce the original hard-coded experiment. Override them
    # on the command line to process another dataset, plan or fold without
    # editing this file.
    parser = argparse.ArgumentParser(
        description="Revert nnU-Net CT normalisation of predicted .mha images."
    )
    parser.add_argument(
        "--ct-plan-path",
        default="/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/preprocessed/Dataset251_synthrad2025_task1_CT_AB_pre_v2r_stitched_masked_synseg/nnUNetResEncUNetLPlans.json",
        help="nnU-Net plans JSON holding the CT intensity statistics.",
    )
    parser.add_argument(
        "--mask-path",
        default="/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/preprocessed/Dataset250_synthrad2025_task1_MR_AB_pre_v2r_stitched_masked_synseg/masks_real",
        help="Directory of masks; pass an empty string to disable masking.",
    )
    parser.add_argument(
        "--pred-path",
        default="/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/results/Dataset250_synthrad2025_task1_MR_AB_pre_v2r_stitched_masked_synseg/nnUNetTrainerMRCT_loss_seg__nnUNetResEncUNetLPlans__3d_fullres/fold_0/validation",
        help="Directory containing the normalised .mha predictions.",
    )
    parser.add_argument(
        "--save-path",
        default=None,
        help="Output directory (default: <pred-path>_revert_norm).",
    )
    parser.add_argument(
        "--mask-outside-value",
        type=float,
        default=-1000,
        help="Value assigned to voxels outside the mask.",
    )
    args = parser.parse_args()

    ct_mean, ct_std = get_ct_normalisation_values(args.ct_plan_path)
    save_path = args.save_path if args.save_path is not None else args.pred_path + "_revert_norm"
    revert_normalisation(
        args.pred_path,
        ct_mean,
        ct_std,
        save_path=save_path,
        mask_path=args.mask_path,
        mask_outside_value=args.mask_outside_value,
    )