Spaces:
Runtime error
Runtime error
zy7_oldserver commited on
Commit ·
fd601de
1
Parent(s): 299136f
This view is limited to 50 files because it contains too many changes. See raw diff
- .gitignore +16 -0
- README.md +1 -20
- dataprocesser/.gitignore +1 -0
- dataprocesser/Preprocess_CT_Mask_generation.py +267 -0
- dataprocesser/Preprocess_MRCT_mask_conversion.py +294 -0
- dataprocesser/Preprocess_MR_Mask_generation.py +306 -0
- dataprocesser/Preprocess_MR_Masks_overlay.py +78 -0
- dataprocesser/__init__.py +8 -0
- dataprocesser/archive/archiv.py +236 -0
- dataprocesser/archive/basics.py +167 -0
- dataprocesser/archive/checkdata.py +91 -0
- dataprocesser/archive/createsegtransform.py +276 -0
- dataprocesser/archive/csv_dataset.py +121 -0
- dataprocesser/archive/csv_dataset_slices.py +20 -0
- dataprocesser/archive/csv_dataset_slices_assigned.py +11 -0
- dataprocesser/archive/data_create_seg.py +28 -0
- dataprocesser/archive/data_slicing.py +13 -0
- dataprocesser/archive/dataset_med.py +188 -0
- dataprocesser/archive/gan_loader.py +310 -0
- dataprocesser/archive/init_dataset.py +0 -0
- dataprocesser/archive/json_dataset_slices.py +28 -0
- dataprocesser/archive/list_dataset_Anika.py +10 -0
- dataprocesser/archive/list_dataset_Anish.py +0 -0
- dataprocesser/archive/list_dataset_Anish_seg.py +42 -0
- dataprocesser/archive/list_dataset_base.py +983 -0
- dataprocesser/archive/list_dataset_combined_seg.py +15 -0
- dataprocesser/archive/list_dataset_combined_seg_assigned.py +1 -0
- dataprocesser/archive/list_dataset_synthrad.py +0 -0
- dataprocesser/archive/list_dataset_synthrad_seg.py +3 -0
- dataprocesser/archive/monai_loader_3D.py +367 -0
- dataprocesser/archive/slice_loader.py +124 -0
- dataprocesser/build_dataset.py +22 -0
- dataprocesser/config_example.yaml +43 -0
- dataprocesser/create_csv.py +87 -0
- dataprocesser/create_csv_xcat.py +25 -0
- dataprocesser/create_json_lodopab.py +59 -0
- dataprocesser/create_json_xcat.py +70 -0
- dataprocesser/customized_datasets.py +115 -0
- dataprocesser/customized_normalization.py +149 -0
- dataprocesser/customized_transform_list.py +149 -0
- dataprocesser/customized_transforms.py +507 -0
- dataprocesser/data_processing/.gitignore +4 -0
- dataprocesser/data_processing/README.md +20 -0
- dataprocesser/data_processing/__init__.py +2 -0
- dataprocesser/data_processing/data_process/.gitignore +1 -0
- dataprocesser/data_processing/data_process/CTbatchevaluate.py +49 -0
- dataprocesser/data_processing/data_process/CTevaluate.py +137 -0
- dataprocesser/data_processing/data_process/convert_dicoms.py +83 -0
- dataprocesser/data_processing/data_process/make_cond.py +37 -0
- dataprocesser/data_processing/data_process/matlab/BCELossIllustration.m +53 -0
.gitignore
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
!.gitignore
|
| 2 |
+
/logs
|
| 3 |
+
__pycache__*
|
| 4 |
+
/*pycache*
|
| 5 |
+
/model/*pycache*
|
| 6 |
+
**/*pycache*
|
| 7 |
+
*.out
|
| 8 |
+
venv*
|
| 9 |
+
helix_log
|
| 10 |
+
checkpoints
|
| 11 |
+
MONAI
|
| 12 |
+
*.svg
|
| 13 |
+
data
|
| 14 |
+
generative-models
|
| 15 |
+
datasets
|
| 16 |
+
notuse
|
README.md
CHANGED
|
@@ -1,20 +1 @@
|
|
| 1 |
-
|
| 2 |
-
title: Frankenstein
|
| 3 |
-
emoji: 🚀
|
| 4 |
-
colorFrom: red
|
| 5 |
-
colorTo: red
|
| 6 |
-
sdk: docker
|
| 7 |
-
app_port: 8501
|
| 8 |
-
tags:
|
| 9 |
-
- streamlit
|
| 10 |
-
pinned: false
|
| 11 |
-
short_description: Artificial Life
|
| 12 |
-
license: mit
|
| 13 |
-
---
|
| 14 |
-
|
| 15 |
-
# Welcome to Streamlit!
|
| 16 |
-
|
| 17 |
-
Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
|
| 18 |
-
|
| 19 |
-
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
|
| 20 |
-
forums](https://discuss.streamlit.io).
|
|
|
|
| 1 |
+
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dataprocesser/.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__pycache__*
|
dataprocesser/Preprocess_CT_Mask_generation.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
import nrrd
|
| 5 |
+
import SimpleITK as sitk
|
| 6 |
+
import cv2
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
def shift_to_min_zero(arr):
|
| 11 |
+
"""
|
| 12 |
+
Shifts the input NumPy array so that the minimum value becomes 0.
|
| 13 |
+
|
| 14 |
+
Parameters:
|
| 15 |
+
arr (numpy.ndarray): The input array to shift.
|
| 16 |
+
|
| 17 |
+
Returns:
|
| 18 |
+
numpy.ndarray: The shifted array with the minimum value as 0.
|
| 19 |
+
"""
|
| 20 |
+
min_value = np.min(arr) # Find the minimum value
|
| 21 |
+
shifted_array = arr - min_value # Subtract the minimum value from all elements
|
| 22 |
+
return shifted_array
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def create_body_mask(numpy_img, body_threshold=-500, min_contour_area=10000):
|
| 26 |
+
"""
|
| 27 |
+
Create a binary body mask from a CT image tensor, using a specific threshold for the body parts.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
tensor_img (torch.Tensor): A tensor representation of a grayscale CT image, with intensity values from -1024 to 1500.
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
torch.Tensor: A binary mask tensor where the entire body region is 1 and the background is 0.
|
| 34 |
+
"""
|
| 35 |
+
# Convert tensor to numpy array
|
| 36 |
+
numpy_img = np.ascontiguousarray(numpy_img.astype(np.int16)) # Ensure we can handle negative values correctly
|
| 37 |
+
#numpy_img = numpy_img.astype(np.int16)
|
| 38 |
+
|
| 39 |
+
# Threshold the image at -500 to separate potential body from the background
|
| 40 |
+
binary_img = np.where(numpy_img > body_threshold, 1, 0).astype(np.uint8)
|
| 41 |
+
|
| 42 |
+
# Find contours from the binary image
|
| 43 |
+
contours, _ = cv2.findContours(binary_img, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
|
| 44 |
+
|
| 45 |
+
# Create an empty mask
|
| 46 |
+
mask = np.zeros_like(binary_img)
|
| 47 |
+
|
| 48 |
+
VERBOSE = False
|
| 49 |
+
# Fill all detected body contours
|
| 50 |
+
if contours:
|
| 51 |
+
for contour in contours:
|
| 52 |
+
if cv2.contourArea(contour) >= min_contour_area:
|
| 53 |
+
if VERBOSE:
|
| 54 |
+
print('current contour area: ', cv2.contourArea(contour), 'threshold: ', min_contour_area)
|
| 55 |
+
cv2.drawContours(mask, [contour], -1, 1, thickness=cv2.FILLED)
|
| 56 |
+
|
| 57 |
+
return mask
|
| 58 |
+
|
| 59 |
+
def apply_mask(normalized_image_array, mask_array):
|
| 60 |
+
return normalized_image_array * mask_array
|
| 61 |
+
|
| 62 |
+
def print_all_info(data, title):
|
| 63 |
+
print(f'min, max of {title}:', np.min(data), np.max(data))
|
| 64 |
+
|
| 65 |
+
def process_CT_segmentation_numpy(mask, csv_simulation_values):
|
| 66 |
+
#df = pd.read_csv(csv_file)
|
| 67 |
+
df = csv_simulation_values
|
| 68 |
+
# Create a dictionary to map organ index to HU values
|
| 69 |
+
hu_values = dict(zip(df['Order Number'], df['HU Value']))
|
| 70 |
+
order_begin_from_0 = True if df['Order Number'].min()==0 else False
|
| 71 |
+
|
| 72 |
+
hu_mask = np.zeros_like(mask)
|
| 73 |
+
# Value Assigment
|
| 74 |
+
hu_mask[mask == 0] = -1000 # background
|
| 75 |
+
for organ_index, hu_value in hu_values.items():
|
| 76 |
+
assert isinstance(hu_value, int), f"Expected mask value an integer, but got {hu_value}. Ensure the mask is created by fine mode of totalsegmentator"
|
| 77 |
+
assert isinstance(organ_index, int), f"Expected organ_index an integer, but got {organ_index}. Ensure the mask is created by fine mode of totalsegmentator"
|
| 78 |
+
if order_begin_from_0:
|
| 79 |
+
hu_mask[mask == (organ_index+1)] = hu_value # mask value begin from 1 as body value, other than 0 in TA2 table, so organ_index+1
|
| 80 |
+
else:
|
| 81 |
+
hu_mask[mask == (organ_index)] = hu_value
|
| 82 |
+
return hu_mask
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# 处理单个图像和分割图
|
| 86 |
+
def process_image(input_path, contour_path, seg_path, seg_tissue_path, csv_simulation_values, output_path1, output_path2, output_path3, body_threshold):
|
| 87 |
+
# 读取原始 MR 图像和分割图
|
| 88 |
+
if input_path.endswith('.nrrd'):
|
| 89 |
+
img, header = nrrd.read(input_path)
|
| 90 |
+
segmentation_img, header_seg = nrrd.read(seg_path)
|
| 91 |
+
seg_tissue_img, header_seg_tissue = nrrd.read(seg_tissue_path)
|
| 92 |
+
elif input_path.endswith('.nii.gz') or input_path.endswith('.nii'):
|
| 93 |
+
import nibabel as nib
|
| 94 |
+
img_metadata = nib.load(input_path)
|
| 95 |
+
img = img_metadata.get_fdata()
|
| 96 |
+
affine = img_metadata.affine
|
| 97 |
+
|
| 98 |
+
seg_metadata = nib.load(seg_path)
|
| 99 |
+
segmentation_img = seg_metadata.get_fdata()
|
| 100 |
+
affine_seg = seg_metadata.affine
|
| 101 |
+
|
| 102 |
+
seg_tissue_metadata = nib.load(seg_tissue_path)
|
| 103 |
+
seg_tissue_img = seg_tissue_metadata.get_fdata()
|
| 104 |
+
|
| 105 |
+
# extract contour
|
| 106 |
+
body_contour = np.zeros_like(img, dtype=np.int16)
|
| 107 |
+
for i in range(img.shape[-1]):
|
| 108 |
+
slice_data = img[:, :, i]
|
| 109 |
+
body_contour[:, :, i] = create_body_mask(slice_data, body_threshold=body_threshold)
|
| 110 |
+
|
| 111 |
+
# CT images don't need additional normalization
|
| 112 |
+
#
|
| 113 |
+
|
| 114 |
+
# normalize to 0-1
|
| 115 |
+
img_normalized = shift_to_min_zero(img)
|
| 116 |
+
# img_normalized = img_normalized/2000 # scale factor
|
| 117 |
+
|
| 118 |
+
# apply mask to ct img
|
| 119 |
+
masked_image = apply_mask(img_normalized, body_contour)
|
| 120 |
+
|
| 121 |
+
# process the mask image
|
| 122 |
+
seg = segmentation_img
|
| 123 |
+
tissue = seg_tissue_img
|
| 124 |
+
tissue[tissue!=0] += 200
|
| 125 |
+
# Create a mask for overlapping areas
|
| 126 |
+
overlap_mask = (seg > 0) & (tissue > 0)
|
| 127 |
+
|
| 128 |
+
# For overlapping areas, keep the lower value (organ values in seg)
|
| 129 |
+
merged_mask = tissue.copy()
|
| 130 |
+
merged_mask[overlap_mask] = seg[overlap_mask]
|
| 131 |
+
|
| 132 |
+
# Keep all non-overlapping areas
|
| 133 |
+
merged_mask[seg > 0] = seg[seg > 0]
|
| 134 |
+
|
| 135 |
+
combined_array = merged_mask + body_contour
|
| 136 |
+
|
| 137 |
+
processed_segmentation = combined_array
|
| 138 |
+
|
| 139 |
+
# assign simulation value to ct segmentation mask
|
| 140 |
+
assigned_segmentation = process_CT_segmentation_numpy(combined_array, csv_simulation_values)
|
| 141 |
+
|
| 142 |
+
if input_path.endswith('.nrrd'):
|
| 143 |
+
# 保存处理后的 MR 图像
|
| 144 |
+
nrrd.write(output_path1, masked_image, header)
|
| 145 |
+
|
| 146 |
+
# 保存处理后的分割图
|
| 147 |
+
nrrd.write(output_path2, processed_segmentation, header_seg)
|
| 148 |
+
|
| 149 |
+
# save the body contour mask
|
| 150 |
+
|
| 151 |
+
elif input_path.endswith('.nii.gz') or input_path.endswith('.nii'):
|
| 152 |
+
img_processed = nib.Nifti1Image(masked_image, affine)
|
| 153 |
+
nib.save(img_processed, output_path1)
|
| 154 |
+
seg_processed = nib.Nifti1Image(processed_segmentation, affine_seg)
|
| 155 |
+
nib.save(seg_processed, output_path2)
|
| 156 |
+
contour_processed = nib.Nifti1Image(body_contour, affine_seg)
|
| 157 |
+
assigned_segmentation_processed = nib.Nifti1Image(assigned_segmentation, affine_seg)
|
| 158 |
+
# Split the path into directory and filename
|
| 159 |
+
directory, filename = os.path.split(output_path2)
|
| 160 |
+
contour_filename = filename.replace('_seg_merged', '_contour')
|
| 161 |
+
contour_path = os.path.join(directory, contour_filename)
|
| 162 |
+
nib.save(contour_processed, contour_path)
|
| 163 |
+
|
| 164 |
+
nib.save(assigned_segmentation_processed, output_path3)
|
| 165 |
+
|
| 166 |
+
return processed_segmentation
|
| 167 |
+
|
| 168 |
+
def analyse_hist(input_path):
|
| 169 |
+
if input_path.endswith('.nrrd'):
|
| 170 |
+
img, header = nrrd.read(input_path)
|
| 171 |
+
elif input_path.endswith('.nii.gz'):
|
| 172 |
+
import nibabel as nib
|
| 173 |
+
img_metadata = nib.load(input_path)
|
| 174 |
+
img = img_metadata.get_fdata()
|
| 175 |
+
affine = img_metadata.affine
|
| 176 |
+
import numpy as np
|
| 177 |
+
import matplotlib.pyplot as plt
|
| 178 |
+
|
| 179 |
+
# Plot the histogram
|
| 180 |
+
print('shape of img: ', img.shape)
|
| 181 |
+
plt.hist(img[:, :, 50], bins=30, edgecolor='black', alpha=0.7)
|
| 182 |
+
plt.xlabel('Value')
|
| 183 |
+
plt.ylabel('Frequency')
|
| 184 |
+
plt.title('Value Distribution')
|
| 185 |
+
plt.show()
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def process_csv(csv_file, output_root, csv_simulation_file, body_threshold=-500):
|
| 189 |
+
# read csv to get simulation value
|
| 190 |
+
csv_simulation_values = pd.read_csv(csv_simulation_file) #.to_numpy()
|
| 191 |
+
#csv_simulation_values = pd.read_csv(csv_simulation_file)
|
| 192 |
+
|
| 193 |
+
# check 2-dimensional csv_simulation_values
|
| 194 |
+
if csv_simulation_values.ndim == 1:
|
| 195 |
+
raise ValueError("CSV should contain two columns: organ_index and simulation_value")
|
| 196 |
+
|
| 197 |
+
if not os.path.exists(csv_file):
|
| 198 |
+
print('csv:', csv_file)
|
| 199 |
+
raise ValueError('csv_file must input a available csv file in simplified form: id, Aorta_diss, seg, img!')
|
| 200 |
+
else:
|
| 201 |
+
print(f'use csv: {csv_file}')
|
| 202 |
+
|
| 203 |
+
data_frame = pd.read_csv(csv_file)
|
| 204 |
+
if len(data_frame) == 0:
|
| 205 |
+
raise RuntimeError(f"Found 0 images in: {csv_file}")
|
| 206 |
+
patient_IDs = data_frame.iloc[:, 0].tolist()
|
| 207 |
+
Aorta_diss = data_frame.iloc[:, 1].tolist()
|
| 208 |
+
segs = data_frame.iloc[:, 2].tolist()
|
| 209 |
+
images = data_frame.iloc[:, 3].tolist()
|
| 210 |
+
|
| 211 |
+
from tqdm import tqdm
|
| 212 |
+
dataset_list = []
|
| 213 |
+
for idx in tqdm(range(len(images))):
|
| 214 |
+
if (images[idx].endswith('.nii.gz') and segs[idx].endswith('.nii.gz')) or \
|
| 215 |
+
(images[idx].endswith('.nii') and segs[idx].endswith('.nii')):
|
| 216 |
+
input_file_path = images[idx]
|
| 217 |
+
seg_file_path = segs[idx]
|
| 218 |
+
patient_id = patient_IDs[idx]
|
| 219 |
+
ad = Aorta_diss[idx]
|
| 220 |
+
seg_tissue_file_path = seg_file_path.replace("_seg","_seg_tissue")
|
| 221 |
+
|
| 222 |
+
root_dir = os.path.dirname(input_file_path)
|
| 223 |
+
|
| 224 |
+
# Get root path (directory path)
|
| 225 |
+
root_path = os.path.dirname(seg_file_path)
|
| 226 |
+
ct_processed_file_name = f"{patient_id}_ct_processed.nii.gz"
|
| 227 |
+
seg_merged_file_name = f"{patient_id}_ct_seg_merged.nii.gz"
|
| 228 |
+
seg_merged_assigned_mask_file_name = f"{patient_id}_ct_seg_merged_assigned_mask.nii.gz"
|
| 229 |
+
|
| 230 |
+
os.makedirs(output_root, exist_ok=True)
|
| 231 |
+
output_file_path1 = os.path.join(output_root, ct_processed_file_name)
|
| 232 |
+
output_file_path2 = os.path.join(output_root, seg_merged_file_name)
|
| 233 |
+
output_file_path3 = os.path.join(output_root, seg_merged_assigned_mask_file_name)
|
| 234 |
+
print(f"Processing {input_file_path} with segmentation {seg_file_path}")
|
| 235 |
+
print(f"Save results to {output_file_path1} and {output_file_path2} and {output_file_path3} \n")
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
processed_seg = process_image(input_file_path, None, seg_file_path, seg_tissue_file_path, csv_simulation_values, output_file_path1, output_file_path2, output_file_path3, body_threshold)
|
| 239 |
+
|
| 240 |
+
# processed_mr_csv_file = ...
|
| 241 |
+
csv_mr_line = [patient_id,ad, output_file_path2, output_file_path1, output_file_path3]
|
| 242 |
+
dataset_list.append(csv_mr_line)
|
| 243 |
+
|
| 244 |
+
import csv
|
| 245 |
+
output_csv_file=os.path.join(output_root, 'processed_csv_file.csv')
|
| 246 |
+
with open(output_csv_file, 'w', newline='') as f:
|
| 247 |
+
csvwriter = csv.writer(f)
|
| 248 |
+
csvwriter.writerow(['id', 'Aorta_diss', 'seg', 'img', 'seg_mask'])
|
| 249 |
+
csvwriter.writerows(dataset_list)
|
| 250 |
+
|
| 251 |
+
if __name__ == "__main__":
|
| 252 |
+
import argparse
|
| 253 |
+
csv_file = r'E:\Projects\yang_proj\SynthRad_GAN\synthrad_conversion\datacsv\ct_synthrad_test_newserver.csv'
|
| 254 |
+
output_root = r'E:\Projects\yang_proj\data\synthrad\processed'
|
| 255 |
+
csv_simulation_file = r'E:\Projects\yang_proj\SynthRad_GAN\synthrad_conversion\TA2_CT_from1.csv'
|
| 256 |
+
process_csv(csv_file, output_root, csv_simulation_file, body_threshold=-500)
|
| 257 |
+
|
| 258 |
+
'''parser = argparse.ArgumentParser(description="Process MR images and segmentation maps, apply masks and replace grayscale values.")
|
| 259 |
+
parser.add_argument('--input_folder1', required=True, help="Path to the folder containing input MR .nrrd files.")
|
| 260 |
+
parser.add_argument('--input_folder2', required=True, help="Path to the folder containing segmentation .nrrd files.")
|
| 261 |
+
parser.add_argument('--output_folder1', required=True, help="Path to the folder to save the output MR files.")
|
| 262 |
+
parser.add_argument('--output_folder2', required=True, help="Path to the folder to save the output segmentation files.")
|
| 263 |
+
parser.add_argument('--csv_simulation_file', required=True, help="CSV file containing simulated CT grayscale values.")
|
| 264 |
+
parser.add_argument('--body_threshold', type=int, default=50, help="Threshold to separate body from background.")
|
| 265 |
+
args = parser.parse_args()
|
| 266 |
+
|
| 267 |
+
process_folder(args.input_folder1, args.input_folder2, args.output_folder1, args.output_folder2, args.csv_simulation_file, args.body_threshold)'''
|
dataprocesser/Preprocess_MRCT_mask_conversion.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import nibabel as nib
|
| 4 |
+
import torch
|
| 5 |
+
import os
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
#from dataprocesser.customized_transforms import create_body_contour
|
| 8 |
+
from dataprocesser.Preprocess_MR_Mask_generation import process_segmentation
|
| 9 |
+
from dataprocesser.Preprocess_CT_Mask_generation import process_CT_segmentation_numpy
|
| 10 |
+
import difflib
|
| 11 |
+
|
| 12 |
+
def find_best_match_smart(organ_name, target_names):
|
| 13 |
+
# 完全匹配
|
| 14 |
+
if organ_name in target_names:
|
| 15 |
+
return organ_name
|
| 16 |
+
# 精确 startswith 匹配(如 vertebrae → vertebrae_Lx)
|
| 17 |
+
matches = [t for t in target_names if t.startswith(organ_name)]
|
| 18 |
+
if matches:
|
| 19 |
+
return matches[0]
|
| 20 |
+
# 再 fallback 到 difflib,但严格一些
|
| 21 |
+
import difflib
|
| 22 |
+
closes = difflib.get_close_matches(organ_name, target_names, n=1, cutoff=0.8)
|
| 23 |
+
return closes #[0] if close else None
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def convert_segmentation_mask_torch(source_mask, source_csv, target_csv, body_contour_value=1):
|
| 27 |
+
"""
|
| 28 |
+
Converts segmentation mask values from source modality to target modality based on organ name mapping.
|
| 29 |
+
|
| 30 |
+
Parameters:
|
| 31 |
+
- source_mask (torch.Tensor): The source segmentation mask tensor.
|
| 32 |
+
- source_csv (str): Path to the CSV file of the source modality (CT or MR).
|
| 33 |
+
- target_csv (str): Path to the CSV file of the target modality (MR or CT).
|
| 34 |
+
- body_contour_value (int): The class value for "body contour" in the target modality.
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
- target_mask (torch.Tensor): The converted segmentation mask tensor.
|
| 38 |
+
"""
|
| 39 |
+
# Load the source and target anatomy lists
|
| 40 |
+
source_df = pd.read_csv(source_csv)
|
| 41 |
+
target_df = pd.read_csv(target_csv)
|
| 42 |
+
|
| 43 |
+
# Create dictionaries mapping class values to organ names
|
| 44 |
+
source_mapping = {}
|
| 45 |
+
for _, row in source_df.iterrows():
|
| 46 |
+
organ_name = row['Organ Name']
|
| 47 |
+
class_value = row.iloc[0]
|
| 48 |
+
source_mapping.setdefault(organ_name, []).append(class_value)
|
| 49 |
+
|
| 50 |
+
target_mapping = {}
|
| 51 |
+
for _, row in target_df.iterrows():
|
| 52 |
+
organ_name = row['Organ Name']
|
| 53 |
+
class_value = row.iloc[0]
|
| 54 |
+
target_mapping.setdefault(organ_name, []).append(class_value)
|
| 55 |
+
|
| 56 |
+
# Create a reverse mapping from class values to organ names for the source modality
|
| 57 |
+
class_to_organ = {class_value: organ_name for organ_name, class_values in source_mapping.items() for class_value in class_values}
|
| 58 |
+
|
| 59 |
+
# Initialize the target mask with zeros
|
| 60 |
+
target_mask = torch.zeros_like(source_mask, dtype=source_mask.dtype)
|
| 61 |
+
|
| 62 |
+
# Convert each unique class in the source mask
|
| 63 |
+
unique_classes = torch.unique(source_mask)
|
| 64 |
+
for class_value in unique_classes:
|
| 65 |
+
# Find the corresponding organ name in the source modality
|
| 66 |
+
organ_name = class_to_organ.get(class_value.item(), None)
|
| 67 |
+
|
| 68 |
+
if class_value.item() == 0: # Preserve background as is
|
| 69 |
+
target_value = 0
|
| 70 |
+
else:
|
| 71 |
+
# If organ name exists, find the corresponding target class values
|
| 72 |
+
if organ_name and organ_name in target_mapping:
|
| 73 |
+
# Pick the first target class value (or handle overlaps if needed)
|
| 74 |
+
target_value = target_mapping[organ_name][0]
|
| 75 |
+
else:
|
| 76 |
+
# Use body contour class value for unmapped organs
|
| 77 |
+
target_value = body_contour_value
|
| 78 |
+
#print(f'Processing for class {class_value.item()}')
|
| 79 |
+
#print(f'Not found {organ_name} in target mapping, replaced with body contour.')
|
| 80 |
+
|
| 81 |
+
# Replace class values in the target mask
|
| 82 |
+
target_mask[source_mask == class_value] = target_value
|
| 83 |
+
|
| 84 |
+
return target_mask
|
| 85 |
+
|
| 86 |
+
def convert_segmentation_mask(source_mask, source_csv, target_csv, body_contour_value=1000):
|
| 87 |
+
"""
|
| 88 |
+
Converts segmentation mask values from source modality to target modality based on organ name mapping.
|
| 89 |
+
|
| 90 |
+
Parameters:
|
| 91 |
+
- source_mask (ndarray): The source segmentation mask array.
|
| 92 |
+
- source_csv (str): Path to the CSV file of the source modality (CT or MR).
|
| 93 |
+
- target_csv (str): Path to the CSV file of the target modality (MR or CT).
|
| 94 |
+
- body_contour_value (int): The class value for "body contour" in the target modality.
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
- target_mask (ndarray): The converted segmentation mask.
|
| 98 |
+
"""
|
| 99 |
+
# Load the source and target anatomy lists
|
| 100 |
+
source_df = pd.read_csv(source_csv)
|
| 101 |
+
target_df = pd.read_csv(target_csv)
|
| 102 |
+
|
| 103 |
+
# Create dictionaries mapping class values to organ names and vice versa
|
| 104 |
+
source_mapping = {}
|
| 105 |
+
for _, row in source_df.iterrows():
|
| 106 |
+
organ_name = row['Organ Name']
|
| 107 |
+
class_value = row.iloc[0]
|
| 108 |
+
source_mapping.setdefault(organ_name, []).append(class_value)
|
| 109 |
+
|
| 110 |
+
target_mapping = {}
|
| 111 |
+
for _, row in target_df.iterrows():
|
| 112 |
+
organ_name = row['Organ Name']
|
| 113 |
+
class_value = row.iloc[0]
|
| 114 |
+
target_mapping.setdefault(organ_name, []).append(class_value)
|
| 115 |
+
|
| 116 |
+
# Create a reverse mapping from class values to organ names for the source modality
|
| 117 |
+
class_to_organ = {class_value: organ_name for organ_name, class_values in source_mapping.items() for class_value in class_values}
|
| 118 |
+
# Initialize the target mask
|
| 119 |
+
target_organ_names = list(target_mapping.keys())
|
| 120 |
+
target_mask = np.full_like(source_mask, 0, dtype=source_mask.dtype)
|
| 121 |
+
|
| 122 |
+
# Convert each unique class in the source mask
|
| 123 |
+
for class_value in np.unique(source_mask):
|
| 124 |
+
# Find the corresponding organ name in the source modality
|
| 125 |
+
organ_name = class_to_organ.get(class_value, None)
|
| 126 |
+
if class_value == 0:
|
| 127 |
+
target_value = 0
|
| 128 |
+
else:
|
| 129 |
+
# If organ name exists, find the corresponding target class values
|
| 130 |
+
if organ_name and organ_name in target_mapping:
|
| 131 |
+
# Pick the first target class value (or handle overlaps if needed)
|
| 132 |
+
target_value = target_mapping[organ_name][0]
|
| 133 |
+
else:
|
| 134 |
+
# Manual mapping: source organ name → target organ name
|
| 135 |
+
manual_mapping = {
|
| 136 |
+
'intervertebral_discs': 'spinal_cord',
|
| 137 |
+
'quadriceps_femoris_left':'gluteus_maximus_left',
|
| 138 |
+
'quadriceps_femoris_right':'gluteus_maximus_right',
|
| 139 |
+
'thigh_medial_compartment_left': 'gluteus_maximus_left',
|
| 140 |
+
'thigh_medial_compartment_right': 'gluteus_maximus_right',
|
| 141 |
+
'thigh_posterior_compartment_left': 'gluteus_maximus_left',
|
| 142 |
+
'thigh_posterior_compartment_right': 'gluteus_maximus_right',
|
| 143 |
+
'sartorius_left': 'gluteus_maximus_left',
|
| 144 |
+
'sartorius_right': 'gluteus_maximus_right',
|
| 145 |
+
# Add more mappings here as needed
|
| 146 |
+
}
|
| 147 |
+
# Check manual mapping first
|
| 148 |
+
if organ_name in manual_mapping and manual_mapping[organ_name] in target_mapping:
|
| 149 |
+
matched_name = manual_mapping[organ_name]
|
| 150 |
+
target_value = target_mapping[matched_name][0]
|
| 151 |
+
print(f"[Manual match] '{organ_name}' → '{matched_name}' → label {target_value}")
|
| 152 |
+
else:
|
| 153 |
+
# Fuzzy match fallback
|
| 154 |
+
close_matches = difflib.get_close_matches(organ_name, target_organ_names, n=1, cutoff=0.4)
|
| 155 |
+
if close_matches:
|
| 156 |
+
matched_name = close_matches[0]
|
| 157 |
+
target_value = target_mapping[matched_name][0]
|
| 158 |
+
print(f"[Fuzzy match] '{organ_name}' → '{matched_name}' → label {target_value}")
|
| 159 |
+
else:
|
| 160 |
+
print(f"[Warning] No match for '{organ_name}', using body contour value.")
|
| 161 |
+
target_value = body_contour_value
|
| 162 |
+
'''close_matches = difflib.get_close_matches(organ_name, target_organ_names, n=1, cutoff=0.4)
|
| 163 |
+
if close_matches:
|
| 164 |
+
matched_name = close_matches[0]
|
| 165 |
+
target_value = target_mapping[matched_name][0]
|
| 166 |
+
print(f"[Fuzzy match] '{organ_name}' → '{matched_name}' → label {target_value}")
|
| 167 |
+
else:
|
| 168 |
+
print(f"[Warning] No match for '{organ_name}', using body contour value.")
|
| 169 |
+
target_value = body_contour_value'''
|
| 170 |
+
# Replace class values in the target mask
|
| 171 |
+
target_mask[source_mask == class_value] = target_value
|
| 172 |
+
|
| 173 |
+
return target_mask
|
| 174 |
+
|
| 175 |
+
def run_mask_conversion(
|
| 176 |
+
mask = r'E:\Projects\yang_proj\data\synthrad\Task1\pelvis\1PA001\ct_seg.nii.gz',
|
| 177 |
+
img = r'E:\Projects\yang_proj\data\synthrad\Task1\pelvis\1PA001\ct.nii.gz',
|
| 178 |
+
MR_csv = r'E:\Projects\yang_proj\SynthRad_GAN\synthrad_conversion\TA2_MR_for_convert.csv',
|
| 179 |
+
CT_csv = r'E:\Projects\yang_proj\SynthRad_GAN\synthrad_conversion\TA2_CT_for_convert.csv',
|
| 180 |
+
output_path = r'mr_mask_from_ct.nii.gz', # output_path = r'ct_mask_from_mr.nii.gz'
|
| 181 |
+
mode = 'ct2mr'
|
| 182 |
+
):
|
| 183 |
+
if mode == 'ct2mr':
|
| 184 |
+
body_threshold=-500
|
| 185 |
+
source_csv = CT_csv
|
| 186 |
+
target_csv = MR_csv
|
| 187 |
+
elif mode == 'mr2ct':
|
| 188 |
+
body_threshold=5
|
| 189 |
+
source_csv = MR_csv
|
| 190 |
+
target_csv = CT_csv
|
| 191 |
+
|
| 192 |
+
source_mask = mask
|
| 193 |
+
img = img
|
| 194 |
+
|
| 195 |
+
seg_metadata = nib.load(source_mask)
|
| 196 |
+
seg = seg_metadata.get_fdata()
|
| 197 |
+
affine = seg_metadata.affine
|
| 198 |
+
|
| 199 |
+
img_metadata = nib.load(img)
|
| 200 |
+
img = img_metadata.get_fdata()
|
| 201 |
+
affine = img_metadata.affine
|
| 202 |
+
|
| 203 |
+
'''body_contour = np.zeros_like(img, dtype=np.int16)
|
| 204 |
+
for i in range(img.shape[2]):
|
| 205 |
+
slice_data = img[:, :, i]
|
| 206 |
+
body_contour[:, :, i] = create_body_contour(slice_data, body_threshold)
|
| 207 |
+
contour = body_contour
|
| 208 |
+
seg_with_contour = seg+contour'''
|
| 209 |
+
seg_with_contour = seg
|
| 210 |
+
target_mask = convert_segmentation_mask(seg_with_contour, source_csv, target_csv, body_contour_value=1)
|
| 211 |
+
if mode == 'ct2mr':
|
| 212 |
+
csv_simulation_file = MR_csv
|
| 213 |
+
csv_values = pd.read_csv(csv_simulation_file, header=None).to_numpy()
|
| 214 |
+
target_mask = process_segmentation(target_mask, csv_values)
|
| 215 |
+
elif mode == 'mr2ct':
|
| 216 |
+
csv_simulation_file = CT_csv
|
| 217 |
+
target_mask = process_CT_segmentation_numpy(target_mask, csv_simulation_file)
|
| 218 |
+
img_processed = nib.Nifti1Image(target_mask, affine)
|
| 219 |
+
nib.save(img_processed, output_path)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def run_mask_conversion_synthrad_test(synthrad_root=r'E:\Projects\yang_proj\data\synthrad\Task1\pelvis', patient_list=['1PA001'], mode='ct2mr', output_csv_file='ct2mr_conversion.csv'):
    """Run mask conversion for a fixed list of SynthRAD patients and write an index CSV.

    Args:
        synthrad_root: Root folder of the SynthRAD Task1 pelvis dataset; each
            patient subfolder holds mr.nii.gz / ct.nii.gz and their segmentations.
        patient_list: Patient IDs (subfolder names) to process.
        mode: 'ct2mr' converts a CT segmentation into an MR-style mask,
            'mr2ct' the other way round.
        output_csv_file: Path of the CSV listing (id, Aorta_diss, seg, img) per patient.
    """
    dataset_list = []
    for patient in tqdm(patient_list):
        mr_mask = os.path.join(synthrad_root, patient, 'mr_merged_seg.nii.gz')
        mr_img = os.path.join(synthrad_root, patient, 'mr.nii.gz')
        ct_mask = os.path.join(synthrad_root, patient, 'ct_seg.nii.gz')
        ct_img = os.path.join(synthrad_root, patient, 'ct.nii.gz')
        MR_csv = r'synthrad_conversion/TA2_MR_for_convert.csv'
        CT_csv = r'synthrad_conversion/TA2_CT_for_convert.csv'
        if mode == 'ct2mr':
            preprocessed_mr_path = r'E:\Projects\yang_proj\data\anika\MR_processed'
            preprocessed_mr_img = os.path.join(preprocessed_mr_path, f'mr_{patient}.nii.gz')
            output_path = os.path.join(synthrad_root, patient, 'mr_mask_from_ct.nii.gz')
            csv_mr_line = [patient, 0, output_path, preprocessed_mr_img]
            # FIX: run_mask_conversion takes (mask, img, MR_csv, CT_csv, output_path, mode);
            # the previous 8-argument call passed both modality pairs at once.
            run_mask_conversion(ct_mask, ct_img, MR_csv, CT_csv, output_path, mode)
        elif mode == 'mr2ct':
            output_path = os.path.join(synthrad_root, patient, 'ct_mask_from_mr.nii.gz')
            csv_mr_line = [patient, 0, output_path, ct_img]
            run_mask_conversion(mr_mask, mr_img, MR_csv, CT_csv, output_path, mode)
        dataset_list.append(csv_mr_line)

    import csv
    with open(output_csv_file, 'w', newline='') as f:
        csvwriter = csv.writer(f)
        csvwriter.writerow(['id', 'Aorta_diss', 'seg', 'img'])
        csvwriter.writerows(dataset_list)
|
| 248 |
+
|
| 249 |
+
def run_mask_conversion_csv(csv_file=r'E:\Projects\yang_proj\data\synthrad\processed\processed_ct_csv_file.csv', mode='ct2mr', output_csv_file='ct2mr_conversion.csv'):
    """Convert segmentation masks for every entry of a dataset CSV.

    The input CSV columns are: patient id, Aorta_diss flag, segmentation path,
    image path, aligned segmentation path. For each row the mask is converted
    according to *mode* and an output CSV (id, Aorta_diss, seg, img) is written.

    Raises:
        RuntimeError: If the input CSV contains no rows.
    """
    data_frame = pd.read_csv(csv_file)
    if len(data_frame) == 0:
        raise RuntimeError(f"Found 0 images in: {csv_file}")
    patient_IDs = data_frame.iloc[:, 0].tolist()
    Aorta_diss = data_frame.iloc[:, 1].tolist()
    segs = data_frame.iloc[:, 2].tolist()
    images = data_frame.iloc[:, 3].tolist()
    aligned_segs = data_frame.iloc[:, 4].tolist()  # currently unused, kept for CSV schema validation
    dataset_list = []
    synthrad_root = r"E:\Projects\yang_proj\data\synthrad\Task1\pelvis"
    # Loop-invariant lookup tables, hoisted out of the per-patient loop.
    MR_csv = r'synthrad_conversion/TA2_MR_for_convert.csv'
    CT_csv = r'synthrad_conversion/TA2_CT_for_convert.csv'
    from tqdm import tqdm
    for idx in tqdm(range(len(images))):
        patient = patient_IDs[idx]
        if mode == 'ct2mr':
            ct_mask = segs[idx]
            ct_img = images[idx]
            preprocessed_mr_path = r'E:\Projects\yang_proj\data\anika\MR_processed'
            preprocessed_mr_img = os.path.join(preprocessed_mr_path, f'mr_{patient}.nii.gz')

            mr_mask_from_ct_folder = r'E:\Projects\yang_proj\data\synthrad\mr_mask_from_ct'
            output_path = os.path.join(mr_mask_from_ct_folder, f'{patient}_mr_mask_from_ct.nii.gz')
            csv_mr_line = [patient, 0, output_path, preprocessed_mr_img]
            run_mask_conversion(ct_mask, ct_img, MR_csv, CT_csv, output_path, mode)

        elif mode == 'mr2ct':
            mr_mask = os.path.join(synthrad_root, patient, 'mr_merged_seg.nii.gz')
            mr_img = os.path.join(synthrad_root, patient, 'mr.nii.gz')
            # FIX: ct_img was previously only defined in the 'ct2mr' branch,
            # so this branch raised NameError when building csv_mr_line.
            ct_img = os.path.join(synthrad_root, patient, 'ct.nii.gz')
            output_path = os.path.join(synthrad_root, patient, 'ct_mask_from_mr.nii.gz')
            csv_mr_line = [patient, 0, output_path, ct_img]
            run_mask_conversion(mr_mask, mr_img, MR_csv, CT_csv, output_path, mode)
        dataset_list.append(csv_mr_line)

    import csv
    with open(output_csv_file, 'w', newline='') as f:
        csvwriter = csv.writer(f)
        csvwriter.writerow(['id', 'Aorta_diss', 'seg', 'img'])
        csvwriter.writerows(dataset_list)
|
| 289 |
+
|
| 290 |
+
if __name__ == "__main__":
    # Script entry point: convert all CT segmentations listed in the processed
    # dataset CSV into MR-style masks and record the results in a new CSV.
    csv_file = r'E:\Projects\yang_proj\data\synthrad\processed\processed_csv_file.csv'
    mode = 'ct2mr'
    output_csv_file = r'E:\Projects\yang_proj\SynthRad_GAN\synthrad_conversion\datacsv\ct2mr_conversion.csv'
    run_mask_conversion_csv(csv_file = csv_file, mode = mode, output_csv_file = output_csv_file)
|
dataprocesser/Preprocess_MR_Mask_generation.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
import nrrd
|
| 5 |
+
import SimpleITK as sitk
|
| 6 |
+
import cv2
|
| 7 |
+
|
| 8 |
+
from dataprocesser.preprocess_MR import step3_vibe_resetsignal
|
| 9 |
+
"""
|
| 10 |
+
该代码用于处理一组 MR 图像和对应的分割图,应用掩膜、进行归一化,并根据 CSV 文件中的仿真 MR 灰度值对分割图进行替换。最后将处理后的 MR 图像和分割图保存。
|
| 11 |
+
|
| 12 |
+
主要步骤:
|
| 13 |
+
1. **读取数据**:从指定的文件夹中读取 MR 图像和对应的分割图。
|
| 14 |
+
2. **归一化处理**:对 MR 图像进行归一化,将其值范围映射到 0 到 255 之间。
|
| 15 |
+
3. **轮廓提取**:从归一化后的 MR 图像中提取出主体区域的轮廓(根据给定的阈值分割),创建掩膜。
|
| 16 |
+
4. **掩膜应用**:将提取出的掩膜应用到归一化后的 MR 图像上,保留主体区域,抑制背景。
|
| 17 |
+
5. **分割图处理**:读取对应的分割图,并与提取出的轮廓进行叠加,之后根据 CSV 文件中的仿真 CT 值替换分割图中的灰度值。
|
| 18 |
+
6. **图像保存**:将处理后的 MR 图像和修改后的分割图保存到指定的输出文件夹中,保证其空间属性和几何信息与输入图像一致。
|
| 19 |
+
7. **输出**:在 ITK-SNAP 等医学图像工具中打开时, MR 图像和分割图能够保持同步和正确的比例显示。
|
| 20 |
+
|
| 21 |
+
函数简介:
|
| 22 |
+
- `normalize`: 对 MR 图像进行归一化处理,将像素值范围映射到 [0, 255]。
|
| 23 |
+
- `create_body_mask`: 从图像中提取出身体的轮廓,生成二值掩膜。
|
| 24 |
+
- `apply_mask`: 将提取的掩膜应用到 MR 图像上,保留轮廓内部的区域。
|
| 25 |
+
- `process_segmentation`: 读取分割图,并根据 CSV 文件中的仿真 CT 值对其灰度值进行替换。
|
| 26 |
+
- `process_image`: 处理单个 MR 图像及其对应的分割图,包括归一化、轮廓提取、掩膜应用、分割图处理等。
|
| 27 |
+
- `process_folder`: 处理整个文件夹中的 MR 图像和分割图,逐一处理所有图像并保存结果。
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
# 归一化函数
|
| 31 |
+
def normalize(img, vmin_out=0, vmax_out=1, norm_min_v=None, norm_max_v=None, epsilon=1e-5):
    """Min-max normalize *img* into the range [vmin_out, vmax_out].

    Values outside [norm_min_v, norm_max_v] are clipped first. When either
    bound is None it is taken from the image itself.

    Args:
        img: Array-like image data.
        vmin_out, vmax_out: Output intensity range.
        norm_min_v, norm_max_v: Input clipping bounds; None means use
            np.min(img) / np.max(img) respectively.
        epsilon: Stabilizer added to the denominator to avoid division by zero.

    Returns:
        The rescaled image (numpy array).
    """
    # FIX: bounds were previously only computed when BOTH were None, so passing
    # just one of them made np.clip receive None and fail. Handle each independently.
    if norm_min_v is None:
        norm_min_v = np.min(img)
    if norm_max_v is None:
        norm_max_v = np.max(img)
    img = np.clip(img, norm_min_v, norm_max_v)
    img = (img - norm_min_v) / (norm_max_v - norm_min_v + epsilon)
    img = img * (vmax_out - vmin_out) + vmin_out
    return img
|
| 39 |
+
|
| 40 |
+
# 创建轮廓掩膜
|
| 41 |
+
def create_body_mask_simple(numpy_img, body_threshold=50):
    """Binary body mask for a single 2D slice: threshold, then fill the
    largest external contour found by OpenCV.

    Args:
        numpy_img: 2D slice of intensity values.
        body_threshold: Intensities above this value are treated as body.

    Returns:
        uint8 mask with 1 inside the largest contour, 0 elsewhere.
    """
    slice_i16 = numpy_img.astype(np.int16)
    thresholded = np.where(slice_i16 > body_threshold, 1, 0).astype(np.uint8)
    found_contours, _ = cv2.findContours(thresholded, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    filled = np.zeros_like(thresholded, dtype=np.uint8)

    if found_contours:
        biggest = max(found_contours, key=cv2.contourArea)
        # cv2.drawContours requires contiguous buffers.
        filled = np.ascontiguousarray(filled)
        biggest = np.ascontiguousarray(biggest)
        cv2.drawContours(filled, [biggest], -1, 1, thickness=cv2.FILLED)

    return filled
|
| 54 |
+
|
| 55 |
+
def create_body_mask(numpy_img, body_threshold=-500, min_contour_area=10000):
    """
    Create a binary body mask from a single 2D CT slice.

    All contours whose area is at least *min_contour_area* are filled, so the
    mask can contain several disjoint body regions (e.g. both legs).

    Args:
        numpy_img (np.ndarray): 2D grayscale slice; CT intensities may be
            negative (e.g. -1024 for air), hence the int16 cast below.
        body_threshold (int): Intensities above this value count as body.
            The default -500 targets CT; callers pass a different value for MR.
        min_contour_area (int): Contours smaller than this are discarded
            (removes noise specks and small artifacts).

    Returns:
        np.ndarray: uint8 mask where body regions are 1 and background is 0.
    """
    # int16 preserves negative CT values; contiguous layout is required by cv2.
    numpy_img = np.ascontiguousarray(numpy_img.astype(np.int16)) # Ensure we can handle negative values correctly
    #numpy_img = numpy_img.astype(np.int16)

    # Threshold to separate potential body from the background.
    binary_img = np.where(numpy_img > body_threshold, 1, 0).astype(np.uint8)

    # Find contours from the binary image (RETR_LIST: all contours, no hierarchy).
    contours, _ = cv2.findContours(binary_img, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)

    # Create an empty mask
    mask = np.zeros_like(binary_img)

    VERBOSE = False
    # Fill every detected contour that is large enough to be a body part.
    if contours:
        for contour in contours:
            if cv2.contourArea(contour) >= min_contour_area:
                if VERBOSE:
                    print('current contour area: ', cv2.contourArea(contour), 'threshold: ', min_contour_area)
                cv2.drawContours(mask, [contour], -1, 1, thickness=cv2.FILLED)

    return mask
|
| 88 |
+
|
| 89 |
+
def apply_mask(normalized_image_array, mask_array):
    """Element-wise mask application: keeps intensities where the mask is
    non-zero and suppresses everything else."""
    masked = np.multiply(normalized_image_array, mask_array)
    return masked
|
| 91 |
+
|
| 92 |
+
def print_all_info(data, title):
    """Log the intensity range of *data*, labelled with *title*, for debugging."""
    lo, hi = np.min(data), np.max(data)
    print(f'min, max of {title}:', lo, hi)
|
| 94 |
+
|
| 95 |
+
# process the segmentation, replace the classes with simulated MR values
|
| 96 |
+
def process_segmentation(combined_array, csv_simulation_values, mr_signal_formula=step3_vibe_resetsignal.calculate_signal_vibe):
    """Replace organ labels in a segmentation with simulated MR signal values.

    Args:
        combined_array: Integer label map (segmentation, possibly with the body
            contour already added so background-inside-body has its own label).
        csv_simulation_values: Numpy array read from the conversion CSV; row 0
            is the header, then columns are organ index, T1, T2, Rho.
        mr_signal_formula: Callable (t1, t2, rho) -> simulated signal value.

    Returns:
        Array of the same shape with each organ label replaced by its
        simulated MR signal; unlisted labels stay 0.
    """
    combined_array = combined_array.astype(np.int16)
    print_all_info(combined_array, 'combine')
    # Skip the CSV header row; columns: organ index, T1, T2, Rho.
    organ_indexs = csv_simulation_values[1:, 0]
    T1_values = csv_simulation_values[1:, 1]
    T2_values = csv_simulation_values[1:, 2]
    Rho_values = csv_simulation_values[1:, 3]
    # Some CSVs number organs from 0, others from 1; the label map is 1-based.
    order_begin_from_0 = organ_indexs.astype(int).min() == 0
    # NOTE(review): zeros_like inherits int16 from combined_array, so float
    # simulation values are truncated on assignment — confirm this is intended.
    assign_value_mask = np.zeros_like(combined_array)

    # (Removed dead manual step counter; the for-loop index already tracks it.)
    for step in range(len(organ_indexs)):
        organ_index = int(organ_indexs[step])
        t1 = float(T1_values[step])
        t2 = float(T2_values[step])
        rho = float(Rho_values[step])

        simulation_value = mr_signal_formula(t1, t2, rho)
        if order_begin_from_0:
            # 0-based CSV: shift by one to match the 1-based label map.
            assign_value_mask[combined_array == organ_index + 1] = simulation_value
        else:
            assign_value_mask[combined_array == organ_index] = simulation_value
    print_all_info(assign_value_mask, 'assignment')
    return assign_value_mask
|
| 128 |
+
|
| 129 |
+
# 处理单个图像和分割图
|
| 130 |
+
def process_image(input_path, contour_path, seg_path, csv_simulation_values, output_path1, output_path2, body_threshold):
    """Process one MR volume and its segmentation: normalize the image, mask it
    with a per-slice body contour, replace segmentation labels with simulated
    values, and save image / segmentation / contour to the output paths.

    Args:
        input_path: MR volume (.nrrd, .nii or .nii.gz).
        contour_path: Unused for input; overwritten below with a derived path
            when saving the contour (NIfTI branch only).
        seg_path: Segmentation volume matching input_path's format.
        csv_simulation_values: Simulation value table passed to process_segmentation.
        output_path1: Destination for the masked, normalized MR image.
        output_path2: Destination for the value-replaced segmentation.
        body_threshold: Intensity threshold for the body-contour extraction.

    Returns:
        The processed segmentation array (NIfTI branch; the .nrrd branch falls
        through without an explicit return of its own).
    """
    # Load the original MR image and its segmentation.
    if input_path.endswith('.nrrd'):
        img, header = nrrd.read(input_path)
        segmentation_img, header_seg = nrrd.read(seg_path)
    elif input_path.endswith('.nii.gz') or input_path.endswith('.nii'):
        import nibabel as nib
        img_metadata = nib.load(input_path)
        img = img_metadata.get_fdata()
        affine = img_metadata.affine

        seg_metadata = nib.load(seg_path)
        segmentation_img = seg_metadata.get_fdata()

    # Percentile-based normalization to [0, norm_max] (robust to outliers).
    norm_max=255 #255
    low_percentile = 5
    high_percentile = 90
    img_normalized = normalize(img, 0, norm_max, np.percentile(img, low_percentile), np.percentile(img, high_percentile), epsilon=0)

    # Extract a body contour mask slice by slice (assumes axis 2 is the slice axis).
    body_contour = np.zeros_like(img, dtype=np.int16)
    for i in range(img.shape[2]):
        slice_data = img[:, :, i]
        body_contour[:, :, i] = create_body_mask(slice_data, body_threshold=body_threshold)

    # Apply the mask to the normalized MR image (suppress background).
    masked_image = apply_mask(img_normalized, body_contour)

    # Process the segmentation:
    # add contour background to the segmentation (all region inside body + 1)
    combined_array = segmentation_img + body_contour
    combined_array = np.clip(combined_array, 0, np.max(segmentation_img) + 1)
    print_all_info(segmentation_img, 'seg')
    processed_segmentation = process_segmentation(combined_array, csv_simulation_values)

    # normalize to 0-1
    # masked_image = masked_image/norm_max
    # processed_segmentation = processed_segmentation/norm_max

    if input_path.endswith('.nrrd'):
        # Save the processed MR image (keeps the original NRRD header).
        nrrd.write(output_path1, masked_image, header)

        # Save the processed segmentation.
        nrrd.write(output_path2, processed_segmentation, header_seg)

        # save the body contour mask

    elif input_path.endswith('.nii.gz') or input_path.endswith('.nii'):
        # Reuse the input affine so outputs stay spatially aligned in viewers.
        img_processed = nib.Nifti1Image(masked_image, affine)
        nib.save(img_processed, output_path1)
        seg_processed = nib.Nifti1Image(processed_segmentation, affine)
        nib.save(seg_processed, output_path2)
        contour_processed = nib.Nifti1Image(body_contour, affine)

        # Split the path into directory and filename to derive the contour path.
        directory, filename = os.path.split(output_path2)
        new_filename = filename.replace('seg', 'contour')
        contour_path = os.path.join(directory, new_filename)

        nib.save(contour_processed, contour_path)
    return processed_segmentation
|
| 193 |
+
|
| 194 |
+
# 处理文件夹
|
| 195 |
+
def process_folder(input_folder1, input_folder2, output_folder1, output_folder2, csv_simulation_file, body_threshold=50):
    """Process every .nrrd MR volume in *input_folder1* against its same-named
    segmentation in *input_folder2*, writing results to the output folders.

    Raises:
        ValueError: If the simulation CSV does not parse into a 2D table.
    """
    # Load the simulation value table (organ index + simulated values).
    simulation_table = pd.read_csv(csv_simulation_file, header=None).to_numpy()
    if simulation_table.ndim == 1:
        raise ValueError("CSV 文件格式不正确,应该包含两列:organ_index 和 simulation_value")

    # Make sure both output folders exist.
    for folder in (output_folder1, output_folder2):
        os.makedirs(folder, exist_ok=True)

    for entry in os.listdir(input_folder1):
        if not entry.endswith('.nrrd'):
            continue
        src_image = os.path.join(input_folder1, entry)
        src_seg = os.path.join(input_folder2, entry)
        dst_image = os.path.join(output_folder1, entry)
        dst_seg = os.path.join(output_folder2, entry)

        print(f"Processing {src_image} with segmentation {src_seg}")
        processed_seg = process_image(src_image, None, src_seg, simulation_table, dst_image, dst_seg, body_threshold)
|
| 216 |
+
|
| 217 |
+
def analyse_hist(input_path):
    """Load a volume and show a histogram of slice index 50 (debug helper).

    Args:
        input_path: Path to a .nrrd, .nii or .nii.gz volume.

    Raises:
        ValueError: For unsupported file extensions (previously this fell
            through silently and crashed later with NameError on `img`).
    """
    if input_path.endswith('.nrrd'):
        img, header = nrrd.read(input_path)
    elif input_path.endswith('.nii.gz') or input_path.endswith('.nii'):
        import nibabel as nib
        img_metadata = nib.load(input_path)
        img = img_metadata.get_fdata()
        affine = img_metadata.affine
    else:
        raise ValueError(f"Unsupported file format: {input_path}")
    import numpy as np
    import matplotlib.pyplot as plt

    # Plot the histogram of one fixed slice.
    print('shape of img: ', img.shape)
    plt.hist(img[:, :, 50], bins=30, edgecolor='black', alpha=0.7)
    plt.xlabel('Value')
    plt.ylabel('Frequency')
    plt.title('Value Distribution')
    plt.show()
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def process_csv(csv_file, output_folder1, output_folder2, csv_simulation_file, body_threshold=50, output_mr_csv_file='processed_mr_csv_file.csv'):
    """Process every NIfTI image/segmentation pair listed in *csv_file* via
    process_image, renaming generic SynthRAD filenames to patient-specific
    ones, and write an index CSV of the results.

    Args:
        csv_file: Dataset CSV (id, Aorta_diss, seg, img, ...) read by
            list_img_seg_ad_pIDs_from_new_simplified_csv.
        output_folder1: Destination folder for processed MR images.
        output_folder2: Destination folder for processed segmentations.
        csv_simulation_file: Simulation value table (organ index + values).
        body_threshold: Threshold for body-contour extraction.
        output_mr_csv_file: Path of the output index CSV.

    Raises:
        ValueError: If the simulation CSV does not parse into a 2D table.
    """
    # Read the simulation value table (organ index + simulated values).
    csv_simulation_values = pd.read_csv(csv_simulation_file, header=None).to_numpy()
    #csv_simulation_values = pd.read_csv(csv_simulation_file)

    # The table must be two-dimensional (header row + value columns).
    if csv_simulation_values.ndim == 1:
        raise ValueError("CSV 文件格式不正确,应该包含两列:organ_index 和 simulation_value")

    # Make sure the output folders exist.
    os.makedirs(output_folder1, exist_ok=True)
    os.makedirs(output_folder2, exist_ok=True)

    from step1_init_data_list import list_img_seg_ad_pIDs_from_new_simplified_csv
    patient_IDs, Aorta_diss, segs, images = list_img_seg_ad_pIDs_from_new_simplified_csv(csv_file)
    from tqdm import tqdm
    dataset_list = []
    for idx in tqdm(range(len(images))):
        # Only process rows where image and segmentation share a NIfTI extension.
        if (images[idx].endswith('.nii.gz') and segs[idx].endswith('.nii.gz')) or \
           (images[idx].endswith('.nii') and segs[idx].endswith('.nii')):
            input_file_path = images[idx]
            seg_file_path = segs[idx]
            patient_id = patient_IDs[idx]
            ad = Aorta_diss[idx]
            root_dir = os.path.dirname(input_file_path)

            output_file_path1 = os.path.join(output_folder1, os.path.relpath(input_file_path, start=root_dir))

            # SynthRAD names every patient's files 'mr' / 'mr_merged_seg';
            # inject the patient ID so outputs don't collide in one folder.
            synthrad_basic_mr_name = 'mr'
            synthrad_basic_seg_name = 'mr_merged_seg'
            if os.path.basename(output_file_path1) == f'{synthrad_basic_mr_name}.nii.gz' or \
               os.path.basename(output_file_path1) == f'{synthrad_basic_mr_name}.nii':
                # Insert the patient ID in the filename
                output_file_path1 = output_file_path1.replace(f'{synthrad_basic_mr_name}', f'mr_{patient_id}')

            output_file_path2 = os.path.join(output_folder2, os.path.relpath(seg_file_path, start=root_dir))

            if os.path.basename(output_file_path2) == f'{synthrad_basic_seg_name}.nii.gz' or \
               os.path.basename(output_file_path2) == f'{synthrad_basic_seg_name}.nii':
                # Insert the patient ID in the filename
                output_file_path2 = output_file_path2.replace(f'{synthrad_basic_seg_name}', f'mr_seg_{patient_id}')

            print(f"Processing {input_file_path} with segmentation {seg_file_path}")
            print(f"Save results to {output_file_path1} and {output_file_path2}")

            processed_seg = process_image(input_file_path, None, seg_file_path, csv_simulation_values, output_file_path1, output_file_path2, body_threshold)

            # Record one output row per processed patient.
            csv_mr_line = [patient_id,ad,output_file_path2,output_file_path1]
            dataset_list.append(csv_mr_line)

    import csv
    with open(output_mr_csv_file, 'w', newline='') as f:
        csvwriter = csv.writer(f)
        csvwriter.writerow(['id', 'Aorta_diss', 'seg', 'img'])
        csvwriter.writerows(dataset_list)
|
| 293 |
+
|
| 294 |
+
if __name__ == "__main__":
    # CLI entry point: batch-process a folder of MR volumes + segmentations.
    import argparse

    parser = argparse.ArgumentParser(description="Process MR images and segmentation maps, apply masks and replace grayscale values.")
    parser.add_argument('--input_folder1', required=True, help="Path to the folder containing input MR .nrrd files.")
    parser.add_argument('--input_folder2', required=True, help="Path to the folder containing segmentation .nrrd files.")
    parser.add_argument('--output_folder1', required=True, help="Path to the folder to save the output MR files.")
    parser.add_argument('--output_folder2', required=True, help="Path to the folder to save the output segmentation files.")
    parser.add_argument('--csv_simulation_file', required=True, help="CSV file containing simulated CT grayscale values.")
    parser.add_argument('--body_threshold', type=int, default=50, help="Threshold to separate body from background.")
    args = parser.parse_args()

    process_folder(args.input_folder1, args.input_folder2, args.output_folder1, args.output_folder2, args.csv_simulation_file, args.body_threshold)
|
dataprocesser/Preprocess_MR_Masks_overlay.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
|
| 3 |
+
import nrrd
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
def load_nifti_file(input_path):
    """Load an image volume and its spatial metadata.

    Args:
        input_path: Path to a .nrrd, .nii or .nii.gz file.

    Returns:
        (data, metadata): the array plus the NRRD header or NIfTI affine,
        depending on the input format.

    Raises:
        ValueError: For unsupported extensions. Previously the function fell
            through and returned None, which crashed callers on unpacking.
    """
    if input_path.endswith('.nrrd'):
        data, header = nrrd.read(input_path)
        return data, header
    elif input_path.endswith('.nii.gz') or input_path.endswith(".nii"):
        import nibabel as nib
        img_metadata = nib.load(input_path)
        img = img_metadata.get_fdata()
        affine = img_metadata.affine
        return img, affine
    raise ValueError(f"Unsupported image format: {input_path}")
|
| 16 |
+
|
| 17 |
+
def save_nrrd_file(data, HeaderOrAffine, input_path, save_path):
    """Save *data* to *save_path* in the same format as *input_path*.

    Args:
        data: Image array to write.
        HeaderOrAffine: NRRD header (for .nrrd) or NIfTI affine (for .nii/.nii.gz).
        input_path: Original file path, used only to select the output format.
        save_path: Destination path.

    Raises:
        ValueError: For unsupported extensions. Previously unsupported inputs
            were silently ignored, dropping the data without warning.
    """
    #nrrd.write(save_path, data, header)
    if input_path.endswith('.nrrd'):
        nrrd.write(save_path, data, HeaderOrAffine)
    elif input_path.endswith('.nii.gz') or input_path.endswith(".nii"):
        import nibabel as nib
        img_processed = nib.Nifti1Image(data, HeaderOrAffine)
        nib.save(img_processed, save_path)
    else:
        raise ValueError(f"Unsupported image format: {input_path}")
|
| 26 |
+
|
| 27 |
+
def overlay_images(mask_data, organ_data):
    """Combine a mask volume with an organ label volume.

    Organ labels 1/2/3 are shifted to 99/199/299 so they occupy a value range
    disjoint from the mask, then the two volumes are summed element-wise.
    """
    shifted = organ_data
    for label, offset in ((1, 98), (2, 197), (3, 296)):
        shifted = np.where(shifted == label, shifted + offset, shifted)
    return mask_data + shifted
|
| 34 |
+
|
| 35 |
+
def main(files1, files2, output_folder=None):
    """Merge each basic MR segmentation with its tissue segmentation.

    Args:
        files1: Paths of the basic MR segmentations (filenames contain 'seg').
        files2: Paths of the tissue segmentations (filenames contain 'seg_tissue').
        output_folder: Where to save merged volumes; None saves each result
            next to its input file.
    """
    # files1 is the list including all basic MR segmentations
    # files2 is the list including all basic MR tissue segmentations

    print("preprocess length of seg files: ", len(files1))
    print("preprocess length of tissue seg files: ", len(files2))

    # Map tissue-seg names back to seg names so the two lists can be intersected.
    files2 = [file.replace('seg_tissue', 'seg') for file in files2]

    files1 = set(files1)
    files2 = set(files2)

    # Only process patients that have BOTH segmentation variants.
    common_files = files1.intersection(files2)

    from tqdm import tqdm
    for filename in tqdm(common_files):
        if filename.endswith(".nrrd") or filename.endswith(".nii.gz") or filename.endswith(".nii"):
            nrrd_path1 = filename
            # Recover the tissue-segmentation path for this patient.
            nrrd_path2 = filename.replace('seg', 'seg_tissue')

            '''
            if os.path.basename(filename) == 'mr_seg.nii.gz':
                patient_ID = os.path.basename(os.path.dirname(filename))
                output_file_name = os.path.basename(filename).replace("seg", f"seg_{patient_ID}")
            else:
                output_file_name = os.path.basename(filename)
            '''

            output_file_name = os.path.basename(filename)
            output_file_name = output_file_name.replace("seg", "merged_seg")
            if output_folder == None:
                # Default: save the merged volume alongside the input file.
                output_folder_current_patient = os.path.dirname(filename)
            else:
                output_folder_current_patient = output_folder
            save_path = os.path.join(output_folder_current_patient, output_file_name)

            print(f"Processing {nrrd_path1} and {nrrd_path2}, saving to {save_path}")

            data1, header1 = load_nifti_file(nrrd_path1)
            data2, header2 = load_nifti_file(nrrd_path2)

            # Merge and save using the basic segmentation's metadata.
            combined_data = overlay_images(data1, data2)
            save_nrrd_file(combined_data, header1, nrrd_path1, save_path)
|
| 78 |
+
|
dataprocesser/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataprocesser.dataset_registry import DATASET_REGISTRY
|
| 2 |
+
# register all the dataset in DATASET_REGISTRY
|
| 3 |
+
import dataprocesser.dataset_anish
|
| 4 |
+
import dataprocesser.dataset_combined_csv
|
| 5 |
+
import dataprocesser.dataset_combined_synthrad_anish
|
| 6 |
+
import dataprocesser.dataset_csv_slice
|
| 7 |
+
import dataprocesser.dataset_json
|
| 8 |
+
import dataprocesser.dataset_synthrad
|
dataprocesser/archive/archiv.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ..basics import get_file_list,crop_volumes, load_volumes
|
| 2 |
+
def load_batch_slices(train_volume_ds,val_volume_ds, train_batch_size=5,val_batch_size=1,window_width=1,ifcheck=True):
    """Build 2D/2.5D slice data loaders from 3D volume datasets.

    Training volumes are cut into patches of *window_width* slices along the
    last axis; validation keeps whole volumes.

    Args:
        train_volume_ds: Training volume dataset yielding 'source'/'target' pairs.
        val_volume_ds: Validation volume dataset (loaded per volume, not per patch).
        train_batch_size: Batch size for the training patch loader.
        val_batch_size: Batch size for the validation loader.
        window_width: Number of slices per patch; 1 squeezes to pure 2D slices.
        ifcheck: Whether to sanity-check the resulting loaders.

    Returns:
        (train_loader, val_loader) DataLoaders.
    """
    patch_func = monai.data.PatchIterd(
        keys=["source", "target"],
        patch_size=(None, None, window_width), # dynamic first two dimensions
        start_pos=(0, 0, 0)
    )
    if window_width==1:
        patch_transform = Compose(
            [
                SqueezeDimd(keys=["source", "target"], dim=-1), # squeeze the last dim
            ]
        )
    else:
        patch_transform = None
    # for training
    train_patch_ds = monai.data.GridPatchDataset(
        data=train_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
    train_loader = DataLoader(
        train_patch_ds,
        batch_size=train_batch_size,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )

    # for validation
    val_loader = DataLoader(
        val_volume_ds,
        num_workers=1,
        batch_size=val_batch_size,
        pin_memory=torch.cuda.is_available())

    if ifcheck:
        check_batch_data(train_loader,val_loader,train_patch_ds,val_volume_ds,train_batch_size,val_batch_size)
    return train_loader,val_loader
|
| 36 |
+
|
| 37 |
+
def load_batch_slices3D(train_volume_ds,val_volume_ds, train_batch_size=5,val_batch_size=1,ifcheck=True):
    """Build 3D chunk data loaders: training volumes are split into 32-slice
    patches along the last axis, validation keeps whole volumes.

    Returns:
        (train_loader, val_loader) DataLoaders.
    """
    pin = torch.cuda.is_available()

    # Iterate 32-slice patches over the full in-plane extent.
    patch_iterator = monai.data.PatchIterd(
        keys=["source", "target"],
        patch_size=(None, None,32),
        start_pos=(0, 0, 0)
    )

    # Training: grid patches, multi-worker loading.
    grid_ds = monai.data.GridPatchDataset(
        data=train_volume_ds, patch_iter=patch_iterator, with_coordinates=False)
    train_loader = DataLoader(
        grid_ds,
        batch_size=train_batch_size,
        num_workers=2,
        pin_memory=pin,
    )

    # Validation: whole volumes, single worker.
    val_loader = DataLoader(
        val_volume_ds,
        num_workers=1,
        batch_size=val_batch_size,
        pin_memory=pin)

    if ifcheck:
        check_batch_data(train_loader,val_loader,grid_ds,val_volume_ds,train_batch_size,val_batch_size)
    return train_loader,val_loader
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def mydataloader_3d(data_pelvis_path,
                    train_number,
                    val_number,
                    train_batch_size,
                    val_batch_size,
                    saved_name_train='./train_ds_2d.csv',
                    saved_name_val='./val_ds_2d.csv',
                    resized_size=(600,400,150),
                    div_size=(16,16,16),
                    ifcheck_volume=True,):
    """Build 3D training/validation loaders for the pelvis dataset.

    Args:
        data_pelvis_path: Root folder of the pelvis data.
        train_number, val_number: Number of training / validation cases.
        train_batch_size, val_batch_size: Loader batch sizes.
        saved_name_train, saved_name_val: CSVs where the file lists are saved.
        resized_size: Target volume size for the resize transform.
        div_size: Divisibility constraint for the padding transform.
        ifcheck_volume: Whether to sanity-check the volume datasets.

    Returns:
        (train_loader, val_loader, train_transforms).
    """
    # Volume-level transforms applied to both image and segmentation.
    normalize = 'zscore'
    train_transforms = get_transforms(normalize, resized_size, div_size)

    train_ds, val_ds = get_file_list(data_pelvis_path,
                                     train_number,
                                     val_number)

    train_volume_ds, val_volume_ds = load_volumes(train_transforms=train_transforms,
                                                  train_ds=train_ds,
                                                  val_ds=val_ds,
                                                  saved_name_train=saved_name_train,
                                                  # FIX: previously passed saved_name_train here,
                                                  # silently ignoring the saved_name_val parameter.
                                                  saved_name_val=saved_name_val,
                                                  ifsave=True,
                                                  ifcheck=ifcheck_volume)

    ifcheck_slices = False
    train_loader, val_loader = load_batch_slices3D(train_volume_ds,
                                                   val_volume_ds,
                                                   train_batch_size,
                                                   val_batch_size=val_batch_size,
                                                   ifcheck=ifcheck_slices)

    return train_loader, val_loader, train_transforms
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
from torchvision.utils import save_image
|
| 108 |
+
def save_dataset_as_png(train_ds, train_volume_ds,saved_img_folder,saved_label_folder):
|
| 109 |
+
train_loader = DataLoader(train_volume_ds, batch_size=1)
|
| 110 |
+
for idx, train_check_data in enumerate(train_loader):
|
| 111 |
+
image_volume = train_check_data['image']
|
| 112 |
+
label_volume = train_check_data['label']
|
| 113 |
+
current_item = train_ds[idx]
|
| 114 |
+
file_name_prex = os.path.basename(os.path.dirname(current_item['image']))
|
| 115 |
+
slices_num=image_volume.shape[-1]
|
| 116 |
+
for i in range(slices_num):
|
| 117 |
+
image_i=image_volume[0,0,:,:,i]
|
| 118 |
+
label_i=label_volume[0,0,:,:,i]
|
| 119 |
+
#print(label_volume.shape)
|
| 120 |
+
#SaveImage(output_dir=saved_img_folder, output_postfix=f'{file_name_prex}_image', output_ext='.png', resample=True)(image_volume[0,:,:,:,0])
|
| 121 |
+
save_image(image_i, f'{saved_img_folder}\{file_name_prex}_image_{i}.png')
|
| 122 |
+
save_image(label_i, f'{saved_label_folder}\{file_name_prex}_label_{i}.png')
|
| 123 |
+
|
| 124 |
+
def pre_dataset_for_stylegan(data_pelvis_path,
|
| 125 |
+
normalize,
|
| 126 |
+
train_number,
|
| 127 |
+
val_number,
|
| 128 |
+
saved_img_folder,
|
| 129 |
+
saved_label_folder,
|
| 130 |
+
saved_name_train='./train_ds_2d.csv',
|
| 131 |
+
saved_name_val='./val_ds_2d.csv',
|
| 132 |
+
resized_size=(600,400,None),
|
| 133 |
+
div_size=(16,16,None),):
|
| 134 |
+
train_transforms = get_transforms(normalize,resized_size,div_size)
|
| 135 |
+
train_ds, val_ds = get_file_list(data_pelvis_path,
|
| 136 |
+
train_number,
|
| 137 |
+
val_number)
|
| 138 |
+
train_volume_ds, _ = load_volumes(train_transforms,
|
| 139 |
+
train_ds,
|
| 140 |
+
val_ds,
|
| 141 |
+
saved_name_train,
|
| 142 |
+
saved_name_val,
|
| 143 |
+
ifsave=False,
|
| 144 |
+
ifcheck=False)
|
| 145 |
+
save_dataset_as_png(train_ds, train_volume_ds,saved_img_folder,saved_label_folder)
|
| 146 |
+
return train_ds,train_volume_ds
|
| 147 |
+
|
| 148 |
+
def sum_slices(data_pelvis_path, num=180):
|
| 149 |
+
train_ds, val_ds=get_file_list(data_pelvis_path, 0, num)
|
| 150 |
+
train_ds_2d, val_ds_2d,\
|
| 151 |
+
all_slices_train,all_slices_val,\
|
| 152 |
+
shape_list_train,shape_list_val = transform_datasets_to_2d(train_ds, val_ds,
|
| 153 |
+
saved_name_train='./train_ds_2d.csv',
|
| 154 |
+
saved_name_val='./val_ds_2d.csv',
|
| 155 |
+
ifsave=False)
|
| 156 |
+
print(all_slices_val)
|
| 157 |
+
return all_slices_val
|
| 158 |
+
|
| 159 |
+
def transform_datasets_to_2d(train_ds, val_ds, saved_name_train, saved_name_val,ifsave=True):
|
| 160 |
+
# Load 2D slices of CT images
|
| 161 |
+
train_ds_2d = []
|
| 162 |
+
val_ds_2d = []
|
| 163 |
+
shape_list_train = []
|
| 164 |
+
shape_list_val = []
|
| 165 |
+
all_slices_train=0
|
| 166 |
+
all_slices_val=0
|
| 167 |
+
|
| 168 |
+
# Load 2D slices for training
|
| 169 |
+
for sample in train_ds:
|
| 170 |
+
train_ds_2d_image = LoadImaged(keys=["source","target"],image_only=True, ensure_channel_first=False, simple_keys=True)(sample)
|
| 171 |
+
name = os.path.basename(os.path.dirname(sample['image']))
|
| 172 |
+
num_slices = train_ds_2d_image["source"].shape[-1]
|
| 173 |
+
shape_list_train.append({'patient': name, 'shape': train_ds_2d_image["image"].shape})
|
| 174 |
+
for i in range(num_slices):
|
| 175 |
+
train_ds_2d.append({'image': train_ds_2d_image['image'][:,:,i], 'label': train_ds_2d_image['label'][:,:,i]})
|
| 176 |
+
all_slices_train += num_slices
|
| 177 |
+
|
| 178 |
+
# Load 2D slices for validation
|
| 179 |
+
for sample in val_ds:
|
| 180 |
+
val_ds_2d_image = LoadImaged(keys=["source","target"],image_only=True, ensure_channel_first=False, simple_keys=True)(sample)
|
| 181 |
+
name = os.path.basename(os.path.dirname(sample['image']))
|
| 182 |
+
shape_list_val.append({'patient': name, 'shape': val_ds_2d_image["image"].shape})
|
| 183 |
+
num_slices = val_ds_2d_image["image"].shape[-1]
|
| 184 |
+
for i in range(num_slices):
|
| 185 |
+
val_ds_2d.append({'image': val_ds_2d_image['image'][:,:,i], 'label': val_ds_2d_image['label'][:,:,i]})
|
| 186 |
+
all_slices_val += num_slices
|
| 187 |
+
# Save shape list to csv
|
| 188 |
+
if ifsave:
|
| 189 |
+
np.savetxt(saved_name_train,shape_list_train,delimiter=',',fmt='%s',newline='\n') # f means format, r means raw string
|
| 190 |
+
np.savetxt(saved_name_val,shape_list_val,delimiter=',',fmt='%s',newline='\n') # f means format, r means raw string
|
| 191 |
+
return train_ds_2d, val_ds_2d, all_slices_train, all_slices_val, shape_list_train, shape_list_val
|
| 192 |
+
|
| 193 |
+
def get_train_val_loaders(train_ds_2d, val_ds_2d, batch_size, val_batch_size,normalize, resized_size=(600,400), div_size=(16,16,None),):
|
| 194 |
+
# Define transforms
|
| 195 |
+
train_transforms = get_transforms(normalize,resized_size,div_size)
|
| 196 |
+
train_transforms_list=train_transforms.__dict__['transforms']
|
| 197 |
+
batch_size = batch_size
|
| 198 |
+
# Create training dataset and data loader
|
| 199 |
+
train_dataset = Dataset(data=train_ds_2d, transform=train_transforms)
|
| 200 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=True)
|
| 201 |
+
|
| 202 |
+
val_batch_size = val_batch_size
|
| 203 |
+
# Create validation dataset and data loader
|
| 204 |
+
val_dataset = Dataset(data=val_ds_2d, transform=train_transforms)
|
| 205 |
+
val_loader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False, num_workers=1, pin_memory=True)
|
| 206 |
+
return train_loader, val_loader, train_transforms_list,train_transforms
|
| 207 |
+
|
| 208 |
+
def mydataloader(data_pelvis_path,
|
| 209 |
+
train_number,
|
| 210 |
+
val_number,
|
| 211 |
+
batch_size,
|
| 212 |
+
val_batch_size,
|
| 213 |
+
saved_name_train='./train_ds_2d.csv',
|
| 214 |
+
saved_name_val='./val_ds_2d.csv',
|
| 215 |
+
resized_size=(600,400)):
|
| 216 |
+
train_ds, val_ds = get_file_list(data_pelvis_path,
|
| 217 |
+
train_number,
|
| 218 |
+
val_number)
|
| 219 |
+
train_ds_2d, val_ds_2d,\
|
| 220 |
+
all_slices_train,all_slices_val,\
|
| 221 |
+
shape_list_train,shape_list_val = transform_datasets_to_2d(train_ds, val_ds,
|
| 222 |
+
saved_name_train,
|
| 223 |
+
saved_name_val,ifsave=True)
|
| 224 |
+
|
| 225 |
+
train_loader, val_loader, \
|
| 226 |
+
train_transforms_list,train_transforms = get_train_val_loaders(train_ds_2d,
|
| 227 |
+
val_ds_2d,
|
| 228 |
+
batch_size=batch_size,
|
| 229 |
+
val_batch_size=val_batch_size,
|
| 230 |
+
normalize='zscore',
|
| 231 |
+
resized_size=resized_size,
|
| 232 |
+
div_size=(16,16,None),)
|
| 233 |
+
return train_loader,val_loader,\
|
| 234 |
+
train_transforms_list,train_transforms,\
|
| 235 |
+
all_slices_train,all_slices_val,\
|
| 236 |
+
shape_list_train,shape_list_val
|
dataprocesser/archive/basics.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import monai
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
from monai.transforms import (
|
| 5 |
+
Compose,
|
| 6 |
+
LoadImaged,
|
| 7 |
+
Rotate90d,
|
| 8 |
+
ScaleIntensityd,
|
| 9 |
+
EnsureChannelFirstd,
|
| 10 |
+
ResizeWithPadOrCropd,
|
| 11 |
+
DivisiblePadd,
|
| 12 |
+
ThresholdIntensityd,
|
| 13 |
+
NormalizeIntensityd,
|
| 14 |
+
SqueezeDimd,
|
| 15 |
+
ShiftIntensityd,
|
| 16 |
+
Identityd,
|
| 17 |
+
CenterSpatialCropd,
|
| 18 |
+
ScaleIntensityRanged,
|
| 19 |
+
Spacingd,
|
| 20 |
+
)
|
| 21 |
+
from torch.utils.data import DataLoader
|
| 22 |
+
from .checkdata import check_volumes, save_volumes
|
| 23 |
+
def get_file_list(data_pelvis_path, train_number, val_number, source='mr', target='ct'):
|
| 24 |
+
#list all files in the folder
|
| 25 |
+
file_list=[i for i in os.listdir(data_pelvis_path) if 'overview' not in i]
|
| 26 |
+
file_list_path=[os.path.join(data_pelvis_path,i) for i in file_list]
|
| 27 |
+
#list all ct and mr files in folder
|
| 28 |
+
source_file_list=[os.path.join(j,f'{source}.nii.gz') for j in file_list_path]
|
| 29 |
+
target_file_list=[os.path.join(j,f'{target}.nii.gz') for j in file_list_path] #mr
|
| 30 |
+
# Dict Version
|
| 31 |
+
# source -> image
|
| 32 |
+
# target -> label
|
| 33 |
+
train_ds = [{'source': i, 'target': j, 'A_paths': i, 'B_paths': j} for i, j in zip(source_file_list[0:train_number], target_file_list[0:train_number])]
|
| 34 |
+
val_ds = [{'source': i, 'target': j, 'A_paths': i, 'B_paths': j} for i, j in zip(source_file_list[-val_number:], target_file_list[-val_number:])]
|
| 35 |
+
print('all files in dataset:',len(file_list))
|
| 36 |
+
return train_ds, val_ds
|
| 37 |
+
|
| 38 |
+
def load_volumes(train_transforms,val_transforms,
|
| 39 |
+
train_crop_ds, val_crop_ds,
|
| 40 |
+
train_ds, val_ds,
|
| 41 |
+
saved_name_train=None, saved_name_val=None,
|
| 42 |
+
ifsave=False,ifcheck=False):
|
| 43 |
+
train_volume_ds = monai.data.Dataset(data=train_crop_ds, transform=train_transforms)
|
| 44 |
+
val_volume_ds = monai.data.Dataset(data=val_crop_ds, transform=val_transforms)
|
| 45 |
+
if ifsave:
|
| 46 |
+
save_volumes(train_ds, val_ds, saved_name_train, saved_name_val)
|
| 47 |
+
if ifcheck:
|
| 48 |
+
check_volumes(train_ds, train_volume_ds, val_volume_ds, val_ds)
|
| 49 |
+
return train_volume_ds,val_volume_ds
|
| 50 |
+
|
| 51 |
+
def crop_volumes(train_ds, val_ds,center_crop,resized_size=(512,512,None),pad='minimum'):
|
| 52 |
+
if center_crop>0:
|
| 53 |
+
crop=Compose([LoadImaged(keys=["source", "target"]),
|
| 54 |
+
EnsureChannelFirstd(keys=["source", "target"]),
|
| 55 |
+
CenterSpatialCropd(keys=["source", "target"], roi_size=(-1,-1,center_crop)),
|
| 56 |
+
|
| 57 |
+
])
|
| 58 |
+
train_crop_ds = monai.data.Dataset(data=train_ds, transform=crop)
|
| 59 |
+
val_crop_ds = monai.data.Dataset(data=val_ds, transform=crop)
|
| 60 |
+
print('center crop:',center_crop)
|
| 61 |
+
else:
|
| 62 |
+
crop=Compose([LoadImaged(keys=["source", "target"]),
|
| 63 |
+
EnsureChannelFirstd(keys=["source", "target"]),
|
| 64 |
+
])
|
| 65 |
+
train_crop_ds = monai.data.Dataset(data=train_ds, transform=crop)
|
| 66 |
+
val_crop_ds = monai.data.Dataset(data=val_ds, transform=crop)
|
| 67 |
+
return train_crop_ds, val_crop_ds
|
| 68 |
+
|
| 69 |
+
def get_transforms(configs, mode='train'):
|
| 70 |
+
normalize=configs.dataset.normalize
|
| 71 |
+
pad=configs.dataset.pad
|
| 72 |
+
resized_size=configs.dataset.resized_size
|
| 73 |
+
WINDOW_WIDTH=configs.dataset.WINDOW_WIDTH
|
| 74 |
+
WINDOW_LEVEL=configs.dataset.WINDOW_LEVEL
|
| 75 |
+
prob=configs.dataset.augmentationProb
|
| 76 |
+
background=configs.dataset.background
|
| 77 |
+
|
| 78 |
+
transform_list=[]
|
| 79 |
+
min, max=WINDOW_LEVEL-(WINDOW_WIDTH/2), WINDOW_LEVEL+(WINDOW_WIDTH/2)
|
| 80 |
+
transform_list.append(ThresholdIntensityd(keys=["target"], threshold=min, above=True, cval=background))
|
| 81 |
+
#transform_list.append(ThresholdIntensityd(keys=["target"], threshold=max, above=False, cval=-1000))
|
| 82 |
+
# filter the source images
|
| 83 |
+
# transform_list.append(ThresholdIntensityd(keys=["source"], threshold=configs.dataset.MRImax, above=False, cval=0))
|
| 84 |
+
if normalize=='zscore':
|
| 85 |
+
transform_list.append(NormalizeIntensityd(keys=["source", "target"], nonzero=False, channel_wise=True))
|
| 86 |
+
print('zscore normalization')
|
| 87 |
+
elif normalize=='minmax':
|
| 88 |
+
transform_list.append(ScaleIntensityd(keys=["source", "target"], minv=-1, maxv=1.0))
|
| 89 |
+
print('minmax normalization')
|
| 90 |
+
|
| 91 |
+
elif normalize=='scale4000':
|
| 92 |
+
transform_list.append(ScaleIntensityd(keys=["source"], minv=-1, maxv=1))
|
| 93 |
+
transform_list.append(ScaleIntensityd(keys=["target"], minv=0))
|
| 94 |
+
transform_list.append(ScaleIntensityd(keys=["target"], factor=-0.99975)) # x=x(1+factor)
|
| 95 |
+
print('scale1000 normalization')
|
| 96 |
+
|
| 97 |
+
elif normalize=='scale1000':
|
| 98 |
+
transform_list.append(ScaleIntensityd(keys=["source"], minv=0, maxv=1))
|
| 99 |
+
transform_list.append(ScaleIntensityd(keys=["target"], minv=0))
|
| 100 |
+
transform_list.append(ScaleIntensityd(keys=["target"], factor=-0.99975))
|
| 101 |
+
print('scale1000 normalization')
|
| 102 |
+
|
| 103 |
+
elif normalize=='inputonlyzscore':
|
| 104 |
+
transform_list.append(NormalizeIntensityd(keys=["source"], nonzero=False, channel_wise=True))
|
| 105 |
+
print('only normalize input MRI images')
|
| 106 |
+
|
| 107 |
+
elif normalize=='inputonlyminmax':
|
| 108 |
+
transform_list.append(ScaleIntensityd(keys=["source"], minv=configs.dataset.normmin, maxv=configs.dataset.normmax))
|
| 109 |
+
print('only normalize input MRI images')
|
| 110 |
+
elif normalize=='none':
|
| 111 |
+
print('no normalization')
|
| 112 |
+
transform_list.append(Spacingd(keys=["source"], pixdim=(1.0, 1.0, 1.0), mode="bilinear")) #
|
| 113 |
+
transform_list.append(Spacingd(keys=["target", "mask"], pixdim=(1.0, 1.0 , 2.5), mode="bilinear")) #
|
| 114 |
+
transform_list.append(ResizeWithPadOrCropd(keys=["source", "target", "mask"], spatial_size=resized_size,mode=pad))
|
| 115 |
+
# transform_list.append(ScaleIntensityRanged(keys=["target"],a_min=WINDOW_LEVEL-(WINDOW_WIDTH/2), a_max=WINDOW_LEVEL+(WINDOW_WIDTH/2),b_min=0, b_max=1, clip=True))
|
| 116 |
+
|
| 117 |
+
if mode == 'train':
|
| 118 |
+
from monai.transforms import (
|
| 119 |
+
# data augmentation
|
| 120 |
+
RandRotated,
|
| 121 |
+
RandZoomd,
|
| 122 |
+
RandBiasFieldd,
|
| 123 |
+
RandAffined,
|
| 124 |
+
RandGridDistortiond,
|
| 125 |
+
RandGridPatchd,
|
| 126 |
+
RandShiftIntensityd,
|
| 127 |
+
RandGibbsNoised,
|
| 128 |
+
RandAdjustContrastd,
|
| 129 |
+
RandGaussianSmoothd,
|
| 130 |
+
RandGaussianSharpend,
|
| 131 |
+
RandGaussianNoised,
|
| 132 |
+
)
|
| 133 |
+
Aug=True
|
| 134 |
+
if Aug:
|
| 135 |
+
transform_list.append(RandRotated(keys=["source", "target", "mask"], range_x = 0.1, range_y = 0.1, range_z = 0.1, prob=prob, padding_mode="border", keep_size=True))
|
| 136 |
+
transform_list.append(RandZoomd(keys=["source", "target", "mask"], prob=prob, min_zoom=0.9, max_zoom=1.3,padding_mode= "minimum" ,keep_size=True))
|
| 137 |
+
transform_list.append(RandAffined(keys=["source", "target", "mask"],padding_mode="border" , prob=prob))
|
| 138 |
+
#transform_list.append(Rand3DElasticd(keys=["source", "target"], prob=prob, sigma_range=(5, 8), magnitude_range=(100, 200), spatial_size=None, mode='bilinear'))
|
| 139 |
+
intensityAug=False
|
| 140 |
+
if intensityAug:
|
| 141 |
+
print('intensity data augmentation is used')
|
| 142 |
+
transform_list.append(RandBiasFieldd(keys=["source"], degree=3, coeff_range=(0.0, 0.1), prob=prob)) # only apply to MRI images
|
| 143 |
+
transform_list.append(RandGaussianNoised(keys=["source"], prob=prob, mean=0.0, std=0.01))
|
| 144 |
+
transform_list.append(RandAdjustContrastd(keys=["source"], prob=prob, gamma=(0.5, 1.5)))
|
| 145 |
+
transform_list.append(RandShiftIntensityd(keys=["source"], prob=prob, offsets=20))
|
| 146 |
+
transform_list.append(RandGaussianSharpend(keys=["source"], alpha=(0.2, 0.8), prob=prob))
|
| 147 |
+
|
| 148 |
+
#transform_list.append(Rotate90d(keys=["source", "target"], k=3))
|
| 149 |
+
#transform_list.append(DivisiblePadd(keys=["source", "target"], k=div_size, mode="minimum"))
|
| 150 |
+
#transform_list.append(Identityd(keys=["source", "target"])) # do nothing for the no norm case
|
| 151 |
+
train_transforms = Compose(transform_list)
|
| 152 |
+
return train_transforms
|
| 153 |
+
|
| 154 |
+
def get_length(dataset, patch_batch_size):
|
| 155 |
+
loader=DataLoader(dataset, batch_size=1)
|
| 156 |
+
iterator = iter(loader)
|
| 157 |
+
sum_nslices=0
|
| 158 |
+
for idx in range(len(loader)):
|
| 159 |
+
check_data = next(iterator)
|
| 160 |
+
nslices=check_data['source'].shape[-1]
|
| 161 |
+
sum_nslices+=nslices
|
| 162 |
+
if sum_nslices%patch_batch_size==0:
|
| 163 |
+
return sum_nslices//patch_batch_size
|
| 164 |
+
else:
|
| 165 |
+
return sum_nslices//patch_batch_size+1
|
| 166 |
+
|
| 167 |
+
|
dataprocesser/archive/checkdata.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils.data import DataLoader
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
def test_volumes_pixdim(train_volume_ds):
|
| 5 |
+
train_loader = DataLoader(train_volume_ds, batch_size=1)
|
| 6 |
+
for step, data in enumerate(train_loader):
|
| 7 |
+
mr_data=data['source']
|
| 8 |
+
ct_data=data['target']
|
| 9 |
+
|
| 10 |
+
print(f"source image shape: {mr_data.shape}")
|
| 11 |
+
print(f"source image affine:\n{mr_data.meta['affine']}")
|
| 12 |
+
print(f"source image pixdim:\n{mr_data.pixdim}")
|
| 13 |
+
|
| 14 |
+
# target image information
|
| 15 |
+
print(f"target image shape: {ct_data.shape}")
|
| 16 |
+
print(f"target image affine:\n{ct_data.meta['affine']}")
|
| 17 |
+
print(f"target image pixdim:\n{ct_data.pixdim}")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def check_volumes(train_ds, train_volume_ds, val_volume_ds, val_ds):
|
| 21 |
+
# use batch_size=1 to check the volumes because the input volumes have different shapes
|
| 22 |
+
train_loader = DataLoader(train_volume_ds, batch_size=1)
|
| 23 |
+
val_loader = DataLoader(val_volume_ds, batch_size=1)
|
| 24 |
+
train_iterator = iter(train_loader)
|
| 25 |
+
val_iterator = iter(val_loader)
|
| 26 |
+
print('check training data:')
|
| 27 |
+
idx=0
|
| 28 |
+
for idx in range(len(train_loader)):
|
| 29 |
+
try:
|
| 30 |
+
train_check_data = next(train_iterator)
|
| 31 |
+
ds_idx = idx * 1
|
| 32 |
+
current_item = train_ds[ds_idx]
|
| 33 |
+
current_name = os.path.basename(os.path.dirname(current_item['image']))
|
| 34 |
+
print(idx, current_name, 'image:', train_check_data['image'].shape, 'label:', train_check_data['label'].shape)
|
| 35 |
+
except:
|
| 36 |
+
ds_idx = idx * 1
|
| 37 |
+
current_item = train_ds[ds_idx]
|
| 38 |
+
current_name = os.path.basename(os.path.dirname(current_item['image']))
|
| 39 |
+
print('check data error! Check the input data:',current_name)
|
| 40 |
+
print("checked all training data.")
|
| 41 |
+
|
| 42 |
+
print('check validation data:')
|
| 43 |
+
idx=0
|
| 44 |
+
for idx in range(len(val_loader)):
|
| 45 |
+
try:
|
| 46 |
+
val_check_data = next(val_iterator)
|
| 47 |
+
ds_idx = idx * 1
|
| 48 |
+
current_item = val_ds[ds_idx]
|
| 49 |
+
current_name = os.path.basename(os.path.dirname(current_item['image']))
|
| 50 |
+
print(idx, current_name, 'image:', val_check_data['image'].shape, 'label:', val_check_data['label'].shape)
|
| 51 |
+
except:
|
| 52 |
+
ds_idx = idx * 1
|
| 53 |
+
current_item = val_ds[ds_idx]
|
| 54 |
+
current_name = os.path.basename(os.path.dirname(current_item['image']))
|
| 55 |
+
print('check data error! Check the input data:',current_name)
|
| 56 |
+
print("checked all validation data.")
|
| 57 |
+
|
| 58 |
+
def save_volumes(train_ds, val_ds, saved_name_train, saved_name_val):
|
| 59 |
+
shape_list_train=[]
|
| 60 |
+
shape_list_val=[]
|
| 61 |
+
# use the function of saving information before
|
| 62 |
+
for sample in train_ds:
|
| 63 |
+
name = os.path.basename(os.path.dirname(sample['image']))
|
| 64 |
+
shape_list_train.append({'patient': name})
|
| 65 |
+
for sample in val_ds:
|
| 66 |
+
name = os.path.basename(os.path.dirname(sample['image']))
|
| 67 |
+
shape_list_val.append({'patient': name})
|
| 68 |
+
np.savetxt(saved_name_train,shape_list_train,delimiter=',',fmt='%s',newline='\n') # f means format, r means raw string
|
| 69 |
+
np.savetxt(saved_name_val,shape_list_val,delimiter=',',fmt='%s',newline='\n') # f means format, r means raw string
|
| 70 |
+
|
| 71 |
+
def check_batch_data(train_loader,val_loader,train_patch_ds,val_volume_ds,train_batch_size,val_batch_size):
|
| 72 |
+
for idx, train_check_data in enumerate(train_loader):
|
| 73 |
+
ds_idx = idx * train_batch_size
|
| 74 |
+
current_item = train_patch_ds[ds_idx]
|
| 75 |
+
print('check train data:')
|
| 76 |
+
print(current_item, 'image:', train_check_data['image'].shape, 'label:', train_check_data['label'].shape)
|
| 77 |
+
|
| 78 |
+
for idx, val_check_data in enumerate(val_loader):
|
| 79 |
+
ds_idx = idx * val_batch_size
|
| 80 |
+
current_item = val_volume_ds[ds_idx]
|
| 81 |
+
print('check val data:')
|
| 82 |
+
print(current_item, 'image:', val_check_data['image'].shape, 'label:', val_check_data['label'].shape)
|
| 83 |
+
|
| 84 |
+
def len_patchloader(train_volume_ds,train_batch_size):
|
| 85 |
+
slice_number=sum(train_volume_ds[i]['source'].shape[-1] for i in range(len(train_volume_ds)))
|
| 86 |
+
print('total slices in training set:',slice_number)
|
| 87 |
+
|
| 88 |
+
import math
|
| 89 |
+
batch_number=sum(math.ceil(train_volume_ds[i]['source'].shape[-1]/train_batch_size) for i in range(len(train_volume_ds)))
|
| 90 |
+
print('total batches in training set:',batch_number)
|
| 91 |
+
return slice_number,batch_number
|
dataprocesser/archive/createsegtransform.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from totalsegmentator.python_api import totalsegmentator
|
| 3 |
+
class CreateMaskTransformd:
|
| 4 |
+
def __init__(self, keys, tissue_min, tissue_max, bone_min, bone_max, mask_value_bones=2,
|
| 5 |
+
if_use_total_seg=False, organ_label_id=52, mask_value_organ=2, fast=True):
|
| 6 |
+
self.keys = keys
|
| 7 |
+
self.tissue_min = tissue_min
|
| 8 |
+
self.tissue_max = tissue_max
|
| 9 |
+
self.bone_min = bone_min
|
| 10 |
+
self.bone_max = bone_max
|
| 11 |
+
self.mask_value_bones = mask_value_bones
|
| 12 |
+
|
| 13 |
+
self.if_use_total_seg = if_use_total_seg
|
| 14 |
+
self.organ_label_id = organ_label_id
|
| 15 |
+
self.mask_value_organ = mask_value_organ
|
| 16 |
+
self.fast = fast
|
| 17 |
+
|
| 18 |
+
def extract_organ_mask(self, input_img, organ_label_id, mask_value):
|
| 19 |
+
# aorta = 52
|
| 20 |
+
"""
|
| 21 |
+
Extracts a binary mask for a specific organ from a labeled NIFTI image.
|
| 22 |
+
|
| 23 |
+
img_in: NIFTI image with segmentation labels.
|
| 24 |
+
organ_name: Name of the organ to extract.
|
| 25 |
+
label_map: Dictionary mapping label IDs to organ names.
|
| 26 |
+
|
| 27 |
+
returns: Binary mask as a NIFTI image.
|
| 28 |
+
"""
|
| 29 |
+
img_in = totalsegmentator(input=input_img, task='total',fast=self.fast)
|
| 30 |
+
data = img_in.get_fdata()
|
| 31 |
+
|
| 32 |
+
# Create a binary mask for the specified organ
|
| 33 |
+
organ_mask_data = np.zeros_like(data)
|
| 34 |
+
organ_mask_data[data == organ_label_id] = mask_value
|
| 35 |
+
|
| 36 |
+
# Create a new NIFTI image for the binary mask
|
| 37 |
+
organ_mask_img = nib.Nifti1Image(organ_mask_data, img_in.affine, img_in.header)
|
| 38 |
+
return organ_mask_img
|
| 39 |
+
|
| 40 |
+
def __call__(self, data):
|
| 41 |
+
for key in self.keys:
|
| 42 |
+
x = data[key]
|
| 43 |
+
|
| 44 |
+
mask = torch.zeros_like(x)
|
| 45 |
+
# [B, H, W, D]
|
| 46 |
+
# create a mask for each slice in the batch
|
| 47 |
+
for i in range(x.shape[0]):
|
| 48 |
+
if self.if_use_total_seg:
|
| 49 |
+
mask_batch_i = self.extract_organ_mask(x[i,:,:,:], organ_label_id=self.organ_label_id, mask_value=self.mask_value_organ)
|
| 50 |
+
mask[i,:,:,:] = mask_batch_i
|
| 51 |
+
for j in range(x.shape[-1]):
|
| 52 |
+
mask_slice = create_body_mask(x[i,:,:,j], body_threshold=self.tissue_min)
|
| 53 |
+
mask[i,:,:, j] = mask_slice
|
| 54 |
+
#mask = torch.zeros_like(x)
|
| 55 |
+
#mask[(x > self.tissue_min) & (x <= self.tissue_max)] = 1
|
| 56 |
+
mask[(x >= self.bone_min) & (x <= self.bone_max)] = self.mask_value_bones
|
| 57 |
+
data[key] = mask
|
| 58 |
+
#print("input and mask shape: ",x.shape,data[key].shape)
|
| 59 |
+
return data
|
| 60 |
+
|
| 61 |
+
class CreateSegTransformd:
|
| 62 |
+
# create a mask by segmenting the input image using totalsegmentator
|
| 63 |
+
def __init__(self, keys, organ_label_id=52, mask_value=2, fast=True):
|
| 64 |
+
self.keys = keys
|
| 65 |
+
self.organ_label_id = organ_label_id
|
| 66 |
+
self.mask_value = mask_value
|
| 67 |
+
self.fast = fast
|
| 68 |
+
|
| 69 |
+
def extract_organ_mask(self, input_img, organ_label_id, mask_value):
|
| 70 |
+
# aorta = 52
|
| 71 |
+
"""
|
| 72 |
+
Extracts a binary mask for a specific organ from a labeled NIFTI image.
|
| 73 |
+
|
| 74 |
+
img_in: NIFTI image with segmentation labels.
|
| 75 |
+
organ_name: Name of the organ to extract.
|
| 76 |
+
label_map: Dictionary mapping label IDs to organ names.
|
| 77 |
+
|
| 78 |
+
returns: Binary mask as a NIFTI image.
|
| 79 |
+
"""
|
| 80 |
+
img_in = totalsegmentator(input=input_img, task='total',fast=self.fast)
|
| 81 |
+
data = img_in.get_fdata()
|
| 82 |
+
|
| 83 |
+
if organ_label_id>0:
|
| 84 |
+
# Create a binary mask for the specified organ
|
| 85 |
+
organ_mask_data = np.zeros_like(data)
|
| 86 |
+
organ_mask_data[data == organ_label_id] = mask_value
|
| 87 |
+
else:
|
| 88 |
+
organ_mask_data=data
|
| 89 |
+
|
| 90 |
+
# Create a new NIFTI image for the binary mask
|
| 91 |
+
organ_mask_img = nib.Nifti1Image(organ_mask_data, img_in.affine, img_in.header)
|
| 92 |
+
return organ_mask_img
|
| 93 |
+
|
| 94 |
+
def __call__(self, data):
|
| 95 |
+
for key in self.keys:
|
| 96 |
+
x = data[key]
|
| 97 |
+
mask = torch.zeros_like(x)
|
| 98 |
+
# [B, H, W, D]
|
| 99 |
+
for i in range(x.shape[0]):
|
| 100 |
+
mask_batch_i = self.extract_organ_mask(x[i,:,:,:], organ_label_id=self.organ_label_id, mask_value=self.mask_value)
|
| 101 |
+
mask[i,:,:,:] = mask_batch_i
|
| 102 |
+
data[key] = mask
|
| 103 |
+
return data
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class CreateTotalSegTransformd:
|
| 107 |
+
# create a mask by segmenting the input image using totalsegmentator
|
| 108 |
+
def __init__(self, keys, fast=True):
|
| 109 |
+
self.keys = keys
|
| 110 |
+
self.fast = fast
|
| 111 |
+
|
| 112 |
+
def extract_organ_mask(self, input_img):
|
| 113 |
+
# aorta = 52
|
| 114 |
+
"""
|
| 115 |
+
Extracts a binary mask for a specific organ from a labeled NIFTI image.
|
| 116 |
+
|
| 117 |
+
img_in: NIFTI image with segmentation labels.
|
| 118 |
+
organ_name: Name of the organ to extract.
|
| 119 |
+
label_map: Dictionary mapping label IDs to organ names.
|
| 120 |
+
|
| 121 |
+
returns: Binary mask as a NIFTI image.
|
| 122 |
+
"""
|
| 123 |
+
#print(input_img.meta)
|
| 124 |
+
input_affine = input_img.meta['affine']
|
| 125 |
+
|
| 126 |
+
input_img = torch_tensor_to_nifti(input_img, affine=input_affine)
|
| 127 |
+
img_in = totalsegmentator(input=input_img, task='total', fast=self.fast)
|
| 128 |
+
data = img_in.get_fdata()
|
| 129 |
+
organ_mask_data=data
|
| 130 |
+
# Create a new NIFTI image for the binary mask
|
| 131 |
+
organ_mask_img = nib.Nifti1Image(organ_mask_data, img_in.affine, img_in.header)
|
| 132 |
+
return organ_mask_img
|
| 133 |
+
|
| 134 |
+
def __call__(self, data):
|
| 135 |
+
for key in self.keys:
|
| 136 |
+
x = data[key]
|
| 137 |
+
mask = torch.zeros_like(x)
|
| 138 |
+
# [B, H, W, D]
|
| 139 |
+
for i in range(x.shape[0]):
|
| 140 |
+
mask_batch_i = self.extract_organ_mask(x[i,:,:,:])
|
| 141 |
+
numpy_data = mask_batch_i.get_fdata()
|
| 142 |
+
|
| 143 |
+
# Convert the NumPy array to a PyTorch tensor
|
| 144 |
+
tensor_data = torch.from_numpy(numpy_data).float()
|
| 145 |
+
mask[i,:,:,:] = tensor_data
|
| 146 |
+
data[key] = mask
|
| 147 |
+
return data
|
| 148 |
+
|
| 149 |
+
def get_transforms(self, transform_list):
    """Assemble the MONAI preprocessing/augmentation pipeline.

    Order of operations: (1) optional mask creation from the CT input,
    (2) intensity normalization selected by ``configs.dataset.normalize``,
    (3) optional voxel-spacing resampling, zoom, divisible padding and a
    final pad/crop to ``resized_size``, (4) optional 90-degree rotation,
    and (5), in training mode, random shape / intensity augmentations.

    Returns:
        monai.transforms.Compose: the composed transform pipeline.

    NOTE(review): the ``transform_list`` parameter is immediately
    overwritten by a fresh empty list below, so the caller's argument is
    effectively ignored — confirm this is intended.
    NOTE(review): ``configs`` and ``mode`` are free names here (only the
    spacing/zoom/pad settings are read via ``self.configs``) — confirm
    both are defined in the scope where this method runs.
    """
    normalize=configs.dataset.normalize
    pad=configs.dataset.pad
    resized_size=configs.dataset.resized_size
    WINDOW_WIDTH=configs.dataset.WINDOW_WIDTH
    WINDOW_LEVEL=configs.dataset.WINDOW_LEVEL
    prob=configs.dataset.augmentationProb
    background=configs.dataset.background
    indicator_A=configs.dataset.indicator_A
    indicator_B=configs.dataset.indicator_B
    load_masks=configs.dataset.load_masks
    transform_list=[]
    input_is_mask=configs.dataset.input_is_mask
    # normally we input CT images and here we create masks for CT images
    if not input_is_mask:
        if not configs.dataset.use_all_masks:
            # Simple thresholded mask: tissue/bone HU windows, bones labelled 2.
            transform_list.append(CreateMaskTransformd(keys=[indicator_A],
                                    tissue_min=configs.dataset.tissue_min,
                                    tissue_max=configs.dataset.tissue_max,
                                    bone_min=configs.dataset.bone_min,
                                    bone_max=configs.dataset.bone_max,
                                    mask_value_bones=2,
                                    ))
        else: # use all masks from the totalsegmentator
            transform_list.append(CreateTotalSegTransformd(keys=[indicator_A],
                                    fast=True))
    # Window bounds derived from level/width.
    # NOTE(review): shadows the builtins `min`/`max` and is unused below —
    # the ThresholdIntensityd lines that consumed it are commented out.
    min, max=WINDOW_LEVEL-(WINDOW_WIDTH/2), WINDOW_LEVEL+(WINDOW_WIDTH/2)
    #transform_list.append(ThresholdIntensityd(keys=[indicator_B], threshold=min, above=True, cval=background))
    #transform_list.append(ThresholdIntensityd(keys=[indicator_B], threshold=max, above=False, cval=-1000))
    # filter the source images
    # transform_list.append(ThresholdIntensityd(keys=[indicator_A], threshold=configs.dataset.MRImax, above=False, cval=0))
    if normalize=='zscore':
        transform_list.append(NormalizeIntensityd(keys=[indicator_B], nonzero=False, channel_wise=True))
        print('zscore normalization')
    elif normalize=='minmax':
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=-1.0, maxv=1.0))
        print('minmax normalization')

    elif normalize=='scale1000_wrongbutworks':
        # Shift minimum to 0, then multiply by (1 + factor) = 0.001.
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=0))
        transform_list.append(ScaleIntensityd(keys=[indicator_B], factor=-0.999))
        print('scale1000 normalization')

    elif normalize=='scale1000':
        # Shift HU by +1024, then multiply by (1 + factor) = 0.001.
        transform_list.append(ShiftIntensityd(keys=[indicator_B], offset=1024))
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.999))
        print('scale1000 normalization')

    elif normalize=='scale4000':
        # Shift HU by +1024, then multiply by (1 + factor) = 0.00025.
        transform_list.append(ShiftIntensityd(keys=[indicator_B], offset=1024))
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.99975))
        print('scale4000 normalization')

    elif normalize=='scale10':
        #transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=0))
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None,factor=-0.9))
        print('scale10 normalization')

    elif normalize=='inputonlyzscore':
        transform_list.append(NormalizeIntensityd(keys=[indicator_A], nonzero=False, channel_wise=True))
        print('only normalize input MRI images')

    elif normalize=='inputonlyminmax':
        transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=configs.dataset.normmin, maxv=configs.dataset.normmax))
        print('only normalize input MRI images')

    elif normalize=='none' or normalize=='nonorm':
        print('no normalization')

    spaceXY=self.configs.dataset.spaceXY
    if spaceXY>0:
        # Resample in-plane spacing; slice thickness fixed at 2.5.
        transform_list.append(Spacingd(keys=[indicator_A], pixdim=(spaceXY, spaceXY, 2.5), mode="bilinear", ensure_same_shape=True)) #
        transform_list.append(Spacingd(keys=[indicator_B, "mask"] if load_masks else [indicator_B],
                                pixdim=(spaceXY, spaceXY , 2.5), mode="bilinear", ensure_same_shape=True))

    transform_list.append(Zoomd(keys=[indicator_A, indicator_B,"mask"] if load_masks
                                else [indicator_A, indicator_B],
                                zoom=configs.dataset.zoom, keep_size=False, mode='area',padding_mode='minimum'))


    # Pad so every spatial dim is divisible by div_size, then force the
    # configured final size with pad or crop.
    transform_list.append(DivisiblePadd(keys=[indicator_A, indicator_B,"mask"] if load_masks else [indicator_A, indicator_B],
                                k=self.configs.dataset.div_size, mode="minimum"))
    transform_list.append(ResizeWithPadOrCropd(keys=[indicator_A, indicator_B,"mask"] if load_masks else [indicator_A, indicator_B],
                                spatial_size=resized_size,mode=pad))

    if configs.dataset.rotate:
        transform_list.append(Rotate90d(keys=[indicator_A, indicator_B, "mask"] if load_masks else [indicator_A, indicator_B], k=3))

    if mode == 'train':
        from monai.transforms import (
            # data augmentation
            RandRotated,
            RandZoomd,
            RandBiasFieldd,
            RandAffined,
            RandGridDistortiond,
            RandGridPatchd,
            RandShiftIntensityd,
            RandGibbsNoised,
            RandAdjustContrastd,
            RandGaussianSmoothd,
            RandGaussianSharpend,
            RandGaussianNoised,
        )
        shapeAug=configs.dataset.shapeAug
        if shapeAug:
            # Random zoom is the only shape augmentation currently active.
            #transform_list.append(RandRotated(keys=[indicator_A, indicator_B, "mask"] if load_masks else [indicator_A, indicator_B],
            #                range_x = 0.0, range_y = 1.0, range_z = 1.0,
            #                prob=prob, padding_mode="border", keep_size=False))
            transform_list.append(RandZoomd(keys=[indicator_A, indicator_B, "mask"] if load_masks else [indicator_A, indicator_B],
                                prob=prob, min_zoom=self.configs.dataset.rand_min_zoom, max_zoom=self.configs.dataset.rand_max_zoom,
                                padding_mode= "minimum" ,keep_size=False))
            #transform_list.append(RandAffined(keys=[indicator_A, indicator_B], padding_mode="border" , prob=prob))
            #transform_list.append(Rand3DElasticd(keys=[indicator_A, indicator_B], prob=prob, sigma_range=(5, 8), magnitude_range=(100, 200), spatial_size=None, mode='bilinear'))
        intensityAug=configs.dataset.intensityAug
        if intensityAug:
            print('intensity data augmentation is used')
            transform_list.append(RandBiasFieldd(keys=[indicator_A], degree=3, coeff_range=(0.0, 0.1), prob=prob)) # only apply to MRI images
            transform_list.append(RandGaussianNoised(keys=[indicator_A], prob=prob, mean=0.0, std=0.01))
            transform_list.append(RandAdjustContrastd(keys=[indicator_A], prob=prob, gamma=(0.5, 1.5)))
            transform_list.append(RandShiftIntensityd(keys=[indicator_A], prob=prob, offsets=20))
            transform_list.append(RandGaussianSharpend(keys=[indicator_A], alpha=(0.2, 0.8), prob=prob))

    #transform_list.append(Rotate90d(keys=[indicator_A, indicator_B], k=3))
    #transform_list.append(DivisiblePadd(keys=[indicator_A, indicator_B], k=div_size, mode="minimum"))
    #transform_list.append(Identityd(keys=[indicator_A, indicator_B])) # do nothing for the no norm case
    train_transforms = Compose(transform_list)
    return train_transforms
|
dataprocesser/archive/csv_dataset.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def get_data_scaler(config):
    """Return a callable that normalizes data assumed to lie in [0, 1].

    When ``config.data.centered`` is truthy the returned function maps
    [0, 1] onto [-1, 1]; otherwise it is the identity.
    """
    if not config.data.centered:
        return lambda x: x
    # Rescale [0, 1] -> [-1, 1]
    return lambda x: 2. * x - 1.
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_data_inverse_scaler(config):
    """Return the inverse of :func:`get_data_scaler`.

    When ``config.data.centered`` is truthy the returned function maps
    [-1, 1] back onto [0, 1]; otherwise it is the identity.
    """
    if not config.data.centered:
        return lambda x: x
    # Rescale [-1, 1] -> [0, 1]
    return lambda x: (x + 1.) * 0.5
|
| 17 |
+
|
| 18 |
+
# Recognized medical-image file extensions (case-sensitive).
IMG_EXTENSIONS = [
    #'.jpg', '.JPG', '.jpeg', '.JPEG',
    #'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.nrrd', '.nii.gz'
]


def is_image_file(filename):
    """Return True if *filename* ends with one of ``IMG_EXTENSIONS``.

    Uses the tuple form of ``str.endswith`` so the extension list is
    checked in a single C-level call instead of a Python generator.
    """
    return filename.endswith(tuple(IMG_EXTENSIONS))
|
| 27 |
+
|
| 28 |
+
def volume_slicer(volume_tensor, transform, all_slices=None):
    """Slice an [H, W, D] volume depth-first and stack onto *all_slices*.

    The volume is permuted to [D, W, H], given a channel axis
    ([D, 1, W, H]), optionally run through *transform*, and concatenated
    to any previously collected slices along dim 0.

    Returns the accumulated slice tensor.
    """
    slices = volume_tensor.permute(2, 1, 0).unsqueeze(1)  # [H, W, D] -> [D, 1, W, H]
    if transform is not None:
        slices = transform(slices)
    #print('stacking volume tensor:', slices.shape)
    if all_slices is None:
        return slices
    return torch.cat((all_slices, slices), 0)
|
| 42 |
+
|
| 43 |
+
class csvDataset_3D(Dataset):
    """Volume-level dataset backed by a CSV listing.

    Each CSV row is expected to carry the image path in its last column
    and the label in its third-from-last column.
    """

    def __init__(self, csv_file, transform=None, load_patient_number=1):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            transform (callable, optional): Optional transform to be applied on a sample.
            load_patient_number (int): keep only the first N rows of the CSV.
        """
        frame = pd.read_csv(csv_file)
        # Truncate the frame to control the dataset length.
        self.data_frame = frame[:load_patient_number]
        # NOTE: stored but not applied in __getitem__ (matches original).
        self.transform = transform

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = self.data_frame.iloc[idx, -1]
        volume = nib.load(img_path).get_fdata()
        # Add a channel dimension for single-channel volumes.
        image = torch.tensor(volume, dtype=torch.float32).unsqueeze(0)

        # Third-from-last column serves as the label (e.g. 'Aorta_diss').
        label = self.data_frame.iloc[idx, -3]

        return {'image': image, 'label': label}
|
| 76 |
+
|
| 77 |
+
class csvDataset_2D(Dataset):
    """Slice-level dataset: loads the volumes listed in a CSV and stacks
    all their axial slices into one tensor up front.

    Each CSV row carries the image path in its last column and the label
    in its third-from-last column.  Samples are served as
    ``{'source': slice_tensor, 'target': label}``.
    """

    def __init__(self, csv_file, transform=None, load_patient_number=1):
        self.csv_file = csv_file
        self.transform = transform
        self.load_patient_number = load_patient_number
        self.data_frame = pd.read_csv(csv_file)
        if len(self.data_frame) == 0:
            raise RuntimeError(f"Found 0 images in: {csv_file}")

        # Initialize dataset
        self.initialize_dataset()

    def initialize_dataset(self):
        """Load every volume, slice it, and cache slices + per-slice labels."""
        print('Loading dataset...')
        self.data_frame = self.data_frame[:self.load_patient_number]
        all_slices = None
        all_labels = []

        for idx in tqdm(range(len(self.data_frame))):
            img_path = self.data_frame.iloc[idx, -1]
            volume = nib.load(img_path)
            volume_data = volume.get_fdata()  # Load as [H, W, D]
            volume_tensor = torch.tensor(volume_data, dtype=torch.float32)
            # volume_slicer permutes depth-first -> appends D slices.
            all_slices = volume_slicer(volume_tensor, self.transform, all_slices)
            label = self.data_frame.iloc[idx, -3]
            # BUGFIX: replicate the label once per *slice*.  The volume is
            # [H, W, D] and volume_slicer emits D slices, so the depth
            # axis (shape[2]) is the correct count; the original used
            # shape[0] (height), which misaligned labels with slices
            # whenever H != D.
            all_labels = all_labels + [label] * volume_tensor.shape[2]

        print('All stacked slices:', all_slices.shape)
        self.all_slices = all_slices
        self.all_labels = all_labels

    def __len__(self):
        return self.all_slices.shape[0]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self.all_slices[idx]
        label = self.all_labels[idx]
        sample = {'source': image, 'target': label}
        return sample

    def reset(self):
        """Re-run the expensive slicing pass from the (truncated) frame."""
        print('Resetting dataset...')
        self.initialize_dataset()
|
dataprocesser/archive/csv_dataset_slices.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils.data import DataLoader
|
| 2 |
+
import os.path
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import torch
|
| 5 |
+
from PIL import ImageFile
|
| 6 |
+
import os
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import monai
|
| 9 |
+
import json
|
| 10 |
+
Image.MAX_IMAGE_PIXELS = None # Disable DecompressionBombError
|
| 11 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True # Disable OSError: image file is truncated
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
from dataprocesser.list_dataset_base import BaseDataLoader
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
|
dataprocesser/archive/csv_dataset_slices_assigned.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataprocesser.csv_dataset_slices import csv_slices_DataLoader
|
| 2 |
+
from dataprocesser.customized_transforms import MaskHUAssigmentd
|
| 3 |
+
|
| 4 |
+
from monai.transforms import (
|
| 5 |
+
ScaleIntensityd,
|
| 6 |
+
ThresholdIntensityd,
|
| 7 |
+
NormalizeIntensityd,
|
| 8 |
+
ShiftIntensityd,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
|
dataprocesser/archive/data_create_seg.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from step1_init_data_list import (
|
| 2 |
+
list_img_ad_from_anish_csv,
|
| 3 |
+
list_img_pID_from_synthrad_folder,
|
| 4 |
+
)
|
| 5 |
+
def run():
    """Create segmentation masks for the first ``number`` target images.

    Picks the image list either from Anish's CSV or from the SynthRAD
    folder layout, then delegates to :func:`create_segmentation`.
    """
    number = 1
    dataset = 'anish'
    if dataset == 'anish':
        # BUGFIX: raw strings — in a normal literal, sequences such as
        # '\t' in '...\train\Task1...' are parsed as escape characters
        # (a tab here) and silently corrupt the Windows path.
        data_dir = r'D:\Projects\SynthRad\synthrad_conversion\healthy_dissec_home.csv'
        target_file_list, _ = list_img_ad_from_anish_csv(data_dir)  # a csv_file
    elif dataset == 'synthrad':
        data_dir = r'D:\Projects\data\synthrad\train\Task1\pelvis'
        target_file_list, _ = list_img_pID_from_synthrad_folder(
            data_dir, accepted_modalities='ct', saved_name="target_filenames.txt")
    create_segmentation(target_file_list[0: number])
|
| 15 |
+
|
| 16 |
+
def create_segmentation(dataset_list):
    """Run TotalSegmentator on every NIfTI path in *dataset_list*.

    Each mask is written next to its input with a ``_seg`` suffix
    (``foo.nii.gz`` -> ``foo_seg.nii.gz``).  A failure on one file is
    reported and does not abort the remaining files.
    """
    import nibabel as nib
    try:
        from totalsegmentator.python_api import totalsegmentator
    except ImportError as err:
        # Fixed: the original bare `except:` swallowed *all* errors
        # (including KeyboardInterrupt) without any detail.
        print(f"totalsegmentator is not available: {err}")
        return
    for input_path in dataset_list:
        print(f'create segmentation mask for {input_path}')
        output_path = input_path.replace('.nii', '_seg.nii')
        try:
            input_img = nib.load(input_path)
            totalsegmentator(input=input_img, output=output_path, task='total', fast=False, ml=True)
            print(f'segmentation mask is saved as {output_path}')
        except Exception as err:
            # Report and continue with the next file instead of aborting.
            print(f"An exception occurred while segmenting {input_path}: {err}")
|
dataprocesser/archive/data_slicing.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataprocesser.step1_init_data_list import init_dataset
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
# Build the configured dataset loader; options and output paths come from
# the project's dataset-initialization helper.
loader, opt, my_paths = init_dataset()
# Output root for the exported 2D slices (raw string: Windows path).
path=r'E:\Projects\yang_proj\data\seg2med\seg2med_nifti_2d_343'
train_path=os.path.join(path, 'train')
val_path=os.path.join(path,'val')
os.makedirs(path,exist_ok=True)
os.makedirs(train_path,exist_ok=True)
os.makedirs(val_path,exist_ok=True)

# Export the train/val volume datasets — presumably as per-slice NIfTI
# files plus a CSV index (verify in BaseDataLoader.save_slices_nifti_and_csv).
loader.save_slices_nifti_and_csv(train_path, loader.train_volume_ds)
loader.save_slices_nifti_and_csv(val_path, loader.val_volume_ds)
|
dataprocesser/archive/dataset_med.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils.data import DataLoader, Dataset
|
| 2 |
+
import torch.utils.data as data
|
| 3 |
+
import os.path
|
| 4 |
+
import random
|
| 5 |
+
from torchvision import transforms
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import torch
|
| 8 |
+
from PIL import ImageFile
|
| 9 |
+
from utils.MattingLaplacian import compute_laplacian
|
| 10 |
+
|
| 11 |
+
import nibabel as nib
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
Image.MAX_IMAGE_PIXELS = None # Disable DecompressionBombError
|
| 15 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True # Disable OSError: image file is truncated
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Recognized medical-image file extensions (case-sensitive).
IMG_EXTENSIONS = [
    #'.jpg', '.JPG', '.jpeg', '.JPEG',
    #'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.nrrd', '.nii.gz'
]


def is_image_file(filename):
    """Return True if *filename* ends with one of ``IMG_EXTENSIONS``."""
    # Tuple form of str.endswith: one C-level call over all extensions.
    return filename.endswith(tuple(IMG_EXTENSIONS))


def make_dataset_modality(dir, modality='ct'):
    """Collect image paths of one modality from a patient-folder tree.

    Expected layout::

        root/
            patient_folder/
                ct_image.nii.gz
                mr_image.nii.gz
                ...
            patient_folder2/
                ct_image.nii.gz
                ...

    Files sitting directly in *dir* are ignored.  A file qualifies when
    :func:`is_image_file` accepts it and *modality* occurs in its name.

    BUGFIX: the original re-walked every patient folder inside an outer
    walk of the whole tree, which appended files in nested sub-folders
    once per ancestor directory.  A single walk that skips the top level
    yields each matching file exactly once (identical results for the
    documented one-level layout).
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        if root == dir:
            continue  # skip files directly under the root
        for fname in fnames:
            if is_image_file(fname) and modality in fname:
                images.append(os.path.join(root, fname))
    return images
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class CTImageDataset(Dataset):
    """2D-slice dataset built from 3D volumes of a single modality.

    Loads up to ``load_patient_number`` volumes found under *root*,
    reorders each to slice-first, replicates the grayscale values to
    three channels, and serves individual slices, each optionally paired
    with a matting Laplacian computed on the fly.
    """

    def __init__(self, root, modality='ct', transform=None,
                 load_patient_number=1,
                 use_lap=True, win_rad=1):
        """
        Args:
            root: directory containing one sub-folder per patient.
            modality: substring that selects files (e.g. 'ct', 'mr').
            transform: optional callable applied to each whole volume
                tensor (shape [N, H, W, 3]) before slices are stacked.
            load_patient_number: number of volumes actually loaded.
            use_lap: whether __getitem__ computes a matting Laplacian.
            win_rad: window radius forwarded to compute_laplacian.
        """
        self.imgs_paths = sorted(make_dataset_modality(root, modality))
        self.transform = transform
        self.to_tensor = transforms.ToTensor() # Might need adjustment for 3D

        if len(self.imgs_paths) == 0:
            raise RuntimeError(f"Found 0 images in: {root}")
        # form the images to be in the form of [D, H, W]
        all_slices = None
        for img_path in self.imgs_paths[:load_patient_number]:
            volume = nib.load(img_path)
            volume_data = volume.get_fdata() # load as [H, W, D]
            #
            # Convert numpy array to PyTorch tensor
            # Note: You might need to add channel dimension or perform other adjustments
            volume_tensor = torch.tensor(volume_data, dtype=torch.float32)
            volume_tensor = volume_tensor.permute(2, 1, 0) # [N, H, W]
            volume_tensor = volume_tensor.unsqueeze(3) # Add channel dimension [N, H, W] -> [N, H, W, 1]
            # pasting grayscale information to all three channels.
            volume_tensor = volume_tensor.repeat(1, 1, 1, 3)
            #print('Debug, volume tensor:',volume_tensor.shape)
            if self.transform is not None:
                volume_tensor = self.transform(volume_tensor)
            # Accumulate slices from all loaded volumes along dim 0.
            if all_slices is None:
                all_slices = volume_tensor
            else:
                all_slices = torch.cat((all_slices, volume_tensor), 0)
        print(f'slices of {modality} dataset:',all_slices.shape)
        self.all_slices = all_slices
        self.use_lap = use_lap
        self.win_rad = win_rad

    def __getitem__(self, index):
        """Return one slice as ``{'img': [C, H, W], 'laplacian_m': ...}``.

        ``laplacian_m`` is None when ``use_lap`` is False.
        """
        img = self.all_slices[index]
        #print('Debug 1, img shape:',img.shape)
        if self.use_lap:
            laplacian_m = compute_laplacian(img, win_rad=self.win_rad)
        else:
            laplacian_m = None
        #print('Debug 2, laplacian_m:',laplacian_m.shape)
        # permute img from [H, W, C] to [C, H, W]
        img = img.permute(2, 0, 1)
        return {'img': img, 'laplacian_m': laplacian_m}

    def __len__(self):
        # Total number of stacked slices across all loaded volumes.
        return self.all_slices.shape[0]
|
| 106 |
+
|
| 107 |
+
from monai.transforms import (
|
| 108 |
+
ResizeWithPadOrCrop,
|
| 109 |
+
ScaleIntensity,
|
| 110 |
+
Compose,
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
def get_data_loader_folder(input_folder, modality,
                           batch_size, new_size=288,
                           height=256, width=256,
                           num_workers=None, load_patient_number=1):
    """Build an infinitely-sampling DataLoader over one modality.

    Intensities are rescaled to [0, 1] first, then each volume is padded
    or cropped to (height, width) in-plane; the depth dimension (-1) is
    left untouched.  ``new_size`` is accepted for API compatibility but
    currently unused.
    """
    pipeline = Compose([
        ScaleIntensity(minv=0, maxv=1.0),
        ResizeWithPadOrCrop(spatial_size=[height, width, -1], mode="minimum"),
    ])

    dataset = CTImageDataset(input_folder, modality=modality,
                             transform=pipeline,
                             load_patient_number=load_patient_number)

    workers = 0 if num_workers is None else num_workers
    return DataLoader(dataset=dataset,
                      batch_size=batch_size,
                      drop_last=True,
                      num_workers=workers,
                      sampler=InfiniteSamplerWrapper(dataset),
                      collate_fn=collate_fn)
|
| 136 |
+
|
| 137 |
+
def main(root = r'C:\Users\56991\Projects\Datasets\Task1\pelvis',modality='ct'):
    """Smoke-test the loader: build it and print each batch's image shape.

    NOTE(review): the loader uses InfiniteSamplerWrapper, whose __len__
    is 2**31, so the loop below does not terminate on its own.
    """
    loader = get_data_loader_folder(
        root, modality,
        batch_size=8,
        new_size=512,
        height=512,
        width=512,
        num_workers=None,
        load_patient_number=1,
    )
    print('Length of loader:', len(loader))
    for i, batch in enumerate(loader):
        print(f'Batch {i}:', batch['img'].shape)
    print('Done')

if __name__=='__main__':
    main()
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def InfiniteSampler(n):
    """Yield indices in [0, n) forever, reshuffling after each pass.

    The very first yield is the *last* element of the initial
    permutation; whenever a permutation is exhausted, the RNG is
    reseeded from the OS and a fresh permutation is drawn.
    """
    order = np.random.permutation(n)
    pos = n - 1
    while True:
        yield order[pos]
        pos += 1
        if pos == n:
            np.random.seed()
            order = np.random.permutation(n)
            pos = 0
| 169 |
+
|
| 170 |
+
class InfiniteSamplerWrapper(data.sampler.Sampler):
    """Adapter exposing :func:`InfiniteSampler` as a DataLoader sampler.

    Reports an effectively unbounded length so the consuming DataLoader
    never runs out of indices.
    """

    def __init__(self, data_source):
        self.num_samples = len(data_source)

    def __iter__(self):
        return iter(InfiniteSampler(self.num_samples))

    def __len__(self):
        # Effectively infinite for any practical epoch length.
        return 2 ** 31
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def collate_fn(batch):
    """Collate dataset samples: stack images, keep Laplacians as a list.

    Images are stacked into a single [B, ...] tensor; the 'laplacian_m'
    entries are returned untouched as a plain list (they may be None).
    """
    stacked = torch.stack([sample['img'] for sample in batch], dim=0)
    laplacians = [sample['laplacian_m'] for sample in batch]
    return {'img': stacked, 'laplacian_m': laplacians}
|
| 188 |
+
|
dataprocesser/archive/gan_loader.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import monai
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
from monai.transforms import (
|
| 5 |
+
Compose,
|
| 6 |
+
LoadImaged,
|
| 7 |
+
EnsureChannelFirstd,
|
| 8 |
+
SqueezeDimd,
|
| 9 |
+
CenterSpatialCropd,
|
| 10 |
+
Rotate90d,
|
| 11 |
+
ScaleIntensityd,
|
| 12 |
+
ResizeWithPadOrCropd,
|
| 13 |
+
DivisiblePadd,
|
| 14 |
+
ThresholdIntensityd,
|
| 15 |
+
NormalizeIntensityd,
|
| 16 |
+
ShiftIntensityd,
|
| 17 |
+
Identityd,
|
| 18 |
+
ScaleIntensityRanged,
|
| 19 |
+
Spacingd,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
from monai.data import Dataset
|
| 23 |
+
from torch.utils.data import DataLoader
|
| 24 |
+
import torch
|
| 25 |
+
from .checkdata import check_volumes, save_volumes, check_batch_data, test_volumes_pixdim
|
| 26 |
+
|
| 27 |
+
def get_transforms(configs, mode='train'):
    """Assemble the MONAI transform pipeline for the GAN loader.

    Builds, in order: intensity normalization selected by
    ``configs.dataset.normalize``, a pad/crop to ``resized_size``,
    optional 90-degree rotation, and (in training mode) random shape
    augmentations on the 'source'/'target'/'mask' keys.

    Returns:
        monai.transforms.Compose: the composed transform pipeline.
    """
    normalize=configs.dataset.normalize
    pad=configs.dataset.pad
    resized_size=configs.dataset.resized_size
    WINDOW_WIDTH=configs.dataset.WINDOW_WIDTH
    WINDOW_LEVEL=configs.dataset.WINDOW_LEVEL
    prob=configs.dataset.augmentationProb
    background=configs.dataset.background

    transform_list=[]
    # Window bounds derived from level/width.
    # NOTE(review): shadows the builtins `min`/`max` and is unused — the
    # ThresholdIntensityd lines that consumed it are commented out.
    min, max=WINDOW_LEVEL-(WINDOW_WIDTH/2), WINDOW_LEVEL+(WINDOW_WIDTH/2)
    #transform_list.append(ThresholdIntensityd(keys=["target"], threshold=min, above=True, cval=background))
    #transform_list.append(ThresholdIntensityd(keys=["target"], threshold=max, above=False, cval=-1000))
    # filter the source images
    # transform_list.append(ThresholdIntensityd(keys=["source"], threshold=configs.dataset.MRImax, above=False, cval=0))
    if normalize=='zscore':
        transform_list.append(NormalizeIntensityd(keys=["source", "target"], nonzero=False, channel_wise=True))
        print('zscore normalization')
    elif normalize=='minmax':
        transform_list.append(ScaleIntensityd(keys=["source", "target"], minv=-1.0, maxv=1.0))
        print('minmax normalization')

    elif normalize=='scale4000':
        # NOTE(review): message below says 'scale1000' but this is the
        # scale4000 branch (target scaled by 1 + factor = 0.00025).
        transform_list.append(ScaleIntensityd(keys=["source"], minv=0, maxv=1))
        transform_list.append(ScaleIntensityd(keys=["target"], minv=0))
        transform_list.append(ScaleIntensityd(keys=["target"], factor=-0.99975)) # x=x(1+factor)
        print('scale1000 normalization')

    elif normalize=='scale1000':
        transform_list.append(ScaleIntensityd(keys=["source"], minv=0, maxv=1))
        transform_list.append(ScaleIntensityd(keys=["target"], minv=0))
        transform_list.append(ScaleIntensityd(keys=["target"], factor=-0.999))
        print('scale1000 normalization')

    elif normalize=='inputonlyzscore':
        transform_list.append(NormalizeIntensityd(keys=["source"], nonzero=False, channel_wise=True))
        print('only normalize input MRI images')

    elif normalize=='inputonlyminmax':
        transform_list.append(ScaleIntensityd(keys=["source"], minv=configs.dataset.normmin, maxv=configs.dataset.normmax))
        print('only normalize input MRI images')

    elif normalize=='none' or normalize=='nonorm':
        print('no normalization')

    # NOTE(review): spaceXY is hard-coded to 0 here, so the Spacingd
    # branch below is dead code.
    spaceXY=0
    if spaceXY>0:
        transform_list.append(Spacingd(keys=["source"], pixdim=(spaceXY, spaceXY, 2.5), mode="bilinear")) #
        transform_list.append(Spacingd(keys=["target", "mask"], pixdim=(spaceXY, spaceXY , 2.5), mode="bilinear")) #
    transform_list.append(ResizeWithPadOrCropd(keys=["source", "target"], spatial_size=resized_size,mode=pad))
    # transform_list.append(ScaleIntensityRanged(keys=["target"],a_min=WINDOW_LEVEL-(WINDOW_WIDTH/2), a_max=WINDOW_LEVEL+(WINDOW_WIDTH/2),b_min=0, b_max=1, clip=True))

    if configs.dataset.rotate:
        transform_list.append(Rotate90d(keys=["source", "target"], k=3))

    if mode == 'train':
        from monai.transforms import (
            # data augmentation
            RandRotated,
            RandZoomd,
            RandBiasFieldd,
            RandAffined,
            RandGridDistortiond,
            RandGridPatchd,
            RandShiftIntensityd,
            RandGibbsNoised,
            RandAdjustContrastd,
            RandGaussianSmoothd,
            RandGaussianSharpend,
            RandGaussianNoised,
        )
        # Shape augmentation is unconditionally on; intensity augmentation
        # is hard-disabled below.
        Aug=True
        if Aug:
            transform_list.append(RandRotated(keys=["source", "target", "mask"], range_x = 0.1, range_y = 0.1, range_z = 0.1, prob=prob, padding_mode="border", keep_size=True))
            transform_list.append(RandZoomd(keys=["source", "target", "mask"], prob=prob, min_zoom=0.9, max_zoom=1.3,padding_mode= "minimum" ,keep_size=True))
            transform_list.append(RandAffined(keys=["source", "target"],padding_mode="border" , prob=prob))
            #transform_list.append(Rand3DElasticd(keys=["source", "target"], prob=prob, sigma_range=(5, 8), magnitude_range=(100, 200), spatial_size=None, mode='bilinear'))
        intensityAug=False
        if intensityAug:
            print('intensity data augmentation is used')
            transform_list.append(RandBiasFieldd(keys=["source"], degree=3, coeff_range=(0.0, 0.1), prob=prob)) # only apply to MRI images
            transform_list.append(RandGaussianNoised(keys=["source"], prob=prob, mean=0.0, std=0.01))
            transform_list.append(RandAdjustContrastd(keys=["source"], prob=prob, gamma=(0.5, 1.5)))
            transform_list.append(RandShiftIntensityd(keys=["source"], prob=prob, offsets=20))
            transform_list.append(RandGaussianSharpend(keys=["source"], alpha=(0.2, 0.8), prob=prob))

    #transform_list.append(Rotate90d(keys=["source", "target"], k=3))
    #transform_list.append(DivisiblePadd(keys=["source", "target"], k=div_size, mode="minimum"))
    #transform_list.append(Identityd(keys=["source", "target"])) # do nothing for the no norm case
    train_transforms = Compose(transform_list)
    return train_transforms
|
| 118 |
+
def myslicesloader(configs, paths):
    """Build 2D-slice training and volume-level validation loaders.

    Pairs `{source}.nii.gz`, `{target}.nii.gz` and `mask.nii.gz` volumes found
    in per-case subfolders of ``configs.dataset.data_dir``, applies the
    train/val transform pipelines, and slices training volumes into single
    2D slices via a grid-patch dataset.

    Args:
        configs: experiment configuration; reads ``configs.dataset.*``
            (data_dir, train_number, val_number, batch_size, val_batch_size,
            center_crop, source, target).
        paths: dict with "saved_name_train"/"saved_name_val" (only used when
            the internal ``ifsave`` debug flag is enabled).

    Returns:
        (train_crop_ds, val_crop_ds, train_loader, val_loader,
         train_transforms, val_transforms)
    """
    data_path = configs.dataset.data_dir
    train_number = configs.dataset.train_number
    val_number = configs.dataset.val_number
    train_batch_size = configs.dataset.batch_size
    val_batch_size = configs.dataset.val_batch_size
    saved_name_train = paths["saved_name_train"]
    saved_name_val = paths["saved_name_val"]
    center_crop = configs.dataset.center_crop
    source = configs.dataset.source
    target = configs.dataset.target

    # volume-level transforms for both image and label
    train_transforms = get_transforms(configs, mode='train')
    val_transforms = get_transforms(configs, mode='val')

    # list all case folders (skip the dataset 'overview' file)
    file_list = [i for i in os.listdir(data_path) if 'overview' not in i]
    file_list_path = [os.path.join(data_path, i) for i in file_list]
    # list all source/target/mask files in each case folder
    mask = 'mask'
    source_file_list = [os.path.join(j, f'{source}.nii.gz') for j in file_list_path]
    target_file_list = [os.path.join(j, f'{target}.nii.gz') for j in file_list_path]
    mask_file_list = [os.path.join(j, f'{mask}.nii.gz') for j in file_list_path]
    train_ds = [{'source': i, 'target': j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k}
                for i, j, k in zip(source_file_list[0:train_number], target_file_list[0:train_number], mask_file_list[0:train_number])]
    val_ds = [{'source': i, 'target': j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k}
              for i, j, k in zip(source_file_list[-val_number:], target_file_list[-val_number:], mask_file_list[-val_number:])]
    print('all files in dataset:', len(file_list))

    # Load volumes and optionally center-crop along the slice axis.
    # FIX: the former if/else duplicated the whole loading pipeline; build it
    # once and append the crop step only when requested (same behavior).
    crop_transforms = [
        LoadImaged(keys=["source", "target", "mask"]),
        EnsureChannelFirstd(keys=["source", "target", "mask"]),
    ]
    if center_crop > 0:
        # roi_size (-1,-1,center_crop): keep full in-plane size, crop slices
        crop_transforms.append(
            CenterSpatialCropd(keys=["source", "target", "mask"], roi_size=(-1, -1, center_crop)))
        print('center crop:', center_crop)
    crop = Compose(crop_transforms)
    train_crop_ds = monai.data.Dataset(data=train_ds, transform=crop)
    val_crop_ds = monai.data.Dataset(data=val_ds, transform=crop)

    # load volumes with the mode-specific transforms applied on top
    train_volume_ds = monai.data.Dataset(data=train_crop_ds, transform=train_transforms)
    val_volume_ds = monai.data.Dataset(data=val_crop_ds, transform=val_transforms)
    ifsave, ifcheck, iftest = False, False, False  # debug-only switches
    if ifsave:
        save_volumes(train_ds, val_ds, saved_name_train, saved_name_val)
    if ifcheck:
        check_volumes(train_ds, train_volume_ds, val_volume_ds, val_ds)
    if iftest:
        test_volumes_pixdim(train_volume_ds)

    # batch-level slicer for both image and label (single-slice windows)
    window_width = 1
    patch_func = monai.data.PatchIterd(
        keys=["source", "target", "mask"],
        patch_size=(None, None, window_width),  # dynamic first two dimensions
        start_pos=(0, 0, 0)
    )
    if window_width == 1:
        patch_transform = Compose(
            [
                SqueezeDimd(keys=["source", "target", "mask"], dim=-1),  # squeeze the last dim
            ]
        )
    else:
        patch_transform = None

    # for training: iterate slices of every training volume
    train_patch_ds = monai.data.GridPatchDataset(
        data=train_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
    train_loader = DataLoader(
        train_patch_ds,
        batch_size=train_batch_size,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )

    # for validation: whole volumes (no slicing)
    val_loader = DataLoader(
        val_volume_ds,
        num_workers=1,
        batch_size=val_batch_size,
        pin_memory=torch.cuda.is_available())

    if ifcheck:
        check_batch_data(train_loader, val_loader, train_patch_ds, val_volume_ds, train_batch_size, val_batch_size)

    return train_crop_ds, val_crop_ds, train_loader, val_loader, train_transforms, val_transforms
|
| 213 |
+
|
| 214 |
+
def ddpmloader(configs, paths):
    """Build 2D-slice train/val loaders for DDPM training (images only).

    Same file discovery as ``myslicesloader`` but only "source"/"target" are
    loaded, and validation is also sliced with a grid-patch dataset.

    NOTE(review): mask paths are still put into the data dicts but never
    loaded (LoadImaged keys exclude "mask"); any downstream transform that
    addresses the "mask" key would receive the raw path string — confirm the
    transforms built by get_transforms() never touch "mask" in this mode.

    Args:
        configs: experiment configuration; reads ``configs.dataset.*``.
        paths: dict with "saved_name_train"/"saved_name_val" (debug saving).

    Returns:
        (train_crop_ds, val_crop_ds, train_loader, val_loader,
         train_transforms, val_transforms)
    """
    data_path = configs.dataset.data_dir
    train_number = configs.dataset.train_number
    val_number = configs.dataset.val_number
    train_batch_size = configs.dataset.batch_size
    val_batch_size = configs.dataset.val_batch_size
    saved_name_train = paths["saved_name_train"]
    saved_name_val = paths["saved_name_val"]
    center_crop = configs.dataset.center_crop
    source = configs.dataset.source
    target = configs.dataset.target

    # volume-level transforms for both image and label
    train_transforms = get_transforms(configs, mode='train')
    val_transforms = get_transforms(configs, mode='val')

    # list all case folders (skip the dataset 'overview' file)
    file_list = [i for i in os.listdir(data_path) if 'overview' not in i]
    file_list_path = [os.path.join(data_path, i) for i in file_list]
    mask = 'mask'
    source_file_list = [os.path.join(j, f'{source}.nii.gz') for j in file_list_path]
    target_file_list = [os.path.join(j, f'{target}.nii.gz') for j in file_list_path]
    mask_file_list = [os.path.join(j, f'{mask}.nii.gz') for j in file_list_path]
    train_ds = [{'source': i, 'target': j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k}
                for i, j, k in zip(source_file_list[0:train_number], target_file_list[0:train_number], mask_file_list[0:train_number])]
    val_ds = [{'source': i, 'target': j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k}
              for i, j, k in zip(source_file_list[-val_number:], target_file_list[-val_number:], mask_file_list[-val_number:])]
    print('all files in dataset:', len(file_list))

    # Load volumes and optionally center-crop along the slice axis.
    # FIX: the former if/else duplicated the whole loading pipeline; build it
    # once and append the crop step only when requested (same behavior).
    crop_transforms = [
        LoadImaged(keys=["source", "target"]),
        EnsureChannelFirstd(keys=["source", "target"]),
    ]
    if center_crop > 0:
        crop_transforms.append(
            CenterSpatialCropd(keys=["source", "target"], roi_size=(-1, -1, center_crop)))
        print('center crop:', center_crop)
    crop = Compose(crop_transforms)
    train_crop_ds = monai.data.Dataset(data=train_ds, transform=crop)
    val_crop_ds = monai.data.Dataset(data=val_ds, transform=crop)

    # load volumes with the mode-specific transforms applied on top
    train_volume_ds = monai.data.Dataset(data=train_crop_ds, transform=train_transforms)
    val_volume_ds = monai.data.Dataset(data=val_crop_ds, transform=val_transforms)
    ifsave, ifcheck, iftest = False, False, False  # debug-only switches
    if ifsave:
        save_volumes(train_ds, val_ds, saved_name_train, saved_name_val)
    if ifcheck:
        check_volumes(train_ds, train_volume_ds, val_volume_ds, val_ds)
    if iftest:
        test_volumes_pixdim(train_volume_ds)

    # batch-level slicer for both image and label (single-slice windows)
    window_width = 1
    patch_func = monai.data.PatchIterd(
        keys=["source", "target"],
        patch_size=(None, None, window_width),  # dynamic first two dimensions
        start_pos=(0, 0, 0)
    )
    if window_width == 1:
        patch_transform = Compose(
            [
                SqueezeDimd(keys=["source", "target"], dim=-1),  # squeeze the last dim
            ]
        )
    else:
        patch_transform = None

    # for training
    train_patch_ds = monai.data.GridPatchDataset(
        data=train_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
    train_loader = DataLoader(
        train_patch_ds,
        batch_size=train_batch_size,
        num_workers=0,
        pin_memory=torch.cuda.is_available(),
    )

    # for validation: DDPM also validates on 2D slices
    val_patch_ds = monai.data.GridPatchDataset(
        data=val_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
    val_loader = DataLoader(
        val_patch_ds,  # val_volume_ds,
        num_workers=0,
        batch_size=val_batch_size,
        pin_memory=torch.cuda.is_available())

    if ifcheck:
        check_batch_data(train_loader, val_loader, train_patch_ds, val_volume_ds, train_batch_size, val_batch_size)

    return train_crop_ds, val_crop_ds, train_loader, val_loader, train_transforms, val_transforms
|
dataprocesser/archive/init_dataset.py
ADDED
|
File without changes
|
dataprocesser/archive/json_dataset_slices.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils.data import DataLoader
|
| 2 |
+
import os.path
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import torch
|
| 5 |
+
from PIL import ImageFile
|
| 6 |
+
import os
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import monai
|
| 9 |
+
|
| 10 |
+
Image.MAX_IMAGE_PIXELS = None # Disable DecompressionBombError
|
| 11 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True # Disable OSError: image file is truncated
|
| 12 |
+
|
| 13 |
+
# Accepted medical-image file extensions; the common raster formats are
# intentionally commented out for this loader.
IMG_EXTENSIONS = [
    #'.jpg', '.JPG', '.jpeg', '.JPEG',
    #'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.nrrd', '.nii.gz'
]

def is_image_file(filename):
    """Return True if *filename* ends with a recognised image extension."""
    # str.endswith accepts a tuple of suffixes, so a single call suffices.
    return filename.endswith(tuple(IMG_EXTENSIONS))
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
from dataprocesser.list_dataset_base import BaseDataLoader
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
|
dataprocesser/archive/list_dataset_Anika.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
dataprocesser/archive/list_dataset_Anish.py
ADDED
|
File without changes
|
dataprocesser/archive/list_dataset_Anish_seg.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils.data import DataLoader, Dataset, random_split
|
| 2 |
+
import torch.utils.data as data
|
| 3 |
+
import os.path
|
| 4 |
+
import random
|
| 5 |
+
from torchvision import transforms
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import torch
|
| 8 |
+
from PIL import ImageFile
|
| 9 |
+
#from utils.MattingLaplacian import compute_laplacian
|
| 10 |
+
|
| 11 |
+
import nibabel as nib
|
| 12 |
+
import numpy as np
|
| 13 |
+
import os
|
| 14 |
+
import csv
|
| 15 |
+
import pandas as pd
|
| 16 |
+
#from transformers import CLIPTokenizer
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
Image.MAX_IMAGE_PIXELS = None # Disable DecompressionBombError
|
| 19 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True # Disable OSError: image file is truncated
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
import numpy as np
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
|
| 26 |
+
# File suffixes this loader treats as images (raster formats disabled).
IMG_EXTENSIONS = [
    #'.jpg', '.JPG', '.jpeg', '.JPEG',
    #'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.nrrd', '.nii.gz'
]

def is_image_file(filename):
    """Check whether *filename* carries one of the accepted suffixes."""
    for ext in IMG_EXTENSIONS:
        if filename.endswith(ext):
            return True
    return False
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
from dataprocesser.customized_transforms import CreateMaskTransformd, MergeMasksTransformd
|
| 39 |
+
from dataprocesser.list_dataset_base import BaseDataLoader
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
|
dataprocesser/archive/list_dataset_base.py
ADDED
|
@@ -0,0 +1,983 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import monai
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
from monai.transforms import (
|
| 5 |
+
Compose,
|
| 6 |
+
LoadImaged,
|
| 7 |
+
EnsureChannelFirstd,
|
| 8 |
+
SqueezeDimd,
|
| 9 |
+
CenterSpatialCropd,
|
| 10 |
+
Rotate90d,
|
| 11 |
+
ScaleIntensityd,
|
| 12 |
+
ResizeWithPadOrCropd,
|
| 13 |
+
DivisiblePadd,
|
| 14 |
+
Zoomd,
|
| 15 |
+
ThresholdIntensityd,
|
| 16 |
+
NormalizeIntensityd,
|
| 17 |
+
ShiftIntensityd,
|
| 18 |
+
Identityd,
|
| 19 |
+
ScaleIntensityRanged,
|
| 20 |
+
Spacingd,
|
| 21 |
+
)
|
| 22 |
+
from torch.utils.data import DataLoader
|
| 23 |
+
from torch.utils.data import ConcatDataset
|
| 24 |
+
import torch
|
| 25 |
+
from abc import ABC, abstractmethod
|
| 26 |
+
|
| 27 |
+
from datetime import datetime
|
| 28 |
+
import json
|
| 29 |
+
from tqdm import tqdm
|
| 30 |
+
|
| 31 |
+
from step1_init_data_list import (
|
| 32 |
+
list_img_ad_from_anish_csv,
|
| 33 |
+
list_img_ad_pIDs_from_anish_csv,
|
| 34 |
+
list_img_pID_from_synthrad_folder,
|
| 35 |
+
list_from_anika_dataset,
|
| 36 |
+
list_from_json,
|
| 37 |
+
list_from_slice_csv,
|
| 38 |
+
)
|
| 39 |
+
from step5_data_check_and_log import finalcheck
|
| 40 |
+
VERBOSE = False
|
| 41 |
+
|
| 42 |
+
def make_dataset_modality():
    """Placeholder factory: return an empty image list (filled by callers)."""
    return []
|
| 45 |
+
|
| 46 |
+
class ABCLoader(ABC):
    """Abstract interface for dataset loaders.

    Only ``__init__`` is declared abstract; the remaining hooks are no-op
    defaults that concrete loaders override as needed.
    """

    @abstractmethod
    def __init__(self):
        """Subclass must implement this method."""

    def get_loader(self):
        """Subclass must implement this method."""

    def create_dataset(self):
        """Subclass must implement this method."""

    def get_transforms(self):
        """Subclass must implement this method."""

    def get_normlization(self):
        """Subclass must implement this method."""

    def get_shape_transform(self):
        """Subclass must implement this method."""
        print("no shape transform here!!!!!!!!!!!!!!!!!!!!!!")

    def get_augmentation(self):
        """Subclass must implement this method."""
|
| 76 |
+
|
| 77 |
+
class BaseDataLoader(ABCLoader):
|
| 78 |
+
def __init__(self,configs,paths=None,dimension=2, **kwargs):
|
| 79 |
+
self.configs=configs
|
| 80 |
+
self.paths=paths
|
| 81 |
+
self.init_parameters_and_transforms()
|
| 82 |
+
self.get_loader()
|
| 83 |
+
#print('all files in dataset:',len(self.source_file_list))
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
self.rotation_level = kwargs.get('rotation_level', 0) # Default to no rotation (0)
|
| 87 |
+
self.zoom_level = kwargs.get('zoom_level', 1.0) # Default to no zoom (1.0)
|
| 88 |
+
self.flip = kwargs.get('flip', 0) # Default to no flip
|
| 89 |
+
|
| 90 |
+
self.create_dataset(dimension=dimension)
|
| 91 |
+
|
| 92 |
+
ifsave = None if paths is None else True
|
| 93 |
+
finalcheck(self.train_ds, self.val_ds,
|
| 94 |
+
self.train_volume_ds, self.val_volume_ds,
|
| 95 |
+
self.train_loader, self.val_loader,
|
| 96 |
+
self.train_patch_ds,
|
| 97 |
+
self.train_batch_size, self.val_batch_size,
|
| 98 |
+
self.saved_name_train, self.saved_name_val,
|
| 99 |
+
self.indicator_A, self.indicator_B,
|
| 100 |
+
ifsave=ifsave, ifcheck=False,iftest_volumes_pixdim=False)
|
| 101 |
+
|
| 102 |
+
def get_loader(self):
|
| 103 |
+
self.source_file_list = []
|
| 104 |
+
self.train_ds=[]
|
| 105 |
+
self.val_ds=[]
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def init_parameters_and_transforms(self):
|
| 109 |
+
self.indicator_A=self.configs.dataset.indicator_A
|
| 110 |
+
self.indicator_B=self.configs.dataset.indicator_B
|
| 111 |
+
self.train_number=self.configs.dataset.train_number
|
| 112 |
+
self.val_number=self.configs.dataset.val_number
|
| 113 |
+
self.train_batch_size=self.configs.dataset.batch_size
|
| 114 |
+
self.val_batch_size=self.configs.dataset.val_batch_size
|
| 115 |
+
self.load_masks=self.configs.dataset.load_masks
|
| 116 |
+
|
| 117 |
+
self.keys = [self.indicator_A, self.indicator_B, "mask"] if self.load_masks else [self.indicator_A, self.indicator_B]
|
| 118 |
+
|
| 119 |
+
if self.configs.model_name=='augmentation':
|
| 120 |
+
# Fixed parameters for rotation and zooming
|
| 121 |
+
self.train_transforms = self.get_augmentation(transform_list=[], flip=self.flip, rotation_level=self.rotation_level, zoom_level=self.zoom_level)
|
| 122 |
+
else:
|
| 123 |
+
self.train_transforms = self.get_transforms(mode='train')
|
| 124 |
+
self.val_transforms = self.get_transforms(mode='val')
|
| 125 |
+
|
| 126 |
+
if self.paths is not None:
|
| 127 |
+
self.saved_name_train=self.paths["saved_name_train"]
|
| 128 |
+
self.saved_name_val=self.paths["saved_name_val"]
|
| 129 |
+
|
| 130 |
+
def create_volume_dataset(self):
|
| 131 |
+
# load volumes and center crop
|
| 132 |
+
center_crop = self.configs.dataset.center_crop
|
| 133 |
+
transformations_crop = [
|
| 134 |
+
LoadImaged(keys=self.keys),
|
| 135 |
+
EnsureChannelFirstd(keys=self.keys),
|
| 136 |
+
]
|
| 137 |
+
if center_crop>0:
|
| 138 |
+
transformations_crop.append(CenterSpatialCropd(keys=self.keys, roi_size=(-1,-1,center_crop)))
|
| 139 |
+
transformations_crop=Compose(transformations_crop)
|
| 140 |
+
train_crop_ds = monai.data.Dataset(data=self.train_ds, transform=transformations_crop)
|
| 141 |
+
val_crop_ds = monai.data.Dataset(data=self.val_ds, transform=transformations_crop)
|
| 142 |
+
|
| 143 |
+
# load volumes
|
| 144 |
+
self.train_volume_ds = monai.data.Dataset(data=train_crop_ds, transform=self.train_transforms)
|
| 145 |
+
self.val_volume_ds = monai.data.Dataset(data=val_crop_ds, transform=self.val_transforms)
|
| 146 |
+
|
| 147 |
+
def create_patch_dataset_and_dataloader(self, dimension=2):
|
| 148 |
+
train_batch_size=self.configs.dataset.batch_size
|
| 149 |
+
val_batch_size=self.configs.dataset.val_batch_size
|
| 150 |
+
if dimension==2:
|
| 151 |
+
# batch-level slicer for both image and label
|
| 152 |
+
window_width=1
|
| 153 |
+
patch_func = monai.data.PatchIterd(
|
| 154 |
+
keys=self.keys,
|
| 155 |
+
patch_size=(None, None, window_width), # dynamic first two dimensions
|
| 156 |
+
start_pos=(0, 0, 0)
|
| 157 |
+
)
|
| 158 |
+
if window_width==1:
|
| 159 |
+
patch_transform = Compose(
|
| 160 |
+
[
|
| 161 |
+
SqueezeDimd(keys=self.keys, dim=-1), # squeeze the last dim
|
| 162 |
+
]
|
| 163 |
+
)
|
| 164 |
+
else:
|
| 165 |
+
patch_transform = None
|
| 166 |
+
|
| 167 |
+
# for training
|
| 168 |
+
train_patch_ds = monai.data.GridPatchDataset(
|
| 169 |
+
data=self.train_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
|
| 170 |
+
train_loader = DataLoader(
|
| 171 |
+
train_patch_ds,
|
| 172 |
+
batch_size=train_batch_size,
|
| 173 |
+
num_workers=self.configs.dataset.num_workers,
|
| 174 |
+
pin_memory=torch.cuda.is_available(),
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# for validation
|
| 178 |
+
if self.configs.model_name=='ddpm' or 'ddpm2d_seg2med' or 'ddpm2d':
|
| 179 |
+
val_patch_ds = monai.data.GridPatchDataset(
|
| 180 |
+
data=self.val_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
|
| 181 |
+
val_loader = DataLoader(
|
| 182 |
+
val_patch_ds, #val_volume_ds,
|
| 183 |
+
num_workers=self.configs.dataset.num_workers,
|
| 184 |
+
batch_size=val_batch_size,
|
| 185 |
+
pin_memory=torch.cuda.is_available())
|
| 186 |
+
else:
|
| 187 |
+
val_loader = DataLoader(
|
| 188 |
+
self.val_volume_ds,
|
| 189 |
+
num_workers=self.configs.dataset.num_workers,
|
| 190 |
+
batch_size=val_batch_size,
|
| 191 |
+
pin_memory=torch.cuda.is_available())
|
| 192 |
+
self.train_patch_ds=train_patch_ds
|
| 193 |
+
|
| 194 |
+
elif dimension==2.5:
|
| 195 |
+
# batch-level slicer for both image and label
|
| 196 |
+
# 2.5 means stack slices together as a small volume patch
|
| 197 |
+
# if window_width>1, means we train a 2.5D network
|
| 198 |
+
patch_size=self.configs.dataset.patch_size # (None, None, window_width)
|
| 199 |
+
window_width=patch_size[-1]
|
| 200 |
+
patch_func = monai.data.PatchIterd(
|
| 201 |
+
keys=self.keys,
|
| 202 |
+
patch_size=patch_size, # dynamic first two dimensions: (None, None, window_width)
|
| 203 |
+
start_pos=(0, 0, 0)
|
| 204 |
+
)
|
| 205 |
+
if window_width==1:
|
| 206 |
+
print(f"slice patch is 1, we use 2D-training")
|
| 207 |
+
patch_transform = Compose(
|
| 208 |
+
[
|
| 209 |
+
SqueezeDimd(keys=self.keys, dim=-1), # squeeze the last dim
|
| 210 |
+
]
|
| 211 |
+
)
|
| 212 |
+
else:
|
| 213 |
+
print(f"use consecutive {window_width} slices for 2.5D-training")
|
| 214 |
+
# there would be an error if original size < patch_size during training, so we should pad it in this case
|
| 215 |
+
patch_transform = ResizeWithPadOrCropd(keys=self.keys,
|
| 216 |
+
spatial_size=patch_size, mode='minimum')
|
| 217 |
+
|
| 218 |
+
# for training
|
| 219 |
+
train_patch_ds = monai.data.GridPatchDataset(
|
| 220 |
+
data=self.train_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
|
| 221 |
+
train_loader = DataLoader(
|
| 222 |
+
train_patch_ds,
|
| 223 |
+
batch_size=train_batch_size,
|
| 224 |
+
num_workers=2,
|
| 225 |
+
pin_memory=torch.cuda.is_available(),
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
# for validation
|
| 229 |
+
if self.configs.model_name=='ddpm':
|
| 230 |
+
val_patch_ds = monai.data.GridPatchDataset(
|
| 231 |
+
data=self.val_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
|
| 232 |
+
val_loader = DataLoader(
|
| 233 |
+
val_patch_ds, #val_volume_ds,
|
| 234 |
+
num_workers=0,
|
| 235 |
+
batch_size=val_batch_size,
|
| 236 |
+
pin_memory=torch.cuda.is_available())
|
| 237 |
+
else:
|
| 238 |
+
val_loader = DataLoader(
|
| 239 |
+
self.val_volume_ds,
|
| 240 |
+
num_workers=1,
|
| 241 |
+
batch_size=val_batch_size,
|
| 242 |
+
pin_memory=torch.cuda.is_available())
|
| 243 |
+
self.train_patch_ds=train_patch_ds
|
| 244 |
+
|
| 245 |
+
elif dimension==3:
|
| 246 |
+
# 3 means use the whole input volume for training
|
| 247 |
+
train_loader = DataLoader(
|
| 248 |
+
self.train_volume_ds,
|
| 249 |
+
num_workers=self.configs.dataset.num_workers,
|
| 250 |
+
batch_size=train_batch_size,
|
| 251 |
+
pin_memory=torch.cuda.is_available())
|
| 252 |
+
val_loader = DataLoader(
|
| 253 |
+
self.val_volume_ds,
|
| 254 |
+
num_workers=self.configs.dataset.num_workers,
|
| 255 |
+
batch_size=val_batch_size,
|
| 256 |
+
pin_memory=torch.cuda.is_available())
|
| 257 |
+
|
| 258 |
+
elif dimension==3.5:
|
| 259 |
+
# 3.5 means create patch from the original volume
|
| 260 |
+
patch_func = monai.data.PatchIterd(
|
| 261 |
+
keys=[self.indicator_A, self.indicator_B],
|
| 262 |
+
patch_size=self.configs.dataset.patch_size, # dynamic first two dimensions
|
| 263 |
+
start_pos=(0, 0, 0),
|
| 264 |
+
mode="replicate",
|
| 265 |
+
)
|
| 266 |
+
patch_transform = None
|
| 267 |
+
|
| 268 |
+
# for training
|
| 269 |
+
train_patch_ds = monai.data.GridPatchDataset(
|
| 270 |
+
data=self.train_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
|
| 271 |
+
train_loader = DataLoader(
|
| 272 |
+
train_patch_ds,
|
| 273 |
+
batch_size=train_batch_size,
|
| 274 |
+
num_workers=self.configs.dataset.num_workers,
|
| 275 |
+
pin_memory=torch.cuda.is_available(),
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
val_patch_ds = monai.data.GridPatchDataset(
|
| 279 |
+
data=self.val_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
|
| 280 |
+
val_loader = DataLoader(
|
| 281 |
+
val_patch_ds, #val_volume_ds,
|
| 282 |
+
num_workers=self.configs.dataset.num_workers,
|
| 283 |
+
batch_size=val_batch_size,
|
| 284 |
+
pin_memory=torch.cuda.is_available())
|
| 285 |
+
else:
|
| 286 |
+
print('dimension of input data must be 2 or 2.5 or 3 or 3.5!')
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
self.train_batch_size=train_batch_size
|
| 291 |
+
self.val_batch_size=val_batch_size
|
| 292 |
+
self.train_loader=train_loader
|
| 293 |
+
self.val_loader=val_loader
|
| 294 |
+
|
| 295 |
+
def create_dataset(self,dimension=2):
|
| 296 |
+
self.create_volume_dataset()
|
| 297 |
+
self.create_patch_dataset_and_dataloader(dimension=dimension)
|
| 298 |
+
|
| 299 |
+
def get_transforms(self, mode='train'):
|
| 300 |
+
transform_list=[]
|
| 301 |
+
transform_list = self.get_pretransforms(transform_list)
|
| 302 |
+
transform_list = self.get_intensity_transforms(transform_list)
|
| 303 |
+
transform_list = self.get_normlization(transform_list)
|
| 304 |
+
transform_list = self.get_shape_transform(transform_list)
|
| 305 |
+
train_transforms = Compose(transform_list)
|
| 306 |
+
return train_transforms
|
| 307 |
+
|
| 308 |
+
def get_pretransforms(self, transform_list):
|
| 309 |
+
#print("customized transforms")
|
| 310 |
+
return transform_list
|
| 311 |
+
|
def get_intensity_transforms(self, transform_list):
    """Window the target image (indicator_B) and shift it to start at zero.

    Intensities are clipped to [level - width/2, level + width/2] using the
    configured WINDOW_LEVEL / WINDOW_WIDTH, then shifted by ``-lower`` so
    the windowed range begins at 0.
    """
    level = self.configs.dataset.WINDOW_LEVEL
    width = self.configs.dataset.WINDOW_WIDTH
    lower = level - width / 2
    upper = level + width / 2
    keys = [self.indicator_B]
    # When clipping from below with above=True, cval must be >= the
    # threshold (ThresholdIntensityd keeps img where mask holds, else cval),
    # otherwise replaced pixels would fall outside the window.
    transform_list.append(
        ThresholdIntensityd(keys=keys, threshold=lower, above=True, cval=lower))
    transform_list.append(
        ThresholdIntensityd(keys=keys, threshold=upper, above=False, cval=upper))
    # Shift so the window's lower bound maps to zero.
    transform_list.append(ShiftIntensityd(keys=keys, offset=-lower))
    return transform_list
| 323 |
+
|
def get_normlization(self, transform_list):
    """Append the normalization transform selected by ``configs.dataset.normalize``.

    Normalization targets the output image (indicator_B), except the
    'inputonly*' modes which normalize only the source image (indicator_A).
    The 'scaleN' modes multiply intensities by (1 + factor), i.e. divide by N.
    Segmentation-mask inputs use mode 'none'/'nonorm' to skip normalization.
    Returns the (possibly extended) transform list.
    """
    normalize = self.configs.dataset.normalize
    indicator_A = self.configs.dataset.indicator_A
    indicator_B = self.configs.dataset.indicator_B
    if normalize == 'zscore':
        transform_list.append(NormalizeIntensityd(keys=[indicator_B], nonzero=False, channel_wise=True))
        print('zscore normalization')
    elif normalize == 'minmax':
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=-1.0, maxv=1.0))
        print('minmax normalization')
    elif normalize == 'scale1000_wrongbutworks':
        # Historical behavior kept for reproducibility: first min-shifts to 0,
        # then scales by (1 - 0.999) = 1/1000.
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=0))
        transform_list.append(ScaleIntensityd(keys=[indicator_B], factor=-0.999))
        print('scale1000 normalization')
    elif normalize == 'scale4000':
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.99975))
        print('scale4000 normalization')
    elif normalize == 'scale2000':
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.9995))
        print('scale2000 normalization')
    elif normalize == 'scale1000':
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.999))
        print('scale1000 normalization')
    elif normalize == 'scale100':
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.99))
        # Fixed: this branch previously printed 'scale10 normalization'.
        print('scale100 normalization')
    elif normalize == 'scale10':
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.9))
        print('scale10 normalization')
    elif normalize == 'inputonlyzscore':
        transform_list.append(NormalizeIntensityd(keys=[indicator_A], nonzero=False, channel_wise=True))
        print('only normalize input MRI images')
    elif normalize == 'inputonlyminmax':
        transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=self.configs.dataset.normmin, maxv=self.configs.dataset.normmax))
        print('only normalize input MRI images')
    elif normalize == 'nonegative':
        transform_list.append(ShiftIntensityd(keys=[indicator_B], offset=self.configs.dataset.offset))
        print('none negative normalization')
    elif normalize == 'none' or normalize == 'nonorm':
        print('no normalization')
    else:
        # Previously unknown modes fell through silently; make that visible.
        print(f'unknown normalize mode "{normalize}" - no normalization applied')
    return transform_list
| 378 |
+
|
def get_shape_transform(self, transform_list):
    """Append spatial transforms: optional resampling, zoom, pad, crop, rotate.

    Resampling to ``spaceXY`` (in-plane) x 2.5 mm is only applied when
    ``spaceXY > 0``; when masks are loaded, the 'mask' key is resampled
    together with indicator_B.  NOTE(review): the mask is resampled with
    bilinear interpolation like the images — confirm this is intended for
    label data.
    """
    cfg = self.configs.dataset
    space_xy = cfg.spaceXY
    pad_value = 0  # constant pad value (original comment: offset*(-1))
    keys = self.keys  # set elsewhere; typically [A, B] (+ 'mask' when loaded)
    if space_xy > 0:
        b_keys = [cfg.indicator_B, "mask"] if cfg.load_masks else [cfg.indicator_B]
        transform_list.append(Spacingd(keys=[cfg.indicator_A],
                                       pixdim=(space_xy, space_xy, 2.5),
                                       mode="bilinear", ensure_same_shape=True))
        transform_list.append(Spacingd(keys=b_keys,
                                       pixdim=(space_xy, space_xy, 2.5),
                                       mode="bilinear", ensure_same_shape=True))
    transform_list.append(Zoomd(keys=keys, zoom=cfg.zoom, keep_size=False,
                                mode='area', padding_mode="constant",
                                value=pad_value))
    transform_list.append(DivisiblePadd(keys=keys, k=cfg.div_size,
                                        mode="constant", value=pad_value))
    transform_list.append(ResizeWithPadOrCropd(keys=keys,
                                               spatial_size=cfg.resized_size,
                                               mode="constant",
                                               value=pad_value))
    if cfg.rotate:
        transform_list.append(Rotate90d(keys=keys, k=3))
    return transform_list
| 402 |
+
|
class anish_loader(BaseDataLoader):
    """Loader for the 'anish' CSV-listed dataset (images + aorta-dissection labels).

    Unlike subclasses that go through ``BaseDataLoader.__init__``, this class
    builds its file lists itself, then reuses the base class's dataset
    creation and final check.
    """

    def __init__(self, configs, paths, dimension=2):
        self.configs = configs
        self.paths = paths
        # NOTE(review): BaseDataLoader.__init__ is intentionally skipped here;
        # attributes it would normally set (e.g. self.keys, used by
        # get_shape_transform) must be provided elsewhere — confirm.
        self.get_loader()
        super().create_dataset(dimension=dimension)
        self.finalcheck(ifsave=True, ifcheck=False, iftest_volumes_pixdim=False)

    def get_loader(self):
        """Read image paths and aorta-dissection flags from a CSV and build
        the train/val sample-dict lists consumed by the MONAI dataset."""
        indicator_A = self.configs.dataset.indicator_A
        indicator_B = self.configs.dataset.indicator_B
        self.indicator_A = indicator_A
        self.indicator_B = indicator_B
        train_number = self.configs.dataset.train_number
        val_number = self.configs.dataset.val_number
        load_masks = self.configs.dataset.load_masks

        # data_dir must either be an existing CSV file or we fall back to the
        # server-specific default CSV shipped with the project.
        if self.configs.dataset.data_dir is not None and os.path.exists(self.configs.dataset.data_dir):
            if self.configs.dataset.data_dir.endswith('.csv'):
                csv_file = self.configs.dataset.data_dir
            else:
                raise ValueError('The data directory in this case must be a csv file!')
        else:
            if self.configs.server in ('helix', 'helixSingle', 'helixMultiple'):
                csv_file = './healthy_dissec_helix.csv'
            else:
                csv_file = './healthy_dissec.csv'

        # When the input is already a segmentation mask, read the seg column.
        load_seg = bool(self.configs.dataset.input_is_mask)
        source_file_list, source_Aorta_diss_list = list_img_ad_from_anish_csv(csv_file, load_seg)
        target_file_list, _ = list_img_ad_from_anish_csv(csv_file)
        mask_file_list, _ = list_img_ad_from_anish_csv(csv_file)

        if load_masks:
            train_ds = [{indicator_A: i, indicator_B: j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k}
                        for i, j, k in zip(source_file_list[0:train_number], target_file_list[0:train_number], mask_file_list[0:train_number])]
            val_ds = [{indicator_A: i, indicator_B: j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k}
                      for i, j, k in zip(source_file_list[-val_number:], target_file_list[-val_number:], mask_file_list[-val_number:])]
        else:
            train_ds = [{indicator_A: i, indicator_B: j, 'A_paths': i, 'B_paths': j, 'Aorta_diss': ad}
                        for i, j, ad in zip(source_file_list[0:train_number], target_file_list[0:train_number], source_Aorta_diss_list[0:train_number])]
            val_ds = [{indicator_A: i, indicator_B: j, 'A_paths': i, 'B_paths': j, 'Aorta_diss': ad}
                      for i, j, ad in zip(source_file_list[-val_number:], target_file_list[-val_number:], source_Aorta_diss_list[-val_number:])]
        self.train_ds = train_ds
        self.val_ds = val_ds
        self.source_file_list = source_file_list
        self.target_file_list = target_file_list
        self.mask_file_list = mask_file_list

    def get_pretransforms(self, transform_list):
        """Prepend a mask-creation transform on the source image when the
        input is a real image (not already a mask).

        Bug fix: the original implementation was missing the final
        ``return transform_list`` and therefore returned None, which broke
        the chaining in ``get_transforms``.
        """
        if not self.configs.dataset.input_is_mask:
            transform_list.append(CreateMaskTransformd(
                keys=[self.configs.dataset.indicator_A],
                tissue_min=self.configs.dataset.tissue_min,
                tissue_max=self.configs.dataset.tissue_max,
                bone_min=self.configs.dataset.bone_min,
                bone_max=self.configs.dataset.bone_max))
        return transform_list
| 473 |
+
|
| 474 |
+
from dataprocesser.customized_transforms import CreateMaskTransformd, MergeMasksTransformd
|
| 475 |
+
|
class synthrad_seg_loader(BaseDataLoader):
    """SynthRAD loader whose network input is a segmentation volume.

    The CT (target) volumes are additionally loaded under the 'mask' key so a
    body contour can be extracted from them and merged into the segmentation
    input (see ``get_pretransforms``).
    """

    def __init__(self, configs, paths, dimension=2, **kwargs):
        super().__init__(configs, paths, dimension, **kwargs)

    def get_loader(self):
        """Build train/val sample dicts from a SynthRAD-style folder tree.

        Cleanup: the original computed ``os.listdir``-based file lists that
        were never used; removed.
        """
        indicator_A = self.configs.dataset.indicator_A
        indicator_B = self.configs.dataset.indicator_B
        self.indicator_A = indicator_A
        self.indicator_B = indicator_B
        train_number = self.configs.dataset.train_number
        val_number = self.configs.dataset.val_number

        logs_folder = self.paths["saved_logs_folder"]
        # 'mask' deliberately lists the *target* (CT) volumes: they are turned
        # into body-contour masks in get_pretransforms.
        source_file_list, patient_IDs = list_img_pID_from_synthrad_folder(
            self.configs.dataset.data_dir,
            accepted_modalities=self.configs.dataset.source_name,
            saved_name=os.path.join(logs_folder, "source_filenames.txt"))
        target_file_list, _ = list_img_pID_from_synthrad_folder(
            self.configs.dataset.data_dir,
            accepted_modalities=self.configs.dataset.target_name,
            saved_name=os.path.join(logs_folder, "target_filenames.txt"))
        mask_file_list, _ = list_img_pID_from_synthrad_folder(
            self.configs.dataset.data_dir,
            accepted_modalities=self.configs.dataset.target_name,
            saved_name=os.path.join(logs_folder, "mask_filenames.txt"))

        self.source_file_list = source_file_list
        self.target_file_list = target_file_list
        self.mask_file_list = mask_file_list

        # SynthRAD cases carry no aorta-dissection label; set it to 0 for all.
        ad = 0
        train_ds = [{indicator_A: i, indicator_B: j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k, 'Aorta_diss': ad, 'patient_ID': pID}
                    for i, j, k, pID in zip(source_file_list[0:train_number], target_file_list[0:train_number], mask_file_list[0:train_number], patient_IDs[0:train_number])]
        val_ds = [{indicator_A: i, indicator_B: j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k, 'Aorta_diss': ad, 'patient_ID': pID}
                  for i, j, k, pID in zip(source_file_list[-val_number:], target_file_list[-val_number:], mask_file_list[-val_number:], patient_IDs[-val_number:])]
        self.train_ds = train_ds
        self.val_ds = val_ds

    def get_pretransforms(self, transform_list):
        """Reduce the raw CT under 'mask' to a body contour and merge it into
        the segmentation input (indicator_A)."""
        indicator_A = self.configs.dataset.indicator_A
        transform_list.append(CreateMaskTransformd(keys=['mask'],
                                                   body_threshold=-500,
                                                   body_mask_value=1,
                                                   ))
        transform_list.append(MergeMasksTransformd(keys=[indicator_A, 'mask']))
        return transform_list
| 525 |
+
|
| 526 |
+
|
| 527 |
+
from dataprocesser.customized_transforms import CreateMaskTransformd, MergeMasksTransformd, MaskHUAssigmentd
|
| 528 |
+
|
| 529 |
+
from monai.transforms import (
|
| 530 |
+
ScaleIntensityd,
|
| 531 |
+
ThresholdIntensityd,
|
| 532 |
+
NormalizeIntensityd,
|
| 533 |
+
ShiftIntensityd,
|
| 534 |
+
)
|
| 535 |
+
|
class anish_seg_loader(BaseDataLoader):
    """CSV-driven loader for the 'anish' dataset with segmentation sources.

    The original CT image is loaded a second time under the 'mask' key; in
    ``get_pretransforms`` it is thresholded into a body contour and merged
    into the segmentation used as network input.
    """

    def __init__(self, configs, paths=None, dimension=2, **kwargs):
        super().__init__(configs, paths, dimension, **kwargs)

    def get_loader(self):
        """Build train/val sample dicts (paths, aorta-dissection label,
        patient ID) from the configured CSV file.

        Cleanup: removed unused ``train_batch_size``/``val_batch_size``
        locals (batch sizes are handled by the dataloader builder).
        """
        indicator_A = self.configs.dataset.indicator_A
        indicator_B = self.configs.dataset.indicator_B
        self.indicator_A = indicator_A
        self.indicator_B = indicator_B
        train_number = self.configs.dataset.train_number
        val_number = self.configs.dataset.val_number
        load_masks = self.configs.dataset.load_masks

        print('use csv dataset:', self.configs.dataset.data_dir)
        if self.configs.dataset.data_dir is not None and os.path.exists(self.configs.dataset.data_dir):
            if self.configs.dataset.data_dir.endswith('.csv'):
                csv_file = self.configs.dataset.data_dir
            else:
                raise ValueError('The data directory in this case must be a csv file!')
        else:
            # Fall back to the server-specific CSV shipped with the project.
            if self.configs.server in ('helix', 'helixSingle', 'helixMultiple'):
                csv_file = './healthy_dissec_helix.csv'
            else:
                csv_file = './healthy_dissec.csv'

        # When the input is already a segmentation, read the seg column.
        load_seg = bool(self.configs.dataset.input_is_mask)
        source_file_list, source_Aorta_diss_list, patient_IDs = list_img_ad_pIDs_from_anish_csv(csv_file, load_seg)
        target_file_list, _, _ = list_img_ad_pIDs_from_anish_csv(csv_file)
        # The original CT images are loaded as 'mask': they are later reduced
        # to a body contour and merged into the segmentation input.
        mask_file_list, _, _ = list_img_ad_pIDs_from_anish_csv(csv_file)

        if load_masks:
            train_ds = [{indicator_A: i, indicator_B: j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k, 'Aorta_diss': ad, 'patient_ID': pID}
                        for i, j, k, ad, pID in zip(source_file_list[0:train_number], target_file_list[0:train_number], mask_file_list[0:train_number], source_Aorta_diss_list[0:train_number], patient_IDs[0:train_number])]
            val_ds = [{indicator_A: i, indicator_B: j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k, 'Aorta_diss': ad, 'patient_ID': pID}
                      for i, j, k, ad, pID in zip(source_file_list[-val_number:], target_file_list[-val_number:], mask_file_list[-val_number:], source_Aorta_diss_list[-val_number:], patient_IDs[-val_number:])]
        else:
            train_ds = [{indicator_A: i, indicator_B: j, 'A_paths': i, 'B_paths': j, 'Aorta_diss': ad}
                        for i, j, ad in zip(source_file_list[0:train_number], target_file_list[0:train_number], source_Aorta_diss_list[0:train_number])]
            val_ds = [{indicator_A: i, indicator_B: j, 'A_paths': i, 'B_paths': j, 'Aorta_diss': ad}
                      for i, j, ad in zip(source_file_list[-val_number:], target_file_list[-val_number:], source_Aorta_diss_list[-val_number:])]
        print('train_ds: \n')
        for i in train_ds:
            print(i)
        print('\n')
        self.train_ds = train_ds
        self.val_ds = val_ds
        self.source_file_list = source_file_list
        self.target_file_list = target_file_list
        self.mask_file_list = mask_file_list

    def get_pretransforms(self, transform_list):
        """Turn the raw CT under 'mask' into a body-contour mask and merge it
        into the segmentation input (indicator_A)."""
        indicator_A = self.configs.dataset.indicator_A
        transform_list.append(CreateMaskTransformd(keys=['mask'],
                                                   body_threshold=-500,
                                                   body_mask_value=1,
                                                   ))
        transform_list.append(MergeMasksTransformd(keys=[indicator_A, 'mask']))
        return transform_list
| 609 |
+
|
| 610 |
+
|
class combined_seg_loader(BaseDataLoader):
    """Concatenates a SynthRAD seg dataset and an 'anish' CSV seg dataset.

    The two child loaders are configured one after the other by mutating
    ``self.configs.dataset`` (data_dir, train/val numbers, offset), so the
    configs object describes dataset 2 after ``get_loader`` has run.
    """

    def __init__(self, configs, paths, dimension=2, **kwargs):
        self.dimension = dimension
        self.train_number_1 = kwargs.get('train_number_1', 170)
        self.train_number_2 = kwargs.get('train_number_2', 152)
        self.val_number_1 = kwargs.get('val_number_1', 10)
        self.val_number_2 = kwargs.get('val_number_2', 10)
        # Raw strings: the previous plain literals only worked because '\T',
        # '\d', '\s', ... happen not to be recognized escape sequences.
        self.data_dir_1 = kwargs.get('data_dir_1', r'E:\Projects\yang_proj\data\synthrad\Task1\pelvis')
        self.data_dir_2 = kwargs.get('data_dir_2', r'E:\Projects\yang_proj\SynthRad_GAN\synthrad_conversion\healthy_dissec.csv')
        super().__init__(configs, paths, dimension, **kwargs)

    def get_loader(self):
        """Instantiate both child loaders and concatenate their datasets."""
        # dataset 1: SynthRAD segmentation data
        self.configs.dataset.data_dir = self.data_dir_1
        self.configs.dataset.train_number = self.train_number_1
        self.configs.dataset.val_number = self.val_number_1
        self.configs.dataset.source_name = ["ct_seg"]
        self.configs.dataset.target_name = ["ct"]
        self.configs.dataset.offset = 1024
        loader1 = synthrad_seg_loader(self.configs, self.paths, self.dimension)

        # dataset 2: anish CSV segmentation data
        self.configs.dataset.data_dir = self.data_dir_2
        self.configs.dataset.train_number = self.train_number_2
        self.configs.dataset.val_number = self.val_number_2
        self.configs.dataset.offset = 1000
        loader2 = anish_seg_loader(self.configs, self.paths, self.dimension)

        self.train_ds = ConcatDataset([loader1.train_ds, loader2.train_ds])
        self.val_ds = ConcatDataset([loader1.val_ds, loader2.val_ds])
        self.source_file_list = loader1.source_file_list + loader2.source_file_list

    def get_pretransforms(self, transform_list):
        """Body-contour creation + merge, as in the child seg loaders."""
        indicator_A = self.configs.dataset.indicator_A
        transform_list.append(CreateMaskTransformd(keys=['mask'],
                                                   body_threshold=-500,
                                                   body_mask_value=1,
                                                   ))
        transform_list.append(MergeMasksTransformd(keys=[indicator_A, 'mask']))
        return transform_list

    def save_nifti(self, save_output_path, case=0):
        """Write every (image, seg) pair served by ``train_loader`` to NIfTI.

        Filenames are '<patientID><case>_<step>' plus a '_seg' suffix for
        the label volume.
        """
        from monai.transforms import SaveImage
        # The writers are configuration-only; create them once instead of
        # re-instantiating them for every batch as the original code did.
        si_input = SaveImage(output_dir=f'{save_output_path}',
                             separate_folder=False,
                             output_postfix='',
                             resample=False)
        si_seg = SaveImage(output_dir=f'{save_output_path}',
                           separate_folder=False,
                           output_postfix='',
                           resample=False)
        step = 0
        with torch.no_grad():
            for data in self.train_loader:
                image_batch = data['img'].squeeze()
                seg_batch = data['seg'].squeeze()
                file_path_batch = data['B_paths']
                Aorta_diss = data['Aorta_diss']

                for i in range(len(file_path_batch)):
                    step += 1
                    file_path = file_path_batch[i]
                    patient_ID = os.path.splitext(os.path.basename(file_path))[0]
                    save_name_img = os.path.join(save_output_path,
                                                 patient_ID + str(case) + '_' + str(step))
                    save_name_seg = os.path.join(save_output_path,
                                                 patient_ID + str(case) + '_' + str(step) + '_seg')
                    si_input(image_batch[i].unsqueeze(0), data['img'].meta, filename=save_name_img)
                    si_seg(seg_batch[i].unsqueeze(0), data['seg'].meta, filename=save_name_seg)
| 699 |
+
|
class combined_seg_assigned_loader(combined_seg_loader):
    """combined_seg_loader variant that additionally maps segmentation labels
    to HU values via a per-anatomy lookup table (CSV)."""

    def __init__(self, configs, paths=None, dimension=2, **kwargs):
        # CSV mapping anatomy labels to HU values, used by MaskHUAssigmentd.
        self.anatomy_list = kwargs.get('anatomy_list', 'synthrad_conversion/TA2_anatomy.csv')
        super().__init__(configs, paths, dimension, **kwargs)

    def get_pretransforms(self, transform_list):
        """Body contour + merge (as in the parent), then HU assignment."""
        a_key = self.configs.dataset.indicator_A
        for t in (
            CreateMaskTransformd(keys=['mask'], body_threshold=-500, body_mask_value=1),
            MergeMasksTransformd(keys=[a_key, 'mask']),
            MaskHUAssigmentd(keys=[self.indicator_A], csv_file=self.anatomy_list),
        ):
            transform_list.append(t)
        return transform_list

    def get_intensity_transforms(self, transform_list):
        """Window BOTH input and target to the configured HU window and shift
        so the window starts at zero (the input is an HU-assigned mask here,
        so it is treated like the CT target)."""
        level = self.configs.dataset.WINDOW_LEVEL
        width = self.configs.dataset.WINDOW_WIDTH
        lower = level - width / 2
        upper = level + width / 2
        keys = [self.indicator_A, self.indicator_B]
        # cval must equal the threshold when clipping so replaced pixels stay
        # inside the window.
        transform_list.append(ThresholdIntensityd(keys=keys, threshold=lower, above=True, cval=lower))
        transform_list.append(ThresholdIntensityd(keys=keys, threshold=upper, above=False, cval=upper))
        transform_list.append(ShiftIntensityd(keys=keys, offset=-lower))
        return transform_list

    def get_normlization(self, transform_list):
        """Normalize input and target together ('zscore'/'scale2000'), or not
        at all; unknown modes add nothing."""
        mode = self.configs.dataset.normalize
        keys = [self.indicator_A, self.indicator_B]
        if mode == 'zscore':
            transform_list.append(NormalizeIntensityd(keys=keys, nonzero=False, channel_wise=True))
            print('zscore normalization')
        elif mode == 'scale2000':
            transform_list.append(ScaleIntensityd(keys=keys, minv=None, maxv=None, factor=-0.9995))
            print('scale2000 normalization')
        elif mode in ('none', 'nonorm'):
            print('no normalization')
        return transform_list
| 745 |
+
|
class slices_nifti_DataLoader(BaseDataLoader):
    """Serves pre-sliced NIfTI data described by train/val dataset.json files.

    Since the samples are already slices, the volume datasets are wrapped
    directly in plain DataLoaders instead of grid-patch datasets.
    """

    def __init__(self, configs, paths=None, dimension=2, **kwargs):
        super().__init__(configs, paths, dimension, **kwargs)

    def get_loader(self):
        """Read the train/val sample lists from <data_dir>/{train,val}/dataset.json."""
        root = self.configs.dataset.data_dir
        print('use json dataset:', root)
        if root is None or not os.path.exists(root):
            raise ValueError('please check the data dir in config file!')
        self.train_ds = list_from_json(os.path.join(root, 'train', 'dataset.json'),
                                       self.indicator_A, self.indicator_B)
        self.val_ds = list_from_json(os.path.join(root, 'val', 'dataset.json'),
                                     self.indicator_A, self.indicator_B)

    def create_patch_dataset_and_dataloader(self, dimension=2):
        """Wrap the slice datasets in DataLoaders (no patching).

        ``dimension`` is accepted for interface compatibility but unused.
        """
        cfg = self.configs.dataset
        pin = torch.cuda.is_available()
        self.train_loader = DataLoader(self.train_volume_ds,
                                       num_workers=cfg.num_workers,
                                       batch_size=cfg.batch_size,
                                       shuffle=True,
                                       pin_memory=pin)
        self.val_loader = DataLoader(self.val_volume_ds,
                                     num_workers=cfg.num_workers,
                                     batch_size=cfg.val_batch_size,
                                     shuffle=False,
                                     pin_memory=pin)
| 778 |
+
|
class csv_slices_DataLoader(BaseDataLoader):
    """Serves pre-sliced data listed in per-split CSV folders ('train'/'val').

    The slices need no grid patching, so plain DataLoaders are used.
    """

    def __init__(self, configs, paths=None, dimension=2, **kwargs):
        super().__init__(configs, paths, dimension, **kwargs)

    def get_loader(self):
        """Read the train/val sample lists from <data_dir>/train and /val."""
        root = self.configs.dataset.data_dir
        print('use csv dataset:', root)
        if root is None or not os.path.exists(root):
            raise ValueError('please check the data dir in config file!')
        self.train_ds = list_from_slice_csv(os.path.join(root, 'train'),
                                            self.indicator_A, self.indicator_B)
        self.val_ds = list_from_slice_csv(os.path.join(root, 'val'),
                                          self.indicator_A, self.indicator_B)

    def create_patch_dataset_and_dataloader(self, dimension=2):
        """Wrap the slice datasets in DataLoaders (no patching).

        ``dimension`` is accepted for interface compatibility but unused.
        """
        cfg = self.configs.dataset
        pin = torch.cuda.is_available()
        self.train_loader = DataLoader(self.train_volume_ds,
                                       num_workers=cfg.num_workers,
                                       batch_size=cfg.batch_size,
                                       shuffle=True,
                                       pin_memory=pin)
        self.val_loader = DataLoader(self.val_volume_ds,
                                     num_workers=cfg.num_workers,
                                     batch_size=cfg.val_batch_size,
                                     shuffle=False,
                                     pin_memory=pin)
| 811 |
+
|
class csv_slices_assigned_DataLoader(csv_slices_DataLoader):
    """csv_slices_DataLoader variant that assigns HU values to the mask input
    and windows/normalizes input and target identically."""

    def __init__(self, configs, paths=None, dimension=2, **kwargs):
        super().__init__(configs, paths, dimension, **kwargs)

    def get_pretransforms(self, transform_list):
        """Map segmentation labels to HU values via the project anatomy table.

        NOTE(review): the table path is hard-coded with a Windows-style
        backslash — confirm it resolves on the deployment platform.
        """
        transform_list.append(MaskHUAssigmentd(keys=[self.indicator_A],
                                               csv_file=r'synthrad_conversion\TA2_anatomy.csv'))
        return transform_list

    def get_intensity_transforms(self, transform_list):
        """Window BOTH input and target to the configured HU window and shift
        so the window starts at zero."""
        level = self.configs.dataset.WINDOW_LEVEL
        width = self.configs.dataset.WINDOW_WIDTH
        lower = level - width / 2
        upper = level + width / 2
        keys = [self.indicator_A, self.indicator_B]
        # cval must equal the threshold when clipping so replaced pixels stay
        # inside the window.
        transform_list.append(ThresholdIntensityd(keys=keys, threshold=lower, above=True, cval=lower))
        transform_list.append(ThresholdIntensityd(keys=keys, threshold=upper, above=False, cval=upper))
        transform_list.append(ShiftIntensityd(keys=keys, offset=-lower))
        return transform_list

    def get_normlization(self, transform_list):
        """Normalize input and target together ('zscore'/'scale2000'), or not
        at all; unknown modes add nothing."""
        mode = self.configs.dataset.normalize
        keys = [self.indicator_A, self.indicator_B]
        if mode == 'zscore':
            transform_list.append(NormalizeIntensityd(keys=keys, nonzero=False, channel_wise=True))
            print('zscore normalization')
        elif mode == 'scale2000':
            transform_list.append(ScaleIntensityd(keys=keys, minv=None, maxv=None, factor=-0.9995))
            print('scale2000 normalization')
        elif mode in ('none', 'nonorm'):
            print('no normalization')
        return transform_list
| 848 |
+
|
| 849 |
+
# for MRI -> CT task
|
| 850 |
+
|
| 851 |
+
class synthrad_mr2ct_loader(BaseDataLoader):
    """Data loader for the SynthRAD MR -> CT translation task.

    Scans ``configs.dataset.data_dir`` (SynthRAD layout:
    ``data_dir/<patient>/<modality>.nii.gz``) for source/target/mask volumes
    and assembles path-dictionary lists in ``self.train_ds`` / ``self.val_ds``.
    ``get_normlization`` supplies the task-specific intensity transforms.
    """

    def __init__(self, configs, paths=None, dimension=2):
        super().__init__(configs, paths, dimension)

    def get_loader(self):
        """Collect file lists and assemble the train/val dataset dicts."""
        indicator_A = self.configs.dataset.indicator_A
        indicator_B = self.configs.dataset.indicator_B
        train_number = self.configs.dataset.train_number
        val_number = self.configs.dataset.val_number
        self.indicator_A = indicator_A
        self.indicator_B = indicator_B
        load_masks = self.configs.dataset.load_masks

        # List all source ("mr"), target ("ct") and mask volumes found in the
        # SynthRAD folder layout. (An earlier, unused os.listdir scan of the
        # data dir was removed — these helpers do the discovery themselves.)
        source_file_list, _ = list_img_pID_from_synthrad_folder(
            self.configs.dataset.data_dir,
            accepted_modalities=self.configs.dataset.source_name, saved_name=None)
        target_file_list, _ = list_img_pID_from_synthrad_folder(
            self.configs.dataset.data_dir,
            accepted_modalities=self.configs.dataset.target_name, saved_name=None)
        mask_file_list, _ = list_img_pID_from_synthrad_folder(
            self.configs.dataset.data_dir,
            accepted_modalities=self.configs.dataset.mask_name, saved_name=None)

        def _write_filenames(images, filename):
            # One path per line, for run bookkeeping/debugging.
            # (Renamed from write_write_file, which shadowed its own
            # ``file`` parameter with the open handle.)
            with open(filename, "w") as fh:
                for image in images:
                    fh.write(f'{image} \n')

        if self.paths is not None:
            logs_folder = self.paths["saved_logs_folder"]
            _write_filenames(source_file_list, os.path.join(logs_folder, "source_filenames.txt"))
            _write_filenames(target_file_list, os.path.join(logs_folder, "target_filenames.txt"))
            _write_filenames(mask_file_list, os.path.join(logs_folder, "mask_filenames.txt"))

        self.source_file_list = source_file_list
        self.target_file_list = target_file_list
        self.mask_file_list = mask_file_list

        if load_masks:
            train_ds = [{indicator_A: i, indicator_B: j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k}
                        for i, j, k in zip(source_file_list[0:train_number], target_file_list[0:train_number], mask_file_list[0:train_number])]
            val_ds = [{indicator_A: i, indicator_B: j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k}
                      for i, j, k in zip(source_file_list[-val_number:], target_file_list[-val_number:], mask_file_list[-val_number:])]
        else:
            train_ds = [{indicator_A: i, indicator_B: j, 'A_paths': i, 'B_paths': j}
                        for i, j in zip(source_file_list[0:train_number], target_file_list[0:train_number])]
            val_ds = [{indicator_A: i, indicator_B: j, 'A_paths': i, 'B_paths': j}
                      for i, j in zip(source_file_list[-val_number:], target_file_list[-val_number:])]
        self.train_ds = train_ds
        self.val_ds = val_ds

    def get_normlization(self, transform_list):
        """Append the configured intensity normalization to *transform_list*.

        Selected by ``configs.dataset.normalize``; returns the same list so
        calls can be chained. Unknown values fall through with no transform
        appended (matching the original behaviour).
        """
        normalize = self.configs.dataset.normalize
        indicator_A = self.configs.dataset.indicator_A
        indicator_B = self.configs.dataset.indicator_B
        if normalize == 'zscore':
            transform_list.append(NormalizeIntensityd(keys=[indicator_A, indicator_B], nonzero=False, channel_wise=True))
            print('zscore normalization')
        elif normalize == 'minmax':
            transform_list.append(ScaleIntensityd(keys=[indicator_A, indicator_B], minv=-1.0, maxv=1.0))
            print('minmax normalization')

        elif normalize == 'scale4000':
            # MR to [0,1]; CT shifted by +1024 then scaled: x = x*(1+factor) ~ x/4000
            transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=0, maxv=1))
            transform_list.append(ShiftIntensityd(keys=[indicator_B], offset=1024))
            transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.99975))  # x=x(1+factor)
            print('scale4000 normalization')

        elif normalize == 'scale1000_wrongbutworks':
            # Kept verbatim: known-imperfect scheme that trained models rely on.
            transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=0, maxv=1))
            transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=0))
            transform_list.append(ScaleIntensityd(keys=[indicator_B], factor=-0.999))
            print('scale1000 normalization')

        elif normalize == 'scale1000':
            transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=0, maxv=1))
            transform_list.append(ShiftIntensityd(keys=[indicator_B], offset=1024))
            transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.999))
            print('scale1000 normalization')

        elif normalize == 'scale10':
            transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=0, maxv=1))
            transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.9))
            print('scale10 normalization')

        elif normalize == 'inputonlyzscore':
            transform_list.append(NormalizeIntensityd(keys=[indicator_A], nonzero=False, channel_wise=True))
            print('only normalize input MRI images')

        elif normalize == 'inputonlyminmax':
            transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=self.configs.dataset.normmin, maxv=self.configs.dataset.normmax))
            print('only normalize input MRI images')

        elif normalize == 'none' or normalize == 'nonorm':
            print('no normalization')
        return transform_list
|
| 953 |
+
|
| 954 |
+
class anika_registrated_mr2ct_loader(synthrad_mr2ct_loader):
    """MR -> CT loader for the registered M2olie patient dataset.

    Pairs each patient's registered MRI with the corresponding CT via
    ``list_from_anika_dataset`` and builds train/val path dictionaries.
    """

    def __init__(self, configs, paths, dimension):
        super().__init__(configs, paths, dimension)

    def get_loader(self):
        """Match registered MRI/CT pairs and build the train/val dataset dicts."""
        indicator_A = self.configs.dataset.indicator_A
        indicator_B = self.configs.dataset.indicator_B
        self.indicator_A = indicator_A
        self.indicator_B = indicator_B
        train_number = self.configs.dataset.train_number
        val_number = self.configs.dataset.val_number

        # Directories come from the config (the old hard-coded E:\ Windows
        # paths were dead code — they were immediately overwritten).
        ct_dir = self.configs.dataset.ct_dir
        mri_dir = self.configs.dataset.mri_dir
        matched_pairs = list_from_anika_dataset(ct_dir, mri_dir, self.configs.dataset.mri_mode)
        # Renamed loop variable (was ``paths``, shadowing self.paths' name).
        for patient_id, pair in matched_pairs.items():
            print(f"Patient ID: {patient_id}, CT: {pair['CT']}, MRI: {pair['MRI']}")

        # use the matched pairs to form the dataset
        train_ds = [{indicator_A: pair['MRI'], indicator_B: pair['CT']}
                    for _, pair in list(matched_pairs.items())[:train_number]]
        val_ds = [{indicator_A: pair['MRI'], indicator_B: pair['CT']}
                  for _, pair in list(matched_pairs.items())[-val_number:]]
        # Fixed: mirror the parent class so downstream code that reads
        # self.train_ds / self.val_ds finds the assembled datasets.
        self.train_ds = train_ds
        self.val_ds = val_ds
|
dataprocesser/archive/list_dataset_combined_seg.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dataprocesser.customized_transforms import CreateMaskTransformd, MergeMasksTransformd
|
| 3 |
+
# Accepted medical-image extensions (common raster formats left disabled).
IMG_EXTENSIONS = [
    #'.jpg', '.JPG', '.jpeg', '.JPEG',
    #'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.nrrd', '.nii.gz'
]

def is_image_file(filename):
    """Return True when *filename* ends with a recognized image extension."""
    for ext in IMG_EXTENSIONS:
        if filename.endswith(ext):
            return True
    return False
|
| 11 |
+
import torch
|
| 12 |
+
from dataprocesser.list_dataset_synthrad_seg import synthrad_seg_loader
|
| 13 |
+
from dataprocesser.list_dataset_Anish_seg import anish_seg_loader
|
| 14 |
+
from dataprocesser.list_dataset_base import BaseDataLoader
|
| 15 |
+
|
dataprocesser/archive/list_dataset_combined_seg_assigned.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
dataprocesser/archive/list_dataset_synthrad.py
ADDED
|
File without changes
|
dataprocesser/archive/list_dataset_synthrad_seg.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
|
dataprocesser/archive/monai_loader_3D.py
ADDED
|
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import monai
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
from monai.transforms import (
|
| 5 |
+
Compose,
|
| 6 |
+
LoadImaged,
|
| 7 |
+
EnsureChannelFirstd,
|
| 8 |
+
SqueezeDimd,
|
| 9 |
+
CenterSpatialCropd,
|
| 10 |
+
Rotate90d,
|
| 11 |
+
ScaleIntensityd,
|
| 12 |
+
ResizeWithPadOrCropd,
|
| 13 |
+
DivisiblePadd,
|
| 14 |
+
|
| 15 |
+
ThresholdIntensityd,
|
| 16 |
+
NormalizeIntensityd,
|
| 17 |
+
ShiftIntensityd,
|
| 18 |
+
Identityd,
|
| 19 |
+
ScaleIntensityRanged,
|
| 20 |
+
Spacingd,
|
| 21 |
+
)
|
| 22 |
+
from torch.utils.data import DataLoader
|
| 23 |
+
import torch
|
| 24 |
+
|
| 25 |
+
# Accepted medical-image extensions (common raster formats left disabled).
IMG_EXTENSIONS = [
    #'.jpg', '.JPG', '.jpeg', '.JPEG',
    #'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.nrrd', '.nii.gz'
]

def is_image_file(filename):
    """Return True when *filename* ends with a recognized image extension."""
    # str.endswith accepts a tuple, so one call covers every extension.
    return filename.endswith(tuple(IMG_EXTENSIONS))

def make_dataset_modality(dir, accepted_modalities=["ct"], saved_name="source_filenames.txt"):
    """Recursively collect image files whose name stem matches a modality.

    Works for a root path at any layer of the SynthRAD layout:
        data_path/Task1 or Task2/pelvis or brain
        |-patient1
        |   |-ct.nii.gz
        |   |-mr.nii.gz
        |-patient2
            |-ct.nii.gz
            |-mr.nii.gz

    Args:
        dir: root directory to walk (must exist).
        accepted_modalities: modality stems to keep, e.g. ["ct"] or ["mr"].
        saved_name: path of a sidecar text file listing the matches, one per
            line; pass None to skip writing (previously this crashed).

    Returns:
        Sorted-walk-ordered list of matching file paths.
    """
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir
    for roots, _, files in sorted(os.walk(dir)):  # os.walk digs all folders and subfolders in all layers of dir
        for file in files:
            # 'ct.nii.gz'.split('.')[0] -> 'ct': match on the modality stem.
            if is_image_file(file) and file.split('.')[0] in accepted_modalities:
                path = os.path.join(roots, file)
                images.append(path)
    print(f'Found {len(images)} {accepted_modalities} files in {dir} \n')
    if saved_name is not None:
        with open(saved_name, "w") as fh:
            for image in images:
                fh.write(f'{image} \n')
    return images
|
| 55 |
+
|
| 56 |
+
class monai_loader_3D:
    """End-to-end 3-D volume data pipeline built on MONAI.

    Discovers source/target (and optional mask) volumes under
    ``configs.dataset.data_dir``, chains load/crop transforms with the
    configured normalization and train-time augmentation, and exposes
    ``train_loader`` / ``val_loader`` plus all intermediate datasets on self.
    """

    def __init__(self, configs, paths):
        self.configs = configs
        self.paths = paths
        self.get_loader()
        self.finalcheck(ifsave=True, ifcheck=False, iftest_volumes_pixdim=False)

    def get_loader(self):
        """Build datasets and DataLoaders and stash them on ``self``."""
        # volume-level transforms for both image and label
        train_transforms = self.get_transforms(self.configs, mode='train')
        val_transforms = self.get_transforms(self.configs, mode='val')
        indicator_A = self.configs.dataset.indicator_A
        indicator_B = self.configs.dataset.indicator_B
        self.indicator_A = indicator_A
        self.indicator_B = indicator_B
        train_number = self.configs.dataset.train_number
        val_number = self.configs.dataset.val_number
        train_batch_size = self.configs.dataset.batch_size
        val_batch_size = self.configs.dataset.val_batch_size
        load_masks = self.configs.dataset.load_masks

        # Dictionary keys depend on whether masks are loaded.
        keys = [indicator_A, indicator_B, "mask"] if load_masks else [indicator_A, indicator_B]

        # Discover files by modality name; sidecar txt files record exactly
        # which volumes were used in this run. (A leftover, unused os.listdir
        # scan of the data dir was removed.)
        source_file_list = make_dataset_modality(
            self.configs.dataset.data_dir,
            accepted_modalities=self.configs.dataset.source_name,
            saved_name=os.path.join(self.paths["saved_logs_folder"], "source_filenames.txt"))
        target_file_list = make_dataset_modality(
            self.configs.dataset.data_dir,
            accepted_modalities=self.configs.dataset.target_name,
            saved_name=os.path.join(self.paths["saved_logs_folder"], "target_filenames.txt"))
        mask_file_list = make_dataset_modality(
            self.configs.dataset.data_dir,
            accepted_modalities=self.configs.dataset.mask_name,
            saved_name=os.path.join(self.paths["saved_logs_folder"], "mask_filenames.txt"))

        if load_masks:
            train_ds = [{indicator_A: i, indicator_B: j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k}
                        for i, j, k in zip(source_file_list[0:train_number], target_file_list[0:train_number], mask_file_list[0:train_number])]
            val_ds = [{indicator_A: i, indicator_B: j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k}
                      for i, j, k in zip(source_file_list[-val_number:], target_file_list[-val_number:], mask_file_list[-val_number:])]
        else:
            train_ds = [{indicator_A: i, indicator_B: j, 'A_paths': i, 'B_paths': j}
                        for i, j in zip(source_file_list[0:train_number], target_file_list[0:train_number])]
            val_ds = [{indicator_A: i, indicator_B: j, 'A_paths': i, 'B_paths': j}
                      for i, j in zip(source_file_list[-val_number:], target_file_list[-val_number:])]

        print('all files in dataset:', len(source_file_list))

        # load volumes and, optionally, center-crop along the slice axis
        center_crop = self.configs.dataset.center_crop
        transformations_crop = [
            LoadImaged(keys=keys),
            EnsureChannelFirstd(keys=keys),
        ]
        if center_crop > 0:
            transformations_crop.append(CenterSpatialCropd(keys=keys, roi_size=(-1, -1, center_crop)))
        transformations_crop = Compose(transformations_crop)
        train_crop_ds = monai.data.Dataset(data=train_ds, transform=transformations_crop)
        val_crop_ds = monai.data.Dataset(data=val_ds, transform=transformations_crop)

        # chained datasets: cropped volumes -> normalization/augmentation
        train_volume_ds = monai.data.Dataset(data=train_crop_ds, transform=train_transforms)
        val_volume_ds = monai.data.Dataset(data=val_crop_ds, transform=val_transforms)

        train_loader = DataLoader(train_volume_ds, batch_size=train_batch_size, shuffle=True, num_workers=self.configs.dataset.num_workers)
        val_loader = DataLoader(val_volume_ds, batch_size=val_batch_size, shuffle=False, num_workers=self.configs.dataset.num_workers)

        self.saved_name_train = self.paths["saved_name_train"]
        self.saved_name_val = self.paths["saved_name_val"]

        self.train_ds = train_ds
        self.val_ds = val_ds
        self.train_volume_ds = train_volume_ds
        self.val_volume_ds = val_volume_ds

        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size

        self.train_crop_ds = train_crop_ds
        self.val_crop_ds = val_crop_ds
        self.train_transforms = train_transforms
        self.val_transforms = val_transforms

        self.train_loader = train_loader
        self.val_loader = val_loader

    def get_transforms(self, configs, mode='train'):
        """Assemble normalization (+ train-time augmentation) transforms.

        Returns a ``Compose`` pipeline; ``mode='train'`` additionally appends
        the shape/intensity augmentations enabled in the config.
        """
        normalize = configs.dataset.normalize
        pad = configs.dataset.pad
        resized_size = configs.dataset.resized_size
        prob = configs.dataset.augmentationProb
        indicator_A = configs.dataset.indicator_A
        indicator_B = configs.dataset.indicator_B
        load_masks = self.configs.dataset.load_masks
        transform_list = []
        # NOTE(review): CT windowing (WINDOW_LEVEL/WINDOW_WIDTH) and MRI
        # clipping via ThresholdIntensityd were disabled in the original; the
        # window bounds (which shadowed the min/max builtins) are no longer
        # computed here.
        if normalize == 'zscore':
            transform_list.append(NormalizeIntensityd(keys=[indicator_A, indicator_B], nonzero=False, channel_wise=True))
            print('zscore normalization')
        elif normalize == 'minmax':
            transform_list.append(ScaleIntensityd(keys=[indicator_A, indicator_B], minv=-1.0, maxv=1.0))
            print('minmax normalization')

        elif normalize == 'scale4000':
            # MR to [0,1]; CT shifted by +1024 then scaled: x = x*(1+factor)
            transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=0, maxv=1))
            transform_list.append(ShiftIntensityd(keys=[indicator_B], offset=1024))
            transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.99975))  # x=x(1+factor)
            print('scale4000 normalization')

        elif normalize == 'scale1000_wrongbutworks':
            # Kept verbatim: known-imperfect scheme that trained models rely on.
            transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=0, maxv=1))
            transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=0))
            transform_list.append(ScaleIntensityd(keys=[indicator_B], factor=-0.999))
            print('scale1000 normalization')

        elif normalize == 'scale1000':
            transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=0, maxv=1))
            transform_list.append(ShiftIntensityd(keys=[indicator_B], offset=1024))
            transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.999))
            print('scale1000 normalization')

        elif normalize == 'scale10':
            transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=0, maxv=1))
            transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.9))
            print('scale10 normalization')

        elif normalize == 'inputonlyzscore':
            transform_list.append(NormalizeIntensityd(keys=[indicator_A], nonzero=False, channel_wise=True))
            print('only normalize input MRI images')

        elif normalize == 'inputonlyminmax':
            transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=configs.dataset.normmin, maxv=configs.dataset.normmax))
            print('only normalize input MRI images')

        elif normalize == 'none' or normalize == 'nonorm':
            print('no normalization')

        spaceXY = 0  # resampling disabled; set >0 to re-enable Spacingd
        if spaceXY > 0:
            transform_list.append(Spacingd(keys=[indicator_A], pixdim=(spaceXY, spaceXY, 2.5), mode="bilinear"))
            # NOTE(review): "bilinear" on the mask would blur labels — switch
            # the mask to "nearest" if this branch is ever re-enabled.
            transform_list.append(Spacingd(keys=[indicator_B, "mask"] if load_masks else [indicator_B],
                                           pixdim=(spaceXY, spaceXY, 2.5), mode="bilinear"))

        transform_list.append(ResizeWithPadOrCropd(keys=[indicator_A, indicator_B, "mask"] if load_masks else [indicator_A, indicator_B],
                                                   spatial_size=resized_size, mode=pad))

        if configs.dataset.rotate:
            transform_list.append(Rotate90d(keys=[indicator_A, indicator_B, "mask"] if load_masks else [indicator_A, indicator_B], k=3))

        if mode == 'train':
            # data augmentation (train only); unused imports were dropped
            from monai.transforms import (
                RandRotated,
                RandZoomd,
                RandBiasFieldd,
                RandAffined,
                RandShiftIntensityd,
                RandAdjustContrastd,
                RandGaussianSharpend,
                RandGaussianNoised,
            )
            if configs.dataset.shapeAug:
                aug_keys = [indicator_A, indicator_B, "mask"] if load_masks else [indicator_A, indicator_B]
                transform_list.append(RandRotated(keys=aug_keys,
                                                  range_x=0.1, range_y=0.1, range_z=0.1,
                                                  prob=prob, padding_mode="border", keep_size=True))
                transform_list.append(RandZoomd(keys=aug_keys,
                                                prob=prob, min_zoom=0.9, max_zoom=1.3, padding_mode="minimum", keep_size=True))
                # NOTE(review): the affine is applied to images only, not the mask.
                transform_list.append(RandAffined(keys=[indicator_A, indicator_B], padding_mode="border", prob=prob))
            if configs.dataset.intensityAug:
                print('intensity data augmentation is used')
                # intensity perturbations target the (MRI) source channel only
                transform_list.append(RandBiasFieldd(keys=[indicator_A], degree=3, coeff_range=(0.0, 0.1), prob=prob))
                transform_list.append(RandGaussianNoised(keys=[indicator_A], prob=prob, mean=0.0, std=0.01))
                transform_list.append(RandAdjustContrastd(keys=[indicator_A], prob=prob, gamma=(0.5, 1.5)))
                transform_list.append(RandShiftIntensityd(keys=[indicator_A], prob=prob, offsets=20))
                transform_list.append(RandGaussianSharpend(keys=[indicator_A], alpha=(0.2, 0.8), prob=prob))

        return Compose(transform_list)

    def finalcheck(self, ifsave=False, ifcheck=False, iftest_volumes_pixdim=False):
        """Optional post-construction bookkeeping and sanity checks."""
        if ifsave:
            self.save_volumes(self.train_ds, self.val_ds, self.saved_name_train, self.saved_name_val)
        if iftest_volumes_pixdim:
            self.test_volumes_pixdim(self.train_volume_ds)
        if ifcheck:
            self.check_volumes(self.train_ds, self.train_volume_ds, self.val_volume_ds, self.val_ds)
            # Fixed: the original passed self.train_patch_ds, which this class
            # never creates (AttributeError when ifcheck=True).
            self.check_batch_data(self.train_loader, self.val_loader,
                                  self.train_volume_ds, self.val_volume_ds,
                                  self.train_batch_size, self.val_batch_size)

    def test_volumes_pixdim(self, train_volume_ds):
        """Print shape/affine/pixdim of every training volume (debug aid)."""
        train_loader = DataLoader(train_volume_ds, batch_size=1)
        for step, data in enumerate(train_loader):
            mr_data = data[self.indicator_A]
            ct_data = data[self.indicator_B]

            print(f"source image shape: {mr_data.shape}")
            print(f"source image affine:\n{mr_data.meta['affine']}")
            print(f"source image pixdim:\n{mr_data.pixdim}")

            # target image information
            print(f"target image shape: {ct_data.shape}")
            print(f"target image affine:\n{ct_data.meta['affine']}")
            print(f"target image pixdim:\n{ct_data.pixdim}")

    def check_volumes(self, train_ds, train_volume_ds, val_volume_ds, val_ds):
        """Iterate every volume once and report which inputs fail to load.

        batch_size=1 because the raw volumes have different shapes.
        """
        train_loader = DataLoader(train_volume_ds, batch_size=1)
        val_loader = DataLoader(val_volume_ds, batch_size=1)

        def _check(loader, ds, tag):
            # Fixed vs. original: items are keyed by the configured indicators,
            # not a hardcoded 'image'/'label'; and the bare `except:` now
            # catches Exception only.
            iterator = iter(loader)
            print(f'check {tag} data:')
            for idx in range(len(loader)):
                current_name = os.path.basename(os.path.dirname(ds[idx][self.indicator_A]))
                try:
                    check_data = next(iterator)
                    print(idx, current_name,
                          'image:', check_data[self.indicator_A].shape,
                          'label:', check_data[self.indicator_B].shape)
                except Exception:
                    print('check data error! Check the input data:', current_name)
            print(f"checked all {tag} data.")

        _check(train_loader, train_ds, 'training')
        _check(val_loader, val_ds, 'validation')

    def save_volumes(self, train_ds, val_ds, saved_name_train, saved_name_val):
        """Write the patient folder name of every train/val sample to disk."""
        shape_list_train = []
        shape_list_val = []
        for sample in train_ds:
            name = os.path.basename(os.path.dirname(sample[self.indicator_A]))
            shape_list_train.append({'patient': name})
        for sample in val_ds:
            name = os.path.basename(os.path.dirname(sample[self.indicator_A]))
            shape_list_val.append({'patient': name})
        np.savetxt(saved_name_train, shape_list_train, delimiter=',', fmt='%s', newline='\n')
        np.savetxt(saved_name_val, shape_list_val, delimiter=',', fmt='%s', newline='\n')

    def check_batch_data(self, train_loader, val_loader, train_patch_ds, val_volume_ds, train_batch_size, val_batch_size):
        """Print each batch alongside the dataset item it starts at (debug aid)."""
        # Fixed vs. original: batches are keyed by the configured indicators,
        # not hardcoded 'image'/'label'.
        for idx, train_check_data in enumerate(train_loader):
            ds_idx = idx * train_batch_size
            current_item = train_patch_ds[ds_idx]
            print('check train data:')
            print(current_item, 'image:', train_check_data[self.indicator_A].shape, 'label:', train_check_data[self.indicator_B].shape)

        for idx, val_check_data in enumerate(val_loader):
            ds_idx = idx * val_batch_size
            current_item = val_volume_ds[ds_idx]
            print('check val data:')
            print(current_item, 'image:', val_check_data[self.indicator_A].shape, 'label:', val_check_data[self.indicator_B].shape)

    def len_patchloader(self, train_volume_ds, train_batch_size):
        """Return (total slices, total batches) over the training volumes."""
        slice_number = sum(train_volume_ds[i][self.indicator_A].shape[-1] for i in range(len(train_volume_ds)))
        print('total slices in training set:', slice_number)

        import math
        batch_number = sum(math.ceil(train_volume_ds[i][self.indicator_A].shape[-1] / train_batch_size) for i in range(len(train_volume_ds)))
        print('total batches in training set:', batch_number)
        return slice_number, batch_number

    def get_length(self, dataset, patch_batch_size):
        """Count total slices in *dataset* and return the batch count (ceil div)."""
        loader = DataLoader(dataset, batch_size=1)
        iterator = iter(loader)
        sum_nslices = 0
        for idx in range(len(loader)):
            check_data = next(iterator)
            nslices = check_data[self.indicator_A].shape[-1]
            sum_nslices += nslices
        if sum_nslices % patch_batch_size == 0:
            return sum_nslices // patch_batch_size
        else:
            return sum_nslices // patch_batch_size + 1
|
| 367 |
+
|
dataprocesser/archive/slice_loader.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import monai
|
| 2 |
+
from monai.transforms import (
|
| 3 |
+
Compose,
|
| 4 |
+
LoadImaged,
|
| 5 |
+
Rotate90d,
|
| 6 |
+
ScaleIntensityd,
|
| 7 |
+
EnsureChannelFirstd,
|
| 8 |
+
ResizeWithPadOrCropd,
|
| 9 |
+
DivisiblePadd,
|
| 10 |
+
ThresholdIntensityd,
|
| 11 |
+
NormalizeIntensityd,
|
| 12 |
+
SqueezeDimd,
|
| 13 |
+
Identityd,
|
| 14 |
+
CenterSpatialCropd,
|
| 15 |
+
)
|
| 16 |
+
from monai.data import Dataset
|
| 17 |
+
from torch.utils.data import DataLoader
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from .basics import get_file_list, get_transforms, load_volumes, crop_volumes
|
| 21 |
+
from .checkdata import check_batch_data, check_volumes, save_volumes
|
| 22 |
+
##### slices #####
|
| 23 |
+
def load_batch_slices(train_volume_ds,val_volume_ds, train_batch_size=8,val_batch_size=1,window_width=1,ifcheck=True):
    """Wrap volume datasets into 2-D-slice (or thin-slab) patch DataLoaders.

    A ``PatchIterd`` walks each volume along its last axis in windows of
    ``window_width`` slices; when the window is a single slice the trailing
    dimension is squeezed away so batches are plain 2-D images.

    Args:
        train_volume_ds: volume-level dataset yielding dicts with
            "source"/"target" keys (training split).
        val_volume_ds: same, for the validation split.
        train_batch_size (int): slices per training batch.
        val_batch_size (int): slices per validation batch.
        window_width (int): number of consecutive slices per patch.
        ifcheck (bool): when True, run check_batch_data on the new loaders.

    Returns:
        (train_loader, val_loader): slice-level DataLoaders.
    """
    patch_func = monai.data.PatchIterd(
        keys=["source", "target"],
        patch_size=(None, None, window_width),  # dynamic first two dimensions
        start_pos=(0, 0, 0)
    )
    if window_width==1:
        patch_transform = Compose(
            [
                SqueezeDimd(keys=["source", "target"], dim=-1),  # squeeze the last dim
            ]
        )
    else:
        # Thicker windows keep their slice dimension.
        patch_transform = None
    # for training
    train_patch_ds = monai.data.GridPatchDataset(
        data=train_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
    train_loader = DataLoader(
        train_patch_ds,
        batch_size=train_batch_size,
        num_workers=0,
        pin_memory=torch.cuda.is_available(),
    )

    # for validation
    val_patch_ds = monai.data.GridPatchDataset(
        data=val_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
    val_loader = DataLoader(
        val_patch_ds, #val_volume_ds,
        num_workers=0,
        batch_size=val_batch_size,
        pin_memory=torch.cuda.is_available())

    if ifcheck:
        check_batch_data(train_loader,val_loader,train_patch_ds,val_volume_ds,train_batch_size,val_batch_size)
    return train_loader,val_loader
|
| 59 |
+
def myslicesloader(data_pelvis_path,
                normalize='minmax',
                pad='minimum',
                train_number=1,
                val_number=1,
                train_batch_size=8,
                val_batch_size=1,
                saved_name_train='./train_ds_2d.csv',
                saved_name_val='./val_ds_2d.csv',
                resized_size=(512,512,None),
                div_size=(16,16,None),
                center_crop=20,
                ifcheck_volume=True,
                ifcheck_sclices=False,):
    """Build volume datasets and 2-D slice loaders for a pelvis dataset.

    Pipeline: build train/val transforms -> list files under
    ``data_pelvis_path`` -> center-crop the volumes -> load them (optionally
    saving the file lists to CSV) -> wrap into single-slice DataLoaders via
    load_batch_slices (window_width fixed to 1 here).

    Args:
        data_pelvis_path (str): root folder of the dataset.
        normalize (str): normalization mode passed to get_transforms.
        pad (str): padding mode passed to get_transforms.
        train_number / val_number (int): number of patients per split.
        train_batch_size / val_batch_size (int): slices per batch.
        saved_name_train / saved_name_val (str): CSV paths for the file lists.
        resized_size / div_size: spatial resize / divisibility constraints.
        center_crop (int): crop parameter forwarded to crop_volumes.
        ifcheck_volume (bool): verify volumes after loading.
        ifcheck_sclices (bool): verify slice batches (note: name keeps the
            original spelling for backward compatibility).

    Returns:
        (train_ds, val_ds, train_loader, val_loader,
         train_transforms, val_transforms)
    """
    # volume-level transforms for both image and label
    train_transforms = get_transforms(normalize,pad,resized_size,div_size,mode='train',prob=0.8)
    val_transforms = get_transforms(normalize,pad,resized_size,div_size,mode='val')
    train_ds, val_ds = get_file_list(data_pelvis_path,
                                    train_number,
                                    val_number)
    train_crop_ds, val_crop_ds = crop_volumes(train_ds, val_ds,center_crop)
    train_ds, val_ds = load_volumes(train_transforms, val_transforms,
                                    train_crop_ds, val_crop_ds,
                                    train_ds, val_ds,
                                    saved_name_train, saved_name_val,
                                    ifsave=True,
                                    ifcheck=ifcheck_volume)
    train_loader,val_loader = load_batch_slices(train_ds,
                                            val_ds,
                                            train_batch_size,
                                            val_batch_size=val_batch_size,
                                            window_width=1,
                                            ifcheck=ifcheck_sclices)
    return train_ds, val_ds, train_loader,val_loader,train_transforms,val_transforms
|
| 94 |
+
|
| 95 |
+
def len_patchloader(train_volume_ds, train_batch_size):
    """Count total slices and per-volume batch count of the training set.

    Slices are counted along the last axis of each sample's "source" array;
    the batch count assumes batches never straddle two volumes, i.e. it is
    the sum of ceil(n_slices / train_batch_size) over volumes.

    Args:
        train_volume_ds: indexable dataset of dicts whose "source" entry has
            the slice axis last (shape[-1]).
        train_batch_size (int): slices per batch.

    Returns:
        (slice_number, batch_number): total slices and total batches.
    """
    import math  # local import kept to avoid touching module-level imports

    slice_number = 0
    batch_number = 0
    # Single pass instead of the original two generator scans.
    for i in range(len(train_volume_ds)):
        nslices = train_volume_ds[i]['source'].shape[-1]
        slice_number += nslices
        batch_number += math.ceil(nslices / train_batch_size)
    print('total slices in training set:', slice_number)
    print('total batches in training set:', batch_number)
    return slice_number, batch_number
|
| 103 |
+
|
| 104 |
+
if __name__ == '__main__':
    # Smoke test: build a 2-D slice loader from a local SynthRAD pelvis set
    # and log the batch shapes it yields.
    dataset_path = r"F:\yang_Projects\Datasets\Task1\pelvis"
    train_volume_ds, _, train_loader, _, _, _ = myslicesloader(dataset_path,
                        normalize='none',
                        train_number=2,
                        val_number=1,
                        train_batch_size=4,
                        val_batch_size=1,
                        saved_name_train='./train_ds_2d.csv',
                        saved_name_val='./val_ds_2d.csv',
                        resized_size=(512, 512, None),
                        div_size=(16,16,None),
                        ifcheck_volume=False,
                        ifcheck_sclices=False,)
    from tqdm import tqdm
    parameter_file = r'.\test.txt'
    for data in tqdm(train_loader):
        with open(parameter_file, 'a') as f:
            # Fix: batches from load_batch_slices are keyed "source"/"target"
            # (see PatchIterd keys above); the old "image"/"label" keys
            # raised KeyError.
            f.write('image batch:' + str(data["source"].shape)+'\n')
            f.write('label batch:' + str(data["target"].shape)+'\n')
            f.write('\n')
|
dataprocesser/build_dataset.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class BaseDataLoader:
    """Base scaffold for dataset loaders.

    NOTE(review): this class calls `init_parameters_and_transforms`,
    `create_dataset` and `finalcheck`, none of which are defined here —
    subclasses are evidently expected to provide them (confirm against the
    concrete loaders in this package). `get_loader` is a stub that
    subclasses should override to populate the file list and datasets.
    """

    def __init__(self, configs, paths=None, dimension=2, **kwargs):
        # configs: configuration consumed by the subclass hooks.
        # paths: optional output location; when given, finalcheck saves.
        self.configs=configs
        self.paths=paths
        self.init_parameters_and_transforms()
        self.get_loader()
        #print('all files in dataset:',len(self.source_file_list))


        # Optional augmentation settings pulled from kwargs.
        self.rotation_level = kwargs.get('rotation_level', 0) # Default to no rotation (0)
        self.zoom_level = kwargs.get('zoom_level', 1.0) # Default to no zoom (1.0)
        self.flip = kwargs.get('flip', 0) # Default to no flip

        self.create_dataset(dimension=dimension)

        # Save outputs only when an output location was provided.
        ifsave = None if paths is None else True
        self.finalcheck(ifsave=ifsave,ifcheck=False,iftest_volumes_pixdim=False)

    def get_loader(self):
        # Stub: subclasses fill these with discovered files / datasets.
        self.source_file_list = []
        self.train_ds=[]
        self.val_ds=[]
|
dataprocesser/config_example.yaml
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model_name: 'ddpm2d_seg2med'
|
| 2 |
+
GPU_ID: [3]
|
| 3 |
+
ckpt_path: 'logs\241118ddpm_512.pt'
|
| 4 |
+
mode: 'test'
|
| 5 |
+
dataset:
|
| 6 |
+
train_csv: 'synthrad_conversion\datacsv\ct_synthrad_testrest_newserver.csv'
|
| 7 |
+
test_csv: 'synthrad_conversion\datacsv\ct_synthrad_testrest_newserver.csv'
|
| 8 |
+
batch_size: 1
|
| 9 |
+
val_batch_size: 8
|
| 10 |
+
normalize: 'scale2000'
|
| 11 |
+
zoom: (1.0,1.0,1.0)
|
| 12 |
+
resized_size: (512,512,None)
|
| 13 |
+
div_size: (None,None,None)
|
| 14 |
+
WINDOW_WIDTH: 2000
|
| 15 |
+
WINDOW_LEVEL: 0
|
| 16 |
+
|
| 17 |
+
train:
|
| 18 |
+
val_epoch_interval: 1
|
| 19 |
+
save_ckpt_interval: 1
|
| 20 |
+
num_epochs: 100
|
| 21 |
+
learning_rate: 0.0002
|
| 22 |
+
writeTensorboard: True
|
| 23 |
+
sample_range_lower: 0
|
| 24 |
+
sample_range_upper: 100000000
|
| 25 |
+
earlystopping_patience: 10
|
| 26 |
+
earlystopping_delta: 0.001
|
| 27 |
+
|
| 28 |
+
validation:
|
| 29 |
+
evaluate_restore_transforms: True
|
| 30 |
+
x_lower_limit: -1000
|
| 31 |
+
x_upper_limit: 3000
|
| 32 |
+
manual_aorta_diss: -1
|
| 33 |
+
ddpm:
|
| 34 |
+
num_train_timesteps: 500
|
| 35 |
+
num_inference_steps: 500
|
| 36 |
+
num_channels: (64, 128, 256, 256)
|
| 37 |
+
attention_levels: (False, False, False, True)
|
| 38 |
+
num_res_units: 2
|
| 39 |
+
norm_num_groups: 32
|
| 40 |
+
num_head_channels: 32
|
| 41 |
+
noise_type: 'normal'
|
| 42 |
+
|
| 43 |
+
|
dataprocesser/create_csv.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
from dataprocesser.dataset_anika import (
|
| 3 |
+
all_list_single_modality_from_anika_dataset_include_duplicate,
|
| 4 |
+
extract_patientID_from_Anika_dataset,
|
| 5 |
+
all_list_from_anika_dataset_include_duplicate)
|
| 6 |
+
from dataprocesser.dataset_synthrad import list_img_pID_from_synthrad_folder
|
| 7 |
+
from dataprocesser.dataset_anish import list_img_seg_ad_pIDs_from_anish_csv
|
| 8 |
+
from dataprocesser.dataset_dominik import all_list_from_dominik_dataset
|
| 9 |
+
from dataprocesser.step1_init_data_list import appart_img_and_seg, appart_merged_seg
|
| 10 |
+
from dataprocesser.step1_init_data_list import extract_patient_id
|
| 11 |
+
|
| 12 |
+
def create_csv_combine_lists_synthrad_anika_mr(synthrad_dir, anika_dir_mr, output_mr_csv_file, ifwrtiecsv=True):
    """Merge the SynthRAD and Anika MR cohorts into one [id, Aorta_diss, seg, img] list.

    Both cohorts get Aorta_diss = 0. When ``ifwrtiecsv`` is True (note: the
    parameter name keeps its original spelling for caller compatibility) the
    combined list is also written to ``output_mr_csv_file``.

    Returns:
        list[list]: combined rows [patient_id, aorta_diss, seg_path, img_path].
    """
    #synthrad_seg_list, synthrad_pIDs = list_img_pID_from_synthrad_folder(synthrad_dir, ["mr_seg"], None)
    seg_name_pattern = "mr_merged_seg" #r"^mr_merged_seg_\d{1}[A-Z]{2}\d{3}$"
    syn_segs, syn_ids = list_img_pID_from_synthrad_folder(synthrad_dir, [seg_name_pattern], None)
    syn_imgs, _ = list_img_pID_from_synthrad_folder(synthrad_dir, ["mr"], None)
    datalist_synthrad = [
        [pid, 0, seg_path, img_path]
        for pid, seg_path, img_path in zip(syn_ids, syn_segs, syn_imgs)
    ]

    raw_mr_entries = all_list_single_modality_from_anika_dataset_include_duplicate(anika_dir_mr)
    anika_imgs, anika_segs = appart_img_and_seg(raw_mr_entries)
    anika_segs = appart_merged_seg(anika_segs)
    anika_ids = extract_patientID_from_Anika_dataset(anika_imgs)

    datalist_mr = [
        [pid, 0, seg_path, img_path]
        for pid, seg_path, img_path in zip(anika_ids, anika_segs, anika_imgs)
    ]

    print('length dataset 1: ', len(datalist_synthrad))
    print('length dataset 2: ', len(datalist_mr))
    dataset_list = datalist_synthrad + datalist_mr
    if ifwrtiecsv:
        create_csv_info_file(dataset_list, output_mr_csv_file)
    return dataset_list
|
| 34 |
+
|
| 35 |
+
def create_csv_info_file(dataset_list, output_mr_csv_file):
    """Write dataset rows to CSV with the header id,Aorta_diss,seg,img."""
    header = ['id', 'Aorta_diss', 'seg', 'img']
    with open(output_mr_csv_file, 'w', newline='') as handle:
        writer = csv.writer(handle)
        writer.writerow(header)
        for row in dataset_list:
            writer.writerow(row)
|
| 40 |
+
|
| 41 |
+
def create_csv_synthrad_mr(synthrad_dir, output_csv_file):
    """Write an [id, Aorta_diss, seg, img] CSV for the SynthRAD MR cohort.

    All rows get Aorta_diss = 0; segmentations are matched by the
    "mr_merged_seg" pattern and images by "mr".
    """
    seg_paths, patient_ids = list_img_pID_from_synthrad_folder(synthrad_dir, ["mr_merged_seg"], None)
    img_paths, _ = list_img_pID_from_synthrad_folder(synthrad_dir, ["mr"], None)
    dataset_list = [
        [pid, 0, seg_path, img_path]
        for pid, seg_path, img_path in zip(patient_ids, seg_paths, img_paths)
    ]

    print('length dataset 2: ', len(dataset_list))
    create_csv_info_file(dataset_list, output_csv_file)
|
| 50 |
+
|
| 51 |
+
def create_csv_combine_lists_synthrad_anish(synthrad_dir, anish_csv, output_csv_file):
    """Combine the SynthRAD CT cohort with the Anish CSV cohort into one CSV.

    SynthRAD rows get Aorta_diss = 0; the Anish CSV supplies its own labels.
    The combined [id, Aorta_diss, seg, img] rows are written to
    ``output_csv_file`` via create_csv_info_file.
    """
    synthrad_seg_list, synthrad_pIDs = list_img_pID_from_synthrad_folder(synthrad_dir, ["ct_seg"], None)
    synthrad_ct_list, _ = list_img_pID_from_synthrad_folder(synthrad_dir, ["ct"], None)
    synthrad_Aorta_diss = [0] * len(synthrad_seg_list)

    #anish_pIDs, anish_Aorta_diss, anish_seg_list, anish_ct_list = list_img_seg_ad_pIDs_from_new_simplified_csv(anish_csv)
    anish_pIDs, anish_Aorta_diss, anish_seg_list, anish_ct_list = list_img_seg_ad_pIDs_from_anish_csv(anish_csv)
    datalist_synthrad = [[id,Aorta_diss,seg,image] for id,Aorta_diss,seg,image in zip(synthrad_pIDs, synthrad_Aorta_diss, synthrad_seg_list, synthrad_ct_list)]
    datalist_anish = [[id,Aorta_diss,seg,image] for id,Aorta_diss,seg,image in zip(anish_pIDs, anish_Aorta_diss, anish_seg_list, anish_ct_list)]

    # Fix: report the actual combined row counts. The original printed
    # len(synthrad_ct_list) for dataset 1 (raw image count, which can differ
    # from the zipped row count) and len(datalist_synthrad) for dataset 2.
    print('length dataset 1: ', len(datalist_synthrad))
    print('length dataset 2: ', len(datalist_anish))
    dataset_list=datalist_synthrad+datalist_anish
    create_csv_info_file(dataset_list, output_csv_file)
|
| 65 |
+
|
| 66 |
+
def create_csv_Anika(ct_dir, mri_dir, output_ct_csv_file, output_mr_csv_file):
    """Write CT and MR info CSVs ([id, Aorta_diss, seg, img]) for Anika's dataset.

    All rows get Aorta_diss = 0. Image and segmentation files are separated
    with appart_img_and_seg, and patient IDs extracted per file.
    """
    ct_list, mr_list = all_list_from_anika_dataset_include_duplicate(ct_dir, mri_dir)
    ct_files, ct_seg_files = appart_img_and_seg(ct_list)
    ct_pIDs = extract_patientID_from_Anika_dataset(ct_files)
    # Fix: size the label list by the number of image files, matching the MR
    # branch below. The original used len(ct_list) (images + segs); the zip
    # truncated the surplus, so output is unchanged, but the intent is
    # one label per image.
    ct_Aorta_diss = [0] * len(ct_files)
    datalist_ct = [[id,Aorta_diss,seg,image] for id,Aorta_diss,seg,image in zip(ct_pIDs, ct_Aorta_diss, ct_seg_files, ct_files)]
    create_csv_info_file(datalist_ct, output_ct_csv_file)

    mr_files, mr_seg_files = appart_img_and_seg(mr_list)
    mr_pIDs = extract_patientID_from_Anika_dataset(mr_files)
    mr_Aorta_diss = [0] * len(mr_files)
    datalist_mr = [[id,Aorta_diss,seg,image] for id,Aorta_diss,seg,image in zip(mr_pIDs, mr_Aorta_diss, mr_seg_files, mr_files)]
    create_csv_info_file(datalist_mr, output_mr_csv_file)
|
| 79 |
+
|
| 80 |
+
def create_csv_Dominik(mri_dir, output_mr_csv_file):
    """Write an [id, Aorta_diss, seg, img] CSV for Dominik's MR dataset.

    All rows get Aorta_diss = 0; merged segmentations are selected via
    appart_merged_seg and patient IDs parsed from each image path.
    """
    raw_entries = all_list_from_dominik_dataset(mri_dir)
    image_paths, seg_paths = appart_img_and_seg(raw_entries)
    seg_paths = appart_merged_seg(seg_paths)
    patient_ids = [extract_patient_id(path) for path in image_paths]
    rows = [
        [pid, 0, seg_path, img_path]
        for pid, seg_path, img_path in zip(patient_ids, seg_paths, image_paths)
    ]
    create_csv_info_file(rows, output_mr_csv_file)
|
dataprocesser/create_csv_xcat.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import csv
|
| 3 |
+
|
| 4 |
+
def extract_prefixes_from_directory(directory):
    """Return the sorted unique prefixes of all .nrrd files in `directory`.

    The prefix is the part of the filename before the first underscore.
    """
    found = set()
    for name in os.listdir(directory):
        if not name.endswith('.nrrd'):
            continue
        found.add(name.split('_')[0])
    return sorted(found)
|
| 11 |
+
|
| 12 |
+
def save_prefixes_to_csv(prefixes, output_csv_path, base_dir=None):
    """Write one CSV row per prefix, each joined onto a base directory.

    Args:
        prefixes: iterable of filename prefixes.
        output_csv_path (str): destination CSV path.
        base_dir (str, optional): directory joined in front of each prefix.
            Defaults to the module-level global ``directory`` to stay
            backward compatible — the original implementation read that
            global directly, which raised NameError when the function was
            imported and called from another module.
    """
    if base_dir is None:
        base_dir = directory  # legacy fallback: global set in __main__ below
    with open(output_csv_path, mode='w', newline='') as file:
        writer = csv.writer(file)
        for prefix in prefixes:
            writer.writerow([os.path.join(base_dir, prefix)])
|
| 17 |
+
|
| 18 |
+
if __name__ == "__main__":
    # Example run on a local XCAT training set (Windows paths).
    # NOTE(review): save_prefixes_to_csv reads the module-level global
    # `directory` set here — it only works when this script is the entry
    # point; confirm before importing that function elsewhere.
    directory = r"F:\yang_Projects\ICTUNET_torch\datasets\train"
    output_csv_path = r"F:\yang_Projects\ICTUNET_torch\data_table\train_all.csv"

    prefixes = extract_prefixes_from_directory(directory)
    save_prefixes_to_csv(prefixes, output_csv_path)

    print(f"CSV file with prefixes saved to: {output_csv_path}")
|
dataprocesser/create_json_lodopab.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tqdm import tqdm
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
from configs import config as cfg
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
VERBOSE = cfg.verbose
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
|
| 10 |
+
# File extensions accepted when scanning a dataset directory.
IMG_EXTENSIONS = [
    '.nrrd', '.nii.gz',
    '.hdf5',
]

def is_image_file(filename):
    """Return True when `filename` ends with one of IMG_EXTENSIONS."""
    return filename.endswith(tuple(IMG_EXTENSIONS))
|
| 19 |
+
|
| 20 |
+
def create_metadata_jsonl_lodopab(base_path, mode='train', output_json_file= 'lodopab_dataset.json'):
    """Pair LoDoPaB ground-truth files with observation files and save as JSON.

    Scans ``<base_path>/ground_truth_<mode>`` for image files (per
    IMG_EXTENSIONS) and derives each observation path by swapping the
    'ground_truth' prefix for 'observation' in the filename.

    Args:
        base_path (str): dataset root containing ground_truth_*/observation_* dirs.
        mode (str): split name, e.g. 'train'.
        output_json_file (str): where to write the resulting JSON list.
    """
    ground_truth_path = os.path.join(base_path, 'ground_truth_'+mode)
    observation_path = os.path.join(base_path, 'observation_'+mode)

    # Initialize dataset list
    dataset_list = []

    # Iterate through the ground truth files
    for gt_file in os.listdir(ground_truth_path):
        if is_image_file(gt_file):
            # Get the corresponding observation file.
            # NOTE(review): the observation file's existence is not verified.
            obs_file = gt_file.replace('ground_truth', 'observation')

            entry = {
                'ground_truth': os.path.join(ground_truth_path, gt_file),
                'observation': os.path.join(observation_path, obs_file)
            }
            dataset_list.append(entry)

    # Save the dataset list as a JSON file
    with open(output_json_file, 'w') as json_file:
        json.dump(dataset_list, json_file, indent=4)

    # Fix: report the actual output path — the original message hard-coded
    # 'lodopab_dataset.json' regardless of output_json_file.
    print(f'Dataset list saved to {output_json_file} with {len(dataset_list)} entries.')
|
| 47 |
+
|
| 48 |
+
def read_metadata_jsonl(file_path):
    """Load and return the JSON dataset list stored at `file_path`."""
    with open(file_path, 'r') as handle:
        return json.load(handle)
|
| 52 |
+
|
| 53 |
+
def print_json_info(data_info):
    """Print each entry's patient_name with a tqdm progress bar."""
    for record in tqdm(data_info, desc="Calculating slice info"):
        print(record['patient_name'])
|
| 56 |
+
|
| 57 |
+
if __name__ == '__main__':
    # Example run against a local copy of the LoDoPaB-CT dataset (Windows path).
    base_path = r"F:\yang_Projects\Datasets\LoDoPaB"
    create_metadata_jsonl_lodopab(base_path, mode='train', output_json_file= './data_table/lodopab_dataset.json')
|
dataprocesser/create_json_xcat.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tqdm import tqdm
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
from configs import config as cfg
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
VERBOSE = cfg.verbose
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
|
| 10 |
+
# File extensions accepted when scanning a dataset directory.
IMG_EXTENSIONS = [
    '.nrrd', '.nii.gz',
    '.hdf5',
]

def is_image_file(filename):
    """Return True when `filename` ends with one of IMG_EXTENSIONS."""
    return filename.endswith(tuple(IMG_EXTENSIONS))
|
| 19 |
+
|
| 20 |
+
def create_metadata_jsonl_xcat(base_path,
                            mode='train',
                            sino_entry = "_sino_Metal.nrrd",
                            img_entry = "_img_GT_noNoise.nrrd",
                            output_json_file= 'xcat_dataset.json'):
    """Pair XCAT sinogram and ground-truth volumes by filename prefix, save as JSON.

    Scans ``<base_path>/<mode>`` for image files (per IMG_EXTENSIONS),
    collects the unique prefixes (text before the first underscore), and for
    each prefix records the sinogram (`prefix + sino_entry`) and ground truth
    (`prefix + img_entry`) paths.

    Args:
        base_path (str): dataset root; the split lives in a subfolder.
        mode (str): split subfolder name, e.g. 'train'.
        sino_entry (str): suffix of the sinogram files.
        img_entry (str): suffix of the ground-truth files.
        output_json_file (str): where to write the resulting JSON list.
    """
    train_set_path = os.path.join(base_path, mode)
    # Initialize dataset list
    dataset_list = []
    prefixes = set()
    for filename in os.listdir(train_set_path):
        if is_image_file(filename):
            prefix = filename.split('_')[0]
            prefixes.add(prefix)
    prefixes = sorted(prefixes)

    for prefix in prefixes:
        # NOTE(review): the constructed file paths are not checked for existence.
        sino_path = os.path.join(train_set_path, prefix + sino_entry)
        img_path = os.path.join(train_set_path, prefix + img_entry)

        entry = {
            'ground_truth': img_path,
            'observation': sino_path
        }
        dataset_list.append(entry)

    # Save the dataset list as a JSON file
    with open(output_json_file, 'w') as json_file:
        json.dump(dataset_list, json_file, indent=4)

    # Fix: report the actual output path — the original message hard-coded
    # 'xcat_dataset.json' regardless of output_json_file.
    print(f'Dataset list saved to {output_json_file} with {len(dataset_list)} entries.')
|
| 54 |
+
|
| 55 |
+
def read_metadata_jsonl(file_path):
    """Load and return the JSON dataset list stored at `file_path`."""
    with open(file_path, 'r') as handle:
        return json.load(handle)
|
| 59 |
+
|
| 60 |
+
def print_json_info(data_info):
    """Print each entry's patient_name with a tqdm progress bar."""
    for record in tqdm(data_info, desc="Calculating slice info"):
        print(record['patient_name'])
|
| 63 |
+
|
| 64 |
+
if __name__ == '__main__':
    # Example run against a local XCAT dataset (Windows path).
    base_path = r"F:\yang_Projects\ICTUNET_torch\datasets"
    create_metadata_jsonl_xcat(base_path,
                            mode='train',
                            sino_entry = "_sino_Metal.nrrd",
                            img_entry = "_img_GT_noNoise.nrrd",
                            output_json_file= './data_table/xcat_dataset.json')
|
dataprocesser/customized_datasets.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.utils.data as data
|
| 2 |
+
import nibabel as nib
|
| 3 |
+
import torch
|
| 4 |
+
import os
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from torch.utils.data import Dataset, DataLoader
|
| 8 |
+
import pandas as pd
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
VERBOSE = False
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def volume_slicer(volume_tensor, transform, all_slices=None):
    """Split a 3-D volume into single-channel 2-D slices and stack them.

    The slice axis (last dim of the input) is moved to the front and a
    channel dim is inserted: [H, W, D] -> [D, 1, W, H]. Note permute(2, 1, 0)
    also swaps the two in-plane axes. When `all_slices` is given, the new
    slices are concatenated after it along dim 0.

    Args:
        volume_tensor (torch.Tensor): 3-D volume with the slice axis last.
        transform (callable | None): optional transform applied to the
            [D, 1, W, H] stack.
        all_slices (torch.Tensor | None): accumulator to append to.

    Returns:
        torch.Tensor: accumulated slice stack [N, 1, W, H].
    """
    stacked = volume_tensor.permute(2, 1, 0)  # slice axis first: [H, W, D] -> [D, W, H]
    stacked = stacked.unsqueeze(1)            # add channel dim: [D, W, H] -> [D, 1, W, H]
    if transform is not None:
        stacked = transform(stacked)

    if all_slices is None:
        return stacked
    return torch.cat((all_slices, stacked), 0)
|
| 28 |
+
|
| 29 |
+
def infinite_loader(loader):
    """Yield batches from `loader` forever.

    After each full pass the underlying dataset's reset() is invoked so the
    next epoch sees freshly (re)initialized data.
    """
    while True:
        # One full pass over the loader.
        yield from loader
        # Explicitly rebuild the underlying dataset before the next pass.
        loader.dataset.reset()
|
| 36 |
+
|
| 37 |
+
class csvDataset_3D(Dataset):
    """Whole-volume dataset driven by a CSV file.

    Each kept CSV row describes one patient: the last column is the path to
    a volume loadable by nibabel, and the third-from-last column supplies
    the label (per the inline comment, the 'Aorta_diss' column).
    Samples are {'image': FloatTensor [1, H, W, D], 'label': value}.
    """

    def __init__(self, csv_file, transform=None, load_patient_number=1):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            transform (callable, optional): Optional transform to be applied on a sample.
            load_patient_number (int): number of leading CSV rows (patients) to keep.
        """
        self.data_frame = pd.read_csv(csv_file)
        # control the length of the dataset
        self.data_frame = self.data_frame[:load_patient_number]
        # NOTE(review): self.transform is stored but never applied in
        # __getitem__ below — confirm whether transforms are intended here.
        self.transform = transform

    def __len__(self):
        # One sample per kept CSV row (per patient volume).
        return len(self.data_frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Last CSV column holds the image path.
        img_path = self.data_frame.iloc[idx, -1]
        image = nib.load(img_path).get_fdata()
        image = torch.tensor(image, dtype=torch.float32)

        # Example: Using the 'Aorta_diss' column as a label
        label = self.data_frame.iloc[idx, -3]
        #label = torch.tensor(label, dtype=torch.float32)

        # If more processing is needed (e.g., normalization, adding channel dimension), do it here
        image = image.unsqueeze(0) # Add channel dimension if it's a single channel image

        sample = {'image': image, 'label': label}

        return sample
|
| 70 |
+
|
| 71 |
+
class csvDataset_2D(Dataset):
    """2-D slice dataset built from a CSV of volume paths.

    Eagerly loads the first ``load_patient_number`` volumes listed in the
    CSV (last column = image path, third-from-last column = label), slices
    each along its last axis via volume_slicer, and serves single slices as
    {'source': image, 'target': label} samples.
    """

    def __init__(self, csv_file, transform=None, load_patient_number=1):
        self.csv_file = csv_file
        self.transform = transform
        self.load_patient_number = load_patient_number
        self.data_frame = pd.read_csv(csv_file)
        if len(self.data_frame) == 0:
            raise RuntimeError(f"Found 0 images in: {csv_file}")

        # Initialize dataset
        self.initialize_dataset()

    def initialize_dataset(self):
        """Load all volumes, slice them, and build the parallel label list."""
        print('Loading dataset...')
        self.data_frame = self.data_frame[:self.load_patient_number]
        all_slices = None
        all_labels = []

        for idx in tqdm(range(len(self.data_frame))):
            img_path = self.data_frame.iloc[idx, -1]
            volume = nib.load(img_path)
            volume_data = volume.get_fdata()  # Load as [H, W, D]
            volume_tensor = torch.tensor(volume_data, dtype=torch.float32)
            previous_count = 0 if all_slices is None else all_slices.shape[0]
            # volume_slicer stacks slices along dim 0 (the volume's last axis).
            all_slices = volume_slicer(volume_tensor, self.transform, all_slices)
            label = self.data_frame.iloc[idx, -3]
            # Fix: replicate the label once per slice actually appended.
            # The original used volume_tensor.shape[0] (the first axis),
            # while slices are produced along the LAST axis — for non-cubic
            # volumes the label list went out of sync with the slice stack.
            appended = all_slices.shape[0] - previous_count
            all_labels = all_labels + [label] * appended

        print('All stacked slices:', all_slices.shape)
        self.all_slices = all_slices
        self.all_labels = all_labels

    def __len__(self):
        return self.all_slices.shape[0]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self.all_slices[idx]
        label = self.all_labels[idx]
        sample = {'source': image, 'target': label}
        return sample

    def reset(self):
        """Rebuild the slice cache (called by infinite_loader between passes)."""
        print('Resetting dataset...')
        self.initialize_dataset()
|
dataprocesser/customized_normalization.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import nibabel as nib
|
| 2 |
+
import numpy as np
|
| 3 |
+
from scipy.interpolate import interp1d
|
| 4 |
+
|
| 5 |
+
def nyul_apply_standard_scale(input_image,
                            standard_hist,
                            input_mask=None,
                            interp_type='linear'):
    """
    Based on J.Reinhold code:
    https://github.com/jcreinhold/intensity-normalization

    Use Nyul and Udupa method ([1,2]) to normalize the intensities
    of a MRI image passed as input.

    Args:
        input_image (np.ndarray): input image to normalize
        standard_hist (str): path to output or use standard histogram landmarks
        input_mask (nii): optional brain mask
        interp_type (str): interpolation kind forwarded to do_hist_normalization

    Returns:
        normalized (np.ndarray): normalized input image

    References:
        [1] N. Laszlo G and J. K. Udupa, "On Standardizing the MR Image
            Intensity Scale," Magn. Reson. Med., vol. 42, pp. 1072-1081,
            1999.
        [2] M. Shah, Y. Xiao, N. Subbanna, S. Francis, D. L. Arnold,
            D. L. Collins, and T. Arbel, "Evaluating intensity
            normalization on MRIs of human brain with multiple sclerosis,"
            Med. Image Anal., vol. 15, no. 2, pp. 267-282, 2011.
    """

    # load learned standard scale and the percentiles
    # NOTE(review): unpacking two arrays from a single np.load implies the
    # file stores an object/2-row array; depending on how it was saved this
    # may require allow_pickle=True — confirm against the training-side save.
    standard_scale, percs = np.load(standard_hist)

    # apply transformation to image
    return do_hist_normalization(input_image,
                                percs,
                                standard_scale,
                                input_mask,
                                interp_type=interp_type)
|
| 44 |
+
|
| 45 |
+
def do_hist_normalization(input_image,
                          landmark_percs,
                          standard_scale,
                          mask=None,
                          interp_type='linear'):
    """Nyul & Udupa piecewise histogram normalization with given landmarks.

    Computes the intensity landmarks of `input_image` over its foreground,
    builds an interpolation mapping those landmarks onto `standard_scale`,
    and applies the mapping to the whole image.

    Based on J.Reinhold code:
    https://github.com/jcreinhold/intensity-normalization

    Args:
        input_image (np.ndarray): image to normalize
        landmark_percs (np.ndarray): landmark percentiles of the standard scale
        standard_scale (np.ndarray): landmark intensities on the standard scale
        mask (np.ndarray): optional foreground mask; when omitted, voxels
            above the image mean are treated as foreground
        interp_type (str): interpolation kind passed to scipy's interp1d

    Returns:
        np.ndarray: normalized image
    """
    foreground = input_image > input_image.mean() if mask is None else mask
    voxels = input_image[foreground > 0]  # foreground intensities only
    landmarks = get_landmarks(voxels, landmark_percs)

    # Piecewise-linear (by default) mapping, extrapolated beyond the landmarks.
    mapping = interp1d(landmarks, standard_scale,
                       kind=interp_type, fill_value='extrapolate')

    return mapping(input_image)

def get_landmarks(img, percs):
    """Return the intensity values of `img` at the given percentiles.

    Based on J.Reinhold code:
    https://github.com/jcreinhold/intensity-normalization

    Args:
        img: array of intensities to take percentiles over
        percs (np.ndarray): percentiles to extract

    Returns:
        np.ndarray: intensity value per requested percentile
    """
    return np.percentile(img, percs)
|
| 93 |
+
|
| 94 |
+
def nyul_train_standard_scale(img_fns,
                              mask_fns=None,
                              i_min=1,
                              i_max=99,
                              i_s_min=1,
                              i_s_max=100,
                              l_percentile=10,
                              u_percentile=90,
                              step=10):
    """Determine the Nyul standard intensity scale for a set of images.

    Based on J. Reinhold's code:
    https://github.com/jcreinhold/intensity-normalization

    Args:
        img_fns (list): NifTI MR image paths used to learn the scale.
        mask_fns (list): corresponding mask paths (estimated when None).
        i_min (float): minimum percentile considered in each image.
        i_max (float): maximum percentile considered in each image.
        i_s_min (float): minimum percentile on the standard scale.
        i_s_max (float): maximum percentile on the standard scale.
        l_percentile (int): lower bound of the middle percentiles.
        u_percentile (int): upper bound of the middle percentiles.
        step (int): step between middle percentiles.

    Returns:
        tuple: (standard_scale, percs) — average landmark intensities and
        the percentile array they correspond to.
    """
    # Estimate masks later if none were provided.
    mask_fns = [None] * len(img_fns) if mask_fns is None else mask_fns

    percs = np.concatenate(([i_min],
                            np.arange(l_percentile, u_percentile + 1, step),
                            [i_max]))
    standard_scale = np.zeros(len(percs))

    # Accumulate each image's landmarks to build the standard scale.
    for img_fn, mask_fn in zip(img_fns, mask_fns):
        print('processing scan ', img_fn)
        # BUG FIX: get_data() was deprecated and removed in nibabel >= 5.0;
        # get_fdata() is the supported accessor.
        img_data = nib.load(img_fn).get_fdata()
        if mask_fn is not None:
            mask_data = nib.load(mask_fn).get_fdata()
        else:
            # Fall back to a mean-intensity foreground estimate.
            mask_data = img_data > img_data.mean()
        masked = img_data[mask_data > 0]  # foreground voxels only
        landmarks = get_landmarks(masked, percs)
        min_p = np.percentile(masked, i_min)
        max_p = np.percentile(masked, i_max)
        # Linearly map this image's intensity range onto the standard range.
        f = interp1d([min_p, max_p], [i_s_min, i_s_max])
        standard_scale += np.array(f(landmarks))

    standard_scale = standard_scale / len(img_fns)  # mean over all images
    return standard_scale, percs
|
dataprocesser/customized_transform_list.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataprocesser.customized_transforms import (
|
| 2 |
+
CreateBodyContourTransformd,
|
| 3 |
+
MergeMasksTransformd,
|
| 4 |
+
UseContourToFilterImaged,
|
| 5 |
+
MaskHUAssigmentd,
|
| 6 |
+
MergeSegTissueTransformd,
|
| 7 |
+
NormalizationMultimodal,
|
| 8 |
+
CreateMaskWithBonesTransformd)
|
| 9 |
+
from monai.transforms import (
|
| 10 |
+
Compose,
|
| 11 |
+
LoadImaged,
|
| 12 |
+
EnsureChannelFirstd,
|
| 13 |
+
SqueezeDimd,
|
| 14 |
+
CenterSpatialCropd,
|
| 15 |
+
Rotate90d,
|
| 16 |
+
ScaleIntensityd,
|
| 17 |
+
ResizeWithPadOrCropd,
|
| 18 |
+
DivisiblePadd,
|
| 19 |
+
Zoomd,
|
| 20 |
+
ThresholdIntensityd,
|
| 21 |
+
NormalizeIntensityd,
|
| 22 |
+
ShiftIntensityd,
|
| 23 |
+
Identityd,
|
| 24 |
+
ScaleIntensityRanged,
|
| 25 |
+
Spacingd,
|
| 26 |
+
SaveImage,
|
| 27 |
+
)
|
| 28 |
+
## intensity transforms
|
| 29 |
+
def add_normalization_transform_single_B(transform_list, indicator_B, normalize):
    """Append an intensity-normalization transform for the target key only.

    Args:
        transform_list (list): MONAI transform list, extended in place.
        indicator_B (str): dictionary key of the target-modality image.
        normalize (str): name of the normalization scheme.

    Returns:
        list: the same transform_list, for chaining.
    """
    if normalize == 'zscore':
        transform_list.append(NormalizeIntensityd(keys=[indicator_B], nonzero=False, channel_wise=True))
        print('zscore normalization')
    elif normalize == 'minmax':
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=-1.0, maxv=1.0))
        print('minmax normalization')

    elif normalize == 'scale1000_wrongbutworks':
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=0))
        transform_list.append(ScaleIntensityd(keys=[indicator_B], factor=-0.999))
        print('scale1000 normalization')

    elif normalize == 'scale4000':
        # factor=-0.99975 multiplies intensities by (1 + factor) == 1/4000
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.99975))
        print('scale4000 normalization')

    elif normalize == 'scale2000':
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.9995))
        print('scale2000 normalization')

    elif normalize == 'scale1000':
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.999))
        print('scale1000 normalization')

    elif normalize == 'scale100':
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.99))
        # BUG FIX: this branch previously logged 'scale10 normalization'.
        print('scale100 normalization')

    elif normalize == 'scale10':
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.9))
        print('scale10 normalization')

    elif normalize == 'nonegative':
        offset = 1000  # shift so the intensity range starts at zero
        transform_list.append(ShiftIntensityd(keys=[indicator_B], offset=offset))
        print('none negative normalization')

    elif normalize == 'none' or normalize == 'nonorm':
        print('no normalization')

    else:
        # Previously unknown schemes were silently ignored; warn instead.
        print(f'WARNING: unknown normalize option "{normalize}", no transform appended')

    return transform_list
|
| 71 |
+
|
| 72 |
+
def add_normalization_multimodal(transform_list, indicator_A, indicator_B):
    """Append the project's multimodal normalization covering both image keys.

    Returns the same list for chaining.
    """
    multimodal_norm = NormalizationMultimodal(keys=[indicator_A, indicator_B])
    transform_list.append(multimodal_norm)
    return transform_list
|
| 75 |
+
|
| 76 |
+
def add_normalization_transform_A_B(transform_list, normalize, indicator_A, indicator_B):
    """Append a normalization transform applied to BOTH image keys.

    Supported schemes: 'zscore', 'scale2000', and 'none'/'nonorm' (no-op).
    Returns the same list for chaining.
    """
    both_keys = [indicator_A, indicator_B]
    if normalize == 'zscore':
        transform_list.append(NormalizeIntensityd(keys=both_keys, nonzero=False, channel_wise=True))
        print('zscore normalization')
    elif normalize == 'scale2000':
        transform_list.append(ScaleIntensityd(keys=both_keys, minv=None, maxv=None, factor=-0.9995))
        print('scale2000 normalization')
    elif normalize == 'none' or normalize == 'nonorm':
        print('no normalization')
    return transform_list
|
| 86 |
+
|
| 87 |
+
def add_normalization_transform_input_only(transform_list, indicator_A, normalize):
    """Append a normalization transform for the input (MRI) key only.

    Args:
        transform_list (list): MONAI transform list, extended in place.
        indicator_A (str): dictionary key of the input-modality image.
        normalize (str): 'inputonlyzscore' or 'inputonlyminmax'.

    Returns:
        list: the same transform_list, for chaining.
    """
    if normalize == 'inputonlyzscore':
        transform_list.append(NormalizeIntensityd(keys=[indicator_A], nonzero=False, channel_wise=True))
        print('only normalize input MRI images')

    elif normalize == 'inputonlyminmax':
        normmin = 0
        normmax = 1
        transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=normmin, maxv=normmax))
        print('only normalize input MRI images')

    # BUG FIX: the original fell off the end without returning, so callers
    # that rebind (transform_list = add_...(...)) received None; every
    # sibling helper in this module returns the list.
    return transform_list
|
| 97 |
+
|
| 98 |
+
def add_CreateContour_MergeMask_transforms(transform_list, indicator_A):
    """Append body-contour creation and contour/segmentation merging.

    'mask' is replaced by a body contour derived from the image, then merged
    into the segmentation stored under indicator_A. Returns the list.
    """
    contour_step = CreateBodyContourTransformd(keys=['mask'],
                                               body_threshold=-500,
                                               body_mask_value=1)
    transform_list.append(contour_step)  # image -> body contour
    transform_list.append(MergeMasksTransformd(keys=[indicator_A, 'mask']))  # seg + contour -> seg
    return transform_list
|
| 105 |
+
|
| 106 |
+
def add_CreateContour_MergeMask_MaskHUAssign_transforms(transform_list, indicator_A, anatomy_list_csv):
    """Append contour creation, contour/seg merging, and HU assignment.

    Pipeline: image -> body contour; seg + contour -> seg; seg labels -> HU
    values looked up in anatomy_list_csv. Returns the list.
    """
    contour_step = CreateBodyContourTransformd(keys=['mask'],
                                               body_threshold=-500,
                                               body_mask_value=1)
    transform_list.append(contour_step)  # image -> contour
    transform_list.append(MergeMasksTransformd(keys=[indicator_A, 'mask']))  # seg+contour -> seg
    transform_list.append(MaskHUAssigmentd(keys=[indicator_A], csv_file=anatomy_list_csv))
    return transform_list
|
| 114 |
+
|
| 115 |
+
def add_CreateContour_MergeSegTissue_MergeMask_MaskHUAssign_transforms(transform_list, indicator_A, anatomy_list_csv, anatomy_list_csv_mr):
    """Append contour creation, tissue-seg merging, contour merging and HU assignment.

    NOTE(review): anatomy_list_csv_mr is accepted but never used here —
    confirm whether the MR anatomy table was meant to feed MaskHUAssigmentd.
    Returns the list.
    """
    contour_step = CreateBodyContourTransformd(keys=['mask'],
                                               body_threshold=-500,
                                               body_mask_value=1)
    transform_list.append(contour_step)  # image -> contour
    transform_list.append(MergeSegTissueTransformd(keys=[indicator_A, 'seg_tissue']))  # seg+seg_tissue -> seg
    transform_list.append(MergeMasksTransformd(keys=[indicator_A, 'mask']))  # seg+contour -> seg
    transform_list.append(MaskHUAssigmentd(keys=[indicator_A], csv_file=anatomy_list_csv))
    return transform_list
|
| 124 |
+
|
| 125 |
+
def add_Windowing_ZeroShift_ContourFilter_A_B_transforms(transform_list, WINDOW_LEVEL, WINDOW_WIDTH, indicator_A, indicator_B):
    """Clip both keys to a HU window, shift to a zero-based range, and mask B.

    ThresholdIntensityd keeps pixels where (img > threshold) when above=True
    (else img < threshold) and writes cval elsewhere, so cval must sit on the
    kept side of each bound. Returns the list.
    """
    both_keys = [indicator_A, indicator_B]
    lower = WINDOW_LEVEL - WINDOW_WIDTH / 2
    upper = WINDOW_LEVEL + WINDOW_WIDTH / 2
    transform_list.append(ThresholdIntensityd(keys=both_keys, threshold=lower, above=True, cval=lower))
    transform_list.append(ThresholdIntensityd(keys=both_keys, threshold=upper, above=False, cval=upper))
    transform_list.append(ShiftIntensityd(keys=both_keys, offset=-lower))  # window floor -> 0
    transform_list.append(UseContourToFilterImaged(keys=[indicator_B, 'mask']))  # zero B outside the body
    return transform_list
|
| 137 |
+
|
| 138 |
+
def add_Windowing_ZeroShift_ContourFilter_single_B_transforms(transform_list, WINDOW_LEVEL, WINDOW_WIDTH, indicator_B):
    """Clip the target key to a HU window, shift to zero-based, and mask it.

    Same windowing logic as the A/B variant but applied to indicator_B only.
    ThresholdIntensityd keeps pixels on the chosen side of threshold and
    writes cval elsewhere, so cval must lie on the kept side. Returns the list.
    """
    lower = WINDOW_LEVEL - WINDOW_WIDTH / 2
    upper = WINDOW_LEVEL + WINDOW_WIDTH / 2
    transform_list.append(ThresholdIntensityd(keys=[indicator_B], threshold=lower, above=True, cval=lower))
    transform_list.append(ThresholdIntensityd(keys=[indicator_B], threshold=upper, above=False, cval=upper))
    transform_list.append(ShiftIntensityd(keys=[indicator_B], offset=-lower))  # window floor -> 0
    transform_list.append(UseContourToFilterImaged(keys=[indicator_B, 'mask']))  # zero outside the body
    return transform_list
|
dataprocesser/customized_transforms.py
ADDED
|
@@ -0,0 +1,507 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
VERBOSE = False
|
| 7 |
+
|
| 8 |
+
def get_data_scaler(config):
    """Return a data normalizer; assumes inputs are already in [0, 1].

    When config.data.centered is set, values are rescaled to [-1, 1];
    otherwise the identity function is returned.
    """
    if not config.data.centered:
        return lambda x: x
    return lambda x: x * 2. - 1.  # [0, 1] -> [-1, 1]
|
| 15 |
+
|
| 16 |
+
def get_data_inverse_scaler(config):
    """Return the inverse of get_data_scaler's normalizer.

    When config.data.centered is set, values in [-1, 1] are mapped back to
    [0, 1]; otherwise the identity function is returned.
    """
    if not config.data.centered:
        return lambda x: x
    return lambda x: (x + 1.) / 2.  # [-1, 1] -> [0, 1]
|
| 23 |
+
|
| 24 |
+
def separate_maps(real_images,
                  tissue_min, tissue_max,
                  bone_min, bone_max):
    """Build a 3-class label map from intensity ranges.

    Labels: 0 = background, 1 = tissue (tissue_min, tissue_max],
    2 = bone [bone_min, bone_max]. Bone overrides tissue where ranges overlap.
    """
    labels = torch.zeros_like(real_images)
    tissue_region = (real_images > tissue_min) & (real_images <= tissue_max)
    bone_region = (real_images >= bone_min) & (real_images <= bone_max)
    labels[tissue_region] = 1
    labels[bone_region] = 2
    return labels
|
| 33 |
+
|
| 34 |
+
def create_body_contour_old(tensor_img, body_threshold=-500):
    """Binary body mask from one CT slice by filling the largest contour.

    Thresholds the slice at body_threshold and fills only the single largest
    external contour — so disjoint body parts (e.g. two arms) are lost,
    which is why this variant was superseded.

    Args:
        tensor_img (torch.Tensor): 2D grayscale CT slice (HU-valued).

    Returns:
        torch.Tensor: int32 mask, 1 inside the body region, 0 elsewhere.
    """
    # int16 preserves negative HU values through the conversion
    slice_arr = tensor_img.numpy().astype(np.int16)

    # Everything above the threshold is a body candidate.
    binary = np.where(slice_arr > body_threshold, 1, 0).astype(np.uint8)

    outlines, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    filled = np.zeros_like(binary)
    if outlines:
        # Treat the largest contour as the body outline and fill it.
        body_outline = max(outlines, key=cv2.contourArea)
        cv2.drawContours(filled, [body_outline], -1, 1, thickness=cv2.FILLED)

    return torch.tensor(filled, dtype=torch.int32)
|
| 64 |
+
|
| 65 |
+
def create_body_contour(tensor_img, body_threshold=-500, min_contour_area=10000):
    """Binary body mask from one CT slice, keeping all large contours.

    Unlike create_body_contour_old, every contour whose area is at least
    min_contour_area is filled, so multiple disjoint body parts (e.g. two
    arms) survive.

    Args:
        tensor_img (torch.Tensor | np.ndarray): 2D grayscale CT slice (HU-valued).
        body_threshold (int): intensities above this count as body.
        min_contour_area (int): smallest contour area (pixels) kept in the mask.

    Returns:
        torch.Tensor: int32 mask, 1 inside body regions, 0 elsewhere.

    Raises:
        TypeError: if tensor_img is neither a torch.Tensor nor an np.ndarray.
    """
    # int16 preserves negative HU values through the conversion.
    if isinstance(tensor_img, torch.Tensor):
        numpy_img = tensor_img.numpy().astype(np.int16)
    elif isinstance(tensor_img, np.ndarray):
        numpy_img = np.ascontiguousarray(tensor_img.astype(np.int16))
    else:
        # BUG FIX: the original only printed a warning here and then crashed
        # with a NameError on the undefined numpy_img; fail loudly instead.
        raise TypeError("tensor_img must be a torch.Tensor or np.ndarray, "
                        f"got {type(tensor_img).__name__}")

    # Everything above the threshold is a body candidate.
    binary_img = np.where(numpy_img > body_threshold, 1, 0).astype(np.uint8)

    contours, _ = cv2.findContours(binary_img, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)

    mask = np.zeros_like(binary_img)
    # Fill every sufficiently large contour.
    for contour in contours:
        if cv2.contourArea(contour) >= min_contour_area:
            if VERBOSE:
                print('current contour area: ', cv2.contourArea(contour), 'threshold: ', min_contour_area)
            cv2.drawContours(mask, [contour], -1, 1, thickness=cv2.FILLED)

    return torch.tensor(mask, dtype=torch.int32)
|
| 104 |
+
|
| 105 |
+
import numpy as np
|
| 106 |
+
import cv2
|
| 107 |
+
|
| 108 |
+
def create_body_contour_by_seg_tissue(binary_mask: np.ndarray, area_threshold=1000) -> np.ndarray:
    """Extract a body-contour mask from a tissue segmentation map.

    Keeps only external contours larger than area_threshold (interior holes
    are ignored and filled over) and returns them as a filled binary mask.

    Args:
        binary_mask: 2D array; any non-zero pixel counts as tissue.
        area_threshold: minimum contour area (pixels) to keep.

    Returns:
        np.uint8 binary mask (0 or 1).
    """
    tissue = (binary_mask > 0).astype(np.uint8).copy() * 255

    # External contours only — holes inside the body are ignored.
    outlines, _ = cv2.findContours(tissue, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    filled = np.zeros_like(tissue, dtype=np.uint8)
    for outline in outlines:
        if cv2.contourArea(outline) > area_threshold:
            cv2.drawContours(filled, [outline], -1, 1, thickness=-1)  # fill interior

    return (filled > 0).astype(np.uint8)
|
| 131 |
+
|
| 132 |
+
import pandas as pd
|
| 133 |
+
|
| 134 |
+
def HU_assignment(mask, csv_file):
    """Replace integer organ labels in *mask* with HU values from a CSV table.

    csv_file must provide 'Order Number' and 'HU Value' columns. Background
    (label 0) becomes -1000 HU. If the table's order numbers start at 0, mask
    labels are assumed shifted by +1 (label 1 corresponds to order 0).

    Args:
        mask (torch.Tensor | np.ndarray): integer label mask.
        csv_file (str): path to the anatomy lookup table.

    Returns:
        Same type as mask: HU-valued image.

    Raises:
        TypeError: if mask is neither a torch.Tensor nor an np.ndarray.
    """
    if isinstance(mask, torch.Tensor):
        hu_mask = torch.zeros_like(mask)
    elif isinstance(mask, np.ndarray):
        hu_mask = np.zeros_like(mask)
    else:
        # BUG FIX: the original fell through and crashed later with a
        # NameError on the undefined hu_mask; fail loudly instead.
        raise TypeError(f"mask must be a torch.Tensor or np.ndarray, got {type(mask).__name__}")

    df = pd.read_csv(csv_file)
    hu_values = dict(zip(df['Order Number'], df['HU Value']))
    order_begin_from_0 = df['Order Number'].min() == 0
    # Value assignment
    hu_mask[mask == 0] = -1000  # background
    for organ_index, hu_value in hu_values.items():
        # BUG FIX: pandas yields numpy integer scalars, which are NOT Python
        # ints, so the original isinstance(..., int) asserts always fired.
        assert isinstance(hu_value, (int, np.integer)), f"Expected mask value an integer, but got {hu_value}. Ensure the mask is created by fine mode of totalsegmentator"
        assert isinstance(organ_index, (int, np.integer)), f"Expected organ_index an integer, but got {organ_index}. Ensure the mask is created by fine mode of totalsegmentator"
        if order_begin_from_0:
            # mask values begin from 1 (body), not 0 as in the TA2 table
            hu_mask[mask == (organ_index + 1)] = hu_value
        else:
            hu_mask[mask == organ_index] = hu_value
    return hu_mask
|
| 153 |
+
|
| 154 |
+
class MaskHUAssigmentd:
    """Dictionary transform: replace integer organ labels with HU values.

    The anatomy CSV must provide 'Order Number' and 'HU Value' columns.
    Background (label 0) becomes -1000 HU. When the table's order numbers
    start at 0, mask labels are assumed shifted by +1 (label 1 == order 0).
    Masks are expected to be torch tensors.
    """

    def __init__(self, keys, csv_file):
        self.keys = keys
        self.df = pd.read_csv(csv_file)
        # The lookup table is fixed after construction, so build it once here
        # instead of rebuilding it on every __call__ as the original did.
        self.hu_values = dict(zip(self.df['Order Number'], self.df['HU Value']))
        self.order_begin_from_0 = self.df['Order Number'].min() == 0

    def __call__(self, data):
        for key in self.keys:
            mask = data[key]
            hu_mask = torch.zeros_like(mask)
            # Value assignment
            hu_mask[mask == 0] = -1000  # background
            for organ_index, hu_value in self.hu_values.items():
                # BUG FIX: pandas yields numpy integer scalars, which are NOT
                # Python ints, so the original isinstance(..., int) asserts
                # always fired.
                assert isinstance(hu_value, (int, np.integer)), f"Expected mask value an integer, but got {hu_value}. Ensure the mask is created by fine mode of totalsegmentator"
                assert isinstance(organ_index, (int, np.integer)), f"Expected organ_index an integer, but got {organ_index}. Ensure the mask is created by fine mode of totalsegmentator"
                if self.order_begin_from_0:
                    # mask values begin from 1 (body), not 0 as in the TA2 table
                    hu_mask[mask == (organ_index + 1)] = hu_value
                else:
                    hu_mask[mask == organ_index] = hu_value
            data[key] = hu_mask
        return data
|
| 180 |
+
import pandas as pd
|
| 181 |
+
import numpy as np
|
| 182 |
+
|
| 183 |
+
def convert_segmentation_mask(source_mask, source_csv, target_csv, body_contour_value=1):
    """Translate segmentation label values between modalities via organ names.

    Each CSV must hold the class value in its first column and an
    'Organ Name' column. Source classes whose organ has no entry in the
    target table (including background) are mapped to body_contour_value.

    Parameters:
    - source_mask (ndarray): the source segmentation mask array.
    - source_csv (str): CSV path for the source modality (CT or MR).
    - target_csv (str): CSV path for the target modality (MR or CT).
    - body_contour_value (int): class value used for unmapped organs.

    Returns:
    - target_mask (ndarray): the converted segmentation mask.
    """
    source_df = pd.read_csv(source_csv)
    target_df = pd.read_csv(target_csv)

    # organ name -> class value, for both modalities
    source_mapping = {row['Organ Name']: row.iloc[0] for _, row in source_df.iterrows()}
    target_mapping = {row['Organ Name']: row.iloc[0] for _, row in target_df.iterrows()}
    # PERF FIX: invert the source mapping once; the original rebuilt this
    # dict inside the loop for every unique class value.
    value_to_organ = {v: k for k, v in source_mapping.items()}

    # Unmapped pixels default to the body-contour class.
    target_mask = np.full_like(source_mask, body_contour_value, dtype=source_mask.dtype)

    for class_value in np.unique(source_mask):
        organ_name = value_to_organ.get(class_value, None)
        if organ_name and organ_name in target_mapping:
            target_value = target_mapping[organ_name]
        else:
            target_value = body_contour_value
        target_mask[source_mask == class_value] = target_value

    return target_mask
|
| 223 |
+
|
| 224 |
+
class CreateBodyContourTransformd:
    """Dictionary transform that replaces each keyed CT image with its
    binary body-contour mask, computed slice-wise along the last axis.

    Expected image layout: [B, H, W, D] — assumed from the indexing below;
    TODO confirm against callers.
    """

    def __init__(self, keys, body_threshold, body_mask_value):
        self.keys = keys
        self.body_threshold = body_threshold      # HU cutoff separating body from air
        self.body_mask_value = body_mask_value    # value written inside the body

    def __call__(self, data):
        for key in self.keys:
            image = data[key]
            contour = torch.zeros_like(image)
            n_batch = image.shape[0]
            n_slices = image.shape[-1]
            # Contour every 2D slice independently.
            for b in range(n_batch):
                for s in range(n_slices):
                    contour[b, :, :, s] = create_body_contour(
                        image[b, :, :, s], body_threshold=self.body_threshold)
            contour[contour == 1] = self.body_mask_value
            if VERBOSE:
                print("created mask shape:", contour.shape)
            data[key] = contour
        return data
|
| 247 |
+
|
| 248 |
+
class CreateBodyContourMultiModalTransformd:
    """Slice-wise body-contour transform, multimodal variant.

    NOTE(review): currently byte-for-byte equivalent in behavior to
    CreateBodyContourTransformd — consider aliasing instead of duplicating.
    Expected image layout: [B, H, W, D] — assumed from the indexing below.
    """

    def __init__(self, keys, body_threshold, body_mask_value):
        self.keys = keys
        self.body_threshold = body_threshold      # HU cutoff separating body from air
        self.body_mask_value = body_mask_value    # value written inside the body

    def __call__(self, data):
        for key in self.keys:
            image = data[key]
            contour = torch.zeros_like(image)
            # Contour every 2D slice of every batch element.
            for b in range(image.shape[0]):
                for s in range(image.shape[-1]):
                    filled = create_body_contour(image[b, :, :, s],
                                                 body_threshold=self.body_threshold)
                    contour[b, :, :, s] = filled
            contour[contour == 1] = self.body_mask_value
            if VERBOSE:
                print("created mask shape:", contour.shape)
            data[key] = contour
        return data
|
| 271 |
+
|
| 272 |
+
def convert_xcat_to_ct_mask(xcat_image, mapping_csv, tolerance=0.5):
    """Convert an XCAT digital-phantom CT (HU values) into a labeled mask.

    mapping_csv must contain 'Organ', 'HU_Value' and 'Mask_Value' columns;
    every pixel within ±tolerance of an organ's HU value is assigned that
    organ's mask value, all other pixels stay 0.

    Parameters:
    - xcat_image (np.ndarray | torch.Tensor): XCAT CT image in HU.
    - mapping_csv (str): path to the organ/HU/mask-value mapping table.
    - tolerance (float): HU matching tolerance (default ±0.5).

    Returns:
    - int32 mask of the same container type as xcat_image.

    Raises:
    - TypeError: if xcat_image is neither an ndarray nor a tensor.
    """
    mapping_df = pd.read_csv(mapping_csv)

    if isinstance(xcat_image, np.ndarray):
        ct_mask = np.zeros_like(xcat_image, dtype=np.int32)
    elif isinstance(xcat_image, torch.Tensor):
        ct_mask = torch.zeros_like(xcat_image, dtype=torch.int32)
    else:
        raise TypeError("xcat_image must be a NumPy ndarray or a PyTorch tensor.")

    # Paint each organ's mask value over pixels inside its HU band.
    for _, row in mapping_df.iterrows():
        organ = row['Organ']
        hu_value = row['HU_Value']
        mask_value = row['Mask_Value']

        lower_bound = hu_value - tolerance
        upper_bound = hu_value + tolerance

        within_band = (xcat_image >= lower_bound) & (xcat_image <= upper_bound)
        ct_mask[within_band] = mask_value

        print(f"Processed {organ} with HU range [{lower_bound}, {upper_bound}] to mask value {mask_value}")
    return ct_mask
|
| 311 |
+
|
| 312 |
+
class UseContourToFilterImaged:
    """Multiply the image under keys[0] by the contour mask under keys[1],
    zeroing everything outside the body, and store the result back under
    the image key.
    """

    def __init__(self, keys: List[str]):
        if len(keys) != 2:
            raise ValueError("Keys must be a list with exactly two string elements.")
        self.image_key, self.contour_key = keys

    def __call__(self, data):
        filtered = data[self.image_key] * data[self.contour_key]
        data[self.image_key] = filtered
        return data
|
| 325 |
+
|
| 326 |
+
class MergeMasksTransformd:
    """Add the body-contour mask (keys[1]) onto the segmentation (keys[0])
    element-wise and store the sum back under the segmentation key.
    """

    def __init__(self, keys: List[str]):
        if len(keys) != 2:
            raise ValueError("Keys must be a list with exactly two string elements.")
        self.seg_key, self.contour_key = keys

    def __call__(self, data):
        combined = data[self.seg_key] + data[self.contour_key]
        data[self.seg_key] = combined
        return data
|
| 341 |
+
|
| 342 |
+
class MergeSegTissueTransformd:
    """Merge a tissue segmentation (keys[1]) into an organ segmentation
    (keys[0]).

    Tissue labels are shifted up by 100 so they never collide with organ
    labels; wherever the organ map is non-zero, the organ label wins.
    NOTE(review): the shift is applied in place, mutating the tissue array
    that remains stored in `data` — preserved from the original.
    """

    def __init__(self, keys: List[str]):
        if len(keys) != 2:
            raise ValueError("Keys must be a list with exactly two string elements.")
        self.seg_key, self.tissue_key = keys

    def __call__(self, data):
        seg = data[self.seg_key]
        tissue = data[self.tissue_key]
        tissue += 100  # keep tissue labels above all organ labels (in place)

        merged = tissue.copy()
        overlap = (seg > 0) & (tissue > 0)
        merged[overlap] = seg[overlap]      # organs take precedence in overlaps
        merged[seg > 0] = seg[seg > 0]      # and anywhere the organ map is set

        data[self.seg_key] = merged
        return data
|
| 367 |
+
|
| 368 |
+
class DivideTransformd:
    """Divide the value stored under every key by a constant factor."""

    def __init__(self, keys: List[str], divide_factor):
        self.keys = keys
        self.divide_factor = divide_factor  # denominator applied to each key

    def __call__(self, data):
        for key in self.keys:
            scaled = data[key] / self.divide_factor
            data[key] = scaled
        return data
|
| 378 |
+
|
| 379 |
+
class MergeMasksTransformOldd:
    """Legacy merge transform: sums every listed mask element-wise (as int32)
    and writes the same merged tensor back under every key."""

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, data):
        # Accumulate all masks into one int32 tensor.
        merged = torch.zeros_like(data[self.keys[0]], dtype=torch.int32)
        for k in self.keys:
            merged = merged + data[k].to(torch.int32)
        # Every key receives the merged result.
        for k in self.keys:
            data[k] = merged
        return data
|
| 392 |
+
|
| 393 |
+
# convert the integer segmented labels to one-hot codes for training
class ConvertToOneHotd:
    """One-hot encode integer label maps stored under ``keys``.

    Expects label tensors of shape [B, H, W] and replaces each with a
    float tensor of shape [B, number_classes, H, W].
    """

    def __init__(self, keys, number_classes):
        self.keys = keys
        self.nc = number_classes

    def __call__(self, data):
        for key in self.keys:
            labels = data[key]
            # scatter_ requires integer (long) indices.
            if labels.dtype != torch.long:
                labels = labels.long()
            encoded = torch.zeros(
                labels.size(0), self.nc, labels.size(1), labels.size(2),
                device=labels.device,
            )
            # Place a 1 at the channel given by each label value.
            encoded.scatter_(1, labels.unsqueeze(1), 1)
            data[key] = encoded
        return data
|
| 410 |
+
|
| 411 |
+
# Example usage
|
| 412 |
+
# Assuming `ct_image_tensor` is a PyTorch tensor of a CT image
|
| 413 |
+
# ct_image_tensor = torch.tensor(img_array, dtype=torch.float32)
|
| 414 |
+
# mask_tensor = create_body_contour(ct_image_tensor)
|
| 415 |
+
|
| 416 |
+
class CreateMaskWithBonesTransform:
    """Turn an intensity image into a 3-label mask.

    Label 1 marks tissue (tissue_min < x <= tissue_max), label 2 marks bone
    (bone_min <= x <= bone_max); everything else stays 0.  Bone overrides
    tissue where the two ranges overlap.
    """

    def __init__(self, tissue_min, tissue_max, bone_min, bone_max):
        self.tissue_min = tissue_min
        self.tissue_max = tissue_max
        self.bone_min = bone_min
        self.bone_max = bone_max

    def __call__(self, x):
        # Start from an all-background mask of the same shape/dtype.
        mask = torch.zeros_like(x)
        tissue_region = (x > self.tissue_min) & (x <= self.tissue_max)
        bone_region = (x >= self.bone_min) & (x <= self.bone_max)
        mask[tissue_region] = 1
        # Applied second on purpose: bone wins over tissue on overlap.
        mask[bone_region] = 2
        return mask
|
| 436 |
+
|
| 437 |
+
class CreateMaskWithBonesTransformd:
    """Dictionary transform building a body-contour + bone mask per volume.

    For each key, every 2D slice of the volume (assumed layout [B, H, W, D])
    is passed through the module-level helper ``create_body_contour`` using
    ``tissue_min`` as the body threshold, then voxels in
    [bone_min, bone_max] are overwritten with label 2.
    """

    def __init__(self, keys, tissue_min, tissue_max, bone_min, bone_max):
        self.keys = keys
        self.tissue_min = tissue_min
        # NOTE(review): tissue_max is stored but never used below — confirm
        # whether the commented-out thresholding was meant to stay removed.
        self.tissue_max = tissue_max
        self.bone_min = bone_min
        self.bone_max = bone_max
    def __call__(self, data):
        for key in self.keys:
            x = data[key]

            mask = torch.zeros_like(x)
            # [B, H, W, D]
            # Build the body contour slice by slice along the last (depth) axis.
            for i in range(x.shape[0]):
                for j in range(x.shape[-1]):
                    mask_slice = create_body_contour(x[i,:,:,j], body_threshold=self.tissue_min)
                    mask[i,:,:, j] = mask_slice
            #mask = torch.zeros_like(x)
            #mask[(x > self.tissue_min) & (x <= self.tissue_max)] = 1
            # Bone voxels override whatever the contour produced.
            mask[(x >= self.bone_min) & (x <= self.bone_max)] = 2
            data[key] = mask
            #print("input and mask shape: ",x.shape,data[key].shape)
        return data
|
| 460 |
+
|
| 461 |
+
class NormalizationMultimodal:
    """Min/max-normalize a prior and a target image to [0, 1] per modality.

    The modality id is read from ``data['modality']`` and selects an
    intensity window for each of the two images; values are clamped to the
    window and rescaled linearly.  Raises ValueError for unknown modality
    ids or if ``keys`` does not contain exactly two entries.
    """

    def __init__(self, keys):
        if len(keys) != 2:
            raise ValueError("Keys must be a list with exactly two string elements.")
        self.prior_key = keys[0]
        self.target_key = keys[1]

        # Intensity windows for the prior image, indexed by modality id.
        self.prior_modality_norm_dict = {
            0: {'min': -300, 'max': 700},  # CT WW=1000, WL=200
            1: {'min': 0, 'max': 9},       # T1
            2: {'min': 0, 'max': 28},      # T2
            3: {'min': 0, 'max': 9},       # VIBE-IN
            4: {'min': 0, 'max': 10},      # VIBE-OPP
            5: {'min': 0, 'max': 6},       # DIXON
        }

        # Intensity windows for the target image, indexed by modality id.
        self.target_modality_norm_dict = {
            0: {'min': -300, 'max': 700},  # CT
            1: {'min': 0, 'max': 800},     # T1
            2: {'min': 0, 'max': 160},     # T2
            3: {'min': 0, 'max': 500},     # VIBE-IN
            4: {'min': 0, 'max': 520},     # VIBE-OPP
            5: {'min': 0, 'max': 560},     # DIXON
        }

    @staticmethod
    def _normalize(x, params):
        """Clamp *x* to [params['min'], params['max']] and rescale to [0, 1]."""
        x = torch.clamp(x, params['min'], params['max'])
        return (x - params['min']) / (params['max'] - params['min'])

    def __call__(self, data):
        modality = int(data['modality'])

        if modality not in self.target_modality_norm_dict:
            raise ValueError(f"Unsupported modality id: {modality}")

        # Same window math for both images; only the lookup table differs.
        # (Refactored: the clamp/rescale code was duplicated inline before.)
        data[self.target_key] = self._normalize(
            data[self.target_key], self.target_modality_norm_dict[modality])
        data[self.prior_key] = self._normalize(
            data[self.prior_key], self.prior_modality_norm_dict[modality])

        return data
|
| 507 |
+
|
dataprocesser/data_processing/.gitignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
saved_png
|
| 2 |
+
data
|
| 3 |
+
__pycache__
|
| 4 |
+
*/*.asv
|
dataprocesser/data_processing/README.md
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mydataloader
|
| 2 |
+
dataloader for all projects
|
| 3 |
+
|
| 4 |
+
--0709 add center_crop in slicer_loader
|
| 5 |
+
|
| 6 |
+
--0709 test recursively push submodule
|
| 7 |
+
|
| 8 |
+
--0710 add center_crop in slicer_loader
|
| 9 |
+
|
| 10 |
+
--0729 add make_cond.py
|
| 11 |
+
|
| 12 |
+
--0730 add conditional_loader.py
|
| 13 |
+
|
| 14 |
+
--1103 add input_only normalization in basics.py
|
| 15 |
+
|
| 16 |
+
--1106 change the place of ResizeWithPadOrCropd into crop_volumes of basics.py to directly get 512*512 reversed output
|
| 17 |
+
|
| 18 |
+
--1108 merge gan_loader.py together. Change get_file_list in basics.py, replace ct and mr as "source" and "target"
|
| 19 |
+
|
| 20 |
+
--1109 delete rotate in basics.py
|
dataprocesser/data_processing/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# python module for loading data from monai
# NOTE(review): '3d_loader' is not a valid Python identifier, so it cannot be
# re-exported by `from <package> import *`; confirm the intended module name.
__all__ = ['3d_loader', 'slice_loader', 'slice_loader2', 'basics','manual_slice_loader']
|
dataprocesser/data_processing/data_process/.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
**pycache**/
|
dataprocesser/data_processing/data_process/CTbatchevaluate.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import os
|
| 3 |
+
from CTevaluate import *
|
| 4 |
+
|
| 5 |
+
def singleevaluate(file_path, window_width = 150, window_level = 30):
    """Evaluate a single CT volume.

    Loads a NIfTI (.nii.gz) or NRRD (.nrrd) file, applies CT windowing, and
    returns (ROI mean, contrast, standard deviation, original volume shape).
    The ROI is a 300x300 crop centered in the first two axes.
    """
    # Load the CT image data.  (Fix: the NIfTI branch previously loaded the
    # same file twice.)
    if file_path.endswith('.nii.gz'):
        ct_image_nifti = nib.load(file_path)
        ct_image_data = ct_image_nifti.get_fdata()
    elif file_path.endswith('.nrrd'):
        ct_image_data, header = nrrd.read(file_path)

    ct_data_shape=ct_image_data.shape

    #plot_ct_value_distribution(ct_image_data)
    ct_image_data = ct_windowing(ct_image_data, window_width, window_level)
    #plot_ct_value_distribution(ct_image_data)

    # cut roi
    center_x = ct_image_data.shape[0] // 2
    center_y = ct_image_data.shape[1] // 2
    ct_image_roi=extract_roi(ct_image_data, center_x=center_x, center_y=center_y, length=300, width=300)
    ct_image_roi_mean=np.mean(ct_image_roi)

    # Calculate contrast and standard deviation of CT values.
    contrast = calculate_contrast(ct_image_data)
    std_deviation = calculate_standard_deviation(ct_image_data)
    return ct_image_roi_mean, contrast, std_deviation, ct_data_shape
|
| 31 |
+
|
| 32 |
+
def batchevaluate(dataset_path, format='.nii.gz', save_path='', nii_name='test'):
    """Run singleevaluate on every file in *dataset_path* matching *format*.

    Results are appended (mode 'a') to '<save_path>/<nii_name>.txt', one
    section per patient file.
    NOTE(review): the parameter name 'format' shadows the builtin.
    """
    for patient_data in glob.glob(dataset_path + "/*"):
        if patient_data.endswith(format):
            patient_name=os.path.basename(os.path.normpath(patient_data))
            print('-------------', patient_name, '-------------')
            ct_image_roi_mean, contrast, std_deviation, ct_data_shape = singleevaluate(patient_data)
            # Append this patient's statistics to the shared report file.
            with open(os.path.join(save_path, f'{nii_name}.txt'), 'a') as f:
                f.write('-------------'+patient_name+'-------------\n')
                f.write('Mean of CT values in ROI: '+str(ct_image_roi_mean)+'\n')
                f.write('Contrast of CT image: '+str(contrast)+'\n')
                f.write('Standard Deviation of CT values: '+str(std_deviation)+'\n')
                f.write('Size of CT image: '+str(ct_data_shape)+'\n')
|
| 44 |
+
def main():
    """Entry point: evaluate every NIfTI volume in the hard-coded dataset folder."""
    # NOTE(review): machine-specific path; adjust before running elsewhere.
    dataset_path=r'D:\Data\dataNeaotomAlpha\NIFTI23072115'
    batchevaluate(dataset_path=dataset_path, format='.nii.gz', save_path=dataset_path, nii_name='evaluate')

if __name__=="__main__":
    main()
|
dataprocesser/data_processing/data_process/CTevaluate.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import nibabel as nib
|
| 3 |
+
import nrrd
|
| 4 |
+
|
| 5 |
+
def extract_roi(ct_image, center_x=256, center_y=256, length=300, width=300):
|
| 6 |
+
"""
|
| 7 |
+
Extract a Region of Interest (ROI) from the CT image.
|
| 8 |
+
|
| 9 |
+
Parameters:
|
| 10 |
+
ct_image (numpy.ndarray): The CT image data as a 3D NumPy array.
|
| 11 |
+
center_x (int): X-coordinate of the center of the ROI.
|
| 12 |
+
center_y (int): Y-coordinate of the center of the ROI.
|
| 13 |
+
length (int): Length of the square ROI.
|
| 14 |
+
width (int): Width of the square ROI.
|
| 15 |
+
|
| 16 |
+
Returns:
|
| 17 |
+
numpy.ndarray: The ROI extracted from the CT image.
|
| 18 |
+
"""
|
| 19 |
+
half_length = length // 2
|
| 20 |
+
half_width = width // 2
|
| 21 |
+
|
| 22 |
+
start_x = max(0, center_x - half_length)
|
| 23 |
+
end_x = min(ct_image.shape[0], center_x + half_length)
|
| 24 |
+
start_y = max(0, center_y - half_width)
|
| 25 |
+
end_y = min(ct_image.shape[1], center_y + half_width)
|
| 26 |
+
|
| 27 |
+
return ct_image[:, start_x:end_x, start_y:end_y]
|
| 28 |
+
|
| 29 |
+
def calculate_contrast(ct_image):
    """
    Calculate the contrast of a CT image.

    Contrast is the image's dynamic range divided by the full Hounsfield
    range [-1024, 3071].

    Parameters:
    ct_image (numpy.ndarray): The CT image data as a 3D NumPy array.

    Returns:
    float: The contrast of the CT image.
    """
    # Typical Hounsfield Units range for CT scans.
    hu_min, hu_max = -1024.0, 3071.0
    dynamic_range = np.max(ct_image) - np.min(ct_image)
    return np.abs(dynamic_range) / (hu_max - hu_min)
|
| 46 |
+
|
| 47 |
+
def calculate_standard_deviation(ct_image):
    """
    Calculate the standard deviation of CT values in the image.

    Parameters:
    ct_image (numpy.ndarray): The CT image data as a 3D NumPy array.

    Returns:
    float: The standard deviation of CT values.
    """
    values = np.asarray(ct_image)
    return values.std()
|
| 58 |
+
|
| 59 |
+
import matplotlib.pyplot as plt
|
| 60 |
+
|
| 61 |
+
def plot_ct_value_distribution(ct_image):
    """
    Plot the distribution of CT values in the image.

    Values outside [-1024, 3071] fall outside the histogram range and are
    therefore not shown.

    Parameters:
    ct_image (numpy.ndarray): The CT image data as a 3D NumPy array.
    """
    # Flatten the 3D array to a 1D array to get all CT values.
    ct_values = ct_image.flatten()

    # Create the histogram of CT values.
    plt.hist(ct_values, bins=100, range=(-1024, 3071), color='blue', alpha=0.7)
    plt.xlabel('CT Value')
    plt.ylabel('Frequency')
    plt.title('Distribution of CT Values')
    plt.grid(True)
    # Blocks until the figure window is closed (interactive backend).
    plt.show()
|
| 78 |
+
|
| 79 |
+
def ct_windowing(ct_image, window_width, window_level):
    """
    Apply CT windowing to the CT image.

    Clips all intensities to [window_level - window_width/2,
    window_level + window_width/2].

    Parameters:
    ct_image (numpy.ndarray): The CT image data as a 3D NumPy array.
    window_width (float): The window width.
    window_level (float): The window level.

    Returns:
    numpy.ndarray: The CT image data after applying windowing.
    """
    half_width = window_width / 2.0
    return np.clip(ct_image, window_level - half_width, window_level + half_width)
|
| 99 |
+
|
| 100 |
+
def main():
    """Demo: load one CT volume, window it, and print ROI/contrast/std statistics."""
    # Replace 'your_ct_image.nii' with the path to your NIfTI CT image file.
    pcct_path = r'D:\Data\dataNeaotomAlpha\Nifti\2511\2511_2.nii.gz'
    cbct_path = r'D:\Data\M2OLIE_Phantom\pre_cbct.nrrd'
    nifti_file_path = pcct_path
    nrrd_file_path = cbct_path

    # Load the NIfTI CT image data using nibabel.
    ct_image_nifti = nib.load(nifti_file_path)
    ct_image_data = ct_image_nifti.get_fdata()

    #ct_image_data, header = nrrd.read(nrrd_file_path)

    # Soft-tissue window: WW=150, WL=30.
    window_width = 150
    window_level = 30
    #plot_ct_value_distribution(ct_image_data)
    ct_image_data = ct_windowing(ct_image_data, window_width, window_level)
    #plot_ct_value_distribution(ct_image_data)

    # cut roi
    center_x = ct_image_data.shape[0] // 2
    center_y = ct_image_data.shape[1] // 2
    ct_image_roi=extract_roi(ct_image_data, center_x=center_x, center_y=center_y, length=300, width=300)
    ct_image_roi_mean=np.mean(ct_image_roi)

    # Calculate contrast and standard deviation of CT values.
    contrast = calculate_contrast(ct_image_data)
    std_deviation = calculate_standard_deviation(ct_image_data)
    print(ct_image_data.shape)

    print("size of ROI:", ct_image_roi.shape)
    print("Mean of CT values in ROI:", ct_image_roi_mean)
    print("Contrast of CT image:", contrast)
    print("Standard Deviation of CT values:", std_deviation)


if __name__ == "__main__":
    main()
|
dataprocesser/data_processing/data_process/convert_dicoms.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dicom2nifti
|
| 2 |
+
import SimpleITK as sitk
|
| 3 |
+
from dicom2nifti.exceptions import ConversionValidationError
|
| 4 |
+
from MsSeg.SegmentationNetworkBasis.NetworkBasis import image as Image
|
| 5 |
+
import os
|
| 6 |
+
import glob
|
| 7 |
+
import pydicom
|
| 8 |
+
|
| 9 |
+
#dataset_path = r"C:\Users\ms97\Documents\MRF-Daten\Messdaten"
|
| 10 |
+
|
| 11 |
+
def fromrootgroupconvert(dataset_path, nii_name='test'):
    """Convert every DICOM patient folder under *dataset_path* to NIfTI.

    Uses dicom2nifti; the output is '<nii_name>_<i>.nii' inside each patient
    folder, then the result is re-read to print its metadata.
    """
    i=0
    for patient_folder in glob.glob(dataset_path + "/*/"):
        print('-------------', patient_folder, '-------------')
        i=i+1
        t1_nii_path = os.path.join(dataset_path, patient_folder, f'{nii_name}_{i}.nii')
        try:
            try:
                t1_dicom_path = os.path.join(dataset_path, patient_folder)
                dicom2nifti.dicom_series_to_nifti(t1_dicom_path, t1_nii_path)
            except OSError as err:
                # NOTE(review): an OSError is reported as "Finished" —
                # presumably treats a known benign error as success; confirm.
                print("Finished for Sequence T1Map " + patient_folder)

            # Re-read the written NIfTI to report its properties.
            t1_img = sitk.ReadImage(t1_nii_path)
            data_info = Image.get_data_info(t1_img)
            print('Data Info T1: ', data_info)
        except (KeyError, IndexError) as err:
            print("Failed for Sequence T1Map " + patient_folder + " ", err)
|
| 29 |
+
|
| 30 |
+
def simplepatientconvert(patient_folder, nii_name='test'):
    """Convert a single DICOM series folder to '<nii_name>.nii' via dicom2nifti,
    then re-read the result and print its metadata."""
    t1_nii_path = os.path.join(patient_folder, f'{nii_name}.nii')
    try:
        try:
            dicom2nifti.dicom_series_to_nifti(patient_folder, t1_nii_path)
        except OSError as err:
            # NOTE(review): OSError is reported as "Finished" — confirm this
            # is a deliberately tolerated error and not a real failure.
            print("Finished for Sequence Dicom " + patient_folder)

        t1_img = sitk.ReadImage(t1_nii_path)
        data_info = Image.get_data_info(t1_img)
        print('Data Info T1: ', data_info)
    except (KeyError, IndexError) as err:
        print("Failed for Sequence Dicom " + patient_folder + " ", err)
|
| 43 |
+
|
| 44 |
+
def itkfromrootgroupconvert(dataset_path, nii_name='test'):
    """Convert each DICOM series folder under *dataset_path* to
    '<folder>.nii.gz' (written next to *dataset_path*) using SimpleITK,
    permuting the axes to [2, 1, 0] first.
    """
    i=0
    for patient_folder in glob.glob(dataset_path + "/*/"):
        print('-------------', patient_folder, '-------------')
        i=i+1
        try:
            try:
                reader = sitk.ImageSeriesReader()
                dicom_names = reader.GetGDCMSeriesFileNames(patient_folder)
                reader.SetFileNames(dicom_names)
                image = reader.Execute()
                basefoldername=os.path.basename(os.path.normpath(patient_folder))
                t1_nii_path = os.path.join(dataset_path, f'{basefoldername}.nii.gz')
                # Added a call to PermuteAxes to change the axes of the data
                image = sitk.PermuteAxes(image, [2, 1, 0])
                sitk.WriteImage(image, t1_nii_path)
            except OSError as err:
                # NOTE(review): if the OSError fires before t1_nii_path is
                # assigned, the ReadImage below raises NameError — confirm.
                print("Finished for Sequence T1Map " + patient_folder)

            t1_img = sitk.ReadImage(t1_nii_path)
            data_info = Image.get_data_info(t1_img)
            print('Data Info Dicom: ', data_info)
        except (KeyError, IndexError) as err:
            print("Failed for Sequence Dicom " + patient_folder + " ", err)
|
| 68 |
+
|
| 69 |
+
def itkforpatientconvert(patient_folder, nii_name='test'):
    """Convert one DICOM series folder to '<nii_name>.nii.gz' using SimpleITK,
    permuting the axes to [2, 1, 0] before writing."""
    reader = sitk.ImageSeriesReader()
    dicom_names = reader.GetGDCMSeriesFileNames(patient_folder)
    reader.SetFileNames(dicom_names)
    image = reader.Execute()
    t1_nii_path = os.path.join(patient_folder, f'{nii_name}.nii.gz')
    # Added a call to PermuteAxes to change the axes of the data
    image = sitk.PermuteAxes(image, [2, 1, 0])
    sitk.WriteImage(image, t1_nii_path)
|
| 78 |
+
|
| 79 |
+
if __name__=="__main__":
    # Convert the whole dataset; the single-patient variant is kept commented
    # below for reference (NOTE(review): 'simpleitkconvert' is not defined
    # in this file — likely meant 'simplepatientconvert').
    dataset_path = r"D:\Data\dataNeaotomAlpha\DICOM_Naeotom\DICOM\23072115"
    itkfromrootgroupconvert(dataset_path)
    #patient_folder = r"D:\Data\dataNeaotomAlpha\Q0Q1Q4"
    #simpleitkconvert(patient_folder,'2511')
|
dataprocesser/data_processing/data_process/make_cond.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import os
|
| 3 |
+
import nibabel as nib
|
| 4 |
+
import numpy as np
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
def make_cond(dataset_path):
    """Write per-slice condition CSVs for every patient folder.

    For each patient directory under *dataset_path* (skipping any whose name
    contains 'overview'), reads ct.nii.gz and mr.nii.gz and writes
    'ct_slice_cond.csv' / 'mr_slice_cond.csv' containing one slice index per
    line under a 'slice' header.
    """
    for patient_folder in tqdm(glob.glob(dataset_path + "/*/")):
        if 'overview' not in patient_folder:
            # Same procedure for both modalities.
            for modality in ('ct', 'mr'):
                image_file = os.path.join(patient_folder, f'{modality}.nii.gz')
                image_data = nib.load(image_file).get_fdata()
                slice_number = image_data.shape[-1]
                # Fix: np.arange(0, slice_number - 1, 1) dropped the last
                # slice index; every slice now gets a label.
                slice_labels = np.arange(slice_number)
                # write into csv
                with open(os.path.join(patient_folder, f'{modality}_slice_cond.csv'), 'w') as f:
                    f.write('slice\n')
                    for label in slice_labels:
                        f.write(str(label) + '\n')
|
| 30 |
+
|
| 31 |
+
def main():
    """Entry point: generate slice-condition CSVs for the pelvis dataset."""
    # Paths for two different workstations; only the first is actually used.
    dataset_path=r'F:\yang_Projects\Datasets\Task1\pelvis'
    dataset_path_razer=r'C:\Users\56991\Projects\Datasets\Task1\pelvis'
    make_cond(dataset_path)

if __name__=="__main__":
    main()
|
dataprocesser/data_processing/data_process/matlab/BCELossIllustration.m
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
% Illustrate the difference between BCE and BCE-with-logits loss on a
% synthetic near-binary image and a noisy copy of it.

% Create an example image (ground truth): random values pushed to ~0 / ~1.
images1 = rand(256, 256); % Random matrix as an example
images1(images1 > 0.5) = 0.9999;
images1(images1 <= 0.5) = 0.0001;

% Add noise to the image
noise = randn(size(images1)) * 0.01; % Gaussian noise
images2 = images1+noise;

% Ensure values are in the range [0, 1]
% NOTE(review): the clamp below is commented out, so images2 can leave
% (0, 1) and log() in calculateBCELoss may produce complex/NaN values.
% images2 = max(min(images2, 1), 0);

% Plot the images
subplot(1, 2, 1);
imshow(images1);
title('Original Image');

subplot(1, 2, 2);
imshow(images2);
title('Noisy Image');

% Calculate BCEWithLogitsLoss
BCEWithLogitsLoss = calculateBCEWithLogitsLoss(images1, images2);
disp(['BCEWithLogitsLoss: ', num2str(BCEWithLogitsLoss)]);

BCELoss=calculateBCELoss(images1, images2);
disp(['BCELoss: ', num2str(BCELoss)]);

function loss = calculateBCELoss(images1, images2)
    % Plain BCE: images2 is used directly as a probability map.
    % (The variable name 'logits2' is historical — no conversion happens.)
    logits2 = images2;

    % Mean element-wise binary cross-entropy.
    loss = mean(mean(-images1 .* log(logits2) - (1 - images1) .* log(1 - logits2)));
end

function loss = calculateBCEWithLogitsLoss(images1, images2)
    % BCE with logits: images2 is treated as raw logits and passed
    % through the sigmoid before the cross-entropy.
    logits2 = images2;

    % Mean element-wise binary cross-entropy on sigmoid(logits).
    loss = mean(mean(-images1 .* log(sigmoid(logits2)) - (1 - images1) .* log(1 - sigmoid(logits2))));
end

function logit = probToLogit(p)
    % Convert probability to logit (helper; currently unused).
    logit = log(p ./ (1 - p));
end

function s = sigmoid(x)
    % Sigmoid function
    s = 1 ./ (1 + exp(-x));
end
|