File size: 7,154 Bytes
19c1f58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import sys
import SimpleITK as sitk
import json
import glob
import os
from tqdm import tqdm
import numpy as np
import torch


# revert normalisation
def get_ct_normalisation_values(ct_plan_path):
    """
    Read the CT foreground intensity statistics from an nnUNet plans file.

    Args:
        ct_plan_path: Path to the JSON plans file of the CT dataset.

    Returns:
        Tuple ``(ct_mean, ct_std)`` taken from channel ``"0"`` of
        ``foreground_intensity_properties_per_channel``.
    """
    with open(ct_plan_path, "r") as plan_file:
        plan = json.load(plan_file)

    # Only the first (and, for CT, only) input channel is relevant here.
    channel_stats = plan['foreground_intensity_properties_per_channel']["0"]
    mean_value = channel_stats['mean']
    std_value = channel_stats['std']
    print(f"CT mean: {mean_value}, CT std: {std_value}")
    return mean_value, std_value

def revert_normalisation(pred_path, ct_mean, ct_std, save_path=None, mask_path=None, mask_outside_value=-1000):
    """
    Revert the normalisation of every CT prediction in a directory.

    Every ``*.nii.gz`` and ``*.mha`` file directly inside ``pred_path`` is
    mapped back with ``value * ct_std + ct_mean`` (the inverse of mean/std
    normalisation). The result keeps the input's spatial metadata, is
    optionally masked, and is written to ``save_path`` under its original
    basename.

    Args:
        pred_path: Directory holding the normalised predictions.
        ct_mean: Mean used by the forward normalisation.
        ct_std: Standard deviation used by the forward normalisation.
        save_path: Output directory (created if missing). Defaults to
            ``<pred_path>_revert_norm``. A trailing path separator on
            ``pred_path`` is ignored, so the default is always a sibling
            directory rather than a subfolder of ``pred_path``.
        mask_path: Optional directory of masks. For a prediction named
            ``CASE_0000.<ext>`` or ``CASE.<ext>`` the mask is read from
            ``mask_path/CASE.<ext>``.
        mask_outside_value: Value assigned to voxels outside the mask.
    """

    def _mask_filename(image_name):
        # Drop only a trailing "_0000" channel suffix right before the
        # extension. A blanket str.replace would also mangle case IDs that
        # merely contain "_0000" (e.g. "case_00001.nii.gz" -> "case1.nii.gz").
        for ext in (".nii.gz", ".mha"):
            if image_name.endswith(ext):
                stem = image_name[:-len(ext)]
                if stem.endswith("_0000"):
                    return stem[:-len("_0000")] + ext
                break
        return image_name

    if save_path is None:
        save_path = os.path.normpath(pred_path) + '_revert_norm'
    os.makedirs(save_path, exist_ok=True)
    # Sorted so the processing order is deterministic across filesystems.
    imgs = sorted(
        glob.glob(os.path.join(pred_path, "*.nii.gz"))
        + glob.glob(os.path.join(pred_path, "*.mha"))
    )

    if mask_path:
        print(f"Applying mask from {mask_path} with outside value {mask_outside_value}")
    else:
        print("No mask provided, normalisation will be applied to all images.")
    for img in tqdm(imgs):
        img_sitk = sitk.ReadImage(img)
        img_array = sitk.GetArrayFromImage(img_sitk)
        img_array = img_array * ct_std + ct_mean
        img_sitk_reverted = sitk.GetImageFromArray(img_array)
        # GetImageFromArray resets origin/spacing/direction; restore them.
        img_sitk_reverted.CopyInformation(img_sitk)

        if mask_path:
            mask_file = _mask_filename(os.path.basename(img))
            mask_itk = sitk.ReadImage(os.path.join(mask_path, mask_file))
            img_sitk_reverted = sitk.Mask(img_sitk_reverted, mask_itk, outsideValue=mask_outside_value)
        sitk.WriteImage(img_sitk_reverted, os.path.join(save_path, os.path.basename(img)))
import SimpleITK as sitk
import numpy as np

def print_sitk_space(img: sitk.Image, name: str = "img"):
    """
    Print the spatial metadata of a SimpleITK image.

    Reports size, spacing, origin, direction and pixel type, with each line
    prefixed by ``[name]``. If ``img`` is not a ``sitk.Image``, a single
    notice is printed instead and the function returns.
    """
    if not isinstance(img, sitk.Image):
        # Notice text (Chinese): "is not a SimpleITK.Image (got <type>),
        # no spatial information to print."
        print(f"[{name}] 不是 SimpleITK.Image(得到 {type(img)}),没有空间信息可打印。")
        return

    ndim = img.GetDimension()
    direction = np.array(img.GetDirection())
    if direction.size == ndim * ndim:
        # GetDirection returns a flat tuple; show it as a square matrix.
        direction = direction.reshape(ndim, ndim)

    report = (
        ("size (x,y,z)     = ", img.GetSize()),
        ("spacing (x,y,z)  = ", img.GetSpacing()),
        ("origin  (x,y,z)  = ", img.GetOrigin()),
        ("direction matrix =\n", direction),
        ("pixel type       = ", img.GetPixelIDTypeAsString()),
    )
    for label, value in report:
        print(f"[{name}] {label}{value}")

def revert_normalisation_modified(pred_path, ct_mean, ct_std, save_path=None,
                                  mask_path=None, mask_sitk=None, mask_outside_value=-1000):
    """
    Revert the normalisation of every CT prediction in a directory, with an
    optional per-case or shared mask.

    Every ``*.nii.gz`` and ``*.mha`` file directly inside ``pred_path`` is
    mapped back with ``value * ct_std + ct_mean``. The result keeps the
    input's spatial metadata, is optionally masked, and is written to
    ``save_path`` under its original basename.

    Args:
        pred_path: Directory holding the normalised predictions.
        ct_mean: Mean used by the forward normalisation.
        ct_std: Standard deviation used by the forward normalisation.
        save_path: Output directory (created if missing). Defaults to
            ``<pred_path>_revert_norm``. A trailing path separator on
            ``pred_path`` is ignored.
        mask_path: Optional directory of per-case masks. For a prediction
            ``CASE_0000.<ext>`` or ``CASE.<ext>`` the mask is read from
            ``mask_path/CASE.<ext>``. Takes precedence over ``mask_sitk``.
        mask_sitk: Optional single in-memory mask applied to every image
            when ``mask_path`` is not given.
        mask_outside_value: Value assigned to voxels outside the mask.
    """

    def _mask_filename(image_name):
        # Drop only a trailing "_0000" channel suffix right before the
        # extension. A blanket str.replace would also mangle case IDs that
        # merely contain "_0000" (e.g. "case_00001.nii.gz" -> "case1.nii.gz").
        for ext in (".nii.gz", ".mha"):
            if image_name.endswith(ext):
                stem = image_name[:-len(ext)]
                if stem.endswith("_0000"):
                    return stem[:-len("_0000")] + ext
                break
        return image_name

    if save_path is None:
        save_path = os.path.normpath(pred_path) + '_revert_norm'
    os.makedirs(save_path, exist_ok=True)

    # Sorted so the processing order is deterministic across filesystems.
    imgs = sorted(
        glob.glob(os.path.join(pred_path, "*.nii.gz"))
        + glob.glob(os.path.join(pred_path, "*.mha"))
    )

    if mask_path:
        print(f"Applying mask from {mask_path} with outside value {mask_outside_value}")
    elif mask_sitk is not None:
        print(f"Applying provided mask_sitk with outside value {mask_outside_value}")
    else:
        print("No mask provided, normalisation will be applied to all images.")

    for img in tqdm(imgs):
        img_sitk = sitk.ReadImage(img)
        img_array = sitk.GetArrayFromImage(img_sitk)
        img_array = img_array * ct_std + ct_mean
        img_sitk_reverted = sitk.GetImageFromArray(img_array)
        # GetImageFromArray resets origin/spacing/direction; restore them.
        img_sitk_reverted.CopyInformation(img_sitk)

        if mask_path:
            mask_file = _mask_filename(os.path.basename(img))
            mask_itk = sitk.ReadImage(os.path.join(mask_path, mask_file))
            img_sitk_reverted = sitk.Mask(img_sitk_reverted, mask_itk, outsideValue=mask_outside_value)
        elif mask_sitk is not None:
            img_sitk_reverted = sitk.Mask(img_sitk_reverted, mask_sitk, outsideValue=mask_outside_value)

        sitk.WriteImage(img_sitk_reverted, os.path.join(save_path, os.path.basename(img)))
def revert_normalisation_single_modified(pred_sitk, ct_mean, ct_std, mask_sitk=None, mr_sitk=None, outside_value=-1000):
    """
    Revert mean/std normalisation on a single prediction and optionally mask it.

    The scaling is done with the ``*`` and ``+`` operators directly on
    ``pred_sitk``. Any operand supporting scalar arithmetic works, e.g. a
    SimpleITK image (via its operator overloads) or a NumPy array. Masking
    goes through ``sitk.Mask`` and therefore needs a SimpleITK image.

    Args:
        pred_sitk: Normalised prediction.
        ct_mean: Mean used by the forward normalisation.
        ct_std: Standard deviation used by the forward normalisation.
        mask_sitk: Optional mask. Voxels outside it are set to ``outside_value``.
        mr_sitk: Currently unused; kept so existing callers keep working.
        outside_value: Value assigned outside the mask.

    Returns:
        The reverted prediction, masked when ``mask_sitk`` is given.
    """
    out = pred_sitk * float(ct_std) + float(ct_mean)

    # Previously ``out`` was only bound inside this branch, so calls without
    # a mask raised UnboundLocalError.
    if mask_sitk is not None:
        out = sitk.Mask(out, mask_sitk, outsideValue=outside_value)

    return out
# def revert_normalisation_single_modified(pred_sitk, ct_mean, ct_std, mask_sitk=None, mr_sitk=None, outside_value=-1000):
#     import SimpleITK as sitk
#     import numpy as np

#     print_sitk_space(pred_sitk, "pred_sitk (in)")  # print the spatial metadata of the incoming image

#     arr = sitk.GetArrayFromImage(pred_sitk).astype(np.float32)  # (z, y, x)
#     arr = arr * float(ct_std) + float(ct_mean)

#     out = sitk.GetImageFromArray(arr)   # the new image defaults to spacing=(1,1,1), origin=(0,0,0), direction=I

#     # Copy spatial metadata from a reference image: prefer mr_sitk (if the output should align with the original MR)
#     ref = mr_sitk if mr_sitk is not None else pred_sitk
#     out.CopyInformation(ref)

#     print_sitk_space(out, "out (after CopyInformation)")  # print the spatial metadata after copying

#     if mask_sitk is not None:
#         # If the grids of out and mask differ, resample out onto the mask's grid first
#         if (out.GetSize()!=mask_sitk.GetSize() or
#             out.GetSpacing()!=mask_sitk.GetSpacing() or
#             out.GetOrigin()!=mask_sitk.GetOrigin() or
#             out.GetDirection()!=mask_sitk.GetDirection()):
#             out = sitk.Resample(out, mask_sitk, sitk.Transform(), sitk.sitkLinear, outside_value, out.GetPixelID())
#         out = sitk.Mask(out, sitk.Cast(mask_sitk, sitk.sitkUInt8), outsideValue=outside_value)

#     return out

def revert_normalisation_single(pred_sitk, ct_mean, ct_std):
    """
    Map one normalised SimpleITK image back to intensity space.

    Computes ``voxel * ct_std + ct_mean`` on the voxel array. Returns a new
    image that carries ``pred_sitk``'s origin, spacing and direction.
    """
    restored = sitk.GetImageFromArray(
        sitk.GetArrayFromImage(pred_sitk) * ct_std + ct_mean
    )
    # The array round-trip drops spatial metadata; copy it back.
    restored.CopyInformation(pred_sitk)
    return restored
if __name__ == "__main__":
    # Usage: python <script> [ct_plan_path] [pred_path]
    # Both arguments are optional; when omitted, the defaults below are used.
    default_ct_plan_path = "/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/preprocessed/Dataset203_synthrad2025_task1_CT/nnUNetPlans.json"
    default_pred_path = "/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/results/Dataset202_synthrad2025_task1_MR_mask/nnUNetTrainerMRCT__nnUNetPlans__3d_fullres/fold_0/validation"

    ct_plan_path = sys.argv[1] if len(sys.argv) > 1 else default_ct_plan_path
    pred_path = sys.argv[2] if len(sys.argv) > 2 else default_pred_path

    ct_mean, ct_std = get_ct_normalisation_values(ct_plan_path)
    revert_normalisation(pred_path, ct_mean, ct_std, save_path=pred_path + "_revert_norm")