File size: 7,154 Bytes
19c1f58 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import sys
import SimpleITK as sitk
import json
import glob
import os
from tqdm import tqdm
import numpy as np
import torch
# revert normalisation
def get_ct_normalisation_values(ct_plan_path, channel="0"):
    """
    Get the mean and standard deviation for CT normalisation.

    The values are read from the ``foreground_intensity_properties_per_channel``
    section of a plans JSON file (the caller in this file passes an
    ``nnUNetPlans.json``).

    Args:
        ct_plan_path: Path to the plans JSON file.
        channel: Channel whose statistics are returned. May be given as an
            int or a str; it is converted with ``str()`` because the JSON
            object is keyed by strings. Defaults to ``"0"``.

    Returns:
        Tuple ``(mean, std)`` exactly as stored in the file.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If the section, the channel, or the ``mean``/``std``
            entries are missing.
    """
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(ct_plan_path, "r", encoding="utf-8") as f:
        ct_plan = json.load(f)

    channel_stats = ct_plan['foreground_intensity_properties_per_channel'][str(channel)]
    ct_mean = channel_stats['mean']
    ct_std = channel_stats['std']
    print(f"CT mean: {ct_mean}, CT std: {ct_std}")
    return ct_mean, ct_std
def revert_normalisation(pred_path, ct_mean, ct_std, save_path=None, mask_path=None, mask_outside_value=-1000):
    """
    Undo the normalisation of every image found in a prediction folder.

    Each ``*.nii.gz`` and ``*.mha`` file directly inside ``pred_path`` is
    read, rescaled voxel-wise as ``value * ct_std + ct_mean`` and written
    under the same file name into ``save_path``.

    Args:
        pred_path: Folder holding the normalised predictions.
        ct_mean: Offset added to every voxel after scaling.
        ct_std: Factor every voxel is multiplied by.
        save_path: Output folder, created if missing. Defaults to
            ``pred_path + '_revert_norm'``.
        mask_path: Optional folder of masks. The mask for an image is looked
            up under the image's file name with every ``_0000`` removed.
        mask_outside_value: Passed to ``sitk.Mask`` as ``outsideValue``.
    """
    output_dir = pred_path + '_revert_norm' if save_path is None else save_path
    os.makedirs(output_dir, exist_ok=True)

    # NIfTI matches are listed first, then MetaImage matches.
    prediction_files = []
    for pattern in ("*.nii.gz", "*.mha"):
        prediction_files.extend(glob.glob(os.path.join(pred_path, pattern)))

    if not mask_path:
        print("No mask provided, normalisation will be applied to all images.")
    else:
        print(f"Applying mask from {mask_path} with outside value {mask_outside_value}")

    for prediction_file in tqdm(prediction_files):
        base_name = os.path.basename(prediction_file)
        source = sitk.ReadImage(prediction_file)

        rescaled = sitk.GetImageFromArray(sitk.GetArrayFromImage(source) * ct_std + ct_mean)
        # An image built from an array has no spatial metadata of its own,
        # so copy it over from the image that was read.
        rescaled.CopyInformation(source)

        if mask_path:
            mask_name = base_name.replace('_0000', '')
            mask = sitk.ReadImage(os.path.join(mask_path, mask_name))
            rescaled = sitk.Mask(rescaled, mask, outsideValue=mask_outside_value)

        sitk.WriteImage(rescaled, os.path.join(output_dir, base_name))
import SimpleITK as sitk
import numpy as np
def print_sitk_space(img: sitk.Image, name: str = "img"):
    """
    Print the spatial metadata of a SimpleITK image.

    Prints size, spacing, origin, direction matrix and pixel type, each line
    prefixed with ``[name]``. If ``img`` is not a ``sitk.Image`` a notice is
    printed instead and nothing else happens.

    Args:
        img: Image to describe.
        name: Label used as the prefix of every printed line.
    """
    if not isinstance(img, sitk.Image):
        # Message reads: "not a SimpleITK.Image (got <type>), no spatial
        # information to print."
        print(f"[{name}] 不是 SimpleITK.Image(得到 {type(img)}),没有空间信息可打印。")
        return

    ndim = img.GetDimension()
    direction_cosines = np.array(img.GetDirection())
    # GetDirection() is flat; show it as an ndim x ndim matrix when the
    # element count allows it.
    if direction_cosines.size == ndim * ndim:
        direction_cosines = direction_cosines.reshape(ndim, ndim)

    report = (
        f"[{name}] size (x,y,z) = {img.GetSize()}",
        f"[{name}] spacing (x,y,z) = {img.GetSpacing()}",
        f"[{name}] origin (x,y,z) = {img.GetOrigin()}",
        f"[{name}] direction matrix =\n{direction_cosines}",
        f"[{name}] pixel type = {img.GetPixelIDTypeAsString()}",
    )
    for line in report:
        print(line)
def revert_normalisation_modified(pred_path, ct_mean, ct_std, save_path=None,
                                  mask_path=None, mask_sitk=None, mask_outside_value=-1000):
    """
    Undo the normalisation of every image in a folder, with optional masking.

    Each ``*.nii.gz`` and ``*.mha`` file directly inside ``pred_path`` is
    read, rescaled voxel-wise as ``value * ct_std + ct_mean`` and written
    under the same file name into ``save_path``.

    Args:
        pred_path: Folder holding the normalised predictions.
        ct_mean: Offset added to every voxel after scaling.
        ct_std: Factor every voxel is multiplied by.
        save_path: Output folder, created if missing. Defaults to
            ``pred_path + '_revert_norm'``.
        mask_path: Optional folder of per-image masks. The mask for an image
            is looked up under the image's file name with every ``_0000``
            removed. Takes precedence over ``mask_sitk``.
        mask_sitk: Optional single mask image applied to every prediction
            when ``mask_path`` is not given.
        mask_outside_value: Passed to ``sitk.Mask`` as ``outsideValue``.
    """
    out_dir = save_path if save_path is not None else pred_path + '_revert_norm'
    os.makedirs(out_dir, exist_ok=True)

    # NIfTI matches are listed first, then MetaImage matches.
    candidates = [
        found
        for suffix in ("*.nii.gz", "*.mha")
        for found in glob.glob(os.path.join(pred_path, suffix))
    ]

    # Decide the masking mode once; a mask folder wins over a shared mask.
    mask_from_folder = bool(mask_path)
    mask_shared = (not mask_from_folder) and mask_sitk is not None

    if mask_from_folder:
        print(f"Applying mask from {mask_path} with outside value {mask_outside_value}")
    elif mask_shared:
        print(f"Applying provided mask_sitk with outside value {mask_outside_value}")
    else:
        print("No mask provided, normalisation will be applied to all images.")

    for src_file in tqdm(candidates):
        out_name = os.path.basename(src_file)
        src_image = sitk.ReadImage(src_file)

        restored_values = sitk.GetArrayFromImage(src_image) * ct_std + ct_mean
        restored = sitk.GetImageFromArray(restored_values)
        # Carry the spatial metadata of the input over to the rebuilt image.
        restored.CopyInformation(src_image)

        if mask_from_folder:
            per_case_mask = sitk.ReadImage(os.path.join(mask_path, out_name.replace('_0000', '')))
            restored = sitk.Mask(restored, per_case_mask, outsideValue=mask_outside_value)
        elif mask_shared:
            restored = sitk.Mask(restored, mask_sitk, outsideValue=mask_outside_value)

        sitk.WriteImage(restored, os.path.join(out_dir, out_name))
def revert_normalisation_single_modified(pred_sitk, ct_mean, ct_std, mask_sitk=None, mr_sitk=None, outside_value=-1000):
    """
    Undo the normalisation of a single in-memory prediction.

    The prediction is rescaled as ``pred_sitk * ct_std + ct_mean`` using the
    object's own arithmetic operators, so no round trip through a NumPy
    array is made.

    Args:
        pred_sitk: Normalised prediction. Must support ``* float`` and
            ``+ float``; when a mask is supplied it is handed to
            ``sitk.Mask`` and therefore has to be a SimpleITK image.
        ct_mean: Offset added after scaling (converted with ``float()``).
        ct_std: Scale factor (converted with ``float()``).
        mask_sitk: Optional mask image; when given, the result is passed
            through ``sitk.Mask`` with ``outsideValue=outside_value``.
        mr_sitk: Unused; kept so existing callers keep working.
        outside_value: Value passed to ``sitk.Mask`` as ``outsideValue``.

    Returns:
        The rescaled prediction, masked if ``mask_sitk`` was provided.
    """
    out = pred_sitk * float(ct_std) + float(ct_mean)

    # Previously the result was only assigned inside this branch, so calling
    # without a mask raised UnboundLocalError on return.
    if mask_sitk is not None:
        out = sitk.Mask(out, mask_sitk, outsideValue=outside_value)
    return out
# def revert_normalisation_single_modified(pred_sitk, ct_mean, ct_std, mask_sitk=None, mr_sitk=None, outside_value=-1000):
#     import SimpleITK as sitk
#     import numpy as np
#     print_sitk_space(pred_sitk, "pred_sitk (in)")  # print the spatial information of the incoming image
#     arr = sitk.GetArrayFromImage(pred_sitk).astype(np.float32)  # (z, y, x)
#     arr = arr * float(ct_std) + float(ct_mean)
#     out = sitk.GetImageFromArray(arr)  # the new image created here defaults to spacing=(1,1,1), origin=(0,0,0), direction=I
#     # Copy the spatial information from a reference image: prefer mr_sitk (if you want alignment with the original MR)
#     ref = mr_sitk if mr_sitk is not None else pred_sitk
#     out.CopyInformation(ref)
#     print_sitk_space(out, "out (after CopyInformation)")  # print the spatial information after copying
#     if mask_sitk is not None:
#         # If the grids of out and mask do not match exactly, resample onto the mask's grid first
#         if (out.GetSize()!=mask_sitk.GetSize() or
#             out.GetSpacing()!=mask_sitk.GetSpacing() or
#             out.GetOrigin()!=mask_sitk.GetOrigin() or
#             out.GetDirection()!=mask_sitk.GetDirection()):
#             out = sitk.Resample(out, mask_sitk, sitk.Transform(), sitk.sitkLinear, outside_value, out.GetPixelID())
#         out = sitk.Mask(out, sitk.Cast(mask_sitk, sitk.sitkUInt8), outsideValue=outside_value)
#     return out
def revert_normalisation_single(pred_sitk, ct_mean, ct_std):
    """
    Undo the normalisation of a single in-memory SimpleITK image.

    Args:
        pred_sitk: Normalised prediction image.
        ct_mean: Offset added to every voxel after scaling.
        ct_std: Factor every voxel is multiplied by.

    Returns:
        A new image holding ``value * ct_std + ct_mean`` with the spatial
        metadata copied from ``pred_sitk``.
    """
    voxels = sitk.GetArrayFromImage(pred_sitk)
    rescaled = sitk.GetImageFromArray(voxels * ct_std + ct_mean)
    # The rebuilt image has no spatial metadata of its own; take it from the input.
    rescaled.CopyInformation(pred_sitk)
    return rescaled
if __name__ == "__main__":
    # Inputs: the plans file holding the CT intensity statistics and the
    # folder of fold-0 validation predictions to convert.
    ct_plan_path = "/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/preprocessed/Dataset203_synthrad2025_task1_CT/nnUNetPlans.json"
    pred_path = "/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/results/Dataset202_synthrad2025_task1_MR_mask/nnUNetTrainerMRCT__nnUNetPlans__3d_fullres/fold_0/validation"

    ct_mean, ct_std = get_ct_normalisation_values(ct_plan_path)

    # Results go to a sibling folder next to the predictions.
    revert_normalisation(
        pred_path,
        ct_mean,
        ct_std,
        save_path=f"{pred_path}_revert_norm",
    )
|