zy7_oldserver committed on
Commit
fd601de
·
1 Parent(s): 299136f
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +16 -0
  2. README.md +1 -20
  3. dataprocesser/.gitignore +1 -0
  4. dataprocesser/Preprocess_CT_Mask_generation.py +267 -0
  5. dataprocesser/Preprocess_MRCT_mask_conversion.py +294 -0
  6. dataprocesser/Preprocess_MR_Mask_generation.py +306 -0
  7. dataprocesser/Preprocess_MR_Masks_overlay.py +78 -0
  8. dataprocesser/__init__.py +8 -0
  9. dataprocesser/archive/archiv.py +236 -0
  10. dataprocesser/archive/basics.py +167 -0
  11. dataprocesser/archive/checkdata.py +91 -0
  12. dataprocesser/archive/createsegtransform.py +276 -0
  13. dataprocesser/archive/csv_dataset.py +121 -0
  14. dataprocesser/archive/csv_dataset_slices.py +20 -0
  15. dataprocesser/archive/csv_dataset_slices_assigned.py +11 -0
  16. dataprocesser/archive/data_create_seg.py +28 -0
  17. dataprocesser/archive/data_slicing.py +13 -0
  18. dataprocesser/archive/dataset_med.py +188 -0
  19. dataprocesser/archive/gan_loader.py +310 -0
  20. dataprocesser/archive/init_dataset.py +0 -0
  21. dataprocesser/archive/json_dataset_slices.py +28 -0
  22. dataprocesser/archive/list_dataset_Anika.py +10 -0
  23. dataprocesser/archive/list_dataset_Anish.py +0 -0
  24. dataprocesser/archive/list_dataset_Anish_seg.py +42 -0
  25. dataprocesser/archive/list_dataset_base.py +983 -0
  26. dataprocesser/archive/list_dataset_combined_seg.py +15 -0
  27. dataprocesser/archive/list_dataset_combined_seg_assigned.py +1 -0
  28. dataprocesser/archive/list_dataset_synthrad.py +0 -0
  29. dataprocesser/archive/list_dataset_synthrad_seg.py +3 -0
  30. dataprocesser/archive/monai_loader_3D.py +367 -0
  31. dataprocesser/archive/slice_loader.py +124 -0
  32. dataprocesser/build_dataset.py +22 -0
  33. dataprocesser/config_example.yaml +43 -0
  34. dataprocesser/create_csv.py +87 -0
  35. dataprocesser/create_csv_xcat.py +25 -0
  36. dataprocesser/create_json_lodopab.py +59 -0
  37. dataprocesser/create_json_xcat.py +70 -0
  38. dataprocesser/customized_datasets.py +115 -0
  39. dataprocesser/customized_normalization.py +149 -0
  40. dataprocesser/customized_transform_list.py +149 -0
  41. dataprocesser/customized_transforms.py +507 -0
  42. dataprocesser/data_processing/.gitignore +4 -0
  43. dataprocesser/data_processing/README.md +20 -0
  44. dataprocesser/data_processing/__init__.py +2 -0
  45. dataprocesser/data_processing/data_process/.gitignore +1 -0
  46. dataprocesser/data_processing/data_process/CTbatchevaluate.py +49 -0
  47. dataprocesser/data_processing/data_process/CTevaluate.py +137 -0
  48. dataprocesser/data_processing/data_process/convert_dicoms.py +83 -0
  49. dataprocesser/data_processing/data_process/make_cond.py +37 -0
  50. dataprocesser/data_processing/data_process/matlab/BCELossIllustration.m +53 -0
.gitignore ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ !.gitignore
2
+ /logs
3
+ __pycache__*
4
+ /*pycache*
5
+ /model/*pycache*
6
+ **/*pycache*
7
+ *.out
8
+ venv*
9
+ helix_log
10
+ checkpoints
11
+ MONAI
12
+ *.svg
13
+ data
14
+ generative-models
15
+ datasets
16
+ notuse
README.md CHANGED
@@ -1,20 +1 @@
1
- ---
2
- title: Frankenstein
3
- emoji: 🚀
4
- colorFrom: red
5
- colorTo: red
6
- sdk: docker
7
- app_port: 8501
8
- tags:
9
- - streamlit
10
- pinned: false
11
- short_description: Artificial Life
12
- license: mit
13
- ---
14
-
15
- # Welcome to Streamlit!
16
-
17
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
18
-
19
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
20
- forums](https://discuss.streamlit.io).
 
1
+ #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dataprocesser/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__*
dataprocesser/Preprocess_CT_Mask_generation.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+ import nrrd
5
+ import SimpleITK as sitk
6
+ import cv2
7
+
8
+ import numpy as np
9
+
10
def shift_to_min_zero(arr):
    """
    Shift the input NumPy array so that its minimum value becomes 0.

    Parameters:
        arr (numpy.ndarray): The input array to shift.

    Returns:
        numpy.ndarray: The shifted array whose minimum value is 0.
    """
    # Subtracting the global minimum maps the value range to [0, max - min].
    return arr - np.min(arr)
23
+
24
+
25
def create_body_mask(numpy_img, body_threshold=-500, min_contour_area=10000):
    """
    Create a binary body mask from a single 2-D CT slice.

    The slice is thresholded at `body_threshold` to separate body from air,
    then every sufficiently large contour is filled to form the mask.

    Args:
        numpy_img (numpy.ndarray): 2-D grayscale CT slice; HU values are
            assumed to fit into int16 (roughly -1024 to 1500) — the cast
            below truncates anything outside that range.
        body_threshold (int): HU threshold separating body from background.
        min_contour_area (float): Contours smaller than this area (pixels)
            are discarded as noise.

    Returns:
        numpy.ndarray: uint8 mask, same shape as the slice; body region is 1,
        background is 0.
    """
    # Cast to int16 so negative HU values survive; ascontiguousarray gives
    # OpenCV the contiguous memory layout it requires.
    numpy_img = np.ascontiguousarray(numpy_img.astype(np.int16))  # Ensure we can handle negative values correctly
    #numpy_img = numpy_img.astype(np.int16)

    # Threshold the image at -500 to separate potential body from the background
    binary_img = np.where(numpy_img > body_threshold, 1, 0).astype(np.uint8)

    # Find contours from the binary image
    contours, _ = cv2.findContours(binary_img, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)

    # Create an empty mask
    mask = np.zeros_like(binary_img)

    VERBOSE = False
    # Fill every detected contour that passes the area filter; RETR_LIST means
    # nested contours (e.g. interior cavities) are filled too.
    if contours:
        for contour in contours:
            if cv2.contourArea(contour) >= min_contour_area:
                if VERBOSE:
                    print('current contour area: ', cv2.contourArea(contour), 'threshold: ', min_contour_area)
                cv2.drawContours(mask, [contour], -1, 1, thickness=cv2.FILLED)

    return mask
58
+
59
def apply_mask(normalized_image_array, mask_array):
    """Suppress the background by elementwise-multiplying image and mask."""
    masked = normalized_image_array * mask_array
    return masked
61
+
62
def print_all_info(data, title):
    """Print the minimum and maximum of *data*, labelled with *title*."""
    lo = np.min(data)
    hi = np.max(data)
    print(f'min, max of {title}:', lo, hi)
64
+
65
def process_CT_segmentation_numpy(mask, csv_simulation_values):
    """
    Replace organ-index labels in *mask* with simulated CT Hounsfield values.

    Parameters:
        mask (numpy.ndarray): Integer segmentation mask; 0 is background.
        csv_simulation_values (pandas.DataFrame): Table with columns
            'Order Number' (organ index) and 'HU Value' (simulated HU).

    Returns:
        numpy.ndarray: Array of the same shape where each organ index is
        replaced by its HU value and background is set to -1000.
    """
    #df = pd.read_csv(csv_file)
    df = csv_simulation_values
    # Create a dictionary to map organ index to HU values
    hu_values = dict(zip(df['Order Number'], df['HU Value']))
    # Some tables number organs from 0 while mask labels start at 1.
    order_begin_from_0 = True if df['Order Number'].min() == 0 else False

    hu_mask = np.zeros_like(mask)
    # Value Assignment
    hu_mask[mask == 0] = -1000  # background air
    for organ_index, hu_value in hu_values.items():
        # BUG FIX: values loaded through pandas are numpy integer scalars
        # (np.int64), which are NOT instances of builtin int, so the original
        # `isinstance(..., int)` assertion always failed on real CSV data.
        # Accept both builtin and numpy integers.
        assert isinstance(hu_value, (int, np.integer)), f"Expected mask value an integer, but got {hu_value}. Ensure the mask is created by fine mode of totalsegmentator"
        assert isinstance(organ_index, (int, np.integer)), f"Expected organ_index an integer, but got {organ_index}. Ensure the mask is created by fine mode of totalsegmentator"
        if order_begin_from_0:
            hu_mask[mask == (organ_index+1)] = hu_value # mask value begin from 1 as body value, other than 0 in TA2 table, so organ_index+1
        else:
            hu_mask[mask == (organ_index)] = hu_value
    return hu_mask
83
+
84
+
85
# Process a single image and its segmentation maps.
def process_image(input_path, contour_path, seg_path, seg_tissue_path, csv_simulation_values, output_path1, output_path2, output_path3, body_threshold):
    """
    Preprocess one CT volume together with its organ and tissue segmentations.

    Steps: load the volume and both segmentations, build a per-slice body
    contour mask, shift intensities so the minimum becomes 0, mask out the
    background, merge organ and tissue labels (tissue labels offset by +200,
    organ labels win on overlap), add the body contour, and assign simulated
    HU values.

    Parameters:
        input_path: volume path (.nrrd, .nii or .nii.gz).
        contour_path: unused; the contour output path is derived from
            output_path2 below.  TODO(review): confirm and drop this param.
        seg_path / seg_tissue_path: organ / tissue segmentation volumes.
        csv_simulation_values (pandas.DataFrame): HU lookup table passed on
            to process_CT_segmentation_numpy.
        output_path1 / output_path2 / output_path3: masked image, merged
            segmentation, HU-assigned segmentation output paths.
        body_threshold (int): HU threshold for the body contour.

    Returns:
        The merged (contour-added) segmentation array.
    """
    # Read the original image and its segmentation maps.
    if input_path.endswith('.nrrd'):
        img, header = nrrd.read(input_path)
        segmentation_img, header_seg = nrrd.read(seg_path)
        seg_tissue_img, header_seg_tissue = nrrd.read(seg_tissue_path)
    elif input_path.endswith('.nii.gz') or input_path.endswith('.nii'):
        import nibabel as nib
        img_metadata = nib.load(input_path)
        img = img_metadata.get_fdata()
        affine = img_metadata.affine

        seg_metadata = nib.load(seg_path)
        segmentation_img = seg_metadata.get_fdata()
        affine_seg = seg_metadata.affine

        seg_tissue_metadata = nib.load(seg_tissue_path)
        seg_tissue_img = seg_tissue_metadata.get_fdata()

    # extract contour, slice by slice along the last axis
    body_contour = np.zeros_like(img, dtype=np.int16)
    for i in range(img.shape[-1]):
        slice_data = img[:, :, i]
        body_contour[:, :, i] = create_body_mask(slice_data, body_threshold=body_threshold)

    # CT images don't need additional normalization
    #

    # normalize to 0-1
    img_normalized = shift_to_min_zero(img)
    # img_normalized = img_normalized/2000 # scale factor

    # apply mask to ct img
    masked_image = apply_mask(img_normalized, body_contour)

    # process the mask image
    seg = segmentation_img
    tissue = seg_tissue_img
    # NOTE(review): this mutates seg_tissue_img in place (same object).
    tissue[tissue!=0] += 200
    # Create a mask for overlapping areas
    overlap_mask = (seg > 0) & (tissue > 0)

    # For overlapping areas, keep the lower value (organ values in seg)
    merged_mask = tissue.copy()
    merged_mask[overlap_mask] = seg[overlap_mask]

    # Keep all non-overlapping areas
    merged_mask[seg > 0] = seg[seg > 0]

    combined_array = merged_mask + body_contour

    processed_segmentation = combined_array

    # assign simulation value to ct segmentation mask
    assigned_segmentation = process_CT_segmentation_numpy(combined_array, csv_simulation_values)

    if input_path.endswith('.nrrd'):
        # Save the processed image.
        nrrd.write(output_path1, masked_image, header)

        # Save the processed segmentation map.
        # NOTE(review): in the .nrrd branch neither the contour nor the
        # HU-assigned segmentation is saved — confirm whether intended.
        nrrd.write(output_path2, processed_segmentation, header_seg)

        # save the body contour mask

    elif input_path.endswith('.nii.gz') or input_path.endswith('.nii'):
        img_processed = nib.Nifti1Image(masked_image, affine)
        nib.save(img_processed, output_path1)
        seg_processed = nib.Nifti1Image(processed_segmentation, affine_seg)
        nib.save(seg_processed, output_path2)
        contour_processed = nib.Nifti1Image(body_contour, affine_seg)
        assigned_segmentation_processed = nib.Nifti1Image(assigned_segmentation, affine_seg)
        # Split the path into directory and filename
        directory, filename = os.path.split(output_path2)
        contour_filename = filename.replace('_seg_merged', '_contour')
        contour_path = os.path.join(directory, contour_filename)
        nib.save(contour_processed, contour_path)

        nib.save(assigned_segmentation_processed, output_path3)

    return processed_segmentation
167
+
168
def analyse_hist(input_path):
    """
    Plot a value histogram of slice index 50 of the given volume.

    Supports .nrrd and .nii.gz inputs.  Interactive inspection helper:
    shows the matplotlib window and returns nothing.
    """
    if input_path.endswith('.nrrd'):
        img, header = nrrd.read(input_path)
    elif input_path.endswith('.nii.gz'):
        import nibabel as nib
        img_metadata = nib.load(input_path)
        img = img_metadata.get_fdata()
        affine = img_metadata.affine  # NOTE(review): read but unused here
    import numpy as np
    import matplotlib.pyplot as plt

    # Plot the histogram of a fixed slice (index 50) — assumes the volume
    # has more than 50 slices along the last axis; TODO confirm.
    print('shape of img: ', img.shape)
    plt.hist(img[:, :, 50], bins=30, edgecolor='black', alpha=0.7)
    plt.xlabel('Value')
    plt.ylabel('Frequency')
    plt.title('Value Distribution')
    plt.show()
186
+
187
+
188
def process_csv(csv_file, output_root, csv_simulation_file, body_threshold=-500):
    """
    Batch-process every CT volume listed in *csv_file*.

    The csv is expected in simplified form with positional columns
    (id, Aorta_diss, seg, img).  Each NIfTI pair is run through
    process_image, and an index csv of the produced files is written to
    ``<output_root>/processed_csv_file.csv``.

    Parameters:
        csv_file (str): Path to the input listing csv.
        output_root (str): Output directory (created if missing).
        csv_simulation_file (str): CSV with simulated HU values per organ.
        body_threshold (int): HU threshold for the body contour.
    """
    # read csv to get simulation value
    csv_simulation_values = pd.read_csv(csv_simulation_file) #.to_numpy()
    #csv_simulation_values = pd.read_csv(csv_simulation_file)

    # check 2-dimensional csv_simulation_values
    # NOTE(review): read_csv returns a DataFrame whose ndim is always 2, so
    # this guard can only fire if a Series were ever passed in instead.
    if csv_simulation_values.ndim == 1:
        raise ValueError("CSV should contain two columns: organ_index and simulation_value")

    if not os.path.exists(csv_file):
        print('csv:', csv_file)
        raise ValueError('csv_file must input a available csv file in simplified form: id, Aorta_diss, seg, img!')
    else:
        print(f'use csv: {csv_file}')

    data_frame = pd.read_csv(csv_file)
    if len(data_frame) == 0:
        raise RuntimeError(f"Found 0 images in: {csv_file}")
    # Positional columns: id, Aorta_diss, seg path, image path.
    patient_IDs = data_frame.iloc[:, 0].tolist()
    Aorta_diss = data_frame.iloc[:, 1].tolist()
    segs = data_frame.iloc[:, 2].tolist()
    images = data_frame.iloc[:, 3].tolist()

    from tqdm import tqdm
    dataset_list = []
    for idx in tqdm(range(len(images))):
        # Only NIfTI pairs are processed; any other extension is skipped silently.
        if (images[idx].endswith('.nii.gz') and segs[idx].endswith('.nii.gz')) or \
            (images[idx].endswith('.nii') and segs[idx].endswith('.nii')):
            input_file_path = images[idx]
            seg_file_path = segs[idx]
            patient_id = patient_IDs[idx]
            ad = Aorta_diss[idx]
            # Tissue segmentation is assumed to sit next to the organ seg,
            # with "_seg" replaced by "_seg_tissue" — TODO confirm naming.
            seg_tissue_file_path = seg_file_path.replace("_seg","_seg_tissue")

            root_dir = os.path.dirname(input_file_path)

            # Get root path (directory path)
            root_path = os.path.dirname(seg_file_path)
            ct_processed_file_name = f"{patient_id}_ct_processed.nii.gz"
            seg_merged_file_name = f"{patient_id}_ct_seg_merged.nii.gz"
            seg_merged_assigned_mask_file_name = f"{patient_id}_ct_seg_merged_assigned_mask.nii.gz"

            os.makedirs(output_root, exist_ok=True)
            output_file_path1 = os.path.join(output_root, ct_processed_file_name)
            output_file_path2 = os.path.join(output_root, seg_merged_file_name)
            output_file_path3 = os.path.join(output_root, seg_merged_assigned_mask_file_name)
            print(f"Processing {input_file_path} with segmentation {seg_file_path}")
            print(f"Save results to {output_file_path1} and {output_file_path2} and {output_file_path3} \n")


            processed_seg = process_image(input_file_path, None, seg_file_path, seg_tissue_file_path, csv_simulation_values, output_file_path1, output_file_path2, output_file_path3, body_threshold)

            # processed_mr_csv_file = ...
            csv_mr_line = [patient_id,ad, output_file_path2, output_file_path1, output_file_path3]
            dataset_list.append(csv_mr_line)

    import csv
    # Write the index of all produced files for downstream steps.
    output_csv_file=os.path.join(output_root, 'processed_csv_file.csv')
    with open(output_csv_file, 'w', newline='') as f:
        csvwriter = csv.writer(f)
        csvwriter.writerow(['id', 'Aorta_diss', 'seg', 'img', 'seg_mask'])
        csvwriter.writerows(dataset_list)
250
+
251
if __name__ == "__main__":
    import argparse
    # Hard-coded local paths for a one-off run; the argparse interface below
    # is kept for reference but currently disabled.
    csv_file = r'E:\Projects\yang_proj\SynthRad_GAN\synthrad_conversion\datacsv\ct_synthrad_test_newserver.csv'
    output_root = r'E:\Projects\yang_proj\data\synthrad\processed'
    csv_simulation_file = r'E:\Projects\yang_proj\SynthRad_GAN\synthrad_conversion\TA2_CT_from1.csv'
    process_csv(csv_file, output_root, csv_simulation_file, body_threshold=-500)

    '''parser = argparse.ArgumentParser(description="Process MR images and segmentation maps, apply masks and replace grayscale values.")
    parser.add_argument('--input_folder1', required=True, help="Path to the folder containing input MR .nrrd files.")
    parser.add_argument('--input_folder2', required=True, help="Path to the folder containing segmentation .nrrd files.")
    parser.add_argument('--output_folder1', required=True, help="Path to the folder to save the output MR files.")
    parser.add_argument('--output_folder2', required=True, help="Path to the folder to save the output segmentation files.")
    parser.add_argument('--csv_simulation_file', required=True, help="CSV file containing simulated CT grayscale values.")
    parser.add_argument('--body_threshold', type=int, default=50, help="Threshold to separate body from background.")
    args = parser.parse_args()

    process_folder(args.input_folder1, args.input_folder2, args.output_folder1, args.output_folder2, args.csv_simulation_file, args.body_threshold)'''
dataprocesser/Preprocess_MRCT_mask_conversion.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import nibabel as nib
4
+ import torch
5
+ import os
6
+ from tqdm import tqdm
7
+ #from dataprocesser.customized_transforms import create_body_contour
8
+ from dataprocesser.Preprocess_MR_Mask_generation import process_segmentation
9
+ from dataprocesser.Preprocess_CT_Mask_generation import process_CT_segmentation_numpy
10
+ import difflib
11
+
12
def find_best_match_smart(organ_name, target_names):
    """
    Find the best-matching name for *organ_name* among *target_names*.

    Matching strategy, in order of preference:
      1. exact match,
      2. prefix match (e.g. 'vertebrae' -> 'vertebrae_L1'),
      3. strict fuzzy match via difflib (cutoff 0.8).

    Returns:
        str | None: the matched target name, or None when nothing matches.
    """
    # Exact match.
    if organ_name in target_names:
        return organ_name
    # Prefix match (e.g. vertebrae -> vertebrae_Lx).
    matches = [t for t in target_names if t.startswith(organ_name)]
    if matches:
        return matches[0]
    # Fall back to difflib with a strict cutoff.
    # BUG FIX: the original returned the raw list from get_close_matches,
    # inconsistent with the string returns above (the commented-out
    # `[0] if close else None` showed the intent).
    closes = difflib.get_close_matches(organ_name, target_names, n=1, cutoff=0.8)
    return closes[0] if closes else None
24
+
25
+
26
def convert_segmentation_mask_torch(source_mask, source_csv, target_csv, body_contour_value=1):
    """
    Convert segmentation labels from a source modality to a target modality.

    Each class value in *source_mask* is looked up by organ name in the
    source CSV and re-labelled with the matching organ's class value from
    the target CSV.  Background (0) is preserved; organs with no counterpart
    in the target table are mapped to *body_contour_value*.

    Parameters:
    - source_mask (torch.Tensor): The source segmentation mask tensor.
    - source_csv (str): Path to the CSV file of the source modality (CT or MR).
    - target_csv (str): Path to the CSV file of the target modality (MR or CT).
    - body_contour_value (int): The class value for "body contour" in the target modality.

    Returns:
    - target_mask (torch.Tensor): The converted segmentation mask tensor.
    """

    def _load_mapping(csv_path):
        # organ name -> list of class values (first CSV column).
        mapping = {}
        for _, record in pd.read_csv(csv_path).iterrows():
            mapping.setdefault(record['Organ Name'], []).append(record.iloc[0])
        return mapping

    source_mapping = _load_mapping(source_csv)
    target_mapping = _load_mapping(target_csv)

    # Reverse lookup: class value -> organ name for the source modality.
    class_to_organ = {
        value: name
        for name, values in source_mapping.items()
        for value in values
    }

    target_mask = torch.zeros_like(source_mask, dtype=source_mask.dtype)

    for class_value in torch.unique(source_mask):
        raw_value = class_value.item()
        organ_name = class_to_organ.get(raw_value, None)

        if raw_value == 0:
            # Background stays background.
            target_value = 0
        elif organ_name and organ_name in target_mapping:
            # First target class value wins when several share an organ name.
            target_value = target_mapping[organ_name][0]
        else:
            # No counterpart in the target modality: fold into body contour.
            target_value = body_contour_value

        target_mask[source_mask == class_value] = target_value

    return target_mask
85
+
86
def convert_segmentation_mask(source_mask, source_csv, target_csv, body_contour_value=1000):
    """
    Converts segmentation mask values from source modality to target modality based on organ name mapping.

    Lookup order for each non-background class: exact organ-name match in the
    target table, then a hand-maintained manual mapping, then a loose fuzzy
    match (difflib, cutoff 0.4), finally *body_contour_value*.

    Parameters:
    - source_mask (ndarray): The source segmentation mask array.
    - source_csv (str): Path to the CSV file of the source modality (CT or MR).
    - target_csv (str): Path to the CSV file of the target modality (MR or CT).
    - body_contour_value (int): The class value for "body contour" in the target modality.

    Returns:
    - target_mask (ndarray): The converted segmentation mask.
    """
    # Load the source and target anatomy lists
    source_df = pd.read_csv(source_csv)
    target_df = pd.read_csv(target_csv)

    # Create dictionaries mapping organ names to their class values
    # (first CSV column is the class value).
    source_mapping = {}
    for _, row in source_df.iterrows():
        organ_name = row['Organ Name']
        class_value = row.iloc[0]
        source_mapping.setdefault(organ_name, []).append(class_value)

    target_mapping = {}
    for _, row in target_df.iterrows():
        organ_name = row['Organ Name']
        class_value = row.iloc[0]
        target_mapping.setdefault(organ_name, []).append(class_value)

    # Create a reverse mapping from class values to organ names for the source modality
    class_to_organ = {class_value: organ_name for organ_name, class_values in source_mapping.items() for class_value in class_values}
    # Initialize the target mask
    target_organ_names = list(target_mapping.keys())
    target_mask = np.full_like(source_mask, 0, dtype=source_mask.dtype)

    # Convert each unique class in the source mask
    for class_value in np.unique(source_mask):
        # Find the corresponding organ name in the source modality.
        # NOTE(review): a class value absent from the source CSV yields
        # organ_name=None; if it then reaches the fuzzy-match branch,
        # difflib.get_close_matches(None, ...) raises — confirm that every
        # mask value appears in source_csv.
        organ_name = class_to_organ.get(class_value, None)
        if class_value == 0:
            # Preserve background as is.
            target_value = 0
        else:
            # If organ name exists, find the corresponding target class values
            if organ_name and organ_name in target_mapping:
                # Pick the first target class value (or handle overlaps if needed)
                target_value = target_mapping[organ_name][0]
            else:
                # Manual mapping: source organ name → target organ name
                manual_mapping = {
                    'intervertebral_discs': 'spinal_cord',
                    'quadriceps_femoris_left':'gluteus_maximus_left',
                    'quadriceps_femoris_right':'gluteus_maximus_right',
                    'thigh_medial_compartment_left': 'gluteus_maximus_left',
                    'thigh_medial_compartment_right': 'gluteus_maximus_right',
                    'thigh_posterior_compartment_left': 'gluteus_maximus_left',
                    'thigh_posterior_compartment_right': 'gluteus_maximus_right',
                    'sartorius_left': 'gluteus_maximus_left',
                    'sartorius_right': 'gluteus_maximus_right',
                    # Add more mappings here as needed
                }
                # Check manual mapping first
                if organ_name in manual_mapping and manual_mapping[organ_name] in target_mapping:
                    matched_name = manual_mapping[organ_name]
                    target_value = target_mapping[matched_name][0]
                    print(f"[Manual match] '{organ_name}' → '{matched_name}' → label {target_value}")
                else:
                    # Fuzzy match fallback
                    close_matches = difflib.get_close_matches(organ_name, target_organ_names, n=1, cutoff=0.4)
                    if close_matches:
                        matched_name = close_matches[0]
                        target_value = target_mapping[matched_name][0]
                        print(f"[Fuzzy match] '{organ_name}' → '{matched_name}' → label {target_value}")
                    else:
                        print(f"[Warning] No match for '{organ_name}', using body contour value.")
                        target_value = body_contour_value
                '''close_matches = difflib.get_close_matches(organ_name, target_organ_names, n=1, cutoff=0.4)
                if close_matches:
                    matched_name = close_matches[0]
                    target_value = target_mapping[matched_name][0]
                    print(f"[Fuzzy match] '{organ_name}' → '{matched_name}' → label {target_value}")
                else:
                    print(f"[Warning] No match for '{organ_name}', using body contour value.")
                    target_value = body_contour_value'''
        # Replace class values in the target mask
        target_mask[source_mask == class_value] = target_value

    return target_mask
174
+
175
def run_mask_conversion(
    mask = r'E:\Projects\yang_proj\data\synthrad\Task1\pelvis\1PA001\ct_seg.nii.gz',
    img = r'E:\Projects\yang_proj\data\synthrad\Task1\pelvis\1PA001\ct.nii.gz',
    MR_csv = r'E:\Projects\yang_proj\SynthRad_GAN\synthrad_conversion\TA2_MR_for_convert.csv',
    CT_csv = r'E:\Projects\yang_proj\SynthRad_GAN\synthrad_conversion\TA2_CT_for_convert.csv',
    output_path = r'mr_mask_from_ct.nii.gz', # output_path = r'ct_mask_from_mr.nii.gz'
    mode = 'ct2mr'
    ):
    """
    Convert one segmentation mask between CT and MR label conventions.

    Loads *mask* (and *img*), relabels the mask via convert_segmentation_mask,
    assigns modality-specific simulation values, and saves the result as a
    NIfTI volume at *output_path*.

    Parameters:
        mask (str): Path to the source segmentation (.nii.gz).
        img (str): Path to the corresponding image volume (loaded but
            currently only used by the disabled contour code below).
        MR_csv / CT_csv (str): Label tables for the two modalities.
        output_path (str): Output NIfTI path.
        mode (str): 'ct2mr' or 'mr2ct'.
            NOTE(review): any other value leaves source_csv/target_csv
            undefined and raises NameError further down.
    """
    if mode == 'ct2mr':
        body_threshold=-500
        source_csv = CT_csv
        target_csv = MR_csv
    elif mode == 'mr2ct':
        body_threshold=5
        source_csv = MR_csv
        target_csv = CT_csv

    source_mask = mask
    img = img

    seg_metadata = nib.load(source_mask)
    seg = seg_metadata.get_fdata()
    affine = seg_metadata.affine

    img_metadata = nib.load(img)
    img = img_metadata.get_fdata()
    # NOTE(review): this overwrites the segmentation affine read above; the
    # output is saved with the *image* affine — confirm intended.
    affine = img_metadata.affine

    '''body_contour = np.zeros_like(img, dtype=np.int16)
    for i in range(img.shape[2]):
        slice_data = img[:, :, i]
        body_contour[:, :, i] = create_body_contour(slice_data, body_threshold)
    contour = body_contour
    seg_with_contour = seg+contour'''
    seg_with_contour = seg
    target_mask = convert_segmentation_mask(seg_with_contour, source_csv, target_csv, body_contour_value=1)
    if mode == 'ct2mr':
        csv_simulation_file = MR_csv
        csv_values = pd.read_csv(csv_simulation_file, header=None).to_numpy()
        target_mask = process_segmentation(target_mask, csv_values)
    elif mode == 'mr2ct':
        csv_simulation_file = CT_csv
        # NOTE(review): process_CT_segmentation_numpy indexes its second
        # argument like a DataFrame (df['Order Number']), but a *path string*
        # is passed here — this branch looks broken; confirm and pass
        # pd.read_csv(csv_simulation_file) instead.
        target_mask = process_CT_segmentation_numpy(target_mask, csv_simulation_file)
    img_processed = nib.Nifti1Image(target_mask, affine)
    nib.save(img_processed, output_path)
220
+
221
+
222
def run_mask_conversion_synthrad_test(synthrad_root = r'E:\Projects\yang_proj\data\synthrad\Task1\pelvis', patient_list=['1PA001'], mode = 'ct2mr', output_csv_file = 'ct2mr_conversion.csv'):
    """
    Convert masks for a list of SynthRAD patients and write an index csv.

    For each patient the mode-appropriate mask/image pair is converted with
    run_mask_conversion, and one row (id, Aorta_diss, seg, img) per patient
    is appended to *output_csv_file*.

    Parameters:
        synthrad_root (str): Root folder with one subfolder per patient.
        patient_list (list[str]): Patient ids to process (read-only).
        mode (str): 'ct2mr' or 'mr2ct'.
        output_csv_file (str): Path of the csv index to write.
    """
    dataset_list = []
    for patient in tqdm(patient_list):
        mr_mask = os.path.join(synthrad_root, patient, 'mr_merged_seg.nii.gz')
        mr_img = os.path.join(synthrad_root, patient, 'mr.nii.gz')
        ct_mask = os.path.join(synthrad_root, patient, 'ct_seg.nii.gz')
        ct_img = os.path.join(synthrad_root, patient, 'ct.nii.gz')
        MR_csv = r'synthrad_conversion/TA2_MR_for_convert.csv'
        CT_csv = r'synthrad_conversion/TA2_CT_for_convert.csv'
        if mode == 'ct2mr':
            preprocessed_mr_path = r'E:\Projects\yang_proj\data\anika\MR_processed'
            preprocessed_mr_img = os.path.join(preprocessed_mr_path, f'mr_{patient}.nii.gz')
            output_path = os.path.join(synthrad_root, patient, 'mr_mask_from_ct.nii.gz')
            csv_mr_line = [patient,0,output_path,preprocessed_mr_img]
            # ct2mr converts the CT-derived mask (matches run_mask_conversion_csv).
            source_mask, source_img = ct_mask, ct_img
        elif mode == 'mr2ct':
            output_path = os.path.join(synthrad_root, patient, 'ct_mask_from_mr.nii.gz')
            csv_mr_line = [patient,0,output_path,ct_img]
            # mr2ct converts the MR-derived mask.
            source_mask, source_img = mr_mask, mr_img
        # BUG FIX: the original passed 8 positional arguments
        # (mr_mask, mr_img, ct_mask, ct_img, MR_csv, CT_csv, output_path, mode)
        # to run_mask_conversion, which accepts only 6
        # (mask, img, MR_csv, CT_csv, output_path, mode) — every call raised
        # TypeError.  Pass the mode-appropriate mask/image pair instead.
        run_mask_conversion(source_mask, source_img, MR_csv, CT_csv, output_path, mode)
        dataset_list.append(csv_mr_line)

    import csv
    with open(output_csv_file, 'w', newline='') as f:
        csvwriter = csv.writer(f)
        csvwriter.writerow(['id', 'Aorta_diss', 'seg', 'img'])
        csvwriter.writerows(dataset_list)
248
+
249
def run_mask_conversion_csv(csv_file = r'E:\Projects\yang_proj\data\synthrad\processed\processed_ct_csv_file.csv', mode = 'ct2mr', output_csv_file = 'ct2mr_conversion.csv'):
    """
    Convert masks for every entry of a processed-data csv listing.

    The listing csv is read positionally (id, Aorta_diss, seg, img,
    aligned seg).  Each entry is converted with run_mask_conversion and an
    index csv (id, Aorta_diss, seg, img) of the outputs is written to
    *output_csv_file*.

    Parameters:
        csv_file (str): Input listing csv.
        mode (str): 'ct2mr' or 'mr2ct'.
        output_csv_file (str): Path of the csv index to write.
    """
    data_frame = pd.read_csv(csv_file)
    if len(data_frame) == 0:
        raise RuntimeError(f"Found 0 images in: {csv_file}")
    # Positional columns: id, Aorta_diss, seg, img, aligned seg.
    patient_IDs = data_frame.iloc[:, 0].tolist()
    Aorta_diss = data_frame.iloc[:, 1].tolist()
    segs = data_frame.iloc[:, 2].tolist()
    images = data_frame.iloc[:, 3].tolist()
    aligned_segs = data_frame.iloc[:, 4].tolist()
    dataset_list = []
    synthrad_root = r"E:\Projects\yang_proj\data\synthrad\Task1\pelvis"
    from tqdm import tqdm
    for idx in tqdm(range(len(images))):
        MR_csv = r'synthrad_conversion/TA2_MR_for_convert.csv'
        CT_csv = r'synthrad_conversion/TA2_CT_for_convert.csv'
        patient = patient_IDs[idx]
        if mode == 'ct2mr':
            ct_mask = segs[idx]
            ct_img = images[idx]
            preprocessed_mr_path = r'E:\Projects\yang_proj\data\anika\MR_processed'
            preprocessed_mr_img = os.path.join(preprocessed_mr_path, f'mr_{patient}.nii.gz')

            mr_mask_from_ct_folder = r'E:\Projects\yang_proj\data\synthrad\mr_mask_from_ct'
            output_path = os.path.join(mr_mask_from_ct_folder, f'{patient}_mr_mask_from_ct.nii.gz')
            csv_mr_line = [patient,0,output_path,preprocessed_mr_img]
            run_mask_conversion(ct_mask, ct_img, MR_csv, CT_csv, output_path, mode)

        elif mode == 'mr2ct':
            mr_mask = os.path.join(synthrad_root, patient, 'mr_merged_seg.nii.gz')
            mr_img = os.path.join(synthrad_root, patient, 'mr.nii.gz')
            output_path = os.path.join(synthrad_root, patient, 'ct_mask_from_mr.nii.gz')
            # BUG FIX: ct_img was only defined in the ct2mr branch, so this
            # branch raised NameError.  The CT image path comes from the
            # listing csv's image column.
            ct_img = images[idx]
            csv_mr_line = [patient,0,output_path,ct_img]
            run_mask_conversion(mr_mask, mr_img, MR_csv, CT_csv, output_path, mode)
        dataset_list.append(csv_mr_line)

    import csv
    with open(output_csv_file, 'w', newline='') as f:
        csvwriter = csv.writer(f)
        csvwriter.writerow(['id', 'Aorta_diss', 'seg', 'img'])
        csvwriter.writerows(dataset_list)
289
+
290
if __name__ == "__main__":
    # One-off run: convert all processed CT masks listed in the csv to
    # MR-style masks and write the resulting index csv.
    csv_file = r'E:\Projects\yang_proj\data\synthrad\processed\processed_csv_file.csv'
    mode = 'ct2mr'
    output_csv_file = r'E:\Projects\yang_proj\SynthRad_GAN\synthrad_conversion\datacsv\ct2mr_conversion.csv'
    run_mask_conversion_csv(csv_file = csv_file, mode = mode, output_csv_file = output_csv_file)
dataprocesser/Preprocess_MR_Mask_generation.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+ import nrrd
5
+ import SimpleITK as sitk
6
+ import cv2
7
+
8
+ from dataprocesser.preprocess_MR import step3_vibe_resetsignal
9
+ """
10
+ 该代码用于处理一组 MR 图像和对应的分割图,应用掩膜、进行归一化,并根据 CSV 文件中的仿真 MR 灰度值对分割图进行替换。最后将处理后的 MR 图像和分割图保存。
11
+
12
+ 主要步骤:
13
+ 1. **读取数据**:从指定的文件夹中读取 MR 图像和对应的分割图。
14
+ 2. **归一化处理**:对 MR 图像进行归一化,将其值范围映射到 0 到 255 之间。
15
+ 3. **轮廓提取**:从归一化后的 MR 图像中提取出主体区域的轮廓(根据给定的阈值分割),创建掩膜。
16
+ 4. **掩膜应用**:将提取出的掩膜应用到归一化后的 MR 图像上,保留主体区域,抑制背景。
17
+ 5. **分割图处理**:读取对应的分割图,并与提取出的轮廓进行叠加,之后根据 CSV 文件中的仿真 CT 值替换分割图中的灰度值。
18
+ 6. **图像保存**:将处理后的 MR 图像和修改后的分割图保存到指定的输出文件夹中,保证其空间属性和几何信息与输入图像一致。
19
+ 7. **输出**:在 ITK-SNAP 等医学图像工具中打开时, MR 图像和分割图能够保持同步和正确的比例显示。
20
+
21
+ 函数简介:
22
+ - `normalize`: 对 MR 图像进行归一化处理,将像素值范围映射到 [0, 255]。
23
+ - `create_body_mask`: 从图像中提取出身体的轮廓,生成二值掩膜。
24
+ - `apply_mask`: 将提取的掩膜应用到 MR 图像上,保留轮廓内部的区域。
25
+ - `process_segmentation`: 读取分割图,并根据 CSV 文件中的仿真 CT 值对其灰度值进行替换。
26
+ - `process_image`: 处理单个 MR 图像及其对应的分割图,包括归一化、轮廓提取、掩膜应用、分割图处理等。
27
+ - `process_folder`: 处理整个文件夹中的 MR 图像和分割图,逐一处理所有图像并保存结果。
28
+ """
29
+
30
+ # 归一化函数
31
+ def normalize(img, vmin_out=0, vmax_out=1, norm_min_v=None, norm_max_v=None, epsilon=1e-5):
32
+ if norm_min_v is None and norm_max_v is None:
33
+ norm_min_v = np.min(img)
34
+ norm_max_v = np.max(img)
35
+ img = np.clip(img, norm_min_v, norm_max_v)
36
+ img = (img - norm_min_v) / (norm_max_v - norm_min_v + epsilon)
37
+ img = img * (vmax_out - vmin_out) + vmin_out
38
+ return img
39
+
40
+ # 创建轮廓掩膜
41
+ def create_body_mask_simple(numpy_img, body_threshold=50):
42
+ numpy_img = numpy_img.astype(np.int16)
43
+ body_mask = np.where(numpy_img > body_threshold, 1, 0).astype(np.uint8)
44
+ contours, _ = cv2.findContours(body_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
45
+ mask = np.zeros_like(body_mask, dtype=np.uint8)
46
+
47
+ if contours:
48
+ largest_contour = max(contours, key=cv2.contourArea)
49
+ mask = np.ascontiguousarray(mask)
50
+ largest_contour = np.ascontiguousarray(largest_contour)
51
+ cv2.drawContours(mask, [largest_contour], -1, 1, thickness=cv2.FILLED)
52
+
53
+ return mask
54
+
55
+ def create_body_mask(numpy_img, body_threshold=-500, min_contour_area=10000):
56
+ """
57
+ Create a binary body mask from a CT image tensor, using a specific threshold for the body parts.
58
+
59
+ Args:
60
+ tensor_img (torch.Tensor): A tensor representation of a grayscale CT image, with intensity values from -1024 to 1500.
61
+
62
+ Returns:
63
+ torch.Tensor: A binary mask tensor where the entire body region is 1 and the background is 0.
64
+ """
65
+ # Convert tensor to numpy array
66
+ numpy_img = np.ascontiguousarray(numpy_img.astype(np.int16)) # Ensure we can handle negative values correctly
67
+ #numpy_img = numpy_img.astype(np.int16)
68
+
69
+ # Threshold the image at -500 to separate potential body from the background
70
+ binary_img = np.where(numpy_img > body_threshold, 1, 0).astype(np.uint8)
71
+
72
+ # Find contours from the binary image
73
+ contours, _ = cv2.findContours(binary_img, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
74
+
75
+ # Create an empty mask
76
+ mask = np.zeros_like(binary_img)
77
+
78
+ VERBOSE = False
79
+ # Fill all detected body contours
80
+ if contours:
81
+ for contour in contours:
82
+ if cv2.contourArea(contour) >= min_contour_area:
83
+ if VERBOSE:
84
+ print('current contour area: ', cv2.contourArea(contour), 'threshold: ', min_contour_area)
85
+ cv2.drawContours(mask, [contour], -1, 1, thickness=cv2.FILLED)
86
+
87
+ return mask
88
+
89
+ def apply_mask(normalized_image_array, mask_array):
90
+ return normalized_image_array * mask_array
91
+
92
+ def print_all_info(data, title):
93
+ print(f'min, max of {title}:', np.min(data), np.max(data))
94
+
95
+ # process the segmentation, replace the classes with simulated MR values
96
+ def process_segmentation(combined_array, csv_simulation_values, mr_signal_formula=step3_vibe_resetsignal.calculate_signal_vibe):
97
+ combined_array = combined_array.astype(np.int16)
98
+ print_all_info(combined_array, 'combine')
99
+ # two columns of unique value 和 simulation value
100
+ # the first element will not be included
101
+ organ_indexs = csv_simulation_values[1:, 0] # first column: organ index
102
+ T1_values = csv_simulation_values[1:, 1] # second column: simulate MRI value
103
+ T2_values = csv_simulation_values[1:, 2]
104
+ Rho_values = csv_simulation_values[1:, 3]
105
+ order_begin_from_0 = True if organ_indexs.astype(int).min()==0 else False
106
+ #print('organ order number begin from 0:', order_begin_from_0)
107
+ #print(organ_indexs)
108
+ assign_value_mask = np.zeros_like(combined_array)
109
+
110
+ step=0
111
+ for step in range(len(organ_indexs)):
112
+ organ_index = organ_indexs[step] # in csv file, organs begin with 1
113
+ t1 = float(T1_values[step])
114
+ t2 = float(T2_values[step])
115
+ rho = float(Rho_values[step])
116
+
117
+ simulation_value = mr_signal_formula(t1, t2, rho)
118
+ organ_index = int(organ_index)
119
+ if order_begin_from_0:
120
+ #print("order in csv begin from 0")
121
+ assign_value_mask[combined_array == organ_index+1] = simulation_value # organ_index+ 1
122
+ else:
123
+ #print("order in csv begin from 1")
124
+ assign_value_mask[combined_array == organ_index] = simulation_value
125
+ step+=1
126
+ print_all_info(assign_value_mask, 'assignment')
127
+ return assign_value_mask
128
+
129
+ # 处理单个图像和分割图
130
+ def process_image(input_path, contour_path, seg_path, csv_simulation_values, output_path1, output_path2, body_threshold):
131
+ # 读取原始 MR 图像和分割图
132
+ if input_path.endswith('.nrrd'):
133
+ img, header = nrrd.read(input_path)
134
+ segmentation_img, header_seg = nrrd.read(seg_path)
135
+ elif input_path.endswith('.nii.gz') or input_path.endswith('.nii'):
136
+ import nibabel as nib
137
+ img_metadata = nib.load(input_path)
138
+ img = img_metadata.get_fdata()
139
+ affine = img_metadata.affine
140
+
141
+ seg_metadata = nib.load(seg_path)
142
+ segmentation_img = seg_metadata.get_fdata()
143
+
144
+ # 归一化处理
145
+ norm_max=255 #255
146
+ low_percentile = 5
147
+ high_percentile = 90
148
+ img_normalized = normalize(img, 0, norm_max, np.percentile(img, low_percentile), np.percentile(img, high_percentile), epsilon=0)
149
+
150
+ # 提取轮廓图
151
+ body_contour = np.zeros_like(img, dtype=np.int16)
152
+ for i in range(img.shape[2]):
153
+ slice_data = img[:, :, i]
154
+ body_contour[:, :, i] = create_body_mask(slice_data, body_threshold=body_threshold)
155
+
156
+ # 应用掩膜到归一化 MR 图像
157
+ masked_image = apply_mask(img_normalized, body_contour)
158
+
159
+ # 处理分割图
160
+ # add contour background to the segmentation (all region inside body + 1)
161
+ combined_array = segmentation_img + body_contour
162
+ combined_array = np.clip(combined_array, 0, np.max(segmentation_img) + 1)
163
+ print_all_info(segmentation_img, 'seg')
164
+ processed_segmentation = process_segmentation(combined_array, csv_simulation_values)
165
+
166
+ # normalize to 0-1
167
+ # masked_image = masked_image/norm_max
168
+ # processed_segmentation = processed_segmentation/norm_max
169
+
170
+ if input_path.endswith('.nrrd'):
171
+ # 保存处理后的 MR 图像
172
+ nrrd.write(output_path1, masked_image, header)
173
+
174
+ # 保存处理后的分割图
175
+ nrrd.write(output_path2, processed_segmentation, header_seg)
176
+
177
+ # save the body contour mask
178
+
179
+ elif input_path.endswith('.nii.gz') or input_path.endswith('.nii'):
180
+ img_processed = nib.Nifti1Image(masked_image, affine)
181
+ nib.save(img_processed, output_path1)
182
+ seg_processed = nib.Nifti1Image(processed_segmentation, affine)
183
+ nib.save(seg_processed, output_path2)
184
+ contour_processed = nib.Nifti1Image(body_contour, affine)
185
+
186
+ # Split the path into directory and filename
187
+ directory, filename = os.path.split(output_path2)
188
+ new_filename = filename.replace('seg', 'contour')
189
+ contour_path = os.path.join(directory, new_filename)
190
+
191
+ nib.save(contour_processed, contour_path)
192
+ return processed_segmentation
193
+
194
+ # 处理文件夹
195
+ def process_folder(input_folder1, input_folder2, output_folder1, output_folder2, csv_simulation_file, body_threshold=50):
196
+ # 读取CSV文件获取仿真CT灰度值 (两列)
197
+ csv_simulation_values = pd.read_csv(csv_simulation_file, header=None).to_numpy()
198
+
199
+ # 检查 csv_simulation_values 是否是二维数组
200
+ if csv_simulation_values.ndim == 1:
201
+ raise ValueError("CSV 文件格式不正确,应该包含两列:organ_index 和 simulation_value")
202
+
203
+ # 确保输出文件夹存在
204
+ os.makedirs(output_folder1, exist_ok=True)
205
+ os.makedirs(output_folder2, exist_ok=True)
206
+
207
+ for filename in os.listdir(input_folder1):
208
+ if filename.endswith('.nrrd'):
209
+ input_file_path = os.path.join(input_folder1, filename)
210
+ seg_file_path = os.path.join(input_folder2, filename)
211
+ output_file_path1 = os.path.join(output_folder1, filename)
212
+ output_file_path2 = os.path.join(output_folder2, filename)
213
+
214
+ print(f"Processing {input_file_path} with segmentation {seg_file_path}")
215
+ processed_seg = process_image(input_file_path, None, seg_file_path, csv_simulation_values, output_file_path1, output_file_path2, body_threshold)
216
+
217
+ def analyse_hist(input_path):
218
+ if input_path.endswith('.nrrd'):
219
+ img, header = nrrd.read(input_path)
220
+ elif input_path.endswith('.nii.gz'):
221
+ import nibabel as nib
222
+ img_metadata = nib.load(input_path)
223
+ img = img_metadata.get_fdata()
224
+ affine = img_metadata.affine
225
+ import numpy as np
226
+ import matplotlib.pyplot as plt
227
+
228
+ # Plot the histogram
229
+ print('shape of img: ', img.shape)
230
+ plt.hist(img[:, :, 50], bins=30, edgecolor='black', alpha=0.7)
231
+ plt.xlabel('Value')
232
+ plt.ylabel('Frequency')
233
+ plt.title('Value Distribution')
234
+ plt.show()
235
+
236
+
237
+ def process_csv(csv_file, output_folder1, output_folder2, csv_simulation_file, body_threshold=50, output_mr_csv_file='processed_mr_csv_file.csv'):
238
+ # 读取CSV文件获取仿真CT灰度值 (两列)
239
+ csv_simulation_values = pd.read_csv(csv_simulation_file, header=None).to_numpy()
240
+ #csv_simulation_values = pd.read_csv(csv_simulation_file)
241
+
242
+ # 检查 csv_simulation_values 是否是二维数组
243
+ if csv_simulation_values.ndim == 1:
244
+ raise ValueError("CSV 文件格式不正确,应该包含两列:organ_index 和 simulation_value")
245
+
246
+ # 确保输出文件夹存在
247
+ os.makedirs(output_folder1, exist_ok=True)
248
+ os.makedirs(output_folder2, exist_ok=True)
249
+
250
+ from step1_init_data_list import list_img_seg_ad_pIDs_from_new_simplified_csv
251
+ patient_IDs, Aorta_diss, segs, images = list_img_seg_ad_pIDs_from_new_simplified_csv(csv_file)
252
+ from tqdm import tqdm
253
+ dataset_list = []
254
+ for idx in tqdm(range(len(images))):
255
+ if (images[idx].endswith('.nii.gz') and segs[idx].endswith('.nii.gz')) or \
256
+ (images[idx].endswith('.nii') and segs[idx].endswith('.nii')):
257
+ input_file_path = images[idx]
258
+ seg_file_path = segs[idx]
259
+ patient_id = patient_IDs[idx]
260
+ ad = Aorta_diss[idx]
261
+ root_dir = os.path.dirname(input_file_path)
262
+
263
+ output_file_path1 = os.path.join(output_folder1, os.path.relpath(input_file_path, start=root_dir))
264
+
265
+ synthrad_basic_mr_name = 'mr'
266
+ synthrad_basic_seg_name = 'mr_merged_seg'
267
+ if os.path.basename(output_file_path1) == f'{synthrad_basic_mr_name}.nii.gz' or \
268
+ os.path.basename(output_file_path1) == f'{synthrad_basic_mr_name}.nii':
269
+ # Insert the patient ID in the filename
270
+ output_file_path1 = output_file_path1.replace(f'{synthrad_basic_mr_name}', f'mr_{patient_id}')
271
+
272
+ output_file_path2 = os.path.join(output_folder2, os.path.relpath(seg_file_path, start=root_dir))
273
+
274
+ if os.path.basename(output_file_path2) == f'{synthrad_basic_seg_name}.nii.gz' or \
275
+ os.path.basename(output_file_path2) == f'{synthrad_basic_seg_name}.nii':
276
+ # Insert the patient ID in the filename
277
+ output_file_path2 = output_file_path2.replace(f'{synthrad_basic_seg_name}', f'mr_seg_{patient_id}')
278
+
279
+ print(f"Processing {input_file_path} with segmentation {seg_file_path}")
280
+ print(f"Save results to {output_file_path1} and {output_file_path2}")
281
+
282
+ processed_seg = process_image(input_file_path, None, seg_file_path, csv_simulation_values, output_file_path1, output_file_path2, body_threshold)
283
+
284
+ # processed_mr_csv_file = ...
285
+ csv_mr_line = [patient_id,ad,output_file_path2,output_file_path1]
286
+ dataset_list.append(csv_mr_line)
287
+
288
+ import csv
289
+ with open(output_mr_csv_file, 'w', newline='') as f:
290
+ csvwriter = csv.writer(f)
291
+ csvwriter.writerow(['id', 'Aorta_diss', 'seg', 'img'])
292
+ csvwriter.writerows(dataset_list)
293
+
294
+ if __name__ == "__main__":
295
+ import argparse
296
+
297
+ parser = argparse.ArgumentParser(description="Process MR images and segmentation maps, apply masks and replace grayscale values.")
298
+ parser.add_argument('--input_folder1', required=True, help="Path to the folder containing input MR .nrrd files.")
299
+ parser.add_argument('--input_folder2', required=True, help="Path to the folder containing segmentation .nrrd files.")
300
+ parser.add_argument('--output_folder1', required=True, help="Path to the folder to save the output MR files.")
301
+ parser.add_argument('--output_folder2', required=True, help="Path to the folder to save the output segmentation files.")
302
+ parser.add_argument('--csv_simulation_file', required=True, help="CSV file containing simulated CT grayscale values.")
303
+ parser.add_argument('--body_threshold', type=int, default=50, help="Threshold to separate body from background.")
304
+ args = parser.parse_args()
305
+
306
+ process_folder(args.input_folder1, args.input_folder2, args.output_folder1, args.output_folder2, args.csv_simulation_file, args.body_threshold)
dataprocesser/Preprocess_MR_Masks_overlay.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
3
+ import nrrd
4
+ import numpy as np
5
+
6
+ def load_nifti_file(input_path):
7
+ if input_path.endswith('.nrrd'):
8
+ data, header = nrrd.read(input_path)
9
+ return data, header
10
+ elif input_path.endswith('.nii.gz') or input_path.endswith(".nii"):
11
+ import nibabel as nib
12
+ img_metadata = nib.load(input_path)
13
+ img = img_metadata.get_fdata()
14
+ affine = img_metadata.affine
15
+ return img, affine
16
+
17
+ def save_nrrd_file(data, HeaderOrAffine, input_path, save_path):
18
+ #nrrd.write(save_path, data, header)
19
+ if input_path.endswith('.nrrd'):
20
+ nrrd.write(save_path, data, HeaderOrAffine)
21
+
22
+ elif input_path.endswith('.nii.gz') or input_path.endswith(".nii"):
23
+ import nibabel as nib
24
+ img_processed = nib.Nifti1Image(data, HeaderOrAffine)
25
+ nib.save(img_processed, save_path)
26
+
27
+ def overlay_images(mask_data, organ_data):
28
+ # Combine the images by adding the pixel values
29
+ organ_data = np.where(organ_data == 1, organ_data + 98, organ_data)
30
+ organ_data = np.where(organ_data == 2, organ_data + 197, organ_data)
31
+ organ_data = np.where(organ_data == 3, organ_data + 296, organ_data)
32
+ combined_data = mask_data + organ_data
33
+ return combined_data
34
+
35
+ def main(files1, files2, output_folder=None):
36
+ # files is the list including all basic MR segmentations
37
+ # files is the list including all basic MR tissue segmentations
38
+
39
+ print("preprocess length of seg files: ", len(files1))
40
+ print("preprocess length of tissue seg files: ", len(files2))
41
+
42
+ files2 = [file.replace('seg_tissue', 'seg') for file in files2]
43
+
44
+ files1 = set(files1)
45
+ files2 = set(files2)
46
+
47
+ common_files = files1.intersection(files2)
48
+
49
+ from tqdm import tqdm
50
+ for filename in tqdm(common_files):
51
+ if filename.endswith(".nrrd") or filename.endswith(".nii.gz") or filename.endswith(".nii"):
52
+ nrrd_path1 = filename
53
+ nrrd_path2 = filename.replace('seg', 'seg_tissue')
54
+
55
+ '''
56
+ if os.path.basename(filename) == 'mr_seg.nii.gz':
57
+ patient_ID = os.path.basename(os.path.dirname(filename))
58
+ output_file_name = os.path.basename(filename).replace("seg", f"seg_{patient_ID}")
59
+ else:
60
+ output_file_name = os.path.basename(filename)
61
+ '''
62
+
63
+ output_file_name = os.path.basename(filename)
64
+ output_file_name = output_file_name.replace("seg", "merged_seg")
65
+ if output_folder == None:
66
+ output_folder_current_patient = os.path.dirname(filename)
67
+ else:
68
+ output_folder_current_patient = output_folder
69
+ save_path = os.path.join(output_folder_current_patient, output_file_name)
70
+
71
+ print(f"Processing {nrrd_path1} and {nrrd_path2}, saving to {save_path}")
72
+
73
+ data1, header1 = load_nifti_file(nrrd_path1)
74
+ data2, header2 = load_nifti_file(nrrd_path2)
75
+
76
+ combined_data = overlay_images(data1, data2)
77
+ save_nrrd_file(combined_data, header1, nrrd_path1, save_path)
78
+
dataprocesser/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from dataprocesser.dataset_registry import DATASET_REGISTRY
2
+ # register all the dataset in DATASET_REGISTRY
3
+ import dataprocesser.dataset_anish
4
+ import dataprocesser.dataset_combined_csv
5
+ import dataprocesser.dataset_combined_synthrad_anish
6
+ import dataprocesser.dataset_csv_slice
7
+ import dataprocesser.dataset_json
8
+ import dataprocesser.dataset_synthrad
dataprocesser/archive/archiv.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..basics import get_file_list,crop_volumes, load_volumes
2
+ def load_batch_slices(train_volume_ds,val_volume_ds, train_batch_size=5,val_batch_size=1,window_width=1,ifcheck=True):
3
+ patch_func = monai.data.PatchIterd(
4
+ keys=["source", "target"],
5
+ patch_size=(None, None, window_width), # dynamic first two dimensions
6
+ start_pos=(0, 0, 0)
7
+ )
8
+ if window_width==1:
9
+ patch_transform = Compose(
10
+ [
11
+ SqueezeDimd(keys=["source", "target"], dim=-1), # squeeze the last dim
12
+ ]
13
+ )
14
+ else:
15
+ patch_transform = None
16
+ # for training
17
+ train_patch_ds = monai.data.GridPatchDataset(
18
+ data=train_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
19
+ train_loader = DataLoader(
20
+ train_patch_ds,
21
+ batch_size=train_batch_size,
22
+ num_workers=2,
23
+ pin_memory=torch.cuda.is_available(),
24
+ )
25
+
26
+ # for validation
27
+ val_loader = DataLoader(
28
+ val_volume_ds,
29
+ num_workers=1,
30
+ batch_size=val_batch_size,
31
+ pin_memory=torch.cuda.is_available())
32
+
33
+ if ifcheck:
34
+ check_batch_data(train_loader,val_loader,train_patch_ds,val_volume_ds,train_batch_size,val_batch_size)
35
+ return train_loader,val_loader
36
+
37
+ def load_batch_slices3D(train_volume_ds,val_volume_ds, train_batch_size=5,val_batch_size=1,ifcheck=True):
38
+ patch_func = monai.data.PatchIterd(
39
+ keys=["source", "target"],
40
+ patch_size=(None, None,32), # dynamic first two dimensions
41
+ start_pos=(0, 0, 0)
42
+ )
43
+
44
+ # for training
45
+ train_patch_ds = monai.data.GridPatchDataset(
46
+ data=train_volume_ds, patch_iter=patch_func, with_coordinates=False)
47
+ train_loader = DataLoader(
48
+ train_patch_ds,
49
+ batch_size=train_batch_size,
50
+ num_workers=2,
51
+ pin_memory=torch.cuda.is_available(),
52
+ )
53
+
54
+ # for validation
55
+ val_loader = DataLoader(
56
+ val_volume_ds,
57
+ num_workers=1,
58
+ batch_size=val_batch_size,
59
+ pin_memory=torch.cuda.is_available())
60
+
61
+ if ifcheck:
62
+ check_batch_data(train_loader,val_loader,train_patch_ds,val_volume_ds,train_batch_size,val_batch_size)
63
+ return train_loader,val_loader
64
+
65
+
66
+
67
+ def mydataloader_3d(data_pelvis_path,
68
+ train_number,
69
+ val_number,
70
+ train_batch_size,
71
+ val_batch_size,
72
+ saved_name_train='./train_ds_2d.csv',
73
+ saved_name_val='./val_ds_2d.csv',
74
+ resized_size=(600,400,150),
75
+ div_size=(16,16,16),
76
+ ifcheck_volume=True,):
77
+ # volume-level transforms for both image and segmentation
78
+ normalize='zscore'
79
+ train_transforms = get_transforms(normalize,resized_size,div_size)
80
+
81
+ train_ds, val_ds = get_file_list(data_pelvis_path,
82
+ train_number,
83
+ val_number)
84
+ #train_volume_ds, val_volume_ds
85
+
86
+ train_volume_ds,val_volume_ds = load_volumes(train_transforms=train_transforms,
87
+ train_ds=train_ds,
88
+ val_ds=val_ds,
89
+ saved_name_train=saved_name_train,
90
+ saved_name_val=saved_name_train,
91
+ ifsave=True,
92
+ ifcheck=ifcheck_volume)
93
+ '''
94
+ train_loader = DataLoader(train_volume_ds, batch_size=train_batch_size)
95
+ val_loader = DataLoader(val_volume_ds, batch_size=val_batch_size)
96
+ '''
97
+ ifcheck_sclices=False
98
+ train_loader,val_loader = load_batch_slices3D(train_volume_ds,
99
+ val_volume_ds,
100
+ train_batch_size,
101
+ val_batch_size=val_batch_size,
102
+ ifcheck=ifcheck_sclices)
103
+
104
+ return train_loader,val_loader,train_transforms
105
+
106
+
107
+ from torchvision.utils import save_image
108
+ def save_dataset_as_png(train_ds, train_volume_ds,saved_img_folder,saved_label_folder):
109
+ train_loader = DataLoader(train_volume_ds, batch_size=1)
110
+ for idx, train_check_data in enumerate(train_loader):
111
+ image_volume = train_check_data['image']
112
+ label_volume = train_check_data['label']
113
+ current_item = train_ds[idx]
114
+ file_name_prex = os.path.basename(os.path.dirname(current_item['image']))
115
+ slices_num=image_volume.shape[-1]
116
+ for i in range(slices_num):
117
+ image_i=image_volume[0,0,:,:,i]
118
+ label_i=label_volume[0,0,:,:,i]
119
+ #print(label_volume.shape)
120
+ #SaveImage(output_dir=saved_img_folder, output_postfix=f'{file_name_prex}_image', output_ext='.png', resample=True)(image_volume[0,:,:,:,0])
121
+ save_image(image_i, f'{saved_img_folder}\{file_name_prex}_image_{i}.png')
122
+ save_image(label_i, f'{saved_label_folder}\{file_name_prex}_label_{i}.png')
123
+
124
+ def pre_dataset_for_stylegan(data_pelvis_path,
125
+ normalize,
126
+ train_number,
127
+ val_number,
128
+ saved_img_folder,
129
+ saved_label_folder,
130
+ saved_name_train='./train_ds_2d.csv',
131
+ saved_name_val='./val_ds_2d.csv',
132
+ resized_size=(600,400,None),
133
+ div_size=(16,16,None),):
134
+ train_transforms = get_transforms(normalize,resized_size,div_size)
135
+ train_ds, val_ds = get_file_list(data_pelvis_path,
136
+ train_number,
137
+ val_number)
138
+ train_volume_ds, _ = load_volumes(train_transforms,
139
+ train_ds,
140
+ val_ds,
141
+ saved_name_train,
142
+ saved_name_val,
143
+ ifsave=False,
144
+ ifcheck=False)
145
+ save_dataset_as_png(train_ds, train_volume_ds,saved_img_folder,saved_label_folder)
146
+ return train_ds,train_volume_ds
147
+
148
+ def sum_slices(data_pelvis_path, num=180):
149
+ train_ds, val_ds=get_file_list(data_pelvis_path, 0, num)
150
+ train_ds_2d, val_ds_2d,\
151
+ all_slices_train,all_slices_val,\
152
+ shape_list_train,shape_list_val = transform_datasets_to_2d(train_ds, val_ds,
153
+ saved_name_train='./train_ds_2d.csv',
154
+ saved_name_val='./val_ds_2d.csv',
155
+ ifsave=False)
156
+ print(all_slices_val)
157
+ return all_slices_val
158
+
159
+ def transform_datasets_to_2d(train_ds, val_ds, saved_name_train, saved_name_val,ifsave=True):
160
+ # Load 2D slices of CT images
161
+ train_ds_2d = []
162
+ val_ds_2d = []
163
+ shape_list_train = []
164
+ shape_list_val = []
165
+ all_slices_train=0
166
+ all_slices_val=0
167
+
168
+ # Load 2D slices for training
169
+ for sample in train_ds:
170
+ train_ds_2d_image = LoadImaged(keys=["source","target"],image_only=True, ensure_channel_first=False, simple_keys=True)(sample)
171
+ name = os.path.basename(os.path.dirname(sample['image']))
172
+ num_slices = train_ds_2d_image["source"].shape[-1]
173
+ shape_list_train.append({'patient': name, 'shape': train_ds_2d_image["image"].shape})
174
+ for i in range(num_slices):
175
+ train_ds_2d.append({'image': train_ds_2d_image['image'][:,:,i], 'label': train_ds_2d_image['label'][:,:,i]})
176
+ all_slices_train += num_slices
177
+
178
+ # Load 2D slices for validation
179
+ for sample in val_ds:
180
+ val_ds_2d_image = LoadImaged(keys=["source","target"],image_only=True, ensure_channel_first=False, simple_keys=True)(sample)
181
+ name = os.path.basename(os.path.dirname(sample['image']))
182
+ shape_list_val.append({'patient': name, 'shape': val_ds_2d_image["image"].shape})
183
+ num_slices = val_ds_2d_image["image"].shape[-1]
184
+ for i in range(num_slices):
185
+ val_ds_2d.append({'image': val_ds_2d_image['image'][:,:,i], 'label': val_ds_2d_image['label'][:,:,i]})
186
+ all_slices_val += num_slices
187
+ # Save shape list to csv
188
+ if ifsave:
189
+ np.savetxt(saved_name_train,shape_list_train,delimiter=',',fmt='%s',newline='\n') # f means format, r means raw string
190
+ np.savetxt(saved_name_val,shape_list_val,delimiter=',',fmt='%s',newline='\n') # f means format, r means raw string
191
+ return train_ds_2d, val_ds_2d, all_slices_train, all_slices_val, shape_list_train, shape_list_val
192
+
193
+ def get_train_val_loaders(train_ds_2d, val_ds_2d, batch_size, val_batch_size,normalize, resized_size=(600,400), div_size=(16,16,None),):
194
+ # Define transforms
195
+ train_transforms = get_transforms(normalize,resized_size,div_size)
196
+ train_transforms_list=train_transforms.__dict__['transforms']
197
+ batch_size = batch_size
198
+ # Create training dataset and data loader
199
+ train_dataset = Dataset(data=train_ds_2d, transform=train_transforms)
200
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=True)
201
+
202
+ val_batch_size = val_batch_size
203
+ # Create validation dataset and data loader
204
+ val_dataset = Dataset(data=val_ds_2d, transform=train_transforms)
205
+ val_loader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False, num_workers=1, pin_memory=True)
206
+ return train_loader, val_loader, train_transforms_list,train_transforms
207
+
208
+ def mydataloader(data_pelvis_path,
209
+ train_number,
210
+ val_number,
211
+ batch_size,
212
+ val_batch_size,
213
+ saved_name_train='./train_ds_2d.csv',
214
+ saved_name_val='./val_ds_2d.csv',
215
+ resized_size=(600,400)):
216
+ train_ds, val_ds = get_file_list(data_pelvis_path,
217
+ train_number,
218
+ val_number)
219
+ train_ds_2d, val_ds_2d,\
220
+ all_slices_train,all_slices_val,\
221
+ shape_list_train,shape_list_val = transform_datasets_to_2d(train_ds, val_ds,
222
+ saved_name_train,
223
+ saved_name_val,ifsave=True)
224
+
225
+ train_loader, val_loader, \
226
+ train_transforms_list,train_transforms = get_train_val_loaders(train_ds_2d,
227
+ val_ds_2d,
228
+ batch_size=batch_size,
229
+ val_batch_size=val_batch_size,
230
+ normalize='zscore',
231
+ resized_size=resized_size,
232
+ div_size=(16,16,None),)
233
+ return train_loader,val_loader,\
234
+ train_transforms_list,train_transforms,\
235
+ all_slices_train,all_slices_val,\
236
+ shape_list_train,shape_list_val
dataprocesser/archive/basics.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import monai
2
+ import os
3
+ import numpy as np
4
+ from monai.transforms import (
5
+ Compose,
6
+ LoadImaged,
7
+ Rotate90d,
8
+ ScaleIntensityd,
9
+ EnsureChannelFirstd,
10
+ ResizeWithPadOrCropd,
11
+ DivisiblePadd,
12
+ ThresholdIntensityd,
13
+ NormalizeIntensityd,
14
+ SqueezeDimd,
15
+ ShiftIntensityd,
16
+ Identityd,
17
+ CenterSpatialCropd,
18
+ ScaleIntensityRanged,
19
+ Spacingd,
20
+ )
21
+ from torch.utils.data import DataLoader
22
+ from .checkdata import check_volumes, save_volumes
23
+ def get_file_list(data_pelvis_path, train_number, val_number, source='mr', target='ct'):
24
+ #list all files in the folder
25
+ file_list=[i for i in os.listdir(data_pelvis_path) if 'overview' not in i]
26
+ file_list_path=[os.path.join(data_pelvis_path,i) for i in file_list]
27
+ #list all ct and mr files in folder
28
+ source_file_list=[os.path.join(j,f'{source}.nii.gz') for j in file_list_path]
29
+ target_file_list=[os.path.join(j,f'{target}.nii.gz') for j in file_list_path] #mr
30
+ # Dict Version
31
+ # source -> image
32
+ # target -> label
33
+ train_ds = [{'source': i, 'target': j, 'A_paths': i, 'B_paths': j} for i, j in zip(source_file_list[0:train_number], target_file_list[0:train_number])]
34
+ val_ds = [{'source': i, 'target': j, 'A_paths': i, 'B_paths': j} for i, j in zip(source_file_list[-val_number:], target_file_list[-val_number:])]
35
+ print('all files in dataset:',len(file_list))
36
+ return train_ds, val_ds
37
+
38
+ def load_volumes(train_transforms,val_transforms,
39
+ train_crop_ds, val_crop_ds,
40
+ train_ds, val_ds,
41
+ saved_name_train=None, saved_name_val=None,
42
+ ifsave=False,ifcheck=False):
43
+ train_volume_ds = monai.data.Dataset(data=train_crop_ds, transform=train_transforms)
44
+ val_volume_ds = monai.data.Dataset(data=val_crop_ds, transform=val_transforms)
45
+ if ifsave:
46
+ save_volumes(train_ds, val_ds, saved_name_train, saved_name_val)
47
+ if ifcheck:
48
+ check_volumes(train_ds, train_volume_ds, val_volume_ds, val_ds)
49
+ return train_volume_ds,val_volume_ds
50
+
51
+ def crop_volumes(train_ds, val_ds,center_crop,resized_size=(512,512,None),pad='minimum'):
52
+ if center_crop>0:
53
+ crop=Compose([LoadImaged(keys=["source", "target"]),
54
+ EnsureChannelFirstd(keys=["source", "target"]),
55
+ CenterSpatialCropd(keys=["source", "target"], roi_size=(-1,-1,center_crop)),
56
+
57
+ ])
58
+ train_crop_ds = monai.data.Dataset(data=train_ds, transform=crop)
59
+ val_crop_ds = monai.data.Dataset(data=val_ds, transform=crop)
60
+ print('center crop:',center_crop)
61
+ else:
62
+ crop=Compose([LoadImaged(keys=["source", "target"]),
63
+ EnsureChannelFirstd(keys=["source", "target"]),
64
+ ])
65
+ train_crop_ds = monai.data.Dataset(data=train_ds, transform=crop)
66
+ val_crop_ds = monai.data.Dataset(data=val_ds, transform=crop)
67
+ return train_crop_ds, val_crop_ds
68
+
69
def get_transforms(configs, mode='train'):
    """Build the MONAI preprocessing pipeline for paired source/target volumes.

    Applies HU thresholding on 'target', one of several normalization schemes
    selected by ``configs.dataset.normalize``, voxel resampling plus
    pad/crop to ``resized_size``, and — when ``mode == 'train'`` — random
    spatial (and optionally intensity) augmentation.

    Returns:
        monai.transforms.Compose over keys 'source'/'target' ('mask' for the
        spatial steps).
    """
    normalize = configs.dataset.normalize
    pad = configs.dataset.pad
    resized_size = configs.dataset.resized_size
    WINDOW_WIDTH = configs.dataset.WINDOW_WIDTH
    WINDOW_LEVEL = configs.dataset.WINDOW_LEVEL
    prob = configs.dataset.augmentationProb
    background = configs.dataset.background

    transform_list = []
    # renamed from `min`/`max`: do not shadow the builtins
    hu_min, hu_max = WINDOW_LEVEL - (WINDOW_WIDTH / 2), WINDOW_LEVEL + (WINDOW_WIDTH / 2)
    transform_list.append(ThresholdIntensityd(keys=["target"], threshold=hu_min, above=True, cval=background))
    #transform_list.append(ThresholdIntensityd(keys=["target"], threshold=hu_max, above=False, cval=-1000))
    # filter the source images
    # transform_list.append(ThresholdIntensityd(keys=["source"], threshold=configs.dataset.MRImax, above=False, cval=0))
    if normalize == 'zscore':
        transform_list.append(NormalizeIntensityd(keys=["source", "target"], nonzero=False, channel_wise=True))
        print('zscore normalization')
    elif normalize == 'minmax':
        transform_list.append(ScaleIntensityd(keys=["source", "target"], minv=-1, maxv=1.0))
        print('minmax normalization')
    elif normalize == 'scale4000':
        transform_list.append(ScaleIntensityd(keys=["source"], minv=-1, maxv=1))
        transform_list.append(ScaleIntensityd(keys=["target"], minv=0))
        transform_list.append(ScaleIntensityd(keys=["target"], factor=-0.99975))  # x = x * (1 + factor)
        # BUGFIX: this branch previously logged 'scale1000 normalization'
        print('scale4000 normalization')
    elif normalize == 'scale1000':
        transform_list.append(ScaleIntensityd(keys=["source"], minv=0, maxv=1))
        transform_list.append(ScaleIntensityd(keys=["target"], minv=0))
        transform_list.append(ScaleIntensityd(keys=["target"], factor=-0.99975))
        print('scale1000 normalization')
    elif normalize == 'inputonlyzscore':
        transform_list.append(NormalizeIntensityd(keys=["source"], nonzero=False, channel_wise=True))
        print('only normalize input MRI images')
    elif normalize == 'inputonlyminmax':
        transform_list.append(ScaleIntensityd(keys=["source"], minv=configs.dataset.normmin, maxv=configs.dataset.normmax))
        print('only normalize input MRI images')
    elif normalize == 'none':
        print('no normalization')
    # resample to a fixed voxel spacing, then pad/crop to the network input size
    transform_list.append(Spacingd(keys=["source"], pixdim=(1.0, 1.0, 1.0), mode="bilinear"))
    transform_list.append(Spacingd(keys=["target", "mask"], pixdim=(1.0, 1.0, 2.5), mode="bilinear"))
    transform_list.append(ResizeWithPadOrCropd(keys=["source", "target", "mask"], spatial_size=resized_size, mode=pad))
    # transform_list.append(ScaleIntensityRanged(keys=["target"], a_min=hu_min, a_max=hu_max, b_min=0, b_max=1, clip=True))

    if mode == 'train':
        from monai.transforms import (
            # data augmentation (unused names removed)
            RandRotated,
            RandZoomd,
            RandBiasFieldd,
            RandAffined,
            RandShiftIntensityd,
            RandAdjustContrastd,
            RandGaussianSharpend,
            RandGaussianNoised,
        )
        Aug = True  # spatial augmentation toggle
        if Aug:
            transform_list.append(RandRotated(keys=["source", "target", "mask"], range_x=0.1, range_y=0.1, range_z=0.1, prob=prob, padding_mode="border", keep_size=True))
            transform_list.append(RandZoomd(keys=["source", "target", "mask"], prob=prob, min_zoom=0.9, max_zoom=1.3, padding_mode="minimum", keep_size=True))
            transform_list.append(RandAffined(keys=["source", "target", "mask"], padding_mode="border", prob=prob))
            #transform_list.append(Rand3DElasticd(keys=["source", "target"], prob=prob, sigma_range=(5, 8), magnitude_range=(100, 200), spatial_size=None, mode='bilinear'))
        intensityAug = False  # intensity augmentation toggle (off by default)
        if intensityAug:
            print('intensity data augmentation is used')
            transform_list.append(RandBiasFieldd(keys=["source"], degree=3, coeff_range=(0.0, 0.1), prob=prob))  # only apply to MRI images
            transform_list.append(RandGaussianNoised(keys=["source"], prob=prob, mean=0.0, std=0.01))
            transform_list.append(RandAdjustContrastd(keys=["source"], prob=prob, gamma=(0.5, 1.5)))
            transform_list.append(RandShiftIntensityd(keys=["source"], prob=prob, offsets=20))
            transform_list.append(RandGaussianSharpend(keys=["source"], alpha=(0.2, 0.8), prob=prob))
    train_transforms = Compose(transform_list)
    return train_transforms
153
+
154
def get_length(dataset, patch_batch_size):
    """Return how many patch batches are needed to cover every slice.

    Sums the last-axis size of each sample's 'source' tensor across the whole
    dataset and divides by ``patch_batch_size``, rounding up.
    """
    total_slices = 0
    for batch in DataLoader(dataset, batch_size=1):
        total_slices += batch['source'].shape[-1]
    # ceiling division: equivalent to the floor-divide-plus-remainder check
    return -(-total_slices // patch_batch_size)
166
+
167
+
dataprocesser/archive/checkdata.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import DataLoader
2
+ import numpy as np
3
+ import os
4
def test_volumes_pixdim(train_volume_ds):
    """Print shape, affine and pixel spacing for every source/target pair."""
    for step, data in enumerate(DataLoader(train_volume_ds, batch_size=1)):
        mr_data, ct_data = data['source'], data['target']
        # source image information
        print(f"source image shape: {mr_data.shape}")
        print(f"source image affine:\n{mr_data.meta['affine']}")
        print(f"source image pixdim:\n{mr_data.pixdim}")
        # target image information
        print(f"target image shape: {ct_data.shape}")
        print(f"target image affine:\n{ct_data.meta['affine']}")
        print(f"target image pixdim:\n{ct_data.pixdim}")
20
def check_volumes(train_ds, train_volume_ds, val_volume_ds, val_ds):
    """Iterate once over both volume datasets, printing each sample's shapes.

    A sample that fails to load is reported by its patient folder name instead
    of aborting the whole check. batch_size=1 is used because the raw volumes
    have heterogeneous shapes.
    """
    _check_split('training', train_ds, train_volume_ds)
    _check_split('validation', val_ds, val_volume_ds)


def _check_split(split, file_ds, volume_ds):
    """Check one split: print shapes per sample, or the offending patient."""
    loader = DataLoader(volume_ds, batch_size=1)
    iterator = iter(loader)
    print(f'check {split} data:')
    for idx in range(len(loader)):
        # patient folder name, recovered from the image path in the file list
        current_name = os.path.basename(os.path.dirname(file_ds[idx]['image']))
        try:
            check_data = next(iterator)
            print(idx, current_name, 'image:', check_data['image'].shape, 'label:', check_data['label'].shape)
        except Exception as err:
            # BUGFIX: was a bare `except:` that hid the reason (and would even
            # swallow KeyboardInterrupt); keep best-effort but report the error.
            print('check data error! Check the input data:', current_name)
            print('error:', err)
    print(f"checked all {split} data.")
57
+
58
+ def save_volumes(train_ds, val_ds, saved_name_train, saved_name_val):
59
+ shape_list_train=[]
60
+ shape_list_val=[]
61
+ # use the function of saving information before
62
+ for sample in train_ds:
63
+ name = os.path.basename(os.path.dirname(sample['image']))
64
+ shape_list_train.append({'patient': name})
65
+ for sample in val_ds:
66
+ name = os.path.basename(os.path.dirname(sample['image']))
67
+ shape_list_val.append({'patient': name})
68
+ np.savetxt(saved_name_train,shape_list_train,delimiter=',',fmt='%s',newline='\n') # f means format, r means raw string
69
+ np.savetxt(saved_name_val,shape_list_val,delimiter=',',fmt='%s',newline='\n') # f means format, r means raw string
70
+
71
def check_batch_data(train_loader, val_loader, train_patch_ds, val_volume_ds, train_batch_size, val_batch_size):
    """Print the first dataset item of each batch together with the batch shapes."""
    for idx, batch in enumerate(train_loader):
        # index of this batch's first item in the underlying dataset
        first_item = train_patch_ds[idx * train_batch_size]
        print('check train data:')
        print(first_item, 'image:', batch['image'].shape, 'label:', batch['label'].shape)

    for idx, batch in enumerate(val_loader):
        first_item = val_volume_ds[idx * val_batch_size]
        print('check val data:')
        print(first_item, 'image:', batch['image'].shape, 'label:', batch['label'].shape)
84
def len_patchloader(train_volume_ds, train_batch_size):
    """Count total slices and per-volume-rounded batches in the training set.

    Batches do not cross volume boundaries, hence ceil per volume rather than
    ceil of the total.

    Returns:
        (slice_number, batch_number)
    """
    import math  # hoisted: was imported mid-function between the two loops

    slice_number = 0
    batch_number = 0
    # single pass over the dataset instead of traversing it twice
    for sample in train_volume_ds:
        nslices = sample['source'].shape[-1]
        slice_number += nslices
        batch_number += math.ceil(nslices / train_batch_size)
    print('total slices in training set:', slice_number)
    print('total batches in training set:', batch_number)
    return slice_number, batch_number
dataprocesser/archive/createsegtransform.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from totalsegmentator.python_api import totalsegmentator
3
class CreateMaskTransformd:
    """Dictionary transform that converts a CT volume into a label mask.

    A body mask is built slice-wise via ``create_body_mask`` (thresholded at
    ``tissue_min``); voxels in ``[bone_min, bone_max]`` are set to
    ``mask_value_bones``; optionally a TotalSegmentator organ
    (``organ_label_id``, e.g. 52 = aorta) is labelled ``mask_value_organ``.
    """

    def __init__(self, keys, tissue_min, tissue_max, bone_min, bone_max, mask_value_bones=2,
                 if_use_total_seg=False, organ_label_id=52, mask_value_organ=2, fast=True):
        self.keys = keys
        self.tissue_min = tissue_min
        self.tissue_max = tissue_max
        self.bone_min = bone_min
        self.bone_max = bone_max
        self.mask_value_bones = mask_value_bones

        self.if_use_total_seg = if_use_total_seg
        self.organ_label_id = organ_label_id
        self.mask_value_organ = mask_value_organ
        self.fast = fast

    def extract_organ_mask(self, input_img, organ_label_id, mask_value):
        """Run TotalSegmentator ('total' task) on *input_img* and return a
        NIFTI image whose voxels equal *mask_value* where the segmentation
        equals *organ_label_id*, and 0 elsewhere.
        """
        img_in = totalsegmentator(input=input_img, task='total', fast=self.fast)
        data = img_in.get_fdata()

        organ_mask_data = np.zeros_like(data)
        organ_mask_data[data == organ_label_id] = mask_value

        return nib.Nifti1Image(organ_mask_data, img_in.affine, img_in.header)

    def __call__(self, data):
        for key in self.keys:
            x = data[key]  # expected [B, H, W, D]
            mask = torch.zeros_like(x)
            for i in range(x.shape[0]):
                if self.if_use_total_seg:
                    organ_img = self.extract_organ_mask(
                        x[i, :, :, :],
                        organ_label_id=self.organ_label_id,
                        mask_value=self.mask_value_organ)
                    # BUGFIX: a Nifti1Image cannot be assigned into a torch
                    # tensor directly; convert its voxel data first (as
                    # CreateTotalSegTransformd already does).
                    mask[i, :, :, :] = torch.from_numpy(organ_img.get_fdata()).to(mask.dtype)
                # NOTE(review): this slice loop overwrites the organ labels
                # written just above — confirm the intended order of the two
                # steps.
                for j in range(x.shape[-1]):
                    mask[i, :, :, j] = create_body_mask(x[i, :, :, j], body_threshold=self.tissue_min)
            mask[(x >= self.bone_min) & (x <= self.bone_max)] = self.mask_value_bones
            data[key] = mask
        return data
+ return data
60
+
61
class CreateSegTransformd:
    """Create a mask by segmenting the input image with TotalSegmentator.

    With ``organ_label_id > 0`` a binary mask of that organ (set to
    ``mask_value``) is produced; otherwise the full multi-label segmentation
    is kept as-is.
    """

    def __init__(self, keys, organ_label_id=52, mask_value=2, fast=True):
        self.keys = keys
        self.organ_label_id = organ_label_id
        self.mask_value = mask_value
        self.fast = fast

    def extract_organ_mask(self, input_img, organ_label_id, mask_value):
        """Segment *input_img* and return either a single-organ mask
        (organ_label_id > 0, e.g. 52 = aorta) or the full label map as a
        NIFTI image.
        """
        img_in = totalsegmentator(input=input_img, task='total', fast=self.fast)
        data = img_in.get_fdata()

        if organ_label_id > 0:
            # binary mask for the specified organ only
            organ_mask_data = np.zeros_like(data)
            organ_mask_data[data == organ_label_id] = mask_value
        else:
            organ_mask_data = data

        return nib.Nifti1Image(organ_mask_data, img_in.affine, img_in.header)

    def __call__(self, data):
        for key in self.keys:
            x = data[key]  # expected [B, H, W, D]
            mask = torch.zeros_like(x)
            for i in range(x.shape[0]):
                organ_img = self.extract_organ_mask(
                    x[i, :, :, :],
                    organ_label_id=self.organ_label_id,
                    mask_value=self.mask_value)
                # BUGFIX: a Nifti1Image cannot be assigned into a torch tensor
                # directly; convert its voxel data first (as
                # CreateTotalSegTransformd already does).
                mask[i, :, :, :] = torch.from_numpy(organ_img.get_fdata()).to(mask.dtype)
            data[key] = mask
        return data
+ return data
104
+
105
+
106
class CreateTotalSegTransformd:
    """Create a full multi-label mask by running TotalSegmentator on the input."""

    def __init__(self, keys, fast=True):
        self.keys = keys
        self.fast = fast

    def extract_organ_mask(self, input_img):
        """Segment *input_img* with TotalSegmentator ('total' task) and return
        the resulting label map as a NIFTI image.
        """
        #print(input_img.meta)
        input_affine = input_img.meta['affine']
        nifti_in = torch_tensor_to_nifti(input_img, affine=input_affine)
        seg = totalsegmentator(input=nifti_in, task='total', fast=self.fast)
        return nib.Nifti1Image(seg.get_fdata(), seg.affine, seg.header)

    def __call__(self, data):
        for key in self.keys:
            x = data[key]  # expected [B, H, W, D]
            mask = torch.zeros_like(x)
            for i in range(x.shape[0]):
                seg_img = self.extract_organ_mask(x[i, :, :, :])
                # NIFTI voxel data -> float torch tensor before assignment
                mask[i, :, :, :] = torch.from_numpy(seg_img.get_fdata()).float()
            data[key] = mask
        return data
+ return data
148
+
149
def get_transforms(self, transform_list, mode='train'):
    """Build the mask-creation + preprocessing/augmentation pipeline.

    BUGFIX: this method referenced the undefined names ``configs`` and
    ``mode``; ``configs`` now comes from ``self.configs`` (already used below
    for ``spaceXY`` and ``div_size``) and ``mode`` is a new keyword argument
    defaulting to 'train' (backward compatible).

    NOTE(review): the incoming ``transform_list`` argument is immediately
    discarded and rebuilt from scratch, matching the original behavior —
    confirm whether callers expect their list to be extended instead.
    """
    configs = self.configs
    normalize = configs.dataset.normalize
    pad = configs.dataset.pad
    resized_size = configs.dataset.resized_size
    WINDOW_WIDTH = configs.dataset.WINDOW_WIDTH
    WINDOW_LEVEL = configs.dataset.WINDOW_LEVEL
    prob = configs.dataset.augmentationProb
    background = configs.dataset.background
    indicator_A = configs.dataset.indicator_A
    indicator_B = configs.dataset.indicator_B
    load_masks = configs.dataset.load_masks
    transform_list = []
    # key set for the spatial transforms ('mask' only when masks are loaded)
    spatial_keys = [indicator_A, indicator_B, "mask"] if load_masks else [indicator_A, indicator_B]

    input_is_mask = configs.dataset.input_is_mask
    # normally we input CT images and here we create masks for CT images
    if not input_is_mask:
        if not configs.dataset.use_all_masks:
            transform_list.append(CreateMaskTransformd(keys=[indicator_A],
                                                       tissue_min=configs.dataset.tissue_min,
                                                       tissue_max=configs.dataset.tissue_max,
                                                       bone_min=configs.dataset.bone_min,
                                                       bone_max=configs.dataset.bone_max,
                                                       mask_value_bones=2,
                                                       ))
        else:  # use all masks from the totalsegmentator
            transform_list.append(CreateTotalSegTransformd(keys=[indicator_A], fast=True))
    # window bounds, currently only referenced by the commented thresholds below
    hu_min, hu_max = WINDOW_LEVEL - (WINDOW_WIDTH / 2), WINDOW_LEVEL + (WINDOW_WIDTH / 2)
    #transform_list.append(ThresholdIntensityd(keys=[indicator_B], threshold=hu_min, above=True, cval=background))
    #transform_list.append(ThresholdIntensityd(keys=[indicator_B], threshold=hu_max, above=False, cval=-1000))
    # filter the source images
    # transform_list.append(ThresholdIntensityd(keys=[indicator_A], threshold=configs.dataset.MRImax, above=False, cval=0))
    if normalize == 'zscore':
        transform_list.append(NormalizeIntensityd(keys=[indicator_B], nonzero=False, channel_wise=True))
        print('zscore normalization')
    elif normalize == 'minmax':
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=-1.0, maxv=1.0))
        print('minmax normalization')
    elif normalize == 'scale1000_wrongbutworks':
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=0))
        transform_list.append(ScaleIntensityd(keys=[indicator_B], factor=-0.999))
        print('scale1000 normalization')
    elif normalize == 'scale1000':
        transform_list.append(ShiftIntensityd(keys=[indicator_B], offset=1024))
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.999))
        print('scale1000 normalization')
    elif normalize == 'scale4000':
        transform_list.append(ShiftIntensityd(keys=[indicator_B], offset=1024))
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.99975))
        print('scale4000 normalization')
    elif normalize == 'scale10':
        #transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=0))
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.9))
        print('scale10 normalization')
    elif normalize == 'inputonlyzscore':
        transform_list.append(NormalizeIntensityd(keys=[indicator_A], nonzero=False, channel_wise=True))
        print('only normalize input MRI images')
    elif normalize == 'inputonlyminmax':
        transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=configs.dataset.normmin, maxv=configs.dataset.normmax))
        print('only normalize input MRI images')
    elif normalize == 'none' or normalize == 'nonorm':
        print('no normalization')

    spaceXY = configs.dataset.spaceXY
    if spaceXY > 0:
        transform_list.append(Spacingd(keys=[indicator_A], pixdim=(spaceXY, spaceXY, 2.5), mode="bilinear", ensure_same_shape=True))
        transform_list.append(Spacingd(keys=[indicator_B, "mask"] if load_masks else [indicator_B],
                                       pixdim=(spaceXY, spaceXY, 2.5), mode="bilinear", ensure_same_shape=True))

    transform_list.append(Zoomd(keys=spatial_keys, zoom=configs.dataset.zoom,
                                keep_size=False, mode='area', padding_mode='minimum'))
    transform_list.append(DivisiblePadd(keys=spatial_keys, k=configs.dataset.div_size, mode="minimum"))
    transform_list.append(ResizeWithPadOrCropd(keys=spatial_keys, spatial_size=resized_size, mode=pad))

    if configs.dataset.rotate:
        transform_list.append(Rotate90d(keys=spatial_keys, k=3))

    if mode == 'train':
        from monai.transforms import (
            # data augmentation (unused names removed)
            RandZoomd,
            RandBiasFieldd,
            RandShiftIntensityd,
            RandAdjustContrastd,
            RandGaussianSharpend,
            RandGaussianNoised,
        )
        shapeAug = configs.dataset.shapeAug
        if shapeAug:
            #transform_list.append(RandRotated(keys=spatial_keys,
            #                                  range_x=0.0, range_y=1.0, range_z=1.0,
            #                                  prob=prob, padding_mode="border", keep_size=False))
            transform_list.append(RandZoomd(keys=spatial_keys, prob=prob,
                                            min_zoom=configs.dataset.rand_min_zoom,
                                            max_zoom=configs.dataset.rand_max_zoom,
                                            padding_mode="minimum", keep_size=False))
            #transform_list.append(RandAffined(keys=[indicator_A, indicator_B], padding_mode="border", prob=prob))
            #transform_list.append(Rand3DElasticd(keys=[indicator_A, indicator_B], prob=prob, sigma_range=(5, 8), magnitude_range=(100, 200), spatial_size=None, mode='bilinear'))
        intensityAug = configs.dataset.intensityAug
        if intensityAug:
            print('intensity data augmentation is used')
            transform_list.append(RandBiasFieldd(keys=[indicator_A], degree=3, coeff_range=(0.0, 0.1), prob=prob))  # only apply to MRI images
            transform_list.append(RandGaussianNoised(keys=[indicator_A], prob=prob, mean=0.0, std=0.01))
            transform_list.append(RandAdjustContrastd(keys=[indicator_A], prob=prob, gamma=(0.5, 1.5)))
            transform_list.append(RandShiftIntensityd(keys=[indicator_A], prob=prob, offsets=20))
            transform_list.append(RandGaussianSharpend(keys=[indicator_A], alpha=(0.2, 0.8), prob=prob))

    train_transforms = Compose(transform_list)
    return train_transforms
dataprocesser/archive/csv_dataset.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def get_data_scaler(config):
    """Data normalizer. Assume data are always in [0, 1]."""
    if not config.data.centered:
        return lambda x: x
    # rescale [0, 1] -> [-1, 1]
    return lambda x: x * 2. - 1.
8
+
9
+
10
def get_data_inverse_scaler(config):
    """Inverse data normalizer."""
    if not config.data.centered:
        return lambda x: x
    # rescale [-1, 1] -> [0, 1]
    return lambda x: (x + 1.) / 2.
17
+
18
# Accepted medical-image extensions; common 2-D formats are intentionally
# disabled (kept commented in the original).
IMG_EXTENSIONS = [
    '.nrrd',
    '.nii.gz',
]
23
+
24
+
25
def is_image_file(filename):
    """Return True if *filename* ends with one of IMG_EXTENSIONS."""
    # str.endswith accepts a tuple of suffixes — one C call instead of any()+generator
    return filename.endswith(tuple(IMG_EXTENSIONS))
27
+
28
def volume_slicer(volume_tensor, transform, all_slices=None):
    """Turn an [H, W, D] volume into per-slice samples, appending to *all_slices*.

    The volume is reordered to [D, H, W], given a channel axis -> [D, 1, H, W],
    optionally transformed, and concatenated onto the running slice stack.
    """
    slices = volume_tensor.permute(2, 1, 0).unsqueeze(1)
    if transform is not None:
        slices = transform(slices)
    if all_slices is None:
        return slices
    return torch.cat((all_slices, slices), 0)
+ return all_slices
42
+
43
class csvDataset_3D(Dataset):
    """Dataset of whole 3-D volumes listed in a CSV file.

    The last CSV column is a NIfTI path and column -3 is the label (e.g.
    'Aorta_diss'). Only the first *load_patient_number* rows are used.

    NOTE(review): *transform* is stored but never applied in __getitem__ —
    confirm whether that is intentional.
    """

    def __init__(self, csv_file, transform=None, load_patient_number=1):
        # control the length of the dataset via the row slice
        self.data_frame = pd.read_csv(csv_file)[:load_patient_number]
        self.transform = transform

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.data_frame.iloc[idx]
        image = torch.tensor(nib.load(row.iloc[-1]).get_fdata(), dtype=torch.float32)
        # add channel dimension for the single-channel volume
        image = image.unsqueeze(0)
        return {'image': image, 'label': row.iloc[-3]}
+ return sample
76
+
77
class csvDataset_2D(Dataset):
    """Dataset of 2-D slices extracted from the volumes listed in a CSV file.

    All volumes (first *load_patient_number* rows) are loaded eagerly, sliced
    via ``volume_slicer`` and stacked into one tensor; each row's label
    (column -3) is repeated once per slice.
    """

    def __init__(self, csv_file, transform=None, load_patient_number=1):
        self.csv_file = csv_file
        self.transform = transform
        self.load_patient_number = load_patient_number
        self.data_frame = pd.read_csv(csv_file)
        if len(self.data_frame) == 0:
            raise RuntimeError(f"Found 0 images in: {csv_file}")

        # Initialize dataset
        self.initialize_dataset()

    def initialize_dataset(self):
        """Load every volume, stack its slices and replicate its label."""
        print('Loading dataset...')
        self.data_frame = self.data_frame[:self.load_patient_number]
        all_slices = None
        all_labels = []

        for idx in tqdm(range(len(self.data_frame))):
            img_path = self.data_frame.iloc[idx, -1]
            volume_data = nib.load(img_path).get_fdata()  # [H, W, D]
            volume_tensor = torch.tensor(volume_data, dtype=torch.float32)
            # -> [D, 1, H, W], appended to the running stack
            all_slices = volume_slicer(volume_tensor, self.transform, all_slices)
            label = self.data_frame.iloc[idx, -3]
            # BUGFIX: volume_slicer stacks shape[-1] (=D) slices per volume,
            # but the label was repeated shape[0] (=H) times, desynchronizing
            # labels from slices.
            all_labels = all_labels + [label] * volume_tensor.shape[-1]

        print('All stacked slices:', all_slices.shape)
        self.all_slices = all_slices
        self.all_labels = all_labels

    def __len__(self):
        return self.all_slices.shape[0]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return {'source': self.all_slices[idx], 'target': self.all_labels[idx]}

    def reset(self):
        print('Resetting dataset...')
        self.initialize_dataset()
+ self.initialize_dataset()
dataprocesser/archive/csv_dataset_slices.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import DataLoader
2
+ import os.path
3
+ from PIL import Image
4
+ import torch
5
+ from PIL import ImageFile
6
+ import os
7
+ import pandas as pd
8
+ import monai
9
+ import json
10
+ Image.MAX_IMAGE_PIXELS = None # Disable DecompressionBombError
11
+ ImageFile.LOAD_TRUNCATED_IMAGES = True # Disable OSError: image file is truncated
12
+
13
+
14
+
15
+
16
+ from dataprocesser.list_dataset_base import BaseDataLoader
17
+
18
+
19
+
20
+
dataprocesser/archive/csv_dataset_slices_assigned.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataprocesser.csv_dataset_slices import csv_slices_DataLoader
2
+ from dataprocesser.customized_transforms import MaskHUAssigmentd
3
+
4
+ from monai.transforms import (
5
+ ScaleIntensityd,
6
+ ThresholdIntensityd,
7
+ NormalizeIntensityd,
8
+ ShiftIntensityd,
9
+ )
10
+
11
+
dataprocesser/archive/data_create_seg.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from step1_init_data_list import (
2
+ list_img_ad_from_anish_csv,
3
+ list_img_pID_from_synthrad_folder,
4
+ )
5
def run():
    """Create segmentation masks for the first *number* target volumes of the
    selected dataset ('anish' CSV listing or 'synthrad' folder layout)."""
    number = 1
    dataset = 'anish'
    if dataset == 'anish':
        # BUGFIX: raw strings — the original non-raw 'D:\...' literals turned
        # sequences like \t into control characters, corrupting the paths.
        data_dir = r'D:\Projects\SynthRad\synthrad_conversion\healthy_dissec_home.csv'
        target_file_list, _ = list_img_ad_from_anish_csv(data_dir)  # a csv_file
    elif dataset == 'synthrad':
        data_dir = r'D:\Projects\data\synthrad\train\Task1\pelvis'
        target_file_list, _ = list_img_pID_from_synthrad_folder(
            data_dir, accepted_modalities='ct', saved_name="target_filenames.txt")
    create_segmentation(target_file_list[0:number])
15
+
16
def create_segmentation(dataset_list):
    """Run TotalSegmentator on each NIfTI path, saving '<name>_seg.nii[.gz]'
    next to the input.

    Best-effort: a failing sample is reported and skipped instead of aborting
    the whole batch (the original bare ``except:`` stopped at the first error
    and hid its cause).
    """
    import nibabel as nib
    try:
        from totalsegmentator.python_api import totalsegmentator
    except ImportError as err:
        print("An exception occurred")
        print('totalsegmentator is not available:', err)
        return
    for input_path in dataset_list:
        try:
            print(f'create segmentation mask for {input_path}')
            output_path = input_path.replace('.nii', '_seg.nii')
            input_img = nib.load(input_path)
            totalsegmentator(input=input_img, output=output_path, task='total', fast=False, ml=True)
            print(f'segmentation mask is saved as {output_path}')
        except Exception as err:
            print("An exception occurred")
            print(f'failed on {input_path}: {err}')
dataprocesser/archive/data_slicing.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from dataprocesser.step1_init_data_list import init_dataset
import os

# Slice the configured train/val volume datasets into 2-D NIfTIs + CSV index.
loader, opt, my_paths = init_dataset()

# output layout: <root>/train and <root>/val
path = r'E:\Projects\yang_proj\data\seg2med\seg2med_nifti_2d_343'
train_path = os.path.join(path, 'train')
val_path = os.path.join(path, 'val')
for folder in (path, train_path, val_path):
    os.makedirs(folder, exist_ok=True)

loader.save_slices_nifti_and_csv(train_path, loader.train_volume_ds)
loader.save_slices_nifti_and_csv(val_path, loader.val_volume_ds)
dataprocesser/archive/dataset_med.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import DataLoader, Dataset
2
+ import torch.utils.data as data
3
+ import os.path
4
+ import random
5
+ from torchvision import transforms
6
+ from PIL import Image
7
+ import torch
8
+ from PIL import ImageFile
9
+ from utils.MattingLaplacian import compute_laplacian
10
+
11
+ import nibabel as nib
12
+ import numpy as np
13
+
14
+ Image.MAX_IMAGE_PIXELS = None # Disable DecompressionBombError
15
+ ImageFile.LOAD_TRUNCATED_IMAGES = True # Disable OSError: image file is truncated
16
+
17
+
18
+
19
# Accepted medical-image extensions; common 2-D formats are intentionally
# disabled (kept commented in the original).
IMG_EXTENSIONS = [
    '.nrrd',
    '.nii.gz',
]
24
+
25
+
26
def is_image_file(filename):
    """Return True if *filename* ends with one of IMG_EXTENSIONS."""
    # str.endswith accepts a tuple of suffixes — one C call instead of any()+generator
    return filename.endswith(tuple(IMG_EXTENSIONS))
28
+
29
def make_dataset_modality(dir, modality='ct'):
    """Collect image paths whose filename contains *modality* from a
    root/<patient>/<files> layout:

        root/
            patient_folder/
                ct_image.nii.gz
                mr_image.nii.gz
            patient_folder2/
                ...
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    images = []
    # outer walk enumerates patient folders; each folder is then walked again
    # so its own sub-directories are also scanned.
    # NOTE(review): files in nested sub-folders are visited by both walks and
    # therefore appended more than once — confirm this is intended.
    for patient_folder, _, _ in sorted(os.walk(dir)):
        if patient_folder == dir:
            continue
        for subdir, _, filenames in sorted(os.walk(patient_folder)):
            for filename in filenames:
                if is_image_file(filename) and modality in filename:
                    images.append(os.path.join(subdir, filename))
    return images
55
+
56
+
57
class CTImageDataset(Dataset):
    """All 2-D slices of one modality under *root*, as 3-channel tensors.

    Volumes from the first *load_patient_number* matching files are loaded,
    reordered to [N, H, W], replicated into 3 identical channels and stacked.
    A matting Laplacian is optionally computed per slice at access time.
    """

    def __init__(self, root, modality='ct', transform=None,
                 load_patient_number=1,
                 use_lap=True, win_rad=1):
        self.imgs_paths = sorted(make_dataset_modality(root, modality))
        self.transform = transform
        self.to_tensor = transforms.ToTensor()  # might need adjustment for 3D

        if not self.imgs_paths:
            raise RuntimeError(f"Found 0 images in: {root}")

        stacked = None
        for img_path in self.imgs_paths[:load_patient_number]:
            volume_data = nib.load(img_path).get_fdata()  # [H, W, D]
            volume_tensor = torch.tensor(volume_data, dtype=torch.float32)
            volume_tensor = volume_tensor.permute(2, 1, 0)  # -> [N, H, W]
            # grayscale replicated into three channels: [N, H, W, 3]
            volume_tensor = volume_tensor.unsqueeze(3).repeat(1, 1, 1, 3)
            if self.transform is not None:
                volume_tensor = self.transform(volume_tensor)
            stacked = volume_tensor if stacked is None else torch.cat((stacked, volume_tensor), 0)
        print(f'slices of {modality} dataset:', stacked.shape)
        self.all_slices = stacked
        self.use_lap = use_lap
        self.win_rad = win_rad

    def __getitem__(self, index):
        img = self.all_slices[index]
        laplacian_m = compute_laplacian(img, win_rad=self.win_rad) if self.use_lap else None
        # [H, W, C] -> [C, H, W]
        return {'img': img.permute(2, 0, 1), 'laplacian_m': laplacian_m}

    def __len__(self):
        return self.all_slices.shape[0]
106
+
107
+ from monai.transforms import (
108
+ ResizeWithPadOrCrop,
109
+ ScaleIntensity,
110
+ Compose,
111
+ )
112
+
113
def get_data_loader_folder(input_folder, modality,
                           batch_size, new_size=288,
                           height=256, width=256,
                           num_workers=None, load_patient_number=1):
    """Build an infinitely-sampled DataLoader of slices from *input_folder*.

    Intensities are scaled to [0, 1] and slices padded/cropped to
    (height, width); batches are drawn forever via InfiniteSamplerWrapper.

    NOTE(review): *new_size* is currently unused (the resize transform is
    commented out in the original) — confirm.
    """
    transform = Compose([
        ScaleIntensity(minv=0, maxv=1.0),
        ResizeWithPadOrCrop(spatial_size=[height, width, -1], mode="minimum"),
    ])
    #transform_list = [ScaleIntensity(factor=-0.9)] + transform_list
    #transform_list = [transforms.Resize(new_size)] + transform_list

    dataset = CTImageDataset(input_folder, modality=modality, transform=transform,
                             load_patient_number=load_patient_number)

    return DataLoader(dataset=dataset,
                      batch_size=batch_size,
                      drop_last=True,
                      num_workers=num_workers if num_workers is not None else 0,
                      sampler=InfiniteSamplerWrapper(dataset),
                      collate_fn=collate_fn)
136
+
137
def main(root=r'C:\Users\56991\Projects\Datasets\Task1\pelvis', modality='ct'):
    """Smoke-test the folder loader: build it and print every batch's shape.

    NOTE(review): the loader uses an infinite sampler, so this loop does not
    terminate on its own.
    """
    loader = get_data_loader_folder(root, modality,
                                    batch_size=8, new_size=512,
                                    height=512, width=512,
                                    num_workers=None, load_patient_number=1)
    print('Length of loader:', len(loader))
    for i, batch in enumerate(loader):
        print(f'Batch {i}:', batch['img'].shape)
    print('Done')
152
+
153
if __name__ == '__main__':
    main()
155
+
156
+
157
+ def InfiniteSampler(n):
158
+ # i = 0
159
+ i = n - 1
160
+ order = np.random.permutation(n)
161
+ while True:
162
+ yield order[i]
163
+ i += 1
164
+ if i >= n:
165
+ np.random.seed()
166
+ order = np.random.permutation(n)
167
+ i = 0
168
+
169
+
170
class InfiniteSamplerWrapper(data.sampler.Sampler):
    """torch Sampler that streams indices endlessly via InfiniteSampler."""

    def __init__(self, data_source):
        self.num_samples = len(data_source)

    def __iter__(self):
        return iter(InfiniteSampler(self.num_samples))

    def __len__(self):
        # effectively infinite
        return 2 ** 31
179
+
180
+
181
def collate_fn(batch):
    """Stack image tensors into one batch; keep Laplacian matrices as a list."""
    return {
        'img': torch.stack([sample['img'] for sample in batch], dim=0),
        'laplacian_m': [sample['laplacian_m'] for sample in batch],
    }
188
+
dataprocesser/archive/gan_loader.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import monai
2
+ import os
3
+ import numpy as np
4
+ from monai.transforms import (
5
+ Compose,
6
+ LoadImaged,
7
+ EnsureChannelFirstd,
8
+ SqueezeDimd,
9
+ CenterSpatialCropd,
10
+ Rotate90d,
11
+ ScaleIntensityd,
12
+ ResizeWithPadOrCropd,
13
+ DivisiblePadd,
14
+ ThresholdIntensityd,
15
+ NormalizeIntensityd,
16
+ ShiftIntensityd,
17
+ Identityd,
18
+ ScaleIntensityRanged,
19
+ Spacingd,
20
+ )
21
+
22
+ from monai.data import Dataset
23
+ from torch.utils.data import DataLoader
24
+ import torch
25
+ from .checkdata import check_volumes, save_volumes, check_batch_data, test_volumes_pixdim
26
+
27
def get_transforms(configs, mode='train'):
    """Build the MONAI transform pipeline for paired source/target volumes.

    Applies the configured intensity normalization, an always-on resize
    (pad-or-crop to ``resized_size``), an optional 90-degree rotation, and —
    in ``'train'`` mode only — random spatial/intensity augmentations.

    Args:
        configs: config object; reads ``configs.dataset.*`` settings.
        mode: 'train' enables random augmentations; any other value builds
            the deterministic validation pipeline.

    Returns:
        monai.transforms.Compose: the assembled pipeline.
    """
    normalize = configs.dataset.normalize
    pad = configs.dataset.pad
    resized_size = configs.dataset.resized_size
    WINDOW_WIDTH = configs.dataset.WINDOW_WIDTH
    WINDOW_LEVEL = configs.dataset.WINDOW_LEVEL
    prob = configs.dataset.augmentationProb
    background = configs.dataset.background

    transform_list = []
    # Window bounds, kept for the commented-out thresholding below.
    # Renamed from `min`/`max` so the builtins are not shadowed.
    window_min, window_max = WINDOW_LEVEL - (WINDOW_WIDTH / 2), WINDOW_LEVEL + (WINDOW_WIDTH / 2)
    #transform_list.append(ThresholdIntensityd(keys=["target"], threshold=window_min, above=True, cval=background))
    #transform_list.append(ThresholdIntensityd(keys=["target"], threshold=window_max, above=False, cval=-1000))
    # filter the source images
    # transform_list.append(ThresholdIntensityd(keys=["source"], threshold=configs.dataset.MRImax, above=False, cval=0))
    if normalize == 'zscore':
        transform_list.append(NormalizeIntensityd(keys=["source", "target"], nonzero=False, channel_wise=True))
        print('zscore normalization')
    elif normalize == 'minmax':
        transform_list.append(ScaleIntensityd(keys=["source", "target"], minv=-1.0, maxv=1.0))
        print('minmax normalization')

    elif normalize == 'scale4000':
        transform_list.append(ScaleIntensityd(keys=["source"], minv=0, maxv=1))
        transform_list.append(ScaleIntensityd(keys=["target"], minv=0))
        transform_list.append(ScaleIntensityd(keys=["target"], factor=-0.99975))  # x = x * (1 + factor)
        print('scale4000 normalization')  # fixed: previously mislabelled as 'scale1000'

    elif normalize == 'scale1000':
        transform_list.append(ScaleIntensityd(keys=["source"], minv=0, maxv=1))
        transform_list.append(ScaleIntensityd(keys=["target"], minv=0))
        transform_list.append(ScaleIntensityd(keys=["target"], factor=-0.999))
        print('scale1000 normalization')

    elif normalize == 'inputonlyzscore':
        transform_list.append(NormalizeIntensityd(keys=["source"], nonzero=False, channel_wise=True))
        print('only normalize input MRI images')

    elif normalize == 'inputonlyminmax':
        transform_list.append(ScaleIntensityd(keys=["source"], minv=configs.dataset.normmin, maxv=configs.dataset.normmax))
        print('only normalize input MRI images')

    elif normalize == 'none' or normalize == 'nonorm':
        print('no normalization')

    spaceXY = 0  # hard-coded off; set > 0 to resample the XY spacing to spaceXY mm
    if spaceXY > 0:
        transform_list.append(Spacingd(keys=["source"], pixdim=(spaceXY, spaceXY, 2.5), mode="bilinear"))
        transform_list.append(Spacingd(keys=["target", "mask"], pixdim=(spaceXY, spaceXY, 2.5), mode="bilinear"))
    transform_list.append(ResizeWithPadOrCropd(keys=["source", "target"], spatial_size=resized_size, mode=pad))
    # transform_list.append(ScaleIntensityRanged(keys=["target"], a_min=window_min, a_max=window_max, b_min=0, b_max=1, clip=True))

    if configs.dataset.rotate:
        transform_list.append(Rotate90d(keys=["source", "target"], k=3))

    if mode == 'train':
        # Augmentation transforms are imported lazily; only training needs them.
        from monai.transforms import (
            RandRotated,
            RandZoomd,
            RandBiasFieldd,
            RandAffined,
            RandGridDistortiond,
            RandGridPatchd,
            RandShiftIntensityd,
            RandGibbsNoised,
            RandAdjustContrastd,
            RandGaussianSmoothd,
            RandGaussianSharpend,
            RandGaussianNoised,
        )
        Aug = True
        if Aug:
            # Spatial augmentations are applied jointly so source/target/mask stay aligned.
            transform_list.append(RandRotated(keys=["source", "target", "mask"], range_x=0.1, range_y=0.1, range_z=0.1, prob=prob, padding_mode="border", keep_size=True))
            transform_list.append(RandZoomd(keys=["source", "target", "mask"], prob=prob, min_zoom=0.9, max_zoom=1.3, padding_mode="minimum", keep_size=True))
            transform_list.append(RandAffined(keys=["source", "target"], padding_mode="border", prob=prob))
            #transform_list.append(Rand3DElasticd(keys=["source", "target"], prob=prob, sigma_range=(5, 8), magnitude_range=(100, 200), spatial_size=None, mode='bilinear'))
        intensityAug = False
        if intensityAug:
            print('intensity data augmentation is used')
            transform_list.append(RandBiasFieldd(keys=["source"], degree=3, coeff_range=(0.0, 0.1), prob=prob))  # only apply to MRI images
            transform_list.append(RandGaussianNoised(keys=["source"], prob=prob, mean=0.0, std=0.01))
            transform_list.append(RandAdjustContrastd(keys=["source"], prob=prob, gamma=(0.5, 1.5)))
            transform_list.append(RandShiftIntensityd(keys=["source"], prob=prob, offsets=20))
            transform_list.append(RandGaussianSharpend(keys=["source"], alpha=(0.2, 0.8), prob=prob))

    #transform_list.append(Rotate90d(keys=["source", "target"], k=3))
    #transform_list.append(DivisiblePadd(keys=["source", "target"], k=div_size, mode="minimum"))
    #transform_list.append(Identityd(keys=["source", "target"]))  # do nothing for the no norm case
    train_transforms = Compose(transform_list)
    return train_transforms
118
def myslicesloader(configs,paths):
    """Build a slice-level training loader and a volume-level validation loader.

    Expects each patient folder under ``configs.dataset.data_dir`` to contain
    ``<source>.nii.gz``, ``<target>.nii.gz`` and ``mask.nii.gz``. The first
    ``train_number`` folders become the training set; the last ``val_number``
    folders the validation set (NOTE(review): these overlap if
    train_number + val_number exceeds the folder count — confirm acceptable).

    Returns:
        (train_crop_ds, val_crop_ds, train_loader, val_loader,
         train_transforms, val_transforms)
    """
    data_path=configs.dataset.data_dir
    train_number=configs.dataset.train_number
    val_number=configs.dataset.val_number
    train_batch_size=configs.dataset.batch_size
    val_batch_size=configs.dataset.val_batch_size
    saved_name_train=paths["saved_name_train"]
    saved_name_val=paths["saved_name_val"]
    center_crop=configs.dataset.center_crop
    source=configs.dataset.source
    target=configs.dataset.target

    # volume-level transforms for both image and label
    train_transforms = get_transforms(configs,mode='train')
    val_transforms = get_transforms(configs,mode='val')

    #list all files in the folder (skip any 'overview' entries)
    file_list=[i for i in os.listdir(data_path) if 'overview' not in i]
    file_list_path=[os.path.join(data_path,i) for i in file_list]
    #list all ct and mr files in folder
    mask='mask'
    source_file_list=[os.path.join(j,f'{source}.nii.gz') for j in file_list_path]
    target_file_list=[os.path.join(j,f'{target}.nii.gz') for j in file_list_path]
    mask_file_list=[os.path.join(j,f'{mask}.nii.gz') for j in file_list_path]
    # Duplicate keys: 'source'/'target'/'mask' feed the transforms,
    # '*_paths' keep the original file locations for bookkeeping.
    train_ds = [{'source': i, 'target': j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k}
                for i, j, k in zip(source_file_list[0:train_number], target_file_list[0:train_number], mask_file_list[0:train_number])]
    val_ds = [{'source': i, 'target': j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k}
                for i, j, k in zip(source_file_list[-val_number:], target_file_list[-val_number:], mask_file_list[-val_number:])]
    print('all files in dataset:',len(file_list))

    # load volumes and center crop (crop applies only along the last axis)
    if center_crop>0:
        crop=Compose([LoadImaged(keys=["source", "target", "mask"]),
                      EnsureChannelFirstd(keys=["source", "target", "mask"]),
                      CenterSpatialCropd(keys=["source", "target", "mask"], roi_size=(-1,-1,center_crop)),
                      ])
        train_crop_ds = monai.data.Dataset(data=train_ds, transform=crop)
        val_crop_ds = monai.data.Dataset(data=val_ds, transform=crop)
        print('center crop:',center_crop)
    else:
        crop=Compose([LoadImaged(keys=["source", "target", "mask"]),
                      EnsureChannelFirstd(keys=["source", "target", "mask"]),
                      ])
        train_crop_ds = monai.data.Dataset(data=train_ds, transform=crop)
        val_crop_ds = monai.data.Dataset(data=val_ds, transform=crop)

    # load volumes
    train_volume_ds = monai.data.Dataset(data=train_crop_ds, transform=train_transforms)
    val_volume_ds = monai.data.Dataset(data=val_crop_ds, transform=val_transforms)
    # Debug switches: save file lists / sanity-check volumes / test pixdims.
    ifsave,ifcheck,iftest=False,False,False
    if ifsave:
        save_volumes(train_ds, val_ds, saved_name_train, saved_name_val)
    if ifcheck:
        check_volumes(train_ds, train_volume_ds, val_volume_ds, val_ds)
    if iftest:
        test_volumes_pixdim(train_volume_ds)

    # batch-level slicer for both image and label
    window_width=1
    patch_func = monai.data.PatchIterd(
        keys=["source", "target", "mask"],
        patch_size=(None, None, window_width), # dynamic first two dimensions
        start_pos=(0, 0, 0)
    )
    if window_width==1:
        patch_transform = Compose(
            [
                SqueezeDimd(keys=["source", "target", "mask"], dim=-1), # squeeze the last dim
            ]
        )
    else:
        patch_transform = None

    # for training: iterate 2D slices extracted from each volume
    train_patch_ds = monai.data.GridPatchDataset(
        data=train_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
    train_loader = DataLoader(
        train_patch_ds,
        batch_size=train_batch_size,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )

    # for validation: whole volumes, not slices
    val_loader = DataLoader(
        val_volume_ds,
        num_workers=1,
        batch_size=val_batch_size,
        pin_memory=torch.cuda.is_available())

    if ifcheck:
        check_batch_data(train_loader,val_loader,train_patch_ds,val_volume_ds,train_batch_size,val_batch_size)

    return train_crop_ds,val_crop_ds,train_loader,val_loader,train_transforms,val_transforms
213
+
214
def ddpmloader(configs,paths):
    """Build slice-level train AND validation loaders for DDPM training.

    Near-identical to ``myslicesloader`` except: masks are listed in the data
    dicts but NOT loaded by the crop transforms, both loaders iterate 2D slice
    patches, and ``num_workers`` is 0 for both.

    Returns:
        (train_crop_ds, val_crop_ds, train_loader, val_loader,
         train_transforms, val_transforms)
    """
    data_path=configs.dataset.data_dir
    train_number=configs.dataset.train_number
    val_number=configs.dataset.val_number
    train_batch_size=configs.dataset.batch_size
    val_batch_size=configs.dataset.val_batch_size
    saved_name_train=paths["saved_name_train"]
    saved_name_val=paths["saved_name_val"]
    center_crop=configs.dataset.center_crop
    source=configs.dataset.source
    target=configs.dataset.target

    # volume-level transforms for both image and label
    train_transforms = get_transforms(configs,mode='train')
    val_transforms = get_transforms(configs,mode='val')

    #list all files in the folder (skip any 'overview' entries)
    file_list=[i for i in os.listdir(data_path) if 'overview' not in i]
    file_list_path=[os.path.join(data_path,i) for i in file_list]
    #list all ct and mr files in folder
    mask='mask'
    source_file_list=[os.path.join(j,f'{source}.nii.gz') for j in file_list_path]
    target_file_list=[os.path.join(j,f'{target}.nii.gz') for j in file_list_path]
    mask_file_list=[os.path.join(j,f'{mask}.nii.gz') for j in file_list_path]
    # Mask paths are recorded in the dicts but never loaded below — the crop
    # transforms only read 'source' and 'target'.
    train_ds = [{'source': i, 'target': j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k}
                for i, j, k in zip(source_file_list[0:train_number], target_file_list[0:train_number], mask_file_list[0:train_number])]
    val_ds = [{'source': i, 'target': j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k}
                for i, j, k in zip(source_file_list[-val_number:], target_file_list[-val_number:], mask_file_list[-val_number:])]
    print('all files in dataset:',len(file_list))

    # load volumes and center crop (crop applies only along the last axis)
    if center_crop>0:
        crop=Compose([LoadImaged(keys=["source", "target"]),
                      EnsureChannelFirstd(keys=["source", "target"]),
                      CenterSpatialCropd(keys=["source", "target"], roi_size=(-1,-1,center_crop)),
                      ])
        train_crop_ds = monai.data.Dataset(data=train_ds, transform=crop)
        val_crop_ds = monai.data.Dataset(data=val_ds, transform=crop)
        print('center crop:',center_crop)
    else:
        crop=Compose([LoadImaged(keys=["source", "target"]),
                      EnsureChannelFirstd(keys=["source", "target"]),
                      ])
        train_crop_ds = monai.data.Dataset(data=train_ds, transform=crop)
        val_crop_ds = monai.data.Dataset(data=val_ds, transform=crop)

    # load volumes
    train_volume_ds = monai.data.Dataset(data=train_crop_ds, transform=train_transforms)
    val_volume_ds = monai.data.Dataset(data=val_crop_ds, transform=val_transforms)
    # Debug switches: save file lists / sanity-check volumes / test pixdims.
    ifsave,ifcheck,iftest=False,False,False
    if ifsave:
        save_volumes(train_ds, val_ds, saved_name_train, saved_name_val)
    if ifcheck:
        check_volumes(train_ds, train_volume_ds, val_volume_ds, val_ds)
    if iftest:
        test_volumes_pixdim(train_volume_ds)

    # batch-level slicer for both image and label
    window_width=1
    patch_func = monai.data.PatchIterd(
        keys=["source", "target"],
        patch_size=(None, None, window_width), # dynamic first two dimensions
        start_pos=(0, 0, 0)
    )
    if window_width==1:
        patch_transform = Compose(
            [
                SqueezeDimd(keys=["source", "target"], dim=-1), # squeeze the last dim
            ]
        )
    else:
        patch_transform = None

    # for training
    train_patch_ds = monai.data.GridPatchDataset(
        data=train_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
    train_loader = DataLoader(
        train_patch_ds,
        batch_size=train_batch_size,
        num_workers=0,
        pin_memory=torch.cuda.is_available(),
    )

    # for validation: also slice patches (unlike myslicesloader)
    val_patch_ds = monai.data.GridPatchDataset(
        data=val_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
    val_loader = DataLoader(
        val_patch_ds, #val_volume_ds,
        num_workers=0,
        batch_size=val_batch_size,
        pin_memory=torch.cuda.is_available())

    if ifcheck:
        check_batch_data(train_loader,val_loader,train_patch_ds,val_volume_ds,train_batch_size,val_batch_size)

    return train_crop_ds,val_crop_ds,train_loader,val_loader,train_transforms,val_transforms
dataprocesser/archive/init_dataset.py ADDED
File without changes
dataprocesser/archive/json_dataset_slices.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import DataLoader
2
+ import os.path
3
+ from PIL import Image
4
+ import torch
5
+ from PIL import ImageFile
6
+ import os
7
+ import pandas as pd
8
+ import monai
9
+
10
+ Image.MAX_IMAGE_PIXELS = None # Disable DecompressionBombError
11
+ ImageFile.LOAD_TRUNCATED_IMAGES = True # Disable OSError: image file is truncated
12
+
13
# Recognised medical-image file extensions.
IMG_EXTENSIONS = [
    #'.jpg', '.JPG', '.jpeg', '.JPEG',
    #'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.nrrd', '.nii.gz'
]

def is_image_file(filename):
    """Return True when *filename* ends with one of IMG_EXTENSIONS."""
    return filename.endswith(tuple(IMG_EXTENSIONS))
21
+
22
+
23
+ from dataprocesser.list_dataset_base import BaseDataLoader
24
+
25
+
26
+
27
+
28
+
dataprocesser/archive/list_dataset_Anika.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
+
5
+
6
+
7
+
8
+
9
+
10
+
dataprocesser/archive/list_dataset_Anish.py ADDED
File without changes
dataprocesser/archive/list_dataset_Anish_seg.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import DataLoader, Dataset, random_split
2
+ import torch.utils.data as data
3
+ import os.path
4
+ import random
5
+ from torchvision import transforms
6
+ from PIL import Image
7
+ import torch
8
+ from PIL import ImageFile
9
+ #from utils.MattingLaplacian import compute_laplacian
10
+
11
+ import nibabel as nib
12
+ import numpy as np
13
+ import os
14
+ import csv
15
+ import pandas as pd
16
+ #from transformers import CLIPTokenizer
17
+ from tqdm import tqdm
18
+ Image.MAX_IMAGE_PIXELS = None # Disable DecompressionBombError
19
+ ImageFile.LOAD_TRUNCATED_IMAGES = True # Disable OSError: image file is truncated
20
+
21
+ import os
22
+ import numpy as np
23
+
24
+ import torch
25
+
26
# Recognised medical-image file extensions.
IMG_EXTENSIONS = [
    #'.jpg', '.JPG', '.jpeg', '.JPEG',
    #'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.nrrd', '.nii.gz'
]

def is_image_file(filename):
    """Return True when *filename* ends with one of IMG_EXTENSIONS."""
    return filename.endswith(tuple(IMG_EXTENSIONS))
34
+
35
+
36
+
37
+
38
+ from dataprocesser.customized_transforms import CreateMaskTransformd, MergeMasksTransformd
39
+ from dataprocesser.list_dataset_base import BaseDataLoader
40
+
41
+
42
+
dataprocesser/archive/list_dataset_base.py ADDED
@@ -0,0 +1,983 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import monai
2
+ import os
3
+ import numpy as np
4
+ from monai.transforms import (
5
+ Compose,
6
+ LoadImaged,
7
+ EnsureChannelFirstd,
8
+ SqueezeDimd,
9
+ CenterSpatialCropd,
10
+ Rotate90d,
11
+ ScaleIntensityd,
12
+ ResizeWithPadOrCropd,
13
+ DivisiblePadd,
14
+ Zoomd,
15
+ ThresholdIntensityd,
16
+ NormalizeIntensityd,
17
+ ShiftIntensityd,
18
+ Identityd,
19
+ ScaleIntensityRanged,
20
+ Spacingd,
21
+ )
22
+ from torch.utils.data import DataLoader
23
+ from torch.utils.data import ConcatDataset
24
+ import torch
25
+ from abc import ABC, abstractmethod
26
+
27
+ from datetime import datetime
28
+ import json
29
+ from tqdm import tqdm
30
+
31
+ from step1_init_data_list import (
32
+ list_img_ad_from_anish_csv,
33
+ list_img_ad_pIDs_from_anish_csv,
34
+ list_img_pID_from_synthrad_folder,
35
+ list_from_anika_dataset,
36
+ list_from_json,
37
+ list_from_slice_csv,
38
+ )
39
+ from step5_data_check_and_log import finalcheck
40
+ VERBOSE = False
41
+
42
def make_dataset_modality():
    """Return a fresh, empty list to be populated with image entries."""
    return []
45
+
46
class ABCLoader(ABC):
    """Abstract interface for dataset loaders.

    Subclasses build file lists, transforms, datasets and dataloaders. Only
    ``__init__`` is formally abstract; the remaining hooks are optional no-op
    defaults that concrete loaders may override.
    """
    @abstractmethod
    def __init__(self):
        """Subclass must implement this method."""
        pass

    def get_loader(self):
        """Subclass must implement this method."""
        pass

    def create_dataset(self):
        """Subclass must implement this method."""
        pass

    def get_transforms(self):
        """Subclass must implement this method."""
        pass

    def get_normlization(self):
        """Subclass must implement this method."""
        pass

    def get_shape_transform(self):
        """Subclass must implement this method."""
        # Warn loudly when a subclass forgot to override the shape transform.
        print("no shape transform here!!!!!!!!!!!!!!!!!!!!!!")
        pass

    def get_augmentation(self):
        """Subclass must implement this method."""
        pass
76
+
77
+ class BaseDataLoader(ABCLoader):
78
+ def __init__(self,configs,paths=None,dimension=2, **kwargs):
79
+ self.configs=configs
80
+ self.paths=paths
81
+ self.init_parameters_and_transforms()
82
+ self.get_loader()
83
+ #print('all files in dataset:',len(self.source_file_list))
84
+
85
+
86
+ self.rotation_level = kwargs.get('rotation_level', 0) # Default to no rotation (0)
87
+ self.zoom_level = kwargs.get('zoom_level', 1.0) # Default to no zoom (1.0)
88
+ self.flip = kwargs.get('flip', 0) # Default to no flip
89
+
90
+ self.create_dataset(dimension=dimension)
91
+
92
+ ifsave = None if paths is None else True
93
+ finalcheck(self.train_ds, self.val_ds,
94
+ self.train_volume_ds, self.val_volume_ds,
95
+ self.train_loader, self.val_loader,
96
+ self.train_patch_ds,
97
+ self.train_batch_size, self.val_batch_size,
98
+ self.saved_name_train, self.saved_name_val,
99
+ self.indicator_A, self.indicator_B,
100
+ ifsave=ifsave, ifcheck=False,iftest_volumes_pixdim=False)
101
+
102
+ def get_loader(self):
103
+ self.source_file_list = []
104
+ self.train_ds=[]
105
+ self.val_ds=[]
106
+
107
+
108
    def init_parameters_and_transforms(self):
        """Cache frequently-used dataset settings and build train/val transforms.

        Requires self.flip / self.rotation_level / self.zoom_level to already be
        set when configs.model_name == 'augmentation', since the augmentation
        pipeline is parameterised by them.
        """
        self.indicator_A=self.configs.dataset.indicator_A
        self.indicator_B=self.configs.dataset.indicator_B
        self.train_number=self.configs.dataset.train_number
        self.val_number=self.configs.dataset.val_number
        self.train_batch_size=self.configs.dataset.batch_size
        self.val_batch_size=self.configs.dataset.val_batch_size
        self.load_masks=self.configs.dataset.load_masks

        # Dict keys the MONAI transforms operate on; "mask" is included only
        # when mask volumes are loaded alongside source/target.
        self.keys = [self.indicator_A, self.indicator_B, "mask"] if self.load_masks else [self.indicator_A, self.indicator_B]

        if self.configs.model_name=='augmentation':
            # Fixed parameters for rotation and zooming
            self.train_transforms = self.get_augmentation(transform_list=[], flip=self.flip, rotation_level=self.rotation_level, zoom_level=self.zoom_level)
        else:
            self.train_transforms = self.get_transforms(mode='train')
        self.val_transforms = self.get_transforms(mode='val')

        # Output names are only available when a paths dict was provided.
        if self.paths is not None:
            self.saved_name_train=self.paths["saved_name_train"]
            self.saved_name_val=self.paths["saved_name_val"]
129
+
130
    def create_volume_dataset(self):
        """Build train/val volume datasets: load image volumes, force channel-
        first layout, optionally centre-crop along the last axis, then apply the
        train/val transform pipelines."""
        # load volumes and center crop
        center_crop = self.configs.dataset.center_crop
        transformations_crop = [
            LoadImaged(keys=self.keys),
            EnsureChannelFirstd(keys=self.keys),
        ]
        if center_crop>0:
            # roi_size (-1, -1, center_crop): keep full extent in the first two
            # dims, crop only the last (slice) dimension.
            transformations_crop.append(CenterSpatialCropd(keys=self.keys, roi_size=(-1,-1,center_crop)))
        transformations_crop=Compose(transformations_crop)
        train_crop_ds = monai.data.Dataset(data=self.train_ds, transform=transformations_crop)
        val_crop_ds = monai.data.Dataset(data=self.val_ds, transform=transformations_crop)

        # load volumes
        self.train_volume_ds = monai.data.Dataset(data=train_crop_ds, transform=self.train_transforms)
        self.val_volume_ds = monai.data.Dataset(data=val_crop_ds, transform=self.val_transforms)
146
+
147
+ def create_patch_dataset_and_dataloader(self, dimension=2):
148
+ train_batch_size=self.configs.dataset.batch_size
149
+ val_batch_size=self.configs.dataset.val_batch_size
150
+ if dimension==2:
151
+ # batch-level slicer for both image and label
152
+ window_width=1
153
+ patch_func = monai.data.PatchIterd(
154
+ keys=self.keys,
155
+ patch_size=(None, None, window_width), # dynamic first two dimensions
156
+ start_pos=(0, 0, 0)
157
+ )
158
+ if window_width==1:
159
+ patch_transform = Compose(
160
+ [
161
+ SqueezeDimd(keys=self.keys, dim=-1), # squeeze the last dim
162
+ ]
163
+ )
164
+ else:
165
+ patch_transform = None
166
+
167
+ # for training
168
+ train_patch_ds = monai.data.GridPatchDataset(
169
+ data=self.train_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
170
+ train_loader = DataLoader(
171
+ train_patch_ds,
172
+ batch_size=train_batch_size,
173
+ num_workers=self.configs.dataset.num_workers,
174
+ pin_memory=torch.cuda.is_available(),
175
+ )
176
+
177
+ # for validation
178
+ if self.configs.model_name=='ddpm' or 'ddpm2d_seg2med' or 'ddpm2d':
179
+ val_patch_ds = monai.data.GridPatchDataset(
180
+ data=self.val_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
181
+ val_loader = DataLoader(
182
+ val_patch_ds, #val_volume_ds,
183
+ num_workers=self.configs.dataset.num_workers,
184
+ batch_size=val_batch_size,
185
+ pin_memory=torch.cuda.is_available())
186
+ else:
187
+ val_loader = DataLoader(
188
+ self.val_volume_ds,
189
+ num_workers=self.configs.dataset.num_workers,
190
+ batch_size=val_batch_size,
191
+ pin_memory=torch.cuda.is_available())
192
+ self.train_patch_ds=train_patch_ds
193
+
194
+ elif dimension==2.5:
195
+ # batch-level slicer for both image and label
196
+ # 2.5 means stack slices together as a small volume patch
197
+ # if window_width>1, means we train a 2.5D network
198
+ patch_size=self.configs.dataset.patch_size # (None, None, window_width)
199
+ window_width=patch_size[-1]
200
+ patch_func = monai.data.PatchIterd(
201
+ keys=self.keys,
202
+ patch_size=patch_size, # dynamic first two dimensions: (None, None, window_width)
203
+ start_pos=(0, 0, 0)
204
+ )
205
+ if window_width==1:
206
+ print(f"slice patch is 1, we use 2D-training")
207
+ patch_transform = Compose(
208
+ [
209
+ SqueezeDimd(keys=self.keys, dim=-1), # squeeze the last dim
210
+ ]
211
+ )
212
+ else:
213
+ print(f"use consecutive {window_width} slices for 2.5D-training")
214
+ # there would be an error if original size < patch_size during training, so we should pad it in this case
215
+ patch_transform = ResizeWithPadOrCropd(keys=self.keys,
216
+ spatial_size=patch_size, mode='minimum')
217
+
218
+ # for training
219
+ train_patch_ds = monai.data.GridPatchDataset(
220
+ data=self.train_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
221
+ train_loader = DataLoader(
222
+ train_patch_ds,
223
+ batch_size=train_batch_size,
224
+ num_workers=2,
225
+ pin_memory=torch.cuda.is_available(),
226
+ )
227
+
228
+ # for validation
229
+ if self.configs.model_name=='ddpm':
230
+ val_patch_ds = monai.data.GridPatchDataset(
231
+ data=self.val_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
232
+ val_loader = DataLoader(
233
+ val_patch_ds, #val_volume_ds,
234
+ num_workers=0,
235
+ batch_size=val_batch_size,
236
+ pin_memory=torch.cuda.is_available())
237
+ else:
238
+ val_loader = DataLoader(
239
+ self.val_volume_ds,
240
+ num_workers=1,
241
+ batch_size=val_batch_size,
242
+ pin_memory=torch.cuda.is_available())
243
+ self.train_patch_ds=train_patch_ds
244
+
245
+ elif dimension==3:
246
+ # 3 means use the whole input volume for training
247
+ train_loader = DataLoader(
248
+ self.train_volume_ds,
249
+ num_workers=self.configs.dataset.num_workers,
250
+ batch_size=train_batch_size,
251
+ pin_memory=torch.cuda.is_available())
252
+ val_loader = DataLoader(
253
+ self.val_volume_ds,
254
+ num_workers=self.configs.dataset.num_workers,
255
+ batch_size=val_batch_size,
256
+ pin_memory=torch.cuda.is_available())
257
+
258
+ elif dimension==3.5:
259
+ # 3.5 means create patch from the original volume
260
+ patch_func = monai.data.PatchIterd(
261
+ keys=[self.indicator_A, self.indicator_B],
262
+ patch_size=self.configs.dataset.patch_size, # dynamic first two dimensions
263
+ start_pos=(0, 0, 0),
264
+ mode="replicate",
265
+ )
266
+ patch_transform = None
267
+
268
+ # for training
269
+ train_patch_ds = monai.data.GridPatchDataset(
270
+ data=self.train_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
271
+ train_loader = DataLoader(
272
+ train_patch_ds,
273
+ batch_size=train_batch_size,
274
+ num_workers=self.configs.dataset.num_workers,
275
+ pin_memory=torch.cuda.is_available(),
276
+ )
277
+
278
+ val_patch_ds = monai.data.GridPatchDataset(
279
+ data=self.val_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
280
+ val_loader = DataLoader(
281
+ val_patch_ds, #val_volume_ds,
282
+ num_workers=self.configs.dataset.num_workers,
283
+ batch_size=val_batch_size,
284
+ pin_memory=torch.cuda.is_available())
285
+ else:
286
+ print('dimension of input data must be 2 or 2.5 or 3 or 3.5!')
287
+
288
+
289
+
290
+ self.train_batch_size=train_batch_size
291
+ self.val_batch_size=val_batch_size
292
+ self.train_loader=train_loader
293
+ self.val_loader=val_loader
294
+
295
    def create_dataset(self,dimension=2):
        """Build volume-level datasets, then patch-level datasets and
        dataloaders for the given training dimension (2, 2.5, 3 or 3.5)."""
        self.create_volume_dataset()
        self.create_patch_dataset_and_dataloader(dimension=dimension)
298
+
299
+ def get_transforms(self, mode='train'):
300
+ transform_list=[]
301
+ transform_list = self.get_pretransforms(transform_list)
302
+ transform_list = self.get_intensity_transforms(transform_list)
303
+ transform_list = self.get_normlization(transform_list)
304
+ transform_list = self.get_shape_transform(transform_list)
305
+ train_transforms = Compose(transform_list)
306
+ return train_transforms
307
+
308
    def get_pretransforms(self, transform_list):
        """Hook for subclasses to prepend custom transforms; the base
        implementation returns the list unchanged."""
        #print("customized transforms")
        return transform_list
311
+
312
    def get_intensity_transforms(self, transform_list):
        """Append intensity windowing for the target (indicator_B): clamp to
        [WINDOW_LEVEL - WINDOW_WIDTH/2, WINDOW_LEVEL + WINDOW_WIDTH/2], then
        shift so the window's lower bound maps to 0 (non-negative output)."""
        threshold_low=self.configs.dataset.WINDOW_LEVEL - self.configs.dataset.WINDOW_WIDTH / 2
        threshold_high=self.configs.dataset.WINDOW_LEVEL + self.configs.dataset.WINDOW_WIDTH / 2
        offset=(-1)*threshold_low
        # if filter out the pixel with values below threshold1, set above=True, and the cval1>=threshold1, otherwise there will be problem
        # mask = img > self.threshold if self.above else img < self.threshold
        # res = where(mask, img, self.cval)
        transform_list.append(ThresholdIntensityd(keys=[self.indicator_B], threshold=threshold_low, above=True, cval=threshold_low))
        transform_list.append(ThresholdIntensityd(keys=[self.indicator_B], threshold=threshold_high, above=False, cval=threshold_high))
        transform_list.append(ShiftIntensityd(keys=[self.indicator_B], offset=offset))
        return transform_list
323
+
324
+ def get_normlization(self, transform_list):
325
+ normalize=self.configs.dataset.normalize
326
+ indicator_A=self.configs.dataset.indicator_A
327
+ indicator_B=self.configs.dataset.indicator_B
328
+ # offset = self.configs.dataset.offset
329
+ # we don't need normalization for segmentation mask
330
+ if normalize=='zscore':
331
+ transform_list.append(NormalizeIntensityd(keys=[indicator_B], nonzero=False, channel_wise=True))
332
+ print('zscore normalization')
333
+ elif normalize=='minmax':
334
+ transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=-1.0, maxv=1.0))
335
+ print('minmax normalization')
336
+
337
+ elif normalize=='scale1000_wrongbutworks':
338
+ transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=0))
339
+ transform_list.append(ScaleIntensityd(keys=[indicator_B], factor=-0.999))
340
+ print('scale1000 normalization')
341
+
342
+ elif normalize=='scale4000':
343
+ transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.99975))
344
+ print('scale4000 normalization')
345
+
346
+ elif normalize=='scale2000':
347
+ transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.9995))
348
+ print('scale2000 normalization')
349
+
350
+ elif normalize=='scale1000':
351
+ transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.999))
352
+ print('scale1000 normalization')
353
+
354
+ elif normalize=='scale100':
355
+ transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None,factor=-0.99))
356
+ print('scale10 normalization')
357
+
358
+ elif normalize=='scale10':
359
+ transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None,factor=-0.9))
360
+ print('scale10 normalization')
361
+
362
+ elif normalize=='inputonlyzscore':
363
+ transform_list.append(NormalizeIntensityd(keys=[indicator_A], nonzero=False, channel_wise=True))
364
+ print('only normalize input MRI images')
365
+
366
+ elif normalize=='inputonlyminmax':
367
+ transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=self.configs.dataset.normmin, maxv=self.configs.dataset.normmax))
368
+ print('only normalize input MRI images')
369
+
370
+ elif normalize == 'nonegative':
371
+ transform_list.append(ShiftIntensityd(keys=[indicator_B], offset=self.configs.dataset.offset))
372
+ print('none negative normalization')
373
+
374
+ elif normalize=='none' or normalize=='nonorm':
375
+ print('no normalization')
376
+
377
+ return transform_list
378
+
379
    def get_shape_transform(self, transform_list):
        """Append spatial transforms: optional resampling, zoom, padding to a
        divisible size, pad/crop to the final shape and optional rotation.

        Mutates ``transform_list`` in place and returns it, like the other
        get_*_transforms hooks.
        """
        spaceXY=self.configs.dataset.spaceXY
        load_masks=self.configs.dataset.load_masks

        indicator_A=self.configs.dataset.indicator_A
        indicator_B=self.configs.dataset.indicator_B
        pad_value=0 #offset*(-1)
        keys = self.keys #[indicator_A, indicator_B, "mask"] if load_masks else [indicator_A, indicator_B]
        # resample in-plane pixel spacing when spaceXY > 0; slice spacing is
        # fixed at 2.5 here — NOTE(review): confirm 2.5 mm is intended for all
        # datasets driven through this base class
        if spaceXY>0:
            transform_list.append(Spacingd(keys=[indicator_A], pixdim=(spaceXY, spaceXY, 2.5), mode="bilinear", ensure_same_shape=True)) #
            # the mask (when loaded) is resampled together with the target key
            transform_list.append(Spacingd(keys=[indicator_B, "mask"] if load_masks else [indicator_B],
                                pixdim=(spaceXY, spaceXY , 2.5), mode="bilinear", ensure_same_shape=True))

        # zoom with area interpolation, then pad so each spatial dim is a
        # multiple of div_size, then pad/crop to the fixed resized_size;
        # all padding uses pad_value
        transform_list.append(Zoomd(keys=keys,
                            zoom=self.configs.dataset.zoom, keep_size=False, mode='area', padding_mode="constant", value=pad_value))
        transform_list.append(DivisiblePadd(keys=keys,
                            k=self.configs.dataset.div_size, mode="constant", value=pad_value))
        transform_list.append(ResizeWithPadOrCropd(keys=keys,
                            spatial_size=self.configs.dataset.resized_size,mode="constant", value=pad_value))

        # optional orientation fix: three 90-degree quarter-turns
        if self.configs.dataset.rotate:
            transform_list.append(Rotate90d(keys=keys, k=3))
        return transform_list
402
+
403
class anish_loader(BaseDataLoader):
    """Loader for the Anish CSV dataset (paired images + aorta-dissection labels).

    Builds train/val sample-dict lists from a CSV listing file and then drives
    BaseDataLoader's dataset creation manually.
    """

    def __init__(self, configs, paths, dimension=2):
        # NOTE: deliberately bypasses BaseDataLoader.__init__ and runs the
        # setup sequence by hand: get_loader -> create_dataset -> finalcheck.
        self.configs = configs
        self.paths = paths
        self.get_loader()
        super().create_dataset(dimension=dimension)
        self.finalcheck(ifsave=True, ifcheck=False, iftest_volumes_pixdim=False)

    def get_loader(self):
        """Populate self.train_ds / self.val_ds from the configured CSV file."""
        indicator_A = self.configs.dataset.indicator_A
        indicator_B = self.configs.dataset.indicator_B
        self.indicator_A = indicator_A
        self.indicator_B = indicator_B
        train_number = self.configs.dataset.train_number
        val_number = self.configs.dataset.val_number
        load_masks = self.configs.dataset.load_masks

        if self.configs.dataset.data_dir is not None and os.path.exists(self.configs.dataset.data_dir):
            # in this loader data_dir must point directly at a csv listing
            if self.configs.dataset.data_dir.endswith('.csv'):
                csv_file = self.configs.dataset.data_dir
            else:
                raise ValueError('The data directory in this case must be a csv file!')
        else:
            # fall back to the bundled listing appropriate for the server
            if self.configs.server == 'helix' or self.configs.server == 'helixSingle' or self.configs.server == 'helixMultiple':
                csv_file = './healthy_dissec_helix.csv'
            else:
                csv_file = './healthy_dissec.csv'

        # when the network input is a segmentation mask, list the seg column
        load_seg = True if self.configs.dataset.input_is_mask else False
        source_file_list, source_Aorta_diss_list = list_img_ad_from_anish_csv(csv_file, load_seg)
        target_file_list, target_Aorta_diss_list = list_img_ad_from_anish_csv(csv_file)
        mask_file_list, mask_Aorta_diss_list = list_img_ad_from_anish_csv(csv_file)

        if load_masks:
            train_ds = [{indicator_A: i, indicator_B: j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k}
                        for i, j, k in zip(source_file_list[0:train_number], target_file_list[0:train_number], mask_file_list[0:train_number])]
            val_ds = [{indicator_A: i, indicator_B: j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k}
                      for i, j, k in zip(source_file_list[-val_number:], target_file_list[-val_number:], mask_file_list[-val_number:])]
        else:
            train_ds = [{indicator_A: i, indicator_B: j, 'A_paths': i, 'B_paths': j, 'Aorta_diss': ad}
                        for i, j, ad in zip(source_file_list[0:train_number], target_file_list[0:train_number], source_Aorta_diss_list[0:train_number])]
            val_ds = [{indicator_A: i, indicator_B: j, 'A_paths': i, 'B_paths': j, 'Aorta_diss': ad}
                      for i, j, ad in zip(source_file_list[-val_number:], target_file_list[-val_number:], source_Aorta_diss_list[-val_number:])]
        self.train_ds = train_ds
        self.val_ds = val_ds
        self.source_file_list = source_file_list
        self.target_file_list = target_file_list
        self.mask_file_list = mask_file_list

    def get_pretransforms(self, transform_list):
        """Derive a tissue/bone mask from the input image unless the input
        already is a mask.

        BUGFIX: always return ``transform_list`` — previously this hook
        returned None, unlike every sibling get_*_transforms hook that
        chains on the returned list.
        """
        if not self.configs.dataset.input_is_mask:
            transform_list.append(CreateMaskTransformd(
                keys=[self.configs.dataset.indicator_A],
                tissue_min=self.configs.dataset.tissue_min,
                tissue_max=self.configs.dataset.tissue_max,
                bone_min=self.configs.dataset.bone_min,
                bone_max=self.configs.dataset.bone_max))
        return transform_list
473
+
474
+ from dataprocesser.customized_transforms import CreateMaskTransformd, MergeMasksTransformd
475
+
476
class synthrad_seg_loader(BaseDataLoader):
    """SynthRad seg -> CT loader: pairs segmentation inputs with CT targets.

    The 'mask' entry deliberately lists the *target* modality again: the
    pretransforms threshold that image into a body contour and merge it
    into the input mask.
    """

    def __init__(self, configs, paths, dimension=2, **kwargs):
        super().__init__(configs, paths, dimension, **kwargs)

    def get_loader(self):
        """Build train/val dict lists from the SynthRad folder layout."""
        cfg = self.configs.dataset
        self.indicator_A = cfg.indicator_A
        self.indicator_B = cfg.indicator_B
        load_masks = cfg.load_masks  # kept for parity with sibling loaders

        # legacy folder scan (results unused, kept for identical error behavior)
        file_list = [entry for entry in os.listdir(cfg.data_dir) if 'overview' not in entry]
        file_list_path = [os.path.join(cfg.data_dir, entry) for entry in file_list]

        logs_dir = self.paths["saved_logs_folder"]
        source_file_list, patient_IDs = list_img_pID_from_synthrad_folder(
            cfg.data_dir, accepted_modalities=cfg.source_name,
            saved_name=os.path.join(logs_dir, "source_filenames.txt"))
        target_file_list, _ = list_img_pID_from_synthrad_folder(
            cfg.data_dir, accepted_modalities=cfg.target_name,
            saved_name=os.path.join(logs_dir, "target_filenames.txt"))
        # body-contour source: intentionally the target modality again
        mask_file_list, _ = list_img_pID_from_synthrad_folder(
            cfg.data_dir, accepted_modalities=cfg.target_name,
            saved_name=os.path.join(logs_dir, "mask_filenames.txt"))

        self.source_file_list = source_file_list
        self.target_file_list = target_file_list
        self.mask_file_list = mask_file_list

        ad = 0  # Manual_Set_Aorta_Diss: SynthRad cases carry no dissection label

        def build(records):
            return [{self.indicator_A: src, self.indicator_B: tgt, 'mask': msk,
                     'A_paths': src, 'B_paths': tgt, 'mask_path': msk,
                     'Aorta_diss': ad, 'patient_ID': pid}
                    for src, tgt, msk, pid in records]

        n_train = cfg.train_number
        n_val = cfg.val_number
        self.train_ds = build(zip(source_file_list[:n_train], target_file_list[:n_train],
                                  mask_file_list[:n_train], patient_IDs[:n_train]))
        self.val_ds = build(zip(source_file_list[-n_val:], target_file_list[-n_val:],
                                mask_file_list[-n_val:], patient_IDs[-n_val:]))

    def get_pretransforms(self, transform_list):
        """Threshold the 'mask' image into a body contour, then merge it
        into the input mask."""
        transform_list.append(CreateMaskTransformd(keys=['mask'],
                                                   body_threshold=-500,
                                                   body_mask_value=1,
                                                   ))
        transform_list.append(MergeMasksTransformd(keys=[self.configs.dataset.indicator_A, 'mask']))
        return transform_list
525
+
526
+
527
+ from dataprocesser.customized_transforms import CreateMaskTransformd, MergeMasksTransformd, MaskHUAssigmentd
528
+
529
+ from monai.transforms import (
530
+ ScaleIntensityd,
531
+ ThresholdIntensityd,
532
+ NormalizeIntensityd,
533
+ ShiftIntensityd,
534
+ )
535
+
536
class anish_seg_loader(BaseDataLoader):
    """Anish CSV loader for seg -> CT with per-sample aorta-dissection labels.

    The 'mask' entries point at the original CT volumes; the pretransforms
    turn them into body contours and merge them into the input mask.
    """

    def __init__(self, configs, paths=None, dimension=2, **kwargs):
        super().__init__(configs, paths, dimension, **kwargs)

    def get_loader(self):
        """Populate self.train_ds / self.val_ds from the configured CSV file."""
        cfg = self.configs.dataset
        self.indicator_A = cfg.indicator_A
        self.indicator_B = cfg.indicator_B
        key_A, key_B = self.indicator_A, self.indicator_B

        print('use csv dataset:', cfg.data_dir)
        if cfg.data_dir is not None and os.path.exists(cfg.data_dir):
            # data_dir must point directly at a csv listing file
            if not cfg.data_dir.endswith('.csv'):
                raise ValueError('The data directory in this case must be a csv file!')
            csv_file = cfg.data_dir
        elif self.configs.server in ('helix', 'helixSingle', 'helixMultiple'):
            csv_file = './healthy_dissec_helix.csv'
        else:
            csv_file = './healthy_dissec.csv'

        # when the network input is a segmentation mask, list the seg column
        load_seg = True if cfg.input_is_mask else False
        source_file_list, ad_list, patient_IDs = list_img_ad_pIDs_from_anish_csv(csv_file, load_seg)
        target_file_list, _, _ = list_img_ad_pIDs_from_anish_csv(csv_file)
        # original CT images listed again: later processed into body contours
        mask_file_list, _, _ = list_img_ad_pIDs_from_anish_csv(csv_file)

        n_train, n_val = cfg.train_number, cfg.val_number
        if cfg.load_masks:
            def record(src, tgt, msk, ad, pid):
                return {key_A: src, key_B: tgt, 'mask': msk, 'A_paths': src,
                        'B_paths': tgt, 'mask_path': msk, 'Aorta_diss': ad,
                        'patient_ID': pid}

            train_ds = [record(*row) for row in zip(
                source_file_list[:n_train], target_file_list[:n_train],
                mask_file_list[:n_train], ad_list[:n_train], patient_IDs[:n_train])]
            val_ds = [record(*row) for row in zip(
                source_file_list[-n_val:], target_file_list[-n_val:],
                mask_file_list[-n_val:], ad_list[-n_val:], patient_IDs[-n_val:])]
        else:
            def record(src, tgt, ad):
                return {key_A: src, key_B: tgt, 'A_paths': src, 'B_paths': tgt,
                        'Aorta_diss': ad}

            train_ds = [record(*row) for row in zip(
                source_file_list[:n_train], target_file_list[:n_train], ad_list[:n_train])]
            val_ds = [record(*row) for row in zip(
                source_file_list[-n_val:], target_file_list[-n_val:], ad_list[-n_val:])]

        # debug dump of the assembled training list
        print('train_ds: \n')
        for entry in train_ds:
            print(entry)
        print('\n')

        self.train_ds = train_ds
        self.val_ds = val_ds
        self.source_file_list = source_file_list
        self.target_file_list = target_file_list
        self.mask_file_list = mask_file_list

    def get_pretransforms(self, transform_list):
        """Threshold the CT 'mask' image into a body contour, then merge it
        into the input mask."""
        transform_list.append(CreateMaskTransformd(keys=['mask'],
                                                   body_threshold=-500,
                                                   body_mask_value=1,
                                                   ))
        transform_list.append(MergeMasksTransformd(keys=[self.configs.dataset.indicator_A, 'mask']))
        return transform_list
609
+
610
+
611
class combined_seg_loader(BaseDataLoader):
    """Concatenates the SynthRad seg dataset and the Anish CSV seg dataset.

    NOTE: get_loader mutates ``self.configs.dataset`` in place to drive the
    two sub-loaders; after construction the configs retain dataset 2's
    values (data_dir, train/val numbers, offset).
    """

    def __init__(self, configs, paths, dimension=2, **kwargs):
        self.dimension = dimension
        self.train_number_1 = kwargs.get('train_number_1', 170)
        self.train_number_2 = kwargs.get('train_number_2', 152)
        self.val_number_1 = kwargs.get('val_number_1', 10)
        self.val_number_2 = kwargs.get('val_number_2', 10)
        # BUGFIX: raw strings for the Windows default paths — the previous
        # plain literals relied on invalid escape sequences ('\y', '\d', ...)
        # that only kept their backslashes by accident and warn on Python 3.12+.
        self.data_dir_1 = kwargs.get('data_dir_1', r'E:\Projects\yang_proj\data\synthrad\Task1\pelvis')
        self.data_dir_2 = kwargs.get('data_dir_2', r'E:\Projects\yang_proj\SynthRad_GAN\synthrad_conversion\healthy_dissec.csv')
        super().__init__(configs, paths, dimension, **kwargs)

    def get_loader(self):
        """Build both sub-loaders and concatenate their train/val datasets."""
        # dataset 1: SynthRad seg -> CT
        self.configs.dataset.data_dir = self.data_dir_1
        self.configs.dataset.train_number = self.train_number_1
        self.configs.dataset.val_number = self.val_number_1
        self.configs.dataset.source_name = ["ct_seg"]
        self.configs.dataset.target_name = ["ct"]
        self.configs.dataset.offset = 1024
        loader1 = synthrad_seg_loader(self.configs, self.paths, self.dimension)

        # dataset 2: Anish CSV seg -> CT (different HU offset)
        self.configs.dataset.data_dir = self.data_dir_2
        self.configs.dataset.train_number = self.train_number_2
        self.configs.dataset.val_number = self.val_number_2
        self.configs.dataset.offset = 1000
        loader2 = anish_seg_loader(self.configs, self.paths, self.dimension)

        self.train_ds = ConcatDataset([loader1.train_ds, loader2.train_ds])
        self.val_ds = ConcatDataset([loader1.val_ds, loader2.val_ds])
        self.source_file_list = loader1.source_file_list + loader2.source_file_list

    def get_pretransforms(self, transform_list):
        """Body contour from the CT 'mask' image, merged into the input mask."""
        transform_list.append(CreateMaskTransformd(keys=['mask'],
                                                   body_threshold=-500,
                                                   body_mask_value=1,
                                                   ))
        transform_list.append(MergeMasksTransformd(keys=[self.configs.dataset.indicator_A, 'mask']))
        return transform_list

    def save_nifti(self, save_output_path, case=0):
        """Write every (image, seg) pair from the train loader to nifti files.

        Filenames are '<patientID><case>_<step>' (+ '_seg' for the mask).
        """
        from monai.transforms import SaveImage
        # hoisted out of the batch loop: the writers are stateless across batches
        si_input = SaveImage(output_dir=f'{save_output_path}',
                             separate_folder=False,
                             output_postfix='',
                             resample=False)
        si_seg = SaveImage(output_dir=f'{save_output_path}',
                           separate_folder=False,
                           output_postfix='',
                           resample=False)
        step = 0
        with torch.no_grad():
            for data in self.train_loader:
                image_batch = data['img'].squeeze()
                seg_batch = data['seg'].squeeze()
                file_path_batch = data['B_paths']

                for i in range(len(file_path_batch)):
                    step += 1
                    file_path = file_path_batch[i]
                    patient_ID = os.path.splitext(os.path.basename(file_path))[0]
                    base = os.path.join(save_output_path,
                                        patient_ID + str(case) + '_' + str(step))
                    si_input(image_batch[i].unsqueeze(0), data['img'].meta, filename=base)
                    si_seg(seg_batch[i].unsqueeze(0), data['seg'].meta, filename=base + '_seg')
700
class combined_seg_assigned_loader(combined_seg_loader):
    """combined_seg_loader variant that additionally maps anatomy mask labels
    to HU values via a csv lookup table (``anatomy_list`` kwarg)."""

    def __init__(self, configs, paths=None, dimension=2, **kwargs):
        self.anatomy_list = kwargs.get('anatomy_list', 'synthrad_conversion/TA2_anatomy.csv')
        super().__init__(configs, paths, dimension, **kwargs)

    def get_pretransforms(self, transform_list):
        """Body contour + mask merge, then anatomy-label -> HU assignment."""
        key_A = self.configs.dataset.indicator_A
        transform_list.append(CreateMaskTransformd(keys=['mask'],
                                                   body_threshold=-500,
                                                   body_mask_value=1,
                                                   ))
        transform_list.append(MergeMasksTransformd(keys=[key_A, 'mask']))
        transform_list.append(MaskHUAssigmentd(keys=[self.indicator_A], csv_file=self.anatomy_list))
        return transform_list

    def get_intensity_transforms(self, transform_list):
        """Clamp both input and target to the HU window, then shift to >= 0."""
        cfg = self.configs.dataset
        low = cfg.WINDOW_LEVEL - cfg.WINDOW_WIDTH / 2
        high = cfg.WINDOW_LEVEL + cfg.WINDOW_WIDTH / 2
        both = [self.indicator_A, self.indicator_B]
        # When filtering out pixels below a threshold, above=True and
        # cval >= threshold must hold, otherwise values are mis-assigned.
        transform_list += [
            ThresholdIntensityd(keys=both, threshold=low, above=True, cval=low),
            ThresholdIntensityd(keys=both, threshold=high, above=False, cval=high),
            ShiftIntensityd(keys=both, offset=-low),
        ]
        return transform_list

    def get_normlization(self, transform_list):
        """Normalize input AND target (both carry HU values here)."""
        normalize = self.configs.dataset.normalize
        both = [self.indicator_A, self.indicator_B]
        if normalize == 'zscore':
            transform_list.append(NormalizeIntensityd(keys=both, nonzero=False, channel_wise=True))
            print('zscore normalization')
        elif normalize == 'scale2000':
            transform_list.append(ScaleIntensityd(keys=both, minv=None, maxv=None, factor=-0.9995))
            print('scale2000 normalization')
        elif normalize == 'none' or normalize == 'nonorm':
            print('no normalization')
        return transform_list
745
+
746
class slices_nifti_DataLoader(BaseDataLoader):
    """Loads pre-sliced nifti samples listed in train/val ``dataset.json`` files."""

    def __init__(self, configs, paths=None, dimension=2, **kwargs):
        super().__init__(configs, paths, dimension, **kwargs)

    def get_loader(self):
        """Read the train/val sample lists from the json files under data_dir."""
        root = self.configs.dataset.data_dir
        print('use json dataset:', root)
        if root is None or not os.path.exists(root):
            raise ValueError('please check the data dir in config file!')
        self.train_ds = list_from_json(os.path.join(root, 'train', 'dataset.json'),
                                       self.indicator_A, self.indicator_B)
        self.val_ds = list_from_json(os.path.join(root, 'val', 'dataset.json'),
                                     self.indicator_A, self.indicator_B)

    def create_patch_dataset_and_dataloader(self, dimension=2):
        """Samples are already 2D slices, so the volume datasets feed the
        loaders directly (no patch extraction)."""
        pin = torch.cuda.is_available()
        workers = self.configs.dataset.num_workers
        self.train_loader = DataLoader(self.train_volume_ds,
                                       num_workers=workers,
                                       batch_size=self.configs.dataset.batch_size,
                                       shuffle=True,
                                       pin_memory=pin)
        self.val_loader = DataLoader(self.val_volume_ds,
                                     num_workers=workers,
                                     batch_size=self.configs.dataset.val_batch_size,
                                     shuffle=False,
                                     pin_memory=pin)
778
+
779
class csv_slices_DataLoader(BaseDataLoader):
    """Loads pre-sliced samples listed in csv files under train/ and val/."""

    def __init__(self, configs, paths=None, dimension=2, **kwargs):
        super().__init__(configs, paths, dimension, **kwargs)

    def get_loader(self):
        """Read the train/val sample lists from csv folders under data_dir."""
        root = self.configs.dataset.data_dir
        print('use csv dataset:', root)
        if root is None or not os.path.exists(root):
            raise ValueError('please check the data dir in config file!')
        self.train_ds = list_from_slice_csv(os.path.join(root, 'train'),
                                            self.indicator_A, self.indicator_B)
        self.val_ds = list_from_slice_csv(os.path.join(root, 'val'),
                                          self.indicator_A, self.indicator_B)

    def create_patch_dataset_and_dataloader(self, dimension=2):
        """Samples are already 2D slices, so the volume datasets feed the
        loaders directly (no patch extraction)."""
        pin = torch.cuda.is_available()
        workers = self.configs.dataset.num_workers
        self.train_loader = DataLoader(self.train_volume_ds,
                                       num_workers=workers,
                                       batch_size=self.configs.dataset.batch_size,
                                       shuffle=True,
                                       pin_memory=pin)
        self.val_loader = DataLoader(self.val_volume_ds,
                                     num_workers=workers,
                                     batch_size=self.configs.dataset.val_batch_size,
                                     shuffle=False,
                                     pin_memory=pin)
811
+
812
class csv_slices_assigned_DataLoader(csv_slices_DataLoader):
    """csv_slices_DataLoader variant that maps anatomy mask labels to HU values.

    GENERALIZATION: the anatomy lookup table is now configurable via the
    ``anatomy_list`` kwarg (default keeps the previously hard-coded path),
    matching the pattern used by combined_seg_assigned_loader.
    """

    def __init__(self, configs, paths=None, dimension=2, **kwargs):
        # must be set before super().__init__, which may build the transforms
        self.anatomy_list = kwargs.get('anatomy_list', r'synthrad_conversion\TA2_anatomy.csv')
        super().__init__(configs, paths, dimension, **kwargs)

    def get_pretransforms(self, transform_list):
        """Assign representative HU values to anatomy labels in the input."""
        transform_list.append(MaskHUAssigmentd(keys=[self.indicator_A], csv_file=self.anatomy_list))
        return transform_list

    def get_intensity_transforms(self, transform_list):
        """Clamp both input and target to the HU window, then shift to >= 0."""
        threshold_low = self.configs.dataset.WINDOW_LEVEL - self.configs.dataset.WINDOW_WIDTH / 2
        threshold_high = self.configs.dataset.WINDOW_LEVEL + self.configs.dataset.WINDOW_WIDTH / 2
        offset = (-1) * threshold_low
        # When filtering out pixels below a threshold, above=True and
        # cval >= threshold must hold, otherwise values are mis-assigned:
        #   mask = img > self.threshold if self.above else img < self.threshold
        #   res  = where(mask, img, self.cval)
        transform_list.append(ThresholdIntensityd(keys=[self.indicator_A, self.indicator_B], threshold=threshold_low, above=True, cval=threshold_low))
        transform_list.append(ThresholdIntensityd(keys=[self.indicator_A, self.indicator_B], threshold=threshold_high, above=False, cval=threshold_high))
        transform_list.append(ShiftIntensityd(keys=[self.indicator_A, self.indicator_B], offset=offset))
        return transform_list

    def get_normlization(self, transform_list):
        """Normalize input AND target (both carry HU values here)."""
        normalize = self.configs.dataset.normalize
        if normalize == 'zscore':
            transform_list.append(NormalizeIntensityd(keys=[self.indicator_A, self.indicator_B], nonzero=False, channel_wise=True))
            print('zscore normalization')
        elif normalize == 'scale2000':
            transform_list.append(ScaleIntensityd(keys=[self.indicator_A, self.indicator_B], minv=None, maxv=None, factor=-0.9995))
            print('scale2000 normalization')
        elif normalize == 'none' or normalize == 'nonorm':
            print('no normalization')
        return transform_list
848
+
849
+ # for MRI -> CT task
850
+
851
class synthrad_mr2ct_loader(BaseDataLoader):
    """SynthRad MR -> CT loader: pairs MR inputs with CT targets from one folder.

    Overrides get_normlization so that the MR input and the CT target can be
    normalized with different schemes (CT 'scaleN' modes shift by +1024 first
    so HU values become non-negative before scaling).
    """
    def __init__(self,configs,paths=None,dimension=2):
        super().__init__(configs,paths,dimension)

    def get_loader(self):
        """Populate self.train_ds / self.val_ds from the SynthRad folder."""
        # volume-level transforms for both image and label
        indicator_A=self.configs.dataset.indicator_A
        indicator_B=self.configs.dataset.indicator_B
        train_number=self.configs.dataset.train_number
        val_number=self.configs.dataset.val_number
        self.indicator_A=indicator_A
        self.indicator_B=indicator_B
        load_masks=self.configs.dataset.load_masks
        # Conditional dictionary keys based on whether masks are loaded

        #list all files in the folder (results unused; kept for error behavior)
        file_list=[i for i in os.listdir(self.configs.dataset.data_dir) if 'overview' not in i]
        file_list_path=[os.path.join(self.configs.dataset.data_dir,i) for i in file_list]
        #list all ct and mr files in folder


        #source_file_list=[os.path.join(j,f'{self.configs.dataset.source_name}.nii.gz') for j in file_list_path] # "ct" for example
        #target_file_list=[os.path.join(j,f'{self.configs.dataset.target_name}.nii.gz') for j in file_list_path] # "mr" for example
        #mask_file_list=[os.path.join(j,f'{self.configs.dataset.mask_name}.nii.gz') for j in file_list_path]
        source_file_list,_=list_img_pID_from_synthrad_folder(self.configs.dataset.data_dir, accepted_modalities=self.configs.dataset.source_name,saved_name=None)
        target_file_list,_=list_img_pID_from_synthrad_folder(self.configs.dataset.data_dir, accepted_modalities=self.configs.dataset.target_name,saved_name=None)
        mask_file_list,_=list_img_pID_from_synthrad_folder(self.configs.dataset.data_dir, accepted_modalities=self.configs.dataset.mask_name,saved_name=None)

        def write_write_file(images, file):
            # dump one path per line for logging/debugging
            with open(file,"w") as file:
                for image in images:
                    file.write(f'{image} \n')

        # record the resolved file lists next to the run logs (when available)
        if self.paths is not None:
            write_write_file(source_file_list, os.path.join(self.paths["saved_logs_folder"],"source_filenames.txt"))
            write_write_file(target_file_list, os.path.join(self.paths["saved_logs_folder"],"target_filenames.txt"))
            write_write_file(mask_file_list, os.path.join(self.paths["saved_logs_folder"],"mask_filenames.txt"))

        self.source_file_list=source_file_list
        self.target_file_list=target_file_list
        self.mask_file_list=mask_file_list

        # first train_number entries train; last val_number entries validate
        if load_masks:
            train_ds = [{indicator_A: i, indicator_B: j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k}
                        for i, j, k in zip(source_file_list[0:train_number], target_file_list[0:train_number], mask_file_list[0:train_number])]
            val_ds = [{indicator_A: i, indicator_B: j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k}
                      for i, j, k in zip(source_file_list[-val_number:], target_file_list[-val_number:], mask_file_list[-val_number:])]
        else:
            train_ds = [{indicator_A: i, indicator_B: j, 'A_paths': i, 'B_paths': j}
                        for i, j in zip(source_file_list[0:train_number], target_file_list[0:train_number])]
            val_ds = [{indicator_A: i, indicator_B: j, 'A_paths': i, 'B_paths': j}
                      for i, j in zip(source_file_list[-val_number:], target_file_list[-val_number:])]
        self.train_ds=train_ds
        self.val_ds=val_ds

    def get_normlization(self, transform_list):
        """Append the configured normalization; input (MR) and target (CT)
        keys are handled differently per mode.

        Mutates ``transform_list`` in place and returns it.
        """
        normalize=self.configs.dataset.normalize
        indicator_A=self.configs.dataset.indicator_A
        indicator_B=self.configs.dataset.indicator_B
        load_masks=self.configs.dataset.load_masks
        if normalize=='zscore':
            transform_list.append(NormalizeIntensityd(keys=[indicator_A, indicator_B], nonzero=False, channel_wise=True))
            print('zscore normalization')
        elif normalize=='minmax':
            transform_list.append(ScaleIntensityd(keys=[indicator_A, indicator_B], minv=-1.0, maxv=1.0))
            print('minmax normalization')

        elif normalize=='scale4000':
            # MR to [0,1]; CT shifted by +1024 then divided by 4000
            transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=0, maxv=1))
            transform_list.append(ShiftIntensityd(keys=[indicator_B], offset=1024))
            transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.99975)) # x=x(1+factor)
            print('scale4000 normalization')

        elif normalize=='scale1000_wrongbutworks':
            # legacy behavior kept for reproducibility: min-shift instead of +1024
            transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=0, maxv=1))
            transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=0))
            transform_list.append(ScaleIntensityd(keys=[indicator_B], factor=-0.999))
            print('scale1000 normalization')

        elif normalize=='scale1000':
            # MR to [0,1]; CT shifted by +1024 then divided by 1000
            transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=0, maxv=1))
            transform_list.append(ShiftIntensityd(keys=[indicator_B], offset=1024))
            transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.999))
            print('scale1000 normalization')

        elif normalize=='scale10':
            # NOTE(review): no +1024 shift here, unlike scale1000/scale4000 —
            # confirm this asymmetry is intended
            transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=0, maxv=1))
            #transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=0))
            transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None,factor=-0.9))
            print('scale10 normalization')

        elif normalize=='inputonlyzscore':
            transform_list.append(NormalizeIntensityd(keys=[indicator_A], nonzero=False, channel_wise=True))
            print('only normalize input MRI images')

        elif normalize=='inputonlyminmax':
            transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=self.configs.dataset.normmin, maxv=self.configs.dataset.normmax))
            print('only normalize input MRI images')

        elif normalize=='none' or normalize=='nonorm':
            print('no normalization')
        return transform_list
953
+
954
+ class anika_registrated_mr2ct_loader(synthrad_mr2ct_loader):
955
+ def __init__(self,configs,paths,dimension):
956
+ super().__init__(configs,paths,dimension)
957
+
958
+ def get_loader(self):
959
+ indicator_A=self.configs.dataset.indicator_A
960
+ indicator_B=self.configs.dataset.indicator_B
961
+ self.indicator_A=indicator_A
962
+ self.indicator_B=indicator_B
963
+ train_number=self.configs.dataset.train_number
964
+ val_number=self.configs.dataset.val_number
965
+ train_batch_size=self.configs.dataset.batch_size
966
+ val_batch_size=self.configs.dataset.val_batch_size
967
+ load_masks=self.configs.dataset.load_masks
968
+
969
+ # Conditional dictionary keys based on whether masks are loaded
970
+ keys = [indicator_A, indicator_B, "mask"] if load_masks else [indicator_A, indicator_B]
971
+
972
+ ct_dir = r'E:\Datasets\M2olie_Patientdata\CT'
973
+ mri_dir = r'E:\Results\MultistepReg\M2olie_Patientdata\Multistep_network_A\predict'
974
+
975
+ ct_dir = self.configs.dataset.ct_dir #'E:\Datasets\M2olie_Patientdata\CT'
976
+ mri_dir = self.configs.dataset.mri_dir #'E:\Results\MultistepReg\M2olie_Patientdata\Multistep_network_A\predict'
977
+ matched_pairs = list_from_anika_dataset(ct_dir, mri_dir, self.configs.dataset.mri_mode)
978
+ for patient_id, paths in matched_pairs.items():
979
+ print(f"Patient ID: {patient_id}, CT: {paths['CT']}, MRI: {paths['MRI']}")
980
+
981
+ # use the matched pairs to form the dataset
982
+ train_ds = [{indicator_A: paths['MRI'], indicator_B: paths['CT']} for patient_id, paths in list(matched_pairs.items())[:train_number]]
983
+ val_ds = [{indicator_A: paths['MRI'], indicator_B: paths['CT']} for patient_id, paths in list(matched_pairs.items())[-val_number:]]
dataprocesser/archive/list_dataset_combined_seg.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dataprocesser.customized_transforms import CreateMaskTransformd, MergeMasksTransformd
3
# Volume file extensions this module accepts (2D formats deliberately disabled).
IMG_EXTENSIONS = [
    #'.jpg', '.JPG', '.jpeg', '.JPEG',
    #'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.nrrd', '.nii.gz'
]

def is_image_file(filename):
    """Return True when *filename* ends with a supported volume extension."""
    return filename.endswith(tuple(IMG_EXTENSIONS))
11
+ import torch
12
+ from dataprocesser.list_dataset_synthrad_seg import synthrad_seg_loader
13
+ from dataprocesser.list_dataset_Anish_seg import anish_seg_loader
14
+ from dataprocesser.list_dataset_base import BaseDataLoader
15
+
dataprocesser/archive/list_dataset_combined_seg_assigned.py ADDED
@@ -0,0 +1 @@
 
 
1
+
dataprocesser/archive/list_dataset_synthrad.py ADDED
File without changes
dataprocesser/archive/list_dataset_synthrad_seg.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+
2
+
3
+
dataprocesser/archive/monai_loader_3D.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import monai
2
+ import os
3
+ import numpy as np
4
+ from monai.transforms import (
5
+ Compose,
6
+ LoadImaged,
7
+ EnsureChannelFirstd,
8
+ SqueezeDimd,
9
+ CenterSpatialCropd,
10
+ Rotate90d,
11
+ ScaleIntensityd,
12
+ ResizeWithPadOrCropd,
13
+ DivisiblePadd,
14
+
15
+ ThresholdIntensityd,
16
+ NormalizeIntensityd,
17
+ ShiftIntensityd,
18
+ Identityd,
19
+ ScaleIntensityRanged,
20
+ Spacingd,
21
+ )
22
+ from torch.utils.data import DataLoader
23
+ import torch
24
+
25
# Volume file extensions this loader accepts (2D formats deliberately disabled).
IMG_EXTENSIONS = [
    #'.jpg', '.JPG', '.jpeg', '.JPEG',
    #'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.nrrd', '.nii.gz'
]

def is_image_file(filename):
    """Return True when *filename* ends with a supported volume extension."""
    return filename.endswith(tuple(IMG_EXTENSIONS))

def make_dataset_modality(dir, accepted_modalities = ["ct"], saved_name="source_filenames.txt"):
    """Recursively collect image files whose base name (text before the first
    '.') is listed in *accepted_modalities*.

    Works for a root directory of any layer, e.g.::

        data_path/Task1 or Task2/pelvis or brain
            patient1/ct.nii.gz, mr.nii.gz
            patient2/ct.nii.gz, mr.nii.gz

    The matched paths are also written to *saved_name*, one per line, and the
    list of paths is returned.
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir
    images = []
    # os.walk visits every folder and subfolder below dir, at any depth
    for parent, _, filenames in sorted(os.walk(dir)):
        for fname in filenames:
            if is_image_file(fname) and fname.split('.')[0] in accepted_modalities:
                images.append(os.path.join(parent, fname))
    print(f'Found {len(images)} {accepted_modalities} files in {dir} \n')
    with open(saved_name, "w") as out:
        for image in images:
            out.write(f'{image} \n')
    return images
55
+
56
class monai_loader_3D:
    """Build paired 3D volume datasets/loaders (source and target modality,
    optionally a mask) from a folder tree, driven entirely by ``configs.dataset``.

    paths must supply "saved_logs_folder", "saved_name_train" and
    "saved_name_val".
    """
    def __init__(self,configs,paths):
        self.configs=configs
        self.paths=paths
        self.get_loader()
        # save train/val patient lists, skip the expensive per-volume checks
        self.finalcheck(ifsave=True,ifcheck=False,iftest_volumes_pixdim=False)

    def get_loader(self):
        """Scan configs.dataset.data_dir, split into train/val, and create
        MONAI datasets plus torch DataLoaders; results are stored on self."""
        # volume-level transforms for both image and label
        train_transforms = self.get_transforms(self.configs,mode='train')
        val_transforms = self.get_transforms(self.configs,mode='val')
        indicator_A=self.configs.dataset.indicator_A  # dict key of the source modality
        indicator_B=self.configs.dataset.indicator_B  # dict key of the target modality
        self.indicator_A=indicator_A
        self.indicator_B=indicator_B
        train_number=self.configs.dataset.train_number
        val_number=self.configs.dataset.val_number
        train_batch_size=self.configs.dataset.batch_size
        val_batch_size=self.configs.dataset.val_batch_size
        load_masks=self.configs.dataset.load_masks

        # Conditional dictionary keys based on whether masks are loaded
        keys = [indicator_A, indicator_B, "mask"] if load_masks else [indicator_A, indicator_B]


        #list all files in the folder
        file_list=[i for i in os.listdir(self.configs.dataset.data_dir) if 'overview' not in i]
        file_list_path=[os.path.join(self.configs.dataset.data_dir,i) for i in file_list]
        #list all ct and mr files in folder
        # NOTE(review): file_list/file_list_path are computed but never used below;
        # the make_dataset_modality scans drive the dataset instead.


        #source_file_list=[os.path.join(j,f'{self.configs.dataset.source_name}.nii.gz') for j in file_list_path] # "ct" for example
        #target_file_list=[os.path.join(j,f'{self.configs.dataset.target_name}.nii.gz') for j in file_list_path] # "mr" for example
        #mask_file_list=[os.path.join(j,f'{self.configs.dataset.mask_name}.nii.gz') for j in file_list_path]
        source_file_list=make_dataset_modality(self.configs.dataset.data_dir, accepted_modalities=self.configs.dataset.source_name, saved_name=os.path.join(self.paths["saved_logs_folder"],"source_filenames.txt"))
        target_file_list=make_dataset_modality(self.configs.dataset.data_dir, accepted_modalities=self.configs.dataset.target_name, saved_name=os.path.join(self.paths["saved_logs_folder"],"target_filenames.txt"))
        mask_file_list=make_dataset_modality(self.configs.dataset.data_dir, accepted_modalities=self.configs.dataset.mask_name, saved_name=os.path.join(self.paths["saved_logs_folder"],"mask_filenames.txt"))

        # first train_number entries -> training, last val_number -> validation
        if load_masks:
            train_ds = [{indicator_A: i, indicator_B: j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k}
                    for i, j, k in zip(source_file_list[0:train_number], target_file_list[0:train_number], mask_file_list[0:train_number])]
            val_ds = [{indicator_A: i, indicator_B: j, 'mask': k, 'A_paths': i, 'B_paths': j, 'mask_path': k}
                    for i, j, k in zip(source_file_list[-val_number:], target_file_list[-val_number:], mask_file_list[-val_number:])]
        else:
            train_ds = [{indicator_A: i, indicator_B: j, 'A_paths': i, 'B_paths': j}
                        for i, j in zip(source_file_list[0:train_number], target_file_list[0:train_number])]
            val_ds = [{indicator_A: i, indicator_B: j, 'A_paths': i, 'B_paths': j}
                    for i, j in zip(source_file_list[-val_number:], target_file_list[-val_number:])]

        print('all files in dataset:',len(source_file_list))

        # load volumes and center crop (crop only along the last axis)
        center_crop = self.configs.dataset.center_crop
        transformations_crop = [
            LoadImaged(keys=keys),
            EnsureChannelFirstd(keys=keys),
        ]
        if center_crop>0:
            transformations_crop.append(CenterSpatialCropd(keys=keys, roi_size=(-1,-1,center_crop)))
        transformations_crop=Compose(transformations_crop)
        train_crop_ds = monai.data.Dataset(data=train_ds, transform=transformations_crop)
        val_crop_ds = monai.data.Dataset(data=val_ds, transform=transformations_crop)

        # load volumes: chain the normalization/augmentation transforms on top
        train_volume_ds = monai.data.Dataset(data=train_crop_ds, transform=train_transforms)
        val_volume_ds = monai.data.Dataset(data=val_crop_ds, transform=val_transforms)

        train_loader = DataLoader(train_volume_ds, batch_size=train_batch_size, shuffle=True, num_workers=self.configs.dataset.num_workers)
        val_loader = DataLoader(val_volume_ds, batch_size=val_batch_size, shuffle=False, num_workers=self.configs.dataset.num_workers)

        self.saved_name_train=self.paths["saved_name_train"]
        self.saved_name_val=self.paths["saved_name_val"]

        self.train_ds=train_ds
        self.val_ds=val_ds
        self.train_volume_ds=train_volume_ds
        self.val_volume_ds=val_volume_ds

        self.train_batch_size=train_batch_size
        self.val_batch_size=val_batch_size

        self.train_crop_ds=train_crop_ds
        self.val_crop_ds=val_crop_ds
        self.train_transforms=train_transforms
        self.val_transforms=val_transforms

        self.train_loader=train_loader
        self.val_loader=val_loader

    def get_transforms(self, configs, mode='train'):
        """Assemble the normalization / resize / augmentation transform chain.

        mode='train' additionally appends shape and intensity augmentations
        when enabled in the config; any other mode returns the deterministic
        part only.
        """
        normalize=configs.dataset.normalize
        pad=configs.dataset.pad
        resized_size=configs.dataset.resized_size
        WINDOW_WIDTH=configs.dataset.WINDOW_WIDTH
        WINDOW_LEVEL=configs.dataset.WINDOW_LEVEL
        prob=configs.dataset.augmentationProb
        background=configs.dataset.background
        indicator_A=configs.dataset.indicator_A
        indicator_B=configs.dataset.indicator_B
        load_masks=self.configs.dataset.load_masks
        transform_list=[]
        # NOTE(review): these shadow the builtins min/max and are only consumed
        # by the commented-out thresholding below.
        min, max=WINDOW_LEVEL-(WINDOW_WIDTH/2), WINDOW_LEVEL+(WINDOW_WIDTH/2)
        #transform_list.append(ThresholdIntensityd(keys=[indicator_B], threshold=min, above=True, cval=background))
        #transform_list.append(ThresholdIntensityd(keys=[indicator_B], threshold=max, above=False, cval=-1000))
        # filter the source images
        # transform_list.append(ThresholdIntensityd(keys=[indicator_A], threshold=configs.dataset.MRImax, above=False, cval=0))
        if normalize=='zscore':
            transform_list.append(NormalizeIntensityd(keys=[indicator_A, indicator_B], nonzero=False, channel_wise=True))
            print('zscore normalization')
        elif normalize=='minmax':
            transform_list.append(ScaleIntensityd(keys=[indicator_A, indicator_B], minv=-1.0, maxv=1.0))
            print('minmax normalization')

        elif normalize=='scale4000':
            # source to [0,1]; target shifted by +1024 then scaled by 1/4000
            transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=0, maxv=1))
            transform_list.append(ShiftIntensityd(keys=[indicator_B], offset=1024))
            transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.99975)) # x=x(1+factor)
            print('scale4000 normalization')

        elif normalize=='scale1000_wrongbutworks':
            transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=0, maxv=1))
            transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=0))
            transform_list.append(ScaleIntensityd(keys=[indicator_B], factor=-0.999))
            print('scale1000 normalization')

        elif normalize=='scale1000':
            transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=0, maxv=1))
            transform_list.append(ShiftIntensityd(keys=[indicator_B], offset=1024))
            transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.999))
            print('scale1000 normalization')

        elif normalize=='scale10':
            transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=0, maxv=1))
            #transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=0))
            transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None,factor=-0.9))
            print('scale10 normalization')

        elif normalize=='inputonlyzscore':
            transform_list.append(NormalizeIntensityd(keys=[indicator_A], nonzero=False, channel_wise=True))
            print('only normalize input MRI images')

        elif normalize=='inputonlyminmax':
            transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=configs.dataset.normmin, maxv=configs.dataset.normmax))
            print('only normalize input MRI images')

        elif normalize=='none' or normalize=='nonorm':
            print('no normalization')

        # resampling is currently disabled (spaceXY hard-coded to 0)
        spaceXY=0
        if spaceXY>0:
            transform_list.append(Spacingd(keys=[indicator_A], pixdim=(spaceXY, spaceXY, 2.5), mode="bilinear")) #
            transform_list.append(Spacingd(keys=[indicator_B, "mask"] if load_masks else [indicator_B],
                                           pixdim=(spaceXY, spaceXY , 2.5), mode="bilinear"))

        transform_list.append(ResizeWithPadOrCropd(keys=[indicator_A, indicator_B,"mask"] if load_masks else [indicator_A, indicator_B],
                spatial_size=resized_size,mode=pad))

        if configs.dataset.rotate:
            transform_list.append(Rotate90d(keys=[indicator_A, indicator_B, "mask"] if load_masks else [indicator_A, indicator_B], k=3))

        if mode == 'train':
            # augmentation transforms are imported lazily so validation mode
            # never pays for them
            from monai.transforms import (
                # data augmentation
                RandRotated,
                RandZoomd,
                RandBiasFieldd,
                RandAffined,
                RandGridDistortiond,
                RandGridPatchd,
                RandShiftIntensityd,
                RandGibbsNoised,
                RandAdjustContrastd,
                RandGaussianSmoothd,
                RandGaussianSharpend,
                RandGaussianNoised,
            )
            shapeAug=configs.dataset.shapeAug
            if shapeAug:
                transform_list.append(RandRotated(keys=[indicator_A, indicator_B, "mask"] if load_masks else [indicator_A, indicator_B],
                                                  range_x = 0.1, range_y = 0.1, range_z = 0.1,
                                                  prob=prob, padding_mode="border", keep_size=True))
                transform_list.append(RandZoomd(keys=[indicator_A, indicator_B, "mask"] if load_masks else [indicator_A, indicator_B],
                                                prob=prob, min_zoom=0.9, max_zoom=1.3,padding_mode= "minimum" ,keep_size=True))
                transform_list.append(RandAffined(keys=[indicator_A, indicator_B], padding_mode="border" , prob=prob))
                #transform_list.append(Rand3DElasticd(keys=[indicator_A, indicator_B], prob=prob, sigma_range=(5, 8), magnitude_range=(100, 200), spatial_size=None, mode='bilinear'))
            intensityAug=configs.dataset.intensityAug
            if intensityAug:
                print('intensity data augmentation is used')
                transform_list.append(RandBiasFieldd(keys=[indicator_A], degree=3, coeff_range=(0.0, 0.1), prob=prob)) # only apply to MRI images
                transform_list.append(RandGaussianNoised(keys=[indicator_A], prob=prob, mean=0.0, std=0.01))
                transform_list.append(RandAdjustContrastd(keys=[indicator_A], prob=prob, gamma=(0.5, 1.5)))
                transform_list.append(RandShiftIntensityd(keys=[indicator_A], prob=prob, offsets=20))
                transform_list.append(RandGaussianSharpend(keys=[indicator_A], alpha=(0.2, 0.8), prob=prob))

        #transform_list.append(Rotate90d(keys=[indicator_A, indicator_B], k=3))
        #transform_list.append(DivisiblePadd(keys=[indicator_A, indicator_B], k=div_size, mode="minimum"))
        #transform_list.append(Identityd(keys=[indicator_A, indicator_B])) # do nothing for the no norm case
        train_transforms = Compose(transform_list)
        return train_transforms

    def finalcheck(self,ifsave=False,ifcheck=False,iftest_volumes_pixdim=False):
        """Optionally save patient lists and run sanity checks on the datasets.

        NOTE(review): the ifcheck branch references self.train_patch_ds, which
        get_loader never assigns — enabling ifcheck would raise AttributeError;
        confirm before use.
        """
        if ifsave:
            self.save_volumes(self.train_ds, self.val_ds, self.saved_name_train, self.saved_name_val)
        if iftest_volumes_pixdim:
            self.test_volumes_pixdim(self.train_volume_ds)
        if ifcheck:
            self.check_volumes(self.train_ds, self.train_volume_ds, self.val_volume_ds, self.val_ds)
            self.check_batch_data(self.train_loader,self.val_loader,
                                self.train_patch_ds,self.val_volume_ds,
                                self.train_batch_size,self.val_batch_size)

    def test_volumes_pixdim(self, train_volume_ds):
        """Print shape/affine/pixdim for every training volume (debug aid)."""
        train_loader = DataLoader(train_volume_ds, batch_size=1)
        for step, data in enumerate(train_loader):
            mr_data=data[self.indicator_A]
            ct_data=data[self.indicator_B]

            print(f"source image shape: {mr_data.shape}")
            print(f"source image affine:\n{mr_data.meta['affine']}")
            print(f"source image pixdim:\n{mr_data.pixdim}")

            # target image information
            print(f"target image shape: {ct_data.shape}")
            print(f"target image affine:\n{ct_data.meta['affine']}")
            print(f"target image pixdim:\n{ct_data.pixdim}")

    def check_volumes(self, train_ds, train_volume_ds, val_volume_ds, val_ds):
        """Iterate every volume once and report items that fail to load.

        NOTE(review): this indexes dataset dicts with 'image'/'label' keys,
        while get_loader builds dicts keyed by the indicator names — verify
        the keys before enabling this check.
        """
        # use batch_size=1 to check the volumes because the input volumes have different shapes
        train_loader = DataLoader(train_volume_ds, batch_size=1)
        val_loader = DataLoader(val_volume_ds, batch_size=1)
        train_iterator = iter(train_loader)
        val_iterator = iter(val_loader)
        print('check training data:')
        idx=0
        for idx in range(len(train_loader)):
            try:
                train_check_data = next(train_iterator)
                ds_idx = idx * 1
                current_item = train_ds[ds_idx]
                current_name = os.path.basename(os.path.dirname(current_item['image']))
                print(idx, current_name, 'image:', train_check_data['image'].shape, 'label:', train_check_data['label'].shape)
            except:
                # bare except kept: any loading failure should only be reported, not abort the scan
                ds_idx = idx * 1
                current_item = train_ds[ds_idx]
                current_name = os.path.basename(os.path.dirname(current_item['image']))
                print('check data error! Check the input data:',current_name)
        print("checked all training data.")

        print('check validation data:')
        idx=0
        for idx in range(len(val_loader)):
            try:
                val_check_data = next(val_iterator)
                ds_idx = idx * 1
                current_item = val_ds[ds_idx]
                current_name = os.path.basename(os.path.dirname(current_item['image']))
                print(idx, current_name, 'image:', val_check_data['image'].shape, 'label:', val_check_data['label'].shape)
            except:
                ds_idx = idx * 1
                current_item = val_ds[ds_idx]
                current_name = os.path.basename(os.path.dirname(current_item['image']))
                print('check data error! Check the input data:',current_name)
        print("checked all validation data.")

    def save_volumes(self, train_ds, val_ds, saved_name_train, saved_name_val):
        """Write the patient folder names of the train/val splits to text files."""
        shape_list_train=[]
        shape_list_val=[]
        # use the function of saving information before
        for sample in train_ds:
            name = os.path.basename(os.path.dirname(sample[self.indicator_A]))
            shape_list_train.append({'patient': name})
        for sample in val_ds:
            name = os.path.basename(os.path.dirname(sample[self.indicator_A]))
            shape_list_val.append({'patient': name})
        np.savetxt(saved_name_train,shape_list_train,delimiter=',',fmt='%s',newline='\n') # f means format, r means raw string
        np.savetxt(saved_name_val,shape_list_val,delimiter=',',fmt='%s',newline='\n') # f means format, r means raw string

    def check_batch_data(self, train_loader,val_loader,train_patch_ds,val_volume_ds,train_batch_size,val_batch_size):
        """Print one line per batch to verify loader/dataset alignment.

        NOTE(review): like check_volumes, this assumes 'image'/'label' keys.
        """
        for idx, train_check_data in enumerate(train_loader):
            ds_idx = idx * train_batch_size
            current_item = train_patch_ds[ds_idx]
            print('check train data:')
            print(current_item, 'image:', train_check_data['image'].shape, 'label:', train_check_data['label'].shape)

        for idx, val_check_data in enumerate(val_loader):
            ds_idx = idx * val_batch_size
            current_item = val_volume_ds[ds_idx]
            print('check val data:')
            print(current_item, 'image:', val_check_data['image'].shape, 'label:', val_check_data['label'].shape)

    def len_patchloader(self, train_volume_ds,train_batch_size):
        """Return (total slice count, total batch count) over the training volumes."""
        slice_number=sum(train_volume_ds[i][self.indicator_A].shape[-1] for i in range(len(train_volume_ds)))
        print('total slices in training set:',slice_number)

        import math
        batch_number=sum(math.ceil(train_volume_ds[i][self.indicator_A].shape[-1]/train_batch_size) for i in range(len(train_volume_ds)))
        print('total batches in training set:',batch_number)
        return slice_number,batch_number

    def get_length(self, dataset, patch_batch_size):
        """Count slices across *dataset* and return the number of patch batches
        (ceiling division by patch_batch_size)."""
        loader=DataLoader(dataset, batch_size=1)
        iterator = iter(loader)
        sum_nslices=0
        for idx in range(len(loader)):
            check_data = next(iterator)
            nslices=check_data[self.indicator_A].shape[-1]
            sum_nslices+=nslices
        if sum_nslices%patch_batch_size==0:
            return sum_nslices//patch_batch_size
        else:
            return sum_nslices//patch_batch_size+1
367
+
dataprocesser/archive/slice_loader.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import monai
2
+ from monai.transforms import (
3
+ Compose,
4
+ LoadImaged,
5
+ Rotate90d,
6
+ ScaleIntensityd,
7
+ EnsureChannelFirstd,
8
+ ResizeWithPadOrCropd,
9
+ DivisiblePadd,
10
+ ThresholdIntensityd,
11
+ NormalizeIntensityd,
12
+ SqueezeDimd,
13
+ Identityd,
14
+ CenterSpatialCropd,
15
+ )
16
+ from monai.data import Dataset
17
+ from torch.utils.data import DataLoader
18
+ import torch
19
+
20
+ from .basics import get_file_list, get_transforms, load_volumes, crop_volumes
21
+ from .checkdata import check_batch_data, check_volumes, save_volumes
22
+ ##### slices #####
23
def load_batch_slices(train_volume_ds,val_volume_ds, train_batch_size=8,val_batch_size=1,window_width=1,ifcheck=True):
    """Wrap volume datasets into 2D slice loaders.

    Volumes are cut along the last axis into windows of *window_width* slices
    via MONAI's PatchIterd/GridPatchDataset; with window_width==1 the trailing
    singleton dim is squeezed so batches are genuinely 2D.

    Returns (train_loader, val_loader).
    """
    patch_func = monai.data.PatchIterd(
        keys=["source", "target"],
        patch_size=(None, None, window_width),  # dynamic first two dimensions
        start_pos=(0, 0, 0)
    )
    if window_width==1:
        patch_transform = Compose(
            [
                SqueezeDimd(keys=["source", "target"], dim=-1),  # squeeze the last dim
            ]
        )
    else:
        patch_transform = None
    # for training
    train_patch_ds = monai.data.GridPatchDataset(
        data=train_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
    train_loader = DataLoader(
        train_patch_ds,
        batch_size=train_batch_size,
        num_workers=0,
        pin_memory=torch.cuda.is_available(),
    )

    # for validation
    val_patch_ds = monai.data.GridPatchDataset(
        data=val_volume_ds, patch_iter=patch_func, transform=patch_transform, with_coordinates=False)
    val_loader = DataLoader(
        val_patch_ds, #val_volume_ds,
        num_workers=0,
        batch_size=val_batch_size,
        pin_memory=torch.cuda.is_available())

    if ifcheck:
        check_batch_data(train_loader,val_loader,train_patch_ds,val_volume_ds,train_batch_size,val_batch_size)
    return train_loader,val_loader
59
def myslicesloader(data_pelvis_path,
                   normalize='minmax',
                   pad='minimum',
                   train_number=1,
                   val_number=1,
                   train_batch_size=8,
                   val_batch_size=1,
                   saved_name_train='./train_ds_2d.csv',
                   saved_name_val='./val_ds_2d.csv',
                   resized_size=(512,512,None),
                   div_size=(16,16,None),
                   center_crop=20,
                   ifcheck_volume=True,
                   ifcheck_sclices=False,):
    """End-to-end pipeline: file listing -> crop -> volume transforms -> 2D
    slice loaders for a SynthRAD-style pelvis folder.

    Note: train_ds/val_ds are rebound by load_volumes, so the first two
    returned values are the transformed volume datasets, not the raw file
    lists.

    Returns (train_ds, val_ds, train_loader, val_loader,
             train_transforms, val_transforms).
    """
    # volume-level transforms for both image and label
    train_transforms = get_transforms(normalize,pad,resized_size,div_size,mode='train',prob=0.8)
    val_transforms = get_transforms(normalize,pad,resized_size,div_size,mode='val')
    train_ds, val_ds = get_file_list(data_pelvis_path,
                                     train_number,
                                     val_number)
    train_crop_ds, val_crop_ds = crop_volumes(train_ds, val_ds,center_crop)
    train_ds, val_ds = load_volumes(train_transforms, val_transforms,
                                    train_crop_ds, val_crop_ds,
                                    train_ds, val_ds,
                                    saved_name_train, saved_name_val,
                                    ifsave=True,
                                    ifcheck=ifcheck_volume)
    # window_width=1 -> pure 2D slices
    train_loader,val_loader = load_batch_slices(train_ds,
                                                val_ds,
                                                train_batch_size,
                                                val_batch_size=val_batch_size,
                                                window_width=1,
                                                ifcheck=ifcheck_sclices)
    return train_ds, val_ds, train_loader,val_loader,train_transforms,val_transforms
94
+
95
def len_patchloader(train_volume_ds,train_batch_size):
    """Count the total number of 2D slices and the resulting number of
    per-volume batches in the training set.

    Each dataset entry's 'source' tensor contributes its last-axis extent as
    the slice count; batches are counted per volume with ceiling division.
    Returns (slice_number, batch_number).
    """
    import math
    slice_number = 0
    batch_number = 0
    for entry_index in range(len(train_volume_ds)):
        depth = train_volume_ds[entry_index]['source'].shape[-1]
        slice_number += depth
        batch_number += math.ceil(depth / train_batch_size)
    print('total slices in training set:',slice_number)
    print('total batches in training set:',batch_number)
    return slice_number,batch_number
103
+
104
if __name__ == '__main__':
    # Smoke test: build a tiny 2-patient slice loader and log every batch
    # shape to a text file.
    dataset_path=r"F:\yang_Projects\Datasets\Task1\pelvis"
    train_volume_ds,_,train_loader,_,_,_ = myslicesloader(dataset_path,
                    normalize='none',
                    train_number=2,
                    val_number=1,
                    train_batch_size=4,
                    val_batch_size=1,
                    saved_name_train='./train_ds_2d.csv',
                    saved_name_val='./val_ds_2d.csv',
                    resized_size=(512, 512, None),
                    div_size=(16,16,None),
                    ifcheck_volume=False,
                    ifcheck_sclices=False,)
    from tqdm import tqdm
    parameter_file=r'.\test.txt'
    # NOTE(review): load_batch_slices yields dicts keyed "source"/"target",
    # but this loop reads "image"/"label" — confirm the keys before running.
    for data in tqdm(train_loader):
        with open(parameter_file, 'a') as f:
            f.write('image batch:' + str(data["image"].shape)+'\n')
            f.write('label batch:' + str(data["label"].shape)+'\n')
            f.write('\n')
+ f.write('\n')
dataprocesser/build_dataset.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class BaseDataLoader:
    """Skeleton for concrete dataset loaders.

    __init__ drives a fixed pipeline: init_parameters_and_transforms() ->
    get_loader() -> create_dataset() -> finalcheck().  Of these, only
    get_loader is defined here; the other hooks must be provided by
    subclasses (instantiating this base class directly would fail —
    TODO confirm hooks exist on every concrete loader).
    """
    def __init__(self, configs, paths=None, dimension=2, **kwargs):
        # configs: config object consumed by subclass hooks
        # paths: optional dict of output locations; None disables saving
        # dimension: 2 or 3, forwarded to create_dataset
        self.configs=configs
        self.paths=paths
        self.init_parameters_and_transforms()
        self.get_loader()
        #print('all files in dataset:',len(self.source_file_list))


        # optional augmentation knobs read by subclasses
        self.rotation_level = kwargs.get('rotation_level', 0) # Default to no rotation (0)
        self.zoom_level = kwargs.get('zoom_level', 1.0) # Default to no zoom (1.0)
        self.flip = kwargs.get('flip', 0) # Default to no flip

        self.create_dataset(dimension=dimension)

        # without output paths there is nowhere to save, so ifsave stays falsy
        ifsave = None if paths is None else True
        self.finalcheck(ifsave=ifsave,ifcheck=False,iftest_volumes_pixdim=False)

    def get_loader(self):
        # default: empty file/dataset lists; concrete loaders override this
        self.source_file_list = []
        self.train_ds=[]
        self.val_ds=[]
dataprocesser/config_example.yaml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_name: 'ddpm2d_seg2med'
2
+ GPU_ID: [3]
3
+ ckpt_path: 'logs\241118ddpm_512.pt'
4
+ mode: 'test'
5
+ dataset:
6
+ train_csv: 'synthrad_conversion\datacsv\ct_synthrad_testrest_newserver.csv'
7
+ test_csv: 'synthrad_conversion\datacsv\ct_synthrad_testrest_newserver.csv'
8
+ batch_size: 1
9
+ val_batch_size: 8
10
+ normalize: 'scale2000'
11
+ zoom: (1.0,1.0,1.0)
12
+ resized_size: (512,512,None)
13
+ div_size: (None,None,None)
14
+ WINDOW_WIDTH: 2000
15
+ WINDOW_LEVEL: 0
16
+
17
+ train:
18
+ val_epoch_interval: 1
19
+ save_ckpt_interval: 1
20
+ num_epochs: 100
21
+ learning_rate: 0.0002
22
+ writeTensorboard: True
23
+ sample_range_lower: 0
24
+ sample_range_upper: 100000000
25
+ earlystopping_patience: 10
26
+ earlystopping_delta: 0.001
27
+
28
+ validation:
29
+ evaluate_restore_transforms: True
30
+ x_lower_limit: -1000
31
+ x_upper_limit: 3000
32
+ manual_aorta_diss: -1
33
+ ddpm:
34
+ num_train_timesteps: 500
35
+ num_inference_steps: 500
36
+ num_channels: (64, 128, 256, 256)
37
+ attention_levels: (False, False, False, True)
38
+ num_res_units: 2
39
+ norm_num_groups: 32
40
+ num_head_channels: 32
41
+ noise_type: 'normal'
42
+
43
+
dataprocesser/create_csv.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ from dataprocesser.dataset_anika import (
3
+ all_list_single_modality_from_anika_dataset_include_duplicate,
4
+ extract_patientID_from_Anika_dataset,
5
+ all_list_from_anika_dataset_include_duplicate)
6
+ from dataprocesser.dataset_synthrad import list_img_pID_from_synthrad_folder
7
+ from dataprocesser.dataset_anish import list_img_seg_ad_pIDs_from_anish_csv
8
+ from dataprocesser.dataset_dominik import all_list_from_dominik_dataset
9
+ from dataprocesser.step1_init_data_list import appart_img_and_seg, appart_merged_seg
10
+ from dataprocesser.step1_init_data_list import extract_patient_id
11
+
12
def create_csv_combine_lists_synthrad_anika_mr(synthrad_dir, anika_dir_mr, output_mr_csv_file, ifwrtiecsv=True):
    """Combine SynthRAD and Anika MR entries into one [id, Aorta_diss, seg, img] list.

    Aorta_diss is hard-coded to 0 for every patient of both datasets.  When
    ``ifwrtiecsv`` (name kept for backward compatibility) is True, the
    combined list is also written to *output_mr_csv_file*.  Returns the list.
    """
    #synthrad_seg_list, synthrad_pIDs = list_img_pID_from_synthrad_folder(synthrad_dir, ["mr_seg"], None)
    seg_name_pattern = "mr_merged_seg" #r"^mr_merged_seg_\d{1}[A-Z]{2}\d{3}$"
    synthrad_seg_list, synthrad_pIDs = list_img_pID_from_synthrad_folder(synthrad_dir, [seg_name_pattern], None)
    synthrad_mr_list, _ = list_img_pID_from_synthrad_folder(synthrad_dir, ["mr"], None)
    synthrad_Aorta_diss = [0] * len(synthrad_seg_list)
    datalist_synthrad = [[id,Aorta_diss,seg,image] for id,Aorta_diss,seg,image in zip(synthrad_pIDs, synthrad_Aorta_diss, synthrad_seg_list, synthrad_mr_list)]

    mr_list = all_list_single_modality_from_anika_dataset_include_duplicate(anika_dir_mr)
    mr_files, mr_seg_files = appart_img_and_seg(mr_list)
    # presumably reduces the seg list to merged segmentations — verify in step1_init_data_list
    mr_seg_files = appart_merged_seg(mr_seg_files)
    mr_pIDs = extract_patientID_from_Anika_dataset(mr_files)

    mr_Aorta_diss = [0] * len(mr_files)
    datalist_mr = [[id,Aorta_diss,seg,image] for id,Aorta_diss,seg,image in zip(mr_pIDs, mr_Aorta_diss, mr_seg_files, mr_files)]

    print('length dataset 1: ', len(datalist_synthrad))
    print('length dataset 2: ', len(datalist_mr))
    dataset_list=datalist_synthrad+datalist_mr
    if ifwrtiecsv:
        create_csv_info_file(dataset_list, output_mr_csv_file)
    return dataset_list
34
+
35
def create_csv_info_file(dataset_list, output_mr_csv_file):
    """Write *dataset_list* rows to a CSV with the standard dataset header
    ['id', 'Aorta_diss', 'seg', 'img']."""
    header = ['id', 'Aorta_diss', 'seg', 'img']
    with open(output_mr_csv_file, 'w', newline='') as handle:
        writer = csv.writer(handle)
        writer.writerow(header)
        for row in dataset_list:
            writer.writerow(row)
40
+
41
def create_csv_synthrad_mr(synthrad_dir, output_csv_file):
    """Write a [id, Aorta_diss, seg, img] CSV for the SynthRAD MR images only."""
    synthrad_seg_list, synthrad_pIDs = list_img_pID_from_synthrad_folder(synthrad_dir, ["mr_merged_seg"], None)
    # NOTE(review): despite the name, this list holds the "mr" images, not CT
    synthrad_ct_list, _ = list_img_pID_from_synthrad_folder(synthrad_dir, ["mr"], None)
    synthrad_Aorta_diss = [0] * len(synthrad_seg_list)  # no dissection labels in SynthRAD
    datalist_synthrad = [[id,Aorta_diss,seg,image] for id,Aorta_diss,seg,image in zip(synthrad_pIDs, synthrad_Aorta_diss, synthrad_seg_list, synthrad_ct_list)]

    # NOTE(review): 'dataset 2' label looks copied from the combine function
    print('length dataset 2: ', len(datalist_synthrad))
    dataset_list=datalist_synthrad
    create_csv_info_file(dataset_list, output_csv_file)
50
+
51
def create_csv_combine_lists_synthrad_anish(synthrad_dir, anish_csv, output_csv_file):
    """Combine SynthRAD CT entries with entries from Anish's CSV and write
    them to *output_csv_file* as [id, Aorta_diss, seg, img] rows.

    SynthRAD patients get Aorta_diss=0; Anish's CSV provides its own labels.
    """
    synthrad_seg_list, synthrad_pIDs = list_img_pID_from_synthrad_folder(synthrad_dir, ["ct_seg"], None)
    synthrad_ct_list, _ = list_img_pID_from_synthrad_folder(synthrad_dir, ["ct"], None)
    synthrad_Aorta_diss = [0] * len(synthrad_seg_list)  # SynthRAD has no dissection cases

    #anish_pIDs, anish_Aorta_diss, anish_seg_list, anish_ct_list = list_img_seg_ad_pIDs_from_new_simplified_csv(anish_csv)
    anish_pIDs, anish_Aorta_diss, anish_seg_list, anish_ct_list = list_img_seg_ad_pIDs_from_anish_csv(anish_csv)
    datalist_synthrad = [[id,Aorta_diss,seg,image] for id,Aorta_diss,seg,image in zip(synthrad_pIDs, synthrad_Aorta_diss, synthrad_seg_list, synthrad_ct_list)]
    datalist_anish = [[id,Aorta_diss,seg,image] for id,Aorta_diss,seg,image in zip(anish_pIDs, anish_Aorta_diss, anish_seg_list, anish_ct_list)]

    # BUG FIX: both prints previously reported SynthRAD sizes
    # (len(synthrad_ct_list) and len(datalist_synthrad)); the Anish dataset
    # size was never shown.
    print('length dataset 1: ', len(datalist_synthrad))
    print('length dataset 2: ', len(datalist_anish))
    dataset_list=datalist_synthrad+datalist_anish
    create_csv_info_file(dataset_list, output_csv_file)
65
+
66
def create_csv_Anika(ct_dir, mri_dir, output_ct_csv_file, output_mr_csv_file):
    """Build two CSVs (one CT, one MR) from the Anika dataset folders.

    Each row is [id, Aorta_diss, seg, img] with Aorta_diss fixed to 0.
    """
    ct_list, mr_list = all_list_from_anika_dataset_include_duplicate(ct_dir, mri_dir)
    ct_files, ct_seg_files = appart_img_and_seg(ct_list)  # split images vs segmentations
    ct_pIDs = extract_patientID_from_Anika_dataset(ct_files)
    ct_Aorta_diss = [0] * len(ct_list)
    datalist_ct = [[id,Aorta_diss,seg,image] for id,Aorta_diss,seg,image in zip(ct_pIDs, ct_Aorta_diss, ct_seg_files, ct_files)]
    create_csv_info_file(datalist_ct, output_ct_csv_file)

    mr_files, mr_seg_files = appart_img_and_seg(mr_list)
    mr_pIDs = extract_patientID_from_Anika_dataset(mr_files)
    mr_Aorta_diss = [0] * len(mr_files)
    datalist_mr = [[id,Aorta_diss,seg,image] for id,Aorta_diss,seg,image in zip(mr_pIDs, mr_Aorta_diss, mr_seg_files, mr_files)]
    create_csv_info_file(datalist_mr, output_mr_csv_file)
79
+
80
def create_csv_Dominik(mri_dir, output_mr_csv_file):
    """Build an MR CSV ([id, Aorta_diss, seg, img], Aorta_diss fixed to 0)
    from the Dominik dataset folder."""
    mr_list = all_list_from_dominik_dataset(mri_dir)
    mr_files, mr_seg_files = appart_img_and_seg(mr_list)
    # presumably reduces the seg list to merged segmentations — verify in step1_init_data_list
    mr_seg_files = appart_merged_seg(mr_seg_files)
    mr_pIDs = [extract_patient_id(mr_file) for mr_file in mr_files]
    mr_Aorta_diss = [0] * len(mr_files)
    datalist_mr = [[id,Aorta_diss,seg,image] for id,Aorta_diss,seg,image in zip(mr_pIDs, mr_Aorta_diss, mr_seg_files, mr_files)]
    create_csv_info_file(datalist_mr, output_mr_csv_file)
dataprocesser/create_csv_xcat.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import csv
3
+
4
def extract_prefixes_from_directory(directory, extension='.nrrd'):
    """Collect the sorted, unique filename prefixes (text before the first '_')
    of all files in *directory* that end with *extension*.

    Args:
        directory: folder to scan (non-recursive).
        extension: file suffix to match; default '.nrrd' keeps the original behavior.

    Returns:
        Sorted list of unique prefixes.
    """
    prefixes = {
        filename.split('_')[0]
        for filename in os.listdir(directory)
        if filename.endswith(extension)
    }
    return sorted(prefixes)
11
+
12
def save_prefixes_to_csv(prefixes, output_csv_path, base_dir=None):
    """Write one row per prefix to *output_csv_path*, each joined onto *base_dir*.

    Args:
        prefixes: iterable of filename prefixes.
        output_csv_path: destination CSV file.
        base_dir: directory joined in front of each prefix. Defaults to the
            module-level ``directory`` for backward compatibility.

    BUGFIX: the original read the global ``directory`` directly, which made the
    function unusable when imported from another module. The global is now only
    a legacy fallback when *base_dir* is not given.
    """
    if base_dir is None:
        base_dir = directory  # legacy fallback to the script-level global
    with open(output_csv_path, mode='w', newline='') as file:
        writer = csv.writer(file)
        for prefix in prefixes:
            writer.writerow([os.path.join(base_dir, prefix)])
17
+
18
+ if __name__ == "__main__":
19
+ directory = r"F:\yang_Projects\ICTUNET_torch\datasets\train"
20
+ output_csv_path = r"F:\yang_Projects\ICTUNET_torch\data_table\train_all.csv"
21
+
22
+ prefixes = extract_prefixes_from_directory(directory)
23
+ save_prefixes_to_csv(prefixes, output_csv_path)
24
+
25
+ print(f"CSV file with prefixes saved to: {output_csv_path}")
dataprocesser/create_json_lodopab.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm import tqdm
2
+ import json
3
+
4
+ from configs import config as cfg
5
+ import os
6
+
7
+ VERBOSE = cfg.verbose
8
+ from collections import defaultdict
9
+
10
+ IMG_EXTENSIONS = [
11
+ #'.jpg', '.JPG', '.jpeg', '.JPEG',
12
+ #'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
13
+ '.nrrd', '.nii.gz',
14
+ '.hdf5',
15
+ ]
16
+
17
def is_image_file(filename):
    """Return True if *filename* ends with one of the extensions in IMG_EXTENSIONS."""
    return filename.endswith(tuple(IMG_EXTENSIONS))
19
+
20
def create_metadata_jsonl_lodopab(base_path, mode='train', output_json_file='lodopab_dataset.json'):
    """Create a JSON list pairing LoDoPaB ground-truth files with observations.

    Args:
        base_path: dataset root containing 'ground_truth_<mode>' and
            'observation_<mode>' subdirectories.
        mode: dataset split ('train', 'validation', 'test', ...).
        output_json_file: path of the JSON file written.
    """
    ground_truth_path = os.path.join(base_path, 'ground_truth_' + mode)
    observation_path = os.path.join(base_path, 'observation_' + mode)

    dataset_list = []
    for gt_file in os.listdir(ground_truth_path):
        if is_image_file(gt_file):
            # Observation files mirror the ground-truth names.
            obs_file = gt_file.replace('ground_truth', 'observation')
            dataset_list.append({
                'ground_truth': os.path.join(ground_truth_path, gt_file),
                'observation': os.path.join(observation_path, obs_file),
            })

    with open(output_json_file, 'w') as json_file:
        json.dump(dataset_list, json_file, indent=4)

    # BUGFIX: report the actual output path instead of the hard-coded default name.
    print(f'Dataset list saved to {output_json_file} with {len(dataset_list)} entries.')
47
+
48
def read_metadata_jsonl(file_path):
    """Load and return the dataset list stored as JSON at *file_path*."""
    with open(file_path, 'r') as handle:
        return json.load(handle)

def print_json_info(data_info):
    """Print the 'patient_name' of every entry, with a tqdm progress bar.

    NOTE(review): entries written by create_metadata_jsonl_lodopab only carry
    'ground_truth'/'observation' keys — confirm which JSON files this is for.
    """
    for entry in tqdm(data_info, desc="Calculating slice info"):
        print(entry['patient_name'])
56
+
57
if __name__ == '__main__':
    # Build the metadata JSON for the LoDoPaB training split.
    base_path = r"F:\yang_Projects\Datasets\LoDoPaB"
    create_metadata_jsonl_lodopab(base_path,
                                  mode='train',
                                  output_json_file='./data_table/lodopab_dataset.json')
dataprocesser/create_json_xcat.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm import tqdm
2
+ import json
3
+
4
+ from configs import config as cfg
5
+ import os
6
+
7
+ VERBOSE = cfg.verbose
8
+ from collections import defaultdict
9
+
10
+ IMG_EXTENSIONS = [
11
+ #'.jpg', '.JPG', '.jpeg', '.JPEG',
12
+ #'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
13
+ '.nrrd', '.nii.gz',
14
+ '.hdf5',
15
+ ]
16
+
17
def is_image_file(filename):
    """Check whether *filename* carries one of the supported image extensions."""
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
19
+
20
def create_metadata_jsonl_xcat(base_path,
                               mode='train',
                               sino_entry="_sino_Metal.nrrd",
                               img_entry="_img_GT_noNoise.nrrd",
                               output_json_file='xcat_dataset.json'):
    """Create a JSON list pairing XCAT ground-truth volumes with sinograms.

    Each case is identified by the filename prefix before the first '_':
    <prefix><img_entry> is the ground truth, <prefix><sino_entry> the observation.

    Args:
        base_path: dataset root; files are read from <base_path>/<mode>.
        mode: dataset split subdirectory name.
        sino_entry: filename suffix of the sinogram (observation) files.
        img_entry: filename suffix of the ground-truth image files.
        output_json_file: path of the JSON file written.
    """
    train_set_path = os.path.join(base_path, mode)

    # Unique case prefixes, sorted for a deterministic output order.
    prefixes = sorted({
        filename.split('_')[0]
        for filename in os.listdir(train_set_path)
        if is_image_file(filename)
    })

    dataset_list = [
        {
            'ground_truth': os.path.join(train_set_path, prefix + img_entry),
            'observation': os.path.join(train_set_path, prefix + sino_entry),
        }
        for prefix in prefixes
    ]

    with open(output_json_file, 'w') as json_file:
        json.dump(dataset_list, json_file, indent=4)

    # BUGFIX: report the actual output path instead of the hard-coded default name.
    print(f'Dataset list saved to {output_json_file} with {len(dataset_list)} entries.')
54
+
55
def read_metadata_jsonl(file_path):
    """Return the dataset list stored as JSON at *file_path*."""
    with open(file_path, 'r') as handle:
        dataset = json.load(handle)
    return dataset

def print_json_info(data_info):
    """Print the 'patient_name' field of each entry with a tqdm progress bar.

    NOTE(review): entries produced by create_metadata_jsonl_xcat only carry
    'ground_truth'/'observation' keys — confirm which JSON this is meant for.
    """
    for record in tqdm(data_info, desc="Calculating slice info"):
        print(record['patient_name'])
63
+
64
if __name__ == '__main__':
    # Build the metadata JSON for the XCAT training split.
    dataset_root = r"F:\yang_Projects\ICTUNET_torch\datasets"
    create_metadata_jsonl_xcat(dataset_root,
                               mode='train',
                               sino_entry="_sino_Metal.nrrd",
                               img_entry="_img_GT_noNoise.nrrd",
                               output_json_file='./data_table/xcat_dataset.json')
dataprocesser/customized_datasets.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.utils.data as data
2
+ import nibabel as nib
3
+ import torch
4
+ import os
5
+ import numpy as np
6
+ import torch
7
+ from torch.utils.data import Dataset, DataLoader
8
+ import pandas as pd
9
+ from tqdm import tqdm
10
+
11
+ VERBOSE = False
12
+
13
+
14
def volume_slicer(volume_tensor, transform, all_slices=None):
    """Turn an [H, W, D] volume into a stack of 2D slices and append it to
    *all_slices* (a new stack is started when *all_slices* is None).

    Note the (2, 1, 0) permute transposes the spatial axes: the output is
    [D, 1, W, H] relative to the input volume.
    """
    slices = volume_tensor.permute(2, 1, 0).unsqueeze(1)  # [H, W, D] -> [D, 1, W, H]
    if transform is not None:
        slices = transform(slices)

    # Either start a new stack or grow the existing one along the slice axis.
    return slices if all_slices is None else torch.cat((all_slices, slices), 0)
28
+
29
def infinite_loader(loader):
    """Yield batches from *loader* forever.

    After each full pass the underlying dataset's reset() hook is invoked, so
    datasets that cache their slices can rebuild themselves between epochs.
    """
    while True:
        yield from loader
        loader.dataset.reset()  # epoch boundary: let the dataset re-initialize
36
+
37
class csvDataset_3D(Dataset):
    """Dataset of whole 3D NIfTI volumes listed in a CSV file.

    CSV layout: image path in the last column, 'Aorta_diss' label in the
    third-to-last column.
    """

    def __init__(self, csv_file, transform=None, load_patient_number=1):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            transform (callable, optional): Optional transform to be applied on a sample.
            load_patient_number (int): keep only the first N rows of the CSV.
        """
        frame = pd.read_csv(csv_file)
        self.data_frame = frame[:load_patient_number]  # truncate to requested size
        self.transform = transform

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.data_frame.iloc[idx]
        volume = nib.load(row.iloc[-1]).get_fdata()
        image = torch.tensor(volume, dtype=torch.float32)

        # 'Aorta_diss' column (third from the end) serves as the label.
        label = row.iloc[-3]

        # Add the channel dimension for single-channel volumes.
        image = image.unsqueeze(0)

        return {'image': image, 'label': label}
70
+
71
class csvDataset_2D(Dataset):
    """Dataset of 2D slices cut from the 3D volumes listed in a CSV file.

    All volumes are sliced eagerly in initialize_dataset(); __getitem__ then
    indexes the pre-stacked slice tensor. CSV layout: image path in the last
    column, 'Aorta_diss' label in the third-to-last column.
    """

    def __init__(self, csv_file, transform=None, load_patient_number=1):
        """
        Args:
            csv_file: path to the CSV with annotations.
            transform: optional transform applied to the stacked slices.
            load_patient_number: keep only the first N rows of the CSV.
        """
        self.csv_file = csv_file
        self.transform = transform
        self.load_patient_number = load_patient_number
        self.data_frame = pd.read_csv(csv_file)
        if len(self.data_frame) == 0:
            raise RuntimeError(f"Found 0 images in: {csv_file}")

        # Initialize dataset
        self.initialize_dataset()

    def initialize_dataset(self):
        """Load the selected volumes, slice them, and stack slices + labels."""
        print('Loading dataset...')
        self.data_frame = self.data_frame[:self.load_patient_number]
        all_slices = None
        all_labels = []

        for idx in tqdm(range(len(self.data_frame))):
            img_path = self.data_frame.iloc[idx, -1]
            volume = nib.load(img_path)
            volume_data = volume.get_fdata()  # loaded as [H, W, D]
            volume_tensor = torch.tensor(volume_data, dtype=torch.float32)
            # -> [D, 1, W, H], appended to the running stack
            all_slices = volume_slicer(volume_tensor, self.transform, all_slices)
            label = self.data_frame.iloc[idx, -3]
            # BUGFIX: one label per produced slice. volume_slicer emits D
            # slices (shape[2] of the [H, W, D] volume); the original used
            # shape[0] (= H), desynchronizing labels and slices whenever
            # H != D. (Assumes the transform preserves the slice count.)
            all_labels = all_labels + [label] * volume_tensor.shape[2]

        print('All stacked slices:', all_slices.shape)
        self.all_slices = all_slices
        self.all_labels = all_labels

    def __len__(self):
        return self.all_slices.shape[0]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self.all_slices[idx]
        label = self.all_labels[idx]
        sample = {'source': image, 'target': label}
        return sample

    def reset(self):
        print('Resetting dataset...')
        self.initialize_dataset()
dataprocesser/customized_normalization.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nibabel as nib
2
+ import numpy as np
3
+ from scipy.interpolate import interp1d
4
+
5
def nyul_apply_standard_scale(input_image,
                              standard_hist,
                              input_mask=None,
                              interp_type='linear'):
    """Normalize *input_image* with the Nyul & Udupa piecewise-linear method.

    Based on J. Reinhold's code:
    https://github.com/jcreinhold/intensity-normalization

    Args:
        input_image (np.ndarray): input image to normalize
        standard_hist (str): path to the learned standard histogram landmarks
            (an np.save'd pair: standard_scale, percentiles)
        input_mask: optional foreground mask
        interp_type (str): interpolation kind for scipy's interp1d

    Returns:
        np.ndarray: normalized input image

    References:
        [1] Nyul & Udupa, "On Standardizing the MR Image Intensity Scale",
            Magn. Reson. Med., vol. 42, pp. 1072-1081, 1999.
        [2] Shah et al., "Evaluating intensity normalization on MRIs of human
            brain with multiple sclerosis", Med. Image Anal. 15(2), 2011.
    """
    # Load the learned standard scale together with its landmark percentiles.
    standard_scale, percs = np.load(standard_hist)

    # Map the image intensities onto the standard scale.
    return do_hist_normalization(input_image, percs, standard_scale,
                                 input_mask, interp_type=interp_type)
44
+
45
def do_hist_normalization(input_image,
                          landmark_percs,
                          standard_scale,
                          mask=None,
                          interp_type='linear'):
    """Apply the Nyul & Udupa histogram normalization with learned landmarks.

    Based on J. Reinhold's code:
    https://github.com/jcreinhold/intensity-normalization

    Args:
        input_image (np.ndarray): image to normalize
        landmark_percs (np.ndarray): landmark percentiles of the standard scale
        standard_scale (np.ndarray): landmark intensities on the standard scale
        mask (np.ndarray): optional foreground mask (estimated by mean
            thresholding when omitted)
        interp_type (str): interpolation kind

    Returns:
        np.ndarray: normalized image
    """
    # Fall back to a crude foreground estimate when no mask is supplied.
    foreground = input_image > input_image.mean() if mask is None else mask
    foreground_voxels = input_image[foreground > 0]
    landmarks = get_landmarks(foreground_voxels, landmark_percs)

    # Piecewise-linear map from the image landmarks to the standard scale.
    mapper = interp1d(landmarks, standard_scale, kind=interp_type,
                      fill_value='extrapolate')
    return mapper(input_image)
76
+
77
def get_landmarks(img, percs):
    """Return the intensity values of *img* at the given percentiles.

    Based on J. Reinhold's code:
    https://github.com/jcreinhold/intensity-normalization

    Args:
        img: array of intensities (typically the masked foreground voxels)
        percs (np.ndarray): landmark percentiles to extract

    Returns:
        np.ndarray: intensity values of *img* corresponding to *percs*
    """
    return np.percentile(img, percs)
93
+
94
def nyul_train_standard_scale(img_fns,
                              mask_fns=None,
                              i_min=1,
                              i_max=99,
                              i_s_min=1,
                              i_s_max=100,
                              l_percentile=10,
                              u_percentile=90,
                              step=10):
    """Determine the Nyul & Udupa standard scale for a set of images.

    Based on J. Reinhold's code:
    https://github.com/jcreinhold/intensity-normalization

    Args:
        img_fns (list): NIfTI MR image paths to learn the scale from
        mask_fns (list): corresponding masks (estimated when omitted)
        i_min / i_max (float): min/max percentile considered in the images
        i_s_min / i_s_max (float): min/max percentile on the standard scale
        l_percentile / u_percentile (int): middle percentile bounds
        step (int): step between middle percentiles (e.g. 10 for deciles)

    Returns:
        standard_scale (np.ndarray): average landmark intensity over images
        percs (np.ndarray): the percentiles used
    """
    # Estimate masks later if none were provided.
    mask_fns = [None] * len(img_fns) if mask_fns is None else mask_fns

    percs = np.concatenate(([i_min],
                            np.arange(l_percentile, u_percentile + 1, step),
                            [i_max]))
    standard_scale = np.zeros(len(percs))

    # Accumulate each volume's (rescaled) landmarks into the standard scale.
    for img_fn, mask_fn in zip(img_fns, mask_fns):
        print('processing scan ', img_fn)
        # BUGFIX: nibabel's get_data() was deprecated and removed (nibabel
        # >= 5.0); get_fdata() is the supported accessor.
        img_data = nib.load(img_fn).get_fdata()
        if mask_fn is not None:
            mask_data = nib.load(mask_fn).get_fdata()
        else:
            mask_data = img_data > img_data.mean()  # crude foreground estimate
        masked = img_data[mask_data > 0]  # only voxels where the mask is non-empty

        landmarks = get_landmarks(masked, percs)
        min_p = np.percentile(masked, i_min)
        max_p = np.percentile(masked, i_max)
        # Rescale this volume's landmarks onto [i_s_min, i_s_max].
        f = interp1d([min_p, max_p], [i_s_min, i_s_max])
        standard_scale += np.array(f(landmarks))

    standard_scale = standard_scale / len(img_fns)  # mean over all volumes
    return standard_scale, percs
dataprocesser/customized_transform_list.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataprocesser.customized_transforms import (
2
+ CreateBodyContourTransformd,
3
+ MergeMasksTransformd,
4
+ UseContourToFilterImaged,
5
+ MaskHUAssigmentd,
6
+ MergeSegTissueTransformd,
7
+ NormalizationMultimodal,
8
+ CreateMaskWithBonesTransformd)
9
+ from monai.transforms import (
10
+ Compose,
11
+ LoadImaged,
12
+ EnsureChannelFirstd,
13
+ SqueezeDimd,
14
+ CenterSpatialCropd,
15
+ Rotate90d,
16
+ ScaleIntensityd,
17
+ ResizeWithPadOrCropd,
18
+ DivisiblePadd,
19
+ Zoomd,
20
+ ThresholdIntensityd,
21
+ NormalizeIntensityd,
22
+ ShiftIntensityd,
23
+ Identityd,
24
+ ScaleIntensityRanged,
25
+ Spacingd,
26
+ SaveImage,
27
+ )
28
+ ## intensity transforms
29
def add_normalization_transform_single_B(transform_list, indicator_B, normalize):
    """Append the intensity-normalization transform selected by *normalize*
    for the target-modality key *indicator_B*.

    The scaleN options divide intensities by N via ScaleIntensityd's factor
    (v = v * (1 + factor)); 'nonegative' shifts intensities by +1000.

    Returns:
        The (mutated) transform_list, for chaining.
    """
    if normalize == 'zscore':
        transform_list.append(NormalizeIntensityd(keys=[indicator_B], nonzero=False, channel_wise=True))
        print('zscore normalization')
    elif normalize == 'minmax':
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=-1.0, maxv=1.0))
        print('minmax normalization')

    elif normalize == 'scale1000_wrongbutworks':
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=0))
        transform_list.append(ScaleIntensityd(keys=[indicator_B], factor=-0.999))
        print('scale1000 normalization')

    elif normalize == 'scale4000':
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.99975))
        print('scale4000 normalization')

    elif normalize == 'scale2000':
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.9995))
        print('scale2000 normalization')

    elif normalize == 'scale1000':
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.999))
        print('scale1000 normalization')

    elif normalize == 'scale100':
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.99))
        # BUGFIX: this branch previously printed 'scale10 normalization'.
        print('scale100 normalization')

    elif normalize == 'scale10':
        transform_list.append(ScaleIntensityd(keys=[indicator_B], minv=None, maxv=None, factor=-0.9))
        print('scale10 normalization')

    elif normalize == 'nonegative':
        offset = 1000
        transform_list.append(ShiftIntensityd(keys=[indicator_B], offset=offset))
        print('none negative normalization')

    elif normalize == 'none' or normalize == 'nonorm':
        print('no normalization')

    return transform_list
71
+
72
def add_normalization_multimodal(transform_list, indicator_A, indicator_B):
    """Append the joint multimodal normalization over both image keys."""
    transform_list.append(NormalizationMultimodal(keys=[indicator_A, indicator_B]))
    return transform_list

def add_normalization_transform_A_B(transform_list, normalize, indicator_A, indicator_B):
    """Append the normalization selected by *normalize* to BOTH image keys."""
    keys = [indicator_A, indicator_B]
    if normalize == 'zscore':
        transform_list.append(NormalizeIntensityd(keys=keys, nonzero=False, channel_wise=True))
        print('zscore normalization')
    elif normalize == 'scale2000':
        # Divide intensities by 2000: v = v * (1 + factor).
        transform_list.append(ScaleIntensityd(keys=keys, minv=None, maxv=None, factor=-0.9995))
        print('scale2000 normalization')
    elif normalize == 'none' or normalize == 'nonorm':
        print('no normalization')
    return transform_list
86
+
87
def add_normalization_transform_input_only(transform_list, indicator_A, normalize):
    """Append a normalization that only touches the input (MRI) key.

    Returns:
        The (mutated) transform_list. BUGFIX: the original fell off the end
        and returned None, unlike every sibling add_* helper; callers chaining
        `transform_list = add_...(...)` silently lost the list.
    """
    if normalize == 'inputonlyzscore':
        transform_list.append(NormalizeIntensityd(keys=[indicator_A], nonzero=False, channel_wise=True))
        print('only normalize input MRI images')

    elif normalize == 'inputonlyminmax':
        normmin = 0
        normmax = 1
        transform_list.append(ScaleIntensityd(keys=[indicator_A], minv=normmin, maxv=normmax))
        print('only normalize input MRI images')

    return transform_list
97
+
98
def add_CreateContour_MergeMask_transforms(transform_list, indicator_A):
    """Append: body-contour creation from the image, then merge the contour
    into the segmentation stored under *indicator_A*."""
    transform_list.append(
        CreateBodyContourTransformd(keys=['mask'],
                                    body_threshold=-500,
                                    body_mask_value=1))                      # image -> contour
    transform_list.append(MergeMasksTransformd(keys=[indicator_A, 'mask']))  # seg+contour -> seg
    return transform_list

def add_CreateContour_MergeMask_MaskHUAssign_transforms(transform_list, indicator_A, anatomy_list_csv):
    """Like add_CreateContour_MergeMask_transforms, plus HU-value assignment
    from *anatomy_list_csv* onto the merged segmentation."""
    transform_list.append(
        CreateBodyContourTransformd(keys=['mask'],
                                    body_threshold=-500,
                                    body_mask_value=1))                      # image -> contour
    transform_list.append(MergeMasksTransformd(keys=[indicator_A, 'mask']))  # seg+contour -> seg
    transform_list.append(MaskHUAssigmentd(keys=[indicator_A], csv_file=anatomy_list_csv))
    return transform_list
114
+
115
def add_CreateContour_MergeSegTissue_MergeMask_MaskHUAssign_transforms(transform_list, indicator_A, anatomy_list_csv, anatomy_list_csv_mr):
    """Contour creation + tissue-seg merge + contour merge + HU assignment.

    NOTE(review): anatomy_list_csv_mr is currently unused (kept for interface
    compatibility) — confirm whether an MR-specific HU assignment was intended.
    """
    transform_list.append(
        CreateBodyContourTransformd(keys=['mask'],
                                    body_threshold=-500,
                                    body_mask_value=1))                                # image -> contour
    transform_list.append(MergeSegTissueTransformd(keys=[indicator_A, 'seg_tissue']))  # seg+seg_tissue -> seg
    transform_list.append(MergeMasksTransformd(keys=[indicator_A, 'mask']))            # seg+contour -> seg
    transform_list.append(MaskHUAssigmentd(keys=[indicator_A], csv_file=anatomy_list_csv))
    return transform_list
124
+
125
def add_Windowing_ZeroShift_ContourFilter_A_B_transforms(transform_list, WINDOW_LEVEL, WINDOW_WIDTH, indicator_A, indicator_B):
    """Clamp both image keys to the CT window, shift the window so it starts
    at zero, then zero out everything outside the body contour of key B."""
    threshold_low = WINDOW_LEVEL - WINDOW_WIDTH / 2
    threshold_high = WINDOW_LEVEL + WINDOW_WIDTH / 2
    offset = (-1) * threshold_low
    keys = [indicator_A, indicator_B]
    # ThresholdIntensityd semantics: mask = img > threshold if above else img < threshold;
    # res = where(mask, img, cval) — so cval must equal the clamped bound.
    transform_list.append(ThresholdIntensityd(keys=keys, threshold=threshold_low, above=True, cval=threshold_low))
    transform_list.append(ThresholdIntensityd(keys=keys, threshold=threshold_high, above=False, cval=threshold_high))
    transform_list.append(ShiftIntensityd(keys=keys, offset=offset))
    transform_list.append(UseContourToFilterImaged(keys=[indicator_B, 'mask']))  # image*contour -> image
    return transform_list

def add_Windowing_ZeroShift_ContourFilter_single_B_transforms(transform_list, WINDOW_LEVEL, WINDOW_WIDTH, indicator_B):
    """Single-key variant of add_Windowing_ZeroShift_ContourFilter_A_B_transforms."""
    threshold_low = WINDOW_LEVEL - WINDOW_WIDTH / 2
    threshold_high = WINDOW_LEVEL + WINDOW_WIDTH / 2
    offset = (-1) * threshold_low
    transform_list.append(ThresholdIntensityd(keys=[indicator_B], threshold=threshold_low, above=True, cval=threshold_low))
    transform_list.append(ThresholdIntensityd(keys=[indicator_B], threshold=threshold_high, above=False, cval=threshold_high))
    transform_list.append(ShiftIntensityd(keys=[indicator_B], offset=offset))
    transform_list.append(UseContourToFilterImaged(keys=[indicator_B, 'mask']))  # image*contour -> image
    return transform_list
dataprocesser/customized_transforms.py ADDED
@@ -0,0 +1,507 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from typing import List
5
+
6
+ VERBOSE = False
7
+
8
def get_data_scaler(config):
    """Data normalizer. Assumes data are always in [0, 1]."""
    if config.data.centered:
        return lambda x: x * 2. - 1.  # rescale [0, 1] -> [-1, 1]
    return lambda x: x

def get_data_inverse_scaler(config):
    """Inverse of get_data_scaler."""
    if config.data.centered:
        return lambda x: (x + 1.) / 2.  # rescale [-1, 1] -> [0, 1]
    return lambda x: x
23
+
24
def separate_maps(real_images,
                  tissue_min, tissue_max,
                  bone_min, bone_max):
    """Build an integer label map from HU-like intensities:
    0 = background, 1 = tissue in (tissue_min, tissue_max],
    2 = bone in [bone_min, bone_max]. Bone is written last, so it wins
    wherever the two ranges overlap."""
    mask = torch.zeros_like(real_images)
    tissue_region = (real_images > tissue_min) & (real_images <= tissue_max)
    bone_region = (real_images >= bone_min) & (real_images <= bone_max)
    mask[tissue_region] = 1
    mask[bone_region] = 2  # applied after tissue -> bone wins on overlap
    return mask
33
+
34
def create_body_contour_old(tensor_img, body_threshold=-500):
    """
    Create a binary body mask from a CT image tensor, using a specific threshold for the body parts.
    There would be problem if more body parts are presented (like two arms),
    because only the single largest contour is kept.

    Args:
        tensor_img (torch.Tensor): A tensor representation of a grayscale CT image, with intensity values from -1024 to 1500.
        body_threshold (int): threshold (HU-like) separating body from the air background.

    Returns:
        torch.Tensor: A binary int32 mask tensor where the entire body region is 1 and the background is 0.
    """
    # Convert tensor to numpy array; int16 so negative HU values survive the cast.
    numpy_img = tensor_img.numpy().astype(np.int16)  # Ensure we can handle negative values correctly

    # Threshold the image to separate potential body from the background.
    binary_img = np.where(numpy_img > body_threshold, 1, 0).astype(np.uint8)
    # Find external contours only (holes inside the body are ignored).
    contours, _ = cv2.findContours(binary_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Create an empty mask and fill the largest contour
    mask = np.zeros_like(binary_img)
    if contours:
        # Assume the largest contour is the body contour
        largest_contour = max(contours, key=cv2.contourArea)
        cv2.drawContours(mask, [largest_contour], -1, 1, thickness=cv2.FILLED)

    # Convert the mask back to a tensor
    mask_tensor = torch.tensor(mask, dtype=torch.int32)

    return mask_tensor
64
+
65
def create_body_contour(tensor_img, body_threshold=-500, min_contour_area=10000):
    """
    Create a binary body mask from a CT image, thresholding the body parts.
    Unlike create_body_contour_old, every sufficiently large contour is kept,
    so multiple body parts (like two arms) are handled.

    Args:
        tensor_img (torch.Tensor | np.ndarray): grayscale CT image with intensity values from -1024 to 1500.
        body_threshold: threshold separating body from the air background.
        min_contour_area: contours smaller than this (in pixels) are discarded.

    Returns:
        torch.Tensor: binary int32 mask where the body region is 1 and the background 0.

    Raises:
        TypeError: if tensor_img is neither a torch.Tensor nor a np.ndarray.
    """
    # Convert the input to a contiguous int16 numpy array (keeps negative HU values).
    if isinstance(tensor_img, torch.Tensor):
        numpy_img = tensor_img.numpy().astype(np.int16)
    elif isinstance(tensor_img, np.ndarray):
        numpy_img = np.ascontiguousarray(tensor_img.astype(np.int16))
    else:
        # BUGFIX: the original only printed a warning here and then crashed
        # with a NameError on the undefined numpy_img; fail fast instead.
        raise TypeError("tensor_img must be a torch.Tensor or np.ndarray, "
                        f"got {type(tensor_img).__name__}")

    # Threshold to separate potential body from the background.
    binary_img = np.where(numpy_img > body_threshold, 1, 0).astype(np.uint8)

    # Find all contours (RETR_LIST: flat list, no hierarchy).
    contours, _ = cv2.findContours(binary_img, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)

    # Fill every contour large enough to be a body part.
    mask = np.zeros_like(binary_img)
    if contours:
        for contour in contours:
            if cv2.contourArea(contour) >= min_contour_area:
                if VERBOSE:
                    print('current contour area: ', cv2.contourArea(contour), 'threshold: ', min_contour_area)
                cv2.drawContours(mask, [contour], -1, 1, thickness=cv2.FILLED)

    # Convert the mask back to a tensor.
    return torch.tensor(mask, dtype=torch.int32)
105
+ import numpy as np
106
+ import cv2
107
+
108
def create_body_contour_by_seg_tissue(binary_mask: np.ndarray, area_threshold=1000) -> np.ndarray:
    """
    Extract the body contour from a tissue segmentation map, keeping only the
    large connected regions, and return a binary mask.

    Args:
        binary_mask: np.ndarray, 2D input map; any non-zero value marks tissue.
        area_threshold: int, minimum contour area (in pixels) to keep.
    Returns:
        contour_mask: np.uint8, 2D binary mask (0 or 1).
    """
    mask_uint8 = (binary_mask > 0).astype(np.uint8).copy() * 255

    # Find external contours only (holes inside the tissue are ignored/filled).
    contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    contour_mask = np.zeros_like(mask_uint8, dtype=np.uint8)

    for cnt in contours:
        area = cv2.contourArea(cnt)
        if area > area_threshold:
            cv2.drawContours(contour_mask, [cnt], -1, 1, thickness=-1)  # fill the contour

    return (contour_mask > 0).astype(np.uint8)
131
+
132
+ import pandas as pd
133
+
134
def HU_assignment(mask, csv_file):
    """Map organ-index labels in *mask* to HU values taken from *csv_file*.

    The CSV must contain integer 'Order Number' and 'HU Value' columns.
    Background (mask == 0) becomes -1000 (air). When the CSV order numbers
    start at 0, mask values are assumed shifted by +1 (body value starts at 1,
    unlike the TA2 table).

    Args:
        mask: torch.Tensor or np.ndarray of integer organ labels.
        csv_file: path to the anatomy list CSV.

    Returns:
        Tensor/array of the same type and shape with HU values assigned.

    Raises:
        TypeError: if mask is neither a torch.Tensor nor a np.ndarray.
    """
    if isinstance(mask, torch.Tensor):
        hu_mask = torch.zeros_like(mask)
    elif isinstance(mask, np.ndarray):
        hu_mask = np.zeros_like(mask)
    else:
        # BUGFIX: previously fell through silently and crashed later with a
        # NameError on the undefined hu_mask; fail fast instead.
        raise TypeError(f"mask must be a torch.Tensor or np.ndarray, got {type(mask).__name__}")

    df = pd.read_csv(csv_file)
    hu_values = dict(zip(df['Order Number'], df['HU Value']))
    order_begin_from_0 = df['Order Number'].min() == 0
    # Value Assignment
    hu_mask[mask == 0] = -1000  # background = air
    for organ_index, hu_value in hu_values.items():
        # BUGFIX: pandas yields numpy integer scalars (np.int64), which are
        # NOT instances of the built-in int on Python 3 — the original
        # isinstance(..., int) asserts rejected every valid CSV.
        assert isinstance(hu_value, (int, np.integer)), f"Expected mask value an integer, but got {hu_value}. Ensure the mask is created by fine mode of totalsegmentator"
        assert isinstance(organ_index, (int, np.integer)), f"Expected organ_index an integer, but got {organ_index}. Ensure the mask is created by fine mode of totalsegmentator"
        if order_begin_from_0:
            # Mask values begin from 1 (body), unlike 0 in the TA2 table.
            hu_mask[mask == (organ_index + 1)] = hu_value
        else:
            hu_mask[mask == organ_index] = hu_value
    return hu_mask
153
+
154
class MaskHUAssigmentd:
    """Dictionary-style transform assigning HU values to organ-label masks.

    For every key in *keys*, organ indices are replaced by HU values read from
    *csv_file* ('Order Number' / 'HU Value' columns); background becomes -1000.
    """

    def __init__(self, keys, csv_file):
        self.keys = keys
        # Parse the anatomy table once here instead of on every __call__
        # (the original rebuilt the dict per key per sample).
        self.df = pd.read_csv(csv_file)
        self.hu_values = dict(zip(self.df['Order Number'], self.df['HU Value']))
        self.order_begin_from_0 = self.df['Order Number'].min() == 0

    def __call__(self, data):
        for key in self.keys:
            mask = data[key]
            hu_mask = torch.zeros_like(mask)
            # Value Assignment
            hu_mask[mask == 0] = -1000  # background = air
            for organ_index, hu_value in self.hu_values.items():
                # BUGFIX: pandas yields numpy integer scalars (np.int64), which
                # are not built-in ints — isinstance(..., int) rejected every
                # valid CSV.
                assert isinstance(hu_value, (int, np.integer)), f"Expected mask value an integer, but got {hu_value}. Ensure the mask is created by fine mode of totalsegmentator"
                assert isinstance(organ_index, (int, np.integer)), f"Expected organ_index an integer, but got {organ_index}. Ensure the mask is created by fine mode of totalsegmentator"
                if self.order_begin_from_0:
                    # Mask values begin from 1 (body), unlike 0 in the TA2 table.
                    hu_mask[mask == (organ_index + 1)] = hu_value
                else:
                    hu_mask[mask == organ_index] = hu_value
            data[key] = hu_mask
        return data
180
+ import pandas as pd
181
+ import numpy as np
182
+
183
def convert_segmentation_mask(source_mask, source_csv, target_csv, body_contour_value=1):
    """
    Converts segmentation mask values from source modality to target modality based on organ name mapping.

    Parameters:
    - source_mask (ndarray): The source segmentation mask array.
    - source_csv (str): Path to the CSV file of the source modality (CT or MR).
    - target_csv (str): Path to the CSV file of the target modality (MR or CT).
    - body_contour_value (int): The class value for "body contour" in the target
      modality, also used for any class that has no mapping.

    Returns:
    - target_mask (ndarray): The converted segmentation mask.
    """
    # Load the source and target anatomy lists.
    source_df = pd.read_csv(source_csv)
    target_df = pd.read_csv(target_csv)

    # Organ name -> class value (first CSV column) for both modalities.
    source_mapping = {row['Organ Name']: row.iloc[0] for _, row in source_df.iterrows()}
    target_mapping = {row['Organ Name']: row.iloc[0] for _, row in target_df.iterrows()}
    # PERF FIX: build the inverse (class value -> organ name) mapping once; the
    # original rebuilt it inside the loop for every unique class value.
    source_value_to_name = {v: k for k, v in source_mapping.items()}

    # Unmapped classes default to the body-contour value.
    target_mask = np.full_like(source_mask, body_contour_value, dtype=source_mask.dtype)

    for class_value in np.unique(source_mask):
        organ_name = source_value_to_name.get(class_value, None)
        if organ_name and organ_name in target_mapping:
            target_value = target_mapping[organ_name]
        else:
            target_value = body_contour_value
        target_mask[source_mask == class_value] = target_value

    return target_mask
223
+
224
class CreateBodyContourTransformd:
    """Dictionary transform: replace each image under ``keys`` by its body
    contour mask, computed slice-wise with ``create_body_contour``.

    # assumes input layout [B, H, W, D] -- TODO confirm against the loader
    """

    def __init__(self, keys, body_threshold, body_mask_value):
        self.keys = keys
        self.body_threshold = body_threshold
        self.body_mask_value = body_mask_value

    def __call__(self, data):
        for key in self.keys:
            image = data[key]
            contour = torch.zeros_like(image)
            batch_size, depth = image.shape[0], image.shape[-1]
            # Build the contour one 2D slice at a time.
            for b in range(batch_size):
                for d in range(depth):
                    contour[b, :, :, d] = create_body_contour(
                        image[b, :, :, d], body_threshold=self.body_threshold
                    )
            # Relabel the binary contour with the configured mask value.
            contour[contour == 1] = self.body_mask_value
            if VERBOSE:
                print("created mask shape:", contour.shape)
            data[key] = contour
        return data
247
+
248
class CreateBodyContourMultiModalTransformd:
    """Multimodal variant of the body-contour transform: each image under
    ``keys`` is replaced by its slice-wise body contour mask.

    # assumes input layout [B, H, W, D] -- TODO confirm against the loader
    """

    def __init__(self, keys, body_threshold, body_mask_value):
        self.keys = keys
        self.body_threshold = body_threshold
        self.body_mask_value = body_mask_value

    def __call__(self, data):
        for key in self.keys:
            volume = data[key]
            mask = torch.zeros_like(volume)
            # Contour is computed per 2D slice over batch and depth axes.
            for batch_idx in range(volume.shape[0]):
                for slice_idx in range(volume.shape[-1]):
                    mask[batch_idx, :, :, slice_idx] = create_body_contour(
                        volume[batch_idx, :, :, slice_idx],
                        body_threshold=self.body_threshold,
                    )
            mask[mask == 1] = self.body_mask_value
            if VERBOSE:
                print("created mask shape:", mask.shape)
            data[key] = mask
        return data
271
+
272
def convert_xcat_to_ct_mask(xcat_image, mapping_csv, tolerance=0.5):
    """
    Converts XCAT CT digital phantom images to simulated CT masks.

    Parameters:
    - xcat_image (torch.Tensor): The XCAT CT image tensor (in HU values).
    - mapping_csv (str): Path to the CSV file containing organ, HU value, and mask value mappings.
    - tolerance (float): Tolerance for HU value matching (default is ±0.5).

    Returns:
    - ct_mask (torch.Tensor): The converted CT mask tensor.
    """
    mapping_df = pd.read_csv(mapping_csv)

    # Start from an all-zero (background) mask of matching type.
    if isinstance(xcat_image, np.ndarray):
        ct_mask = np.zeros_like(xcat_image, dtype=np.int32)
    elif isinstance(xcat_image, torch.Tensor):
        ct_mask = torch.zeros_like(xcat_image, dtype=torch.int32)
    else:
        raise TypeError("xcat_image must be a NumPy ndarray or a PyTorch tensor.")

    # Assign one mask value per mapped organ, matching HU within +/- tolerance.
    for _, row in mapping_df.iterrows():
        organ = row['Organ']
        hu_value = row['HU_Value']
        mask_value = row['Mask_Value']

        lower_bound = hu_value - tolerance
        upper_bound = hu_value + tolerance
        ct_mask[(xcat_image >= lower_bound) & (xcat_image <= upper_bound)] = mask_value

        print(f"Processed {organ} with HU range [{lower_bound}, {upper_bound}] to mask value {mask_value}")
    return ct_mask
311
+
312
class UseContourToFilterImaged:
    """Multiply an image by its body-contour mask so everything outside the
    contour is zeroed. ``keys`` must be [image_key, contour_key]."""

    def __init__(self, keys: List[str]):
        if len(keys) != 2:
            raise ValueError("Keys must be a list with exactly two string elements.")
        self.image_key, self.contour_key = keys

    def __call__(self, data):
        data[self.image_key] = data[self.image_key] * data[self.contour_key]
        return data
325
+
326
class MergeMasksTransformd:
    """Add a body-contour mask onto a segmentation mask; the sum is stored
    back under the segmentation key. ``keys`` = [seg_key, contour_key]."""

    def __init__(self, keys: List[str]):
        if len(keys) != 2:
            raise ValueError("Keys must be a list with exactly two string elements.")
        self.seg_key, self.contour_key = keys

    def __call__(self, data):
        data[self.seg_key] = data[self.seg_key] + data[self.contour_key]
        return data
341
+
342
class MergeSegTissueTransformd:
    """Merge an organ segmentation with a tissue mask.

    Tissue labels are offset by +100 so they always sit above the organ
    labels; wherever the segmentation has a label (> 0), the organ label
    wins. The merged mask is stored back under the segmentation key.
    ``keys`` = [seg_key, tissue_key].
    """

    def __init__(self, keys: List[str]):
        if len(keys) != 2:
            raise ValueError("Keys must be a list with exactly two string elements.")
        self.seg_key = keys[0]
        self.tissue_key = keys[1]

    def __call__(self, data):
        seg = data[self.seg_key]
        # Bug fix: the original did ``tissue += 100`` in place, silently
        # mutating data[self.tissue_key] as a side effect; compute the offset
        # copy instead so the input stays untouched.
        tissue = data[self.tissue_key] + 100  # keep tissue values above organ labels

        # Organ labels take precedence everywhere the segmentation is set
        # (the old intermediate overlap-mask step was superseded by this
        # assignment and has been removed).
        merged_mask = tissue.copy()
        merged_mask[seg > 0] = seg[seg > 0]

        data[self.seg_key] = merged_mask
        return data
367
+
368
class DivideTransformd:
    """Divide every entry stored under ``keys`` by ``divide_factor``."""

    def __init__(self, keys: List[str], divide_factor):
        self.keys = keys
        self.divide_factor = divide_factor

    def __call__(self, data):
        for key in self.keys:
            data[key] = data[key] / self.divide_factor
        return data
378
+
379
class MergeMasksTransformOldd:
    """Sum all masks under ``keys`` (cast to int32) and write the combined
    mask back under every key."""

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, data):
        merged = torch.zeros_like(data[self.keys[0]], dtype=torch.int32)
        for key in self.keys:
            merged = merged + data[key].to(torch.int32)
        # Every key ends up pointing at the same merged result.
        for key in self.keys:
            data[key] = merged
        return data
392
+
393
+ # convert the integer segemented labels to one-hot codes for training
394
class ConvertToOneHotd:
    """One-hot encode integer label maps stored under ``keys``.

    Input of shape [B, H, W] becomes [B, number_classes, H, W].
    """

    def __init__(self, keys, number_classes):
        self.keys = keys
        self.nc = number_classes

    def __call__(self, data):
        for key in self.keys:
            labels = data[key]
            # scatter_ requires int64 indices.
            if labels.dtype != torch.long:
                labels = labels.long()
            encoded = torch.zeros(
                labels.size(0), self.nc, labels.size(1), labels.size(2),
                device=labels.device,
            )
            encoded.scatter_(1, labels.unsqueeze(1), 1)
            data[key] = encoded
        return data
410
+
411
+ # Example usage
412
+ # Assuming `ct_image_tensor` is a PyTorch tensor of a CT image
413
+ # ct_image_tensor = torch.tensor(img_array, dtype=torch.float32)
414
+ # mask_tensor = create_body_contour(ct_image_tensor)
415
+
416
class CreateMaskWithBonesTransform:
    """Threshold a CT volume into a 3-label mask:
    0 = background, 1 = soft tissue, 2 = bone."""

    def __init__(self, tissue_min, tissue_max, bone_min, bone_max):
        self.tissue_min = tissue_min
        self.tissue_max = tissue_max
        self.bone_min = bone_min
        self.bone_max = bone_max

    def __call__(self, x):
        labels = torch.zeros_like(x)
        tissue_region = (x > self.tissue_min) & (x <= self.tissue_max)
        bone_region = (x >= self.bone_min) & (x <= self.bone_max)
        labels[tissue_region] = 1
        # Bone is applied second so it wins where the windows overlap.
        labels[bone_region] = 2
        return labels
436
+
437
class CreateMaskWithBonesTransformd:
    """Dictionary transform: body contour (label 1, via ``create_body_contour``)
    plus a bone label (2) from an HU window.

    # assumes input layout [B, H, W, D] -- TODO confirm against the loader
    """

    def __init__(self, keys, tissue_min, tissue_max, bone_min, bone_max):
        self.keys = keys
        self.tissue_min = tissue_min
        self.tissue_max = tissue_max
        self.bone_min = bone_min
        self.bone_max = bone_max

    def __call__(self, data):
        for key in self.keys:
            volume = data[key]
            mask = torch.zeros_like(volume)
            # Body contour is computed per 2D slice, thresholded at tissue_min.
            for b in range(volume.shape[0]):
                for d in range(volume.shape[-1]):
                    mask[b, :, :, d] = create_body_contour(
                        volume[b, :, :, d], body_threshold=self.tissue_min
                    )
            # Bone overrides the contour label inside the bone HU window.
            mask[(volume >= self.bone_min) & (volume <= self.bone_max)] = 2
            data[key] = mask
        return data
460
+
461
class NormalizationMultimodal:
    """Min/max-normalize a (prior, target) image pair to [0, 1] using
    per-modality intensity windows; the modality id is read from
    ``data['modality']``. ``keys`` = [prior_key, target_key]."""

    def __init__(self, keys):
        if len(keys) != 2:
            raise ValueError("Keys must be a list with exactly two string elements.")
        self.prior_key, self.target_key = keys

        self.prior_modality_norm_dict = {
            0: {'min': -300, 'max': 700},  # CT WW=1000, WL=200
            1: {'min': 0, 'max': 9},       # T1
            2: {'min': 0, 'max': 28},      # T2
            3: {'min': 0, 'max': 9},       # VIBE-IN
            4: {'min': 0, 'max': 10},      # VIBE-OPP
            5: {'min': 0, 'max': 6},       # DIXON
        }

        self.target_modality_norm_dict = {
            0: {'min': -300, 'max': 700},  # CT
            1: {'min': 0, 'max': 800},     # T1
            2: {'min': 0, 'max': 160},     # T2
            3: {'min': 0, 'max': 500},     # VIBE-IN
            4: {'min': 0, 'max': 520},     # VIBE-OPP
            5: {'min': 0, 'max': 560},     # DIXON
        }

    @staticmethod
    def _normalize(image, params):
        # Clip to the window, then scale to [0, 1].
        clipped = torch.clamp(image, params['min'], params['max'])
        return (clipped - params['min']) / (params['max'] - params['min'])

    def __call__(self, data):
        modality = int(data['modality'])

        if modality not in self.target_modality_norm_dict:
            raise ValueError(f"Unsupported modality id: {modality}")

        data[self.target_key] = self._normalize(
            data[self.target_key], self.target_modality_norm_dict[modality]
        )
        data[self.prior_key] = self._normalize(
            data[self.prior_key], self.prior_modality_norm_dict[modality]
        )
        return data
507
+
dataprocesser/data_processing/.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ saved_png
2
+ data
3
+ __pycache__
4
+ */*.asv
dataprocesser/data_processing/README.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mydataloader
2
+ dataloader for all projects
3
+
4
+ --0709 add center_crop in slicer_loader
5
+
6
+ --0709 test recursively push submodule
7
+
8
+ --0710 add center_crop in slicer_loader
9
+
10
+ --0729 add make_cond.py
11
+
12
+ --0730 add conditional_loader.py
13
+
14
+ --1103 add input_only normalization in basics.py
15
+
16
+ --1106 change the place of ResizeWithPadOrCropd into crop_volumes of basics.py to directly get 512*512 reversed output
17
+
18
+ --1108 merge gan_loader.py together. Change get_file_list in basics.py, replace ct and mr as "source" and "target"
19
+
20
+ --1109 delete rotate in basics.py
dataprocesser/data_processing/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # python module for loading data from monai
2
+ __all__ = ['3d_loader', 'slice_loader', 'slice_loader2', 'basics','manual_slice_loader']
dataprocesser/data_processing/data_process/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ **pycache**/
dataprocesser/data_processing/data_process/CTbatchevaluate.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ from CTevaluate import *
4
+
5
def singleevaluate(file_path, window_width = 150, window_level = 30):
    """Evaluate one CT volume (.nii.gz or .nrrd).

    Applies windowing, extracts a central ROI, and returns
    (roi_mean, contrast, std_deviation, original_shape).

    Raises:
        ValueError: if the file has an unsupported extension (previously this
            led to an undefined-variable NameError further down).
    """
    if file_path.endswith('.nii.gz'):
        # Bug fix: the volume used to be loaded twice (duplicated lines).
        ct_image_data = nib.load(file_path).get_fdata()
    elif file_path.endswith('.nrrd'):
        ct_image_data, header = nrrd.read(file_path)
    else:
        raise ValueError(f"Unsupported file format: {file_path}")

    ct_data_shape = ct_image_data.shape

    #plot_ct_value_distribution(ct_image_data)
    ct_image_data = ct_windowing(ct_image_data, window_width, window_level)
    #plot_ct_value_distribution(ct_image_data)

    # Cut a 300x300 ROI around the in-plane center.
    center_x = ct_image_data.shape[0] // 2
    center_y = ct_image_data.shape[1] // 2
    ct_image_roi = extract_roi(ct_image_data, center_x=center_x, center_y=center_y, length=300, width=300)
    ct_image_roi_mean = np.mean(ct_image_roi)

    # Calculate contrast and standard deviation of CT values.
    contrast = calculate_contrast(ct_image_data)
    std_deviation = calculate_standard_deviation(ct_image_data)
    return ct_image_roi_mean, contrast, std_deviation, ct_data_shape
31
+
32
def batchevaluate(dataset_path, format='.nii.gz', save_path='', nii_name='test'):
    """Evaluate every volume with the given extension in ``dataset_path`` and
    append the metrics to ``<save_path>/<nii_name>.txt``."""
    report_path = os.path.join(save_path, f'{nii_name}.txt')
    for patient_data in glob.glob(dataset_path + "/*"):
        if not patient_data.endswith(format):
            continue
        patient_name = os.path.basename(os.path.normpath(patient_data))
        print('-------------', patient_name, '-------------')
        roi_mean, contrast, std_deviation, shape = singleevaluate(patient_data)
        # Append this patient's metrics to the running report.
        with open(report_path, 'a') as report:
            report.write('-------------' + patient_name + '-------------\n')
            report.write('Mean of CT values in ROI: ' + str(roi_mean) + '\n')
            report.write('Contrast of CT image: ' + str(contrast) + '\n')
            report.write('Standard Deviation of CT values: ' + str(std_deviation) + '\n')
            report.write('Size of CT image: ' + str(shape) + '\n')
44
def main():
    """Run the batch evaluation over the demo NIfTI folder."""
    nifti_root = r'D:\Data\dataNeaotomAlpha\NIFTI23072115'
    batchevaluate(dataset_path=nifti_root, format='.nii.gz', save_path=nifti_root, nii_name='evaluate')

if __name__=="__main__":
    main()
dataprocesser/data_processing/data_process/CTevaluate.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import nibabel as nib
3
+ import nrrd
4
+
5
def extract_roi(ct_image, center_x=256, center_y=256, length=300, width=300):
    """
    Extract a Region of Interest (ROI) from the CT image.

    Parameters:
    ct_image (numpy.ndarray): The CT image data as a 3D NumPy array.
    center_x (int): X-coordinate of the center of the ROI.
    center_y (int): Y-coordinate of the center of the ROI.
    length (int): Length of the square ROI.
    width (int): Width of the square ROI.

    Returns:
    numpy.ndarray: The ROI extracted from the CT image (all slices along
    axis 0, cropped in axes 1 and 2).
    """
    half_length = length // 2
    half_width = width // 2

    # Bug fix: the upper bounds were clamped against shape[0]/shape[1]
    # although the slice below is applied to axes 1 and 2, which could
    # truncate the ROI (or empty it) for non-cubic volumes.
    start_x = max(0, center_x - half_length)
    end_x = min(ct_image.shape[1], center_x + half_length)
    start_y = max(0, center_y - half_width)
    end_y = min(ct_image.shape[2], center_y + half_width)

    return ct_image[:, start_x:end_x, start_y:end_y]
28
+
29
def calculate_contrast(ct_image):
    """
    Calculate the contrast of a CT image.

    Parameters:
    ct_image (numpy.ndarray): The CT image data as a 3D NumPy array.

    Returns:
    float: the value span of the image normalised by the assumed full
    Hounsfield range [-1024, 3071].
    """
    # Assumed Hounsfield Units range for CT scans.
    hu_min = -1024.0
    hu_max = 3071.0
    value_span = np.abs(np.max(ct_image) - np.min(ct_image))
    return value_span / (hu_max - hu_min)
46
+
47
def calculate_standard_deviation(ct_image):
    """
    Calculate the standard deviation of CT values in the image.

    Parameters:
    ct_image (numpy.ndarray): The CT image data as a 3D NumPy array.

    Returns:
    float: The (population) standard deviation of CT values.
    """
    return np.std(ct_image)
58
+
59
+ import matplotlib.pyplot as plt
60
+
61
def plot_ct_value_distribution(ct_image):
    """
    Plot the distribution of CT values in the image.

    Parameters:
    ct_image (numpy.ndarray): The CT image data as a 3D NumPy array.
    """
    # Flatten to 1D so the histogram sees every voxel value.
    all_values = ct_image.flatten()

    plt.hist(all_values, bins=100, range=(-1024, 3071), color='blue', alpha=0.7)
    plt.xlabel('CT Value')
    plt.ylabel('Frequency')
    plt.title('Distribution of CT Values')
    plt.grid(True)
    plt.show()
78
+
79
def ct_windowing(ct_image, window_width, window_level):
    """
    Apply CT windowing to the CT image.

    Parameters:
    ct_image (numpy.ndarray): The CT image data as a 3D NumPy array.
    window_width (float): The window width.
    window_level (float): The window level.

    Returns:
    numpy.ndarray: values clipped to [level - width/2, level + width/2].
    """
    low = window_level - window_width / 2.0
    high = window_level + window_width / 2.0
    return np.clip(ct_image, low, high)
99
+
100
def main():
    """Demo: load one CT volume, apply windowing, and print ROI/contrast/std
    metrics. Paths are hard-coded to a local Windows machine."""
    # Replace 'your_ct_image.nii' with the path to your NIfTI CT image file.
    pcct_path = r'D:\Data\dataNeaotomAlpha\Nifti\2511\2511_2.nii.gz'
    cbct_path = r'D:\Data\M2OLIE_Phantom\pre_cbct.nrrd'
    nifti_file_path = pcct_path
    # NOTE(review): nrrd_file_path is set but never used (the nrrd branch
    # below is commented out).
    nrrd_file_path = cbct_path

    # Load the NIfTI CT image data using nibabel.
    ct_image_nifti = nib.load(nifti_file_path)
    ct_image_data = ct_image_nifti.get_fdata()

    #ct_image_data, header = nrrd.read(nrrd_file_path)

    # Soft-tissue window (WW=150, WL=30).
    window_width = 150
    window_level = 30
    #plot_ct_value_distribution(ct_image_data)
    ct_image_data = ct_windowing(ct_image_data, window_width, window_level)
    #plot_ct_value_distribution(ct_image_data)

    # cut roi
    center_x = ct_image_data.shape[0] // 2
    center_y = ct_image_data.shape[1] // 2
    ct_image_roi=extract_roi(ct_image_data, center_x=center_x, center_y=center_y, length=300, width=300)
    ct_image_roi_mean=np.mean(ct_image_roi)

    # Calculate contrast and standard deviation of CT values.
    contrast = calculate_contrast(ct_image_data)
    std_deviation = calculate_standard_deviation(ct_image_data)
    print(ct_image_data.shape)

    print("size of ROI:", ct_image_roi.shape)
    print("Mean of CT values in ROI:", ct_image_roi_mean)
    print("Contrast of CT image:", contrast)
    print("Standard Deviation of CT values:", std_deviation)


if __name__ == "__main__":
    main()
dataprocesser/data_processing/data_process/convert_dicoms.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dicom2nifti
2
+ import SimpleITK as sitk
3
+ from dicom2nifti.exceptions import ConversionValidationError
4
+ from MsSeg.SegmentationNetworkBasis.NetworkBasis import image as Image
5
+ import os
6
+ import glob
7
+ import pydicom
8
+
9
+ #dataset_path = r"C:\Users\ms97\Documents\MRF-Daten\Messdaten"
10
+
11
def fromrootgroupconvert(dataset_path, nii_name='test'):
    """Convert every patient DICOM folder under dataset_path to NIfTI using
    dicom2nifti, writing <nii_name>_<i>.nii per patient, then print the
    image info of the result."""
    i=0
    for patient_folder in glob.glob(dataset_path + "/*/"):
        print('-------------', patient_folder, '-------------')
        i=i+1
        # NOTE(review): patient_folder from glob is already a full path, so
        # os.path.join ignores dataset_path here -- confirm this is intended.
        t1_nii_path = os.path.join(dataset_path, patient_folder, f'{nii_name}_{i}.nii')
        try:
            try:
                t1_dicom_path = os.path.join(dataset_path, patient_folder)
                dicom2nifti.dicom_series_to_nifti(t1_dicom_path, t1_nii_path)
            except OSError as err:
                # NOTE(review): OSError is swallowed and reported as success --
                # presumably a known benign error from dicom2nifti; confirm.
                print("Finished for Sequence T1Map " + patient_folder)

            # Re-read the written NIfTI to report its geometry.
            t1_img = sitk.ReadImage(t1_nii_path)
            data_info = Image.get_data_info(t1_img)
            print('Data Info T1: ', data_info)
        except (KeyError, IndexError) as err:
            print("Failed for Sequence T1Map " + patient_folder + " ", err)
29
+
30
def simplepatientconvert(patient_folder, nii_name='test'):
    """Convert a single patient DICOM folder to <patient_folder>/<nii_name>.nii
    using dicom2nifti and print the resulting image info."""
    t1_nii_path = os.path.join(patient_folder, f'{nii_name}.nii')
    try:
        try:
            dicom2nifti.dicom_series_to_nifti(patient_folder, t1_nii_path)
        except OSError as err:
            # NOTE(review): OSError is swallowed and reported as success --
            # presumably a known benign error from dicom2nifti; confirm.
            print("Finished for Sequence Dicom " + patient_folder)

        # Re-read the written NIfTI to report its geometry.
        t1_img = sitk.ReadImage(t1_nii_path)
        data_info = Image.get_data_info(t1_img)
        print('Data Info T1: ', data_info)
    except (KeyError, IndexError) as err:
        print("Failed for Sequence Dicom " + patient_folder + " ", err)
43
+
44
def itkfromrootgroupconvert(dataset_path, nii_name='test'):
    """Convert every patient DICOM folder under dataset_path to
    <dataset_path>/<folder_name>.nii.gz using SimpleITK, permuting the
    axes (z <-> x) before writing."""
    i=0
    for patient_folder in glob.glob(dataset_path + "/*/"):
        print('-------------', patient_folder, '-------------')
        # NOTE(review): i is incremented but never used in this function.
        i=i+1
        try:
            try:
                reader = sitk.ImageSeriesReader()
                dicom_names = reader.GetGDCMSeriesFileNames(patient_folder)
                reader.SetFileNames(dicom_names)
                image = reader.Execute()
                basefoldername=os.path.basename(os.path.normpath(patient_folder))
                t1_nii_path = os.path.join(dataset_path, f'{basefoldername}.nii.gz')
                # Added a call to PermuteAxes to change the axes of the data
                image = sitk.PermuteAxes(image, [2, 1, 0])
                sitk.WriteImage(image, t1_nii_path)
            except OSError as err:
                # NOTE(review): OSError is swallowed and reported as success;
                # t1_nii_path may then be unset/stale for the read below.
                print("Finished for Sequence T1Map " + patient_folder)

            # Re-read the written NIfTI to report its geometry.
            t1_img = sitk.ReadImage(t1_nii_path)
            data_info = Image.get_data_info(t1_img)
            print('Data Info Dicom: ', data_info)
        except (KeyError, IndexError) as err:
            print("Failed for Sequence Dicom " + patient_folder + " ", err)
68
+
69
def itkforpatientconvert(patient_folder, nii_name='test'):
    """Read the DICOM series in ``patient_folder`` with SimpleITK and write it
    to <patient_folder>/<nii_name>.nii.gz with the axes permuted (z <-> x)."""
    series_reader = sitk.ImageSeriesReader()
    series_reader.SetFileNames(series_reader.GetGDCMSeriesFileNames(patient_folder))
    volume = series_reader.Execute()
    out_path = os.path.join(patient_folder, f'{nii_name}.nii.gz')
    # Added a call to PermuteAxes to change the axes of the data
    volume = sitk.PermuteAxes(volume, [2, 1, 0])
    sitk.WriteImage(volume, out_path)
78
+
79
+ if __name__=="__main__":
80
+ dataset_path = r"D:\Data\dataNeaotomAlpha\DICOM_Naeotom\DICOM\23072115"
81
+ itkfromrootgroupconvert(dataset_path)
82
+ #patient_folder = r"D:\Data\dataNeaotomAlpha\Q0Q1Q4"
83
+ #simpleitkconvert(patient_folder,'2511')
dataprocesser/data_processing/data_process/make_cond.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ import nibabel as nib
4
+ import numpy as np
5
+ from tqdm import tqdm
6
def make_cond(dataset_path):
    """Create per-slice condition CSVs for every patient folder.

    For each patient folder under ``dataset_path`` (skipping 'overview'),
    writes ``ct_slice_cond.csv`` and ``mr_slice_cond.csv`` containing one
    slice index per row under a 'slice' header.
    """
    for patient_folder in tqdm(glob.glob(dataset_path + "/*/")):
        if 'overview' in patient_folder:
            continue
        for modality in ('ct', 'mr'):
            _write_slice_cond(patient_folder, modality)


def _write_slice_cond(patient_folder, modality):
    """Write <modality>_slice_cond.csv listing every slice index of the volume."""
    volume = nib.load(os.path.join(patient_folder, f'{modality}.nii.gz')).get_fdata()
    n_slices = volume.shape[-1]
    # Bug fix: np.arange(0, n_slices - 1) dropped the last slice label
    # because arange's stop is exclusive; the stop must be n_slices.
    labels = np.arange(n_slices)
    with open(os.path.join(patient_folder, f'{modality}_slice_cond.csv'), 'w') as f:
        f.write('slice\n')
        for label in labels:
            f.write(str(label) + '\n')
30
+
31
def main():
    """Generate slice-condition CSVs for the pelvis dataset."""
    pelvis_root = r'F:\yang_Projects\Datasets\Task1\pelvis'
    # Alternate machine path, kept for reference (unused).
    dataset_path_razer = r'C:\Users\56991\Projects\Datasets\Task1\pelvis'
    make_cond(pelvis_root)

if __name__=="__main__":
    main()
dataprocesser/data_processing/data_process/matlab/BCELossIllustration.m ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
% Illustration of BCE vs BCE-with-logits loss on a synthetic binary image.
% Create an example image (ground truth)
images1 = rand(256, 256); % Random matrix as an example
images1(images1 > 0.5) = 0.9999;
images1(images1 <= 0.5) = 0.0001;

% Add noise to the image
noise = randn(size(images1)) * 0.01; % Gaussian noise
images2 = images1+noise;

% Ensure values are in the range [0, 1]
% NOTE(review): the clipping below is disabled, so images2 can leave [0, 1]
% and make log() in calculateBCELoss complex-valued -- confirm intended.
% images2 = max(min(images2, 1), 0);

% Plot the images
subplot(1, 2, 1);
imshow(images1);
title('Original Image');

subplot(1, 2, 2);
imshow(images2);
title('Noisy Image');

% Calculate BCEWithLogitsLoss
BCEWithLogitsLoss = calculateBCEWithLogitsLoss(images1, images2);
disp(['BCEWithLogitsLoss: ', num2str(BCEWithLogitsLoss)]);

BCELoss=calculateBCELoss(images1, images2);
disp(['BCELoss: ', num2str(BCELoss)]);
29
function loss = calculateBCELoss(images1, images2)
    % Mean binary cross-entropy; images2 is used directly as probabilities.
    % NOTE(review): despite the copied-over comment in the original, no
    % probability->logit conversion happens here -- logits2 is a plain copy
    % of images2, and values outside (0, 1) make log() complex/undefined.
    logits2 = images2;

    % Mean BCE over all pixels (inner mean over rows, outer over columns)
    loss = mean(mean(-images1 .* log(logits2) - (1 - images1) .* log(1 - logits2)));
end
36
+
37
function loss = calculateBCEWithLogitsLoss(images1, images2)
    % Mean BCE-with-logits: images2 is treated as raw logits and passed
    % through the sigmoid before the cross-entropy (no conversion happens
    % on this line -- logits2 is a plain copy).
    logits2 = images2;

    % Calculate BCEWithLogitsLoss (mean over rows, then columns)
    loss = mean(mean(-images1 .* log(sigmoid(logits2)) - (1 - images1) .* log(1 - sigmoid(logits2))));
end
44
+
45
function logit = probToLogit(p)
    % Convert probability p in (0, 1) to a logit (inverse sigmoid).
    % NOTE(review): currently unused by the script above.
    logit = log(p ./ (1 - p));
end
49
+
50
function s = sigmoid(x)
    % Element-wise logistic sigmoid.
    s = 1 ./ (1 + exp(-x));
end