siddharthdhara17 committed on
Commit
345e387
·
verified ·
1 Parent(s): 8b47631

Upload baselines/prepare_data.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. baselines/prepare_data.py +189 -0
baselines/prepare_data.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Prepare LIDC-IDRI data for deterministic baselines.
3
+ Creates flat directories with majority-vote merged masks.
4
+ Also prepares nnU-Net format dataset.
5
+ """
6
+ import os
7
+ import sys
8
+ import glob
9
+ import argparse
10
+ import numpy as np
11
+ from PIL import Image
12
+ from tqdm import tqdm
13
+ import shutil
14
+
15
+
16
def majority_vote_mask(mask_paths, min_votes=2):
    """Merge several annotator masks into a single consensus mask.

    Each mask is loaded as 8-bit grayscale and binarized (a pixel counts as
    foreground when its value is above 127). A pixel is foreground in the
    result when at least ``min_votes`` of the masks mark it as foreground.

    Args:
        mask_paths: Iterable of mask file paths for one slice. All masks must
            have the same size.
        min_votes: Minimum number of agreeing masks required for a foreground
            pixel. Defaults to 2, i.e. the ">= 2 annotators agree" rule.

    Returns:
        A 2-D ``uint8`` array containing only 0 (background) and 255
        (foreground), suitable for saving as a PNG.

    Raises:
        ValueError: If ``min_votes`` is below 1, ``mask_paths`` is empty, or
            the masks do not all share one shape.
    """
    # A threshold below 1 would mark every pixel as foreground.
    if min_votes < 1:
        raise ValueError(f"min_votes must be >= 1, got {min_votes}")

    mask_paths = list(mask_paths)
    if not mask_paths:
        raise ValueError("majority_vote_mask needs at least one mask path")

    masks = []
    for path in mask_paths:
        # The context manager releases the file handle deterministically.
        with Image.open(path) as img:
            gray = np.array(img.convert("L"))
        masks.append((gray > 127).astype(np.uint8))

    # Per-pixel count of annotators that marked the pixel as foreground.
    # np.stack raises ValueError when the masks differ in shape.
    agreement = np.sum(np.stack(masks, axis=0), axis=0)
    vote = (agreement >= min_votes).astype(np.uint8)
    return vote * np.uint8(255)  # 0/255 so the PNG is directly viewable
28
+
29
+
30
def process_split(data_dir, output_dir, split_name):
    """Flatten one split into ``images/`` and ``masks/`` folders.

    Walks ``data_dir/LIDC-IDRI-*/nodule-*/images/slice-*.png``. For every
    slice that has a same-named file in at least two of the nodule's
    ``mask-*`` folders, a majority-vote mask is built and written to
    ``output_dir/masks`` and the image is copied to ``output_dir/images``,
    both named ``<patient>_<nodule>_<slice>.png``. Slices with fewer than
    two masks are skipped and counted.

    Args:
        data_dir: Directory containing the ``LIDC-IDRI-*`` patient folders.
        output_dir: Destination directory; ``images`` and ``masks``
            sub-folders are created if missing.
        split_name: Label used in the progress bar and the summary line.

    Returns:
        The number of slices written.
    """
    images_dir = os.path.join(output_dir, "images")
    masks_dir = os.path.join(output_dir, "masks")
    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(masks_dir, exist_ok=True)

    # Find all patient directories
    patient_dirs = sorted(glob.glob(os.path.join(data_dir, "LIDC-IDRI-*")))

    count = 0
    skipped = 0
    for patient_dir in tqdm(patient_dirs, desc=f"Processing {split_name}"):
        patient_id = os.path.basename(patient_dir)
        nodule_dirs = sorted(glob.glob(os.path.join(patient_dir, "nodule-*")))

        for nodule_dir in nodule_dirs:
            nodule_id = os.path.basename(nodule_dir)
            image_files = sorted(
                glob.glob(os.path.join(nodule_dir, "images", "slice-*.png"))
            )
            # The annotator folders are the same for every slice of a nodule,
            # so list them once instead of once per slice.
            mask_dirs = sorted(glob.glob(os.path.join(nodule_dir, "mask-*")))

            for img_path in image_files:
                slice_name = os.path.basename(img_path)  # e.g., slice-0.png
                # splitext removes only the extension, never an inner ".png".
                slice_id = os.path.splitext(slice_name)[0]  # e.g., slice-0

                # Keep the annotator masks that actually exist for this slice
                candidates = (os.path.join(d, slice_name) for d in mask_dirs)
                mask_paths = [p for p in candidates if os.path.exists(p)]

                if len(mask_paths) < 2:
                    skipped += 1
                    continue

                # Output filename: LIDC-IDRI-0001_nodule-0_slice-0.png
                out_name = f"{patient_id}_{nodule_id}_{slice_id}.png"

                # Build the mask before writing anything, so a failure here
                # cannot leave an image without a matching mask.
                mv_mask = majority_vote_mask(mask_paths)

                shutil.copy2(img_path, os.path.join(images_dir, out_name))
                Image.fromarray(mv_mask).save(os.path.join(masks_dir, out_name))

                count += 1

    print(f"{split_name}: Processed {count} slices, skipped {skipped}")
    return count
79
+
80
+
81
def _convert_flat_split(flat_dir, images_out, labels_out, desc):
    """Copy one flat split into nnU-Net naming; return {case_id: original_name}."""
    case_names = {}
    image_paths = sorted(glob.glob(os.path.join(flat_dir, "images", "*.png")))
    for i, img_path in enumerate(tqdm(image_paths, desc=desc)):
        file_name = os.path.basename(img_path)
        case_id = f"LIDC_{i:05d}"

        # Images carry the channel suffix _0000 (single channel).
        shutil.copy2(img_path, os.path.join(images_out, f"{case_id}_0000.png"))

        # Labels are stored as 0/1 instead of the 0/255 used by the flat masks.
        mask_path = os.path.join(flat_dir, "masks", file_name)
        with Image.open(mask_path) as mask_img:
            label = (np.array(mask_img.convert("L")) > 127).astype(np.uint8)
        Image.fromarray(label).save(os.path.join(labels_out, f"{case_id}.png"))

        case_names[case_id] = os.path.splitext(file_name)[0]
    return case_names


def prepare_nnunet_format(flat_train_dir, flat_test_dir, nnunet_raw_dir):
    """Convert the flat train/test folders to an nnU-Net v2 style dataset.

    Creates ``nnunet_raw_dir/Dataset001_LIDC`` with ``imagesTr``,
    ``labelsTr``, ``imagesTs`` and ``labelsTs``. Images are copied as
    ``LIDC_XXXXX_0000.png`` and masks are rewritten as 0/1 labels named
    ``LIDC_XXXXX.png``. Case numbers follow the sorted file order and start
    at 0 in each split. Also writes ``dataset.json`` and
    ``test_case_mapping.json`` (test case ID -> original file stem, for
    converting predictions back).

    Args:
        flat_train_dir: Flat training folder containing ``images`` and
            ``masks`` sub-folders.
        flat_test_dir: Flat test folder with the same layout.
        nnunet_raw_dir: Directory under which the dataset folder is created.

    Raises:
        FileNotFoundError: If an image has no same-named file in ``masks``.
    """
    import json

    dataset_dir = os.path.join(nnunet_raw_dir, "Dataset001_LIDC")

    imagesTr = os.path.join(dataset_dir, "imagesTr")
    labelsTr = os.path.join(dataset_dir, "labelsTr")
    imagesTs = os.path.join(dataset_dir, "imagesTs")
    labelsTs = os.path.join(dataset_dir, "labelsTs")

    for d in [imagesTr, labelsTr, imagesTs, labelsTs]:
        os.makedirs(d, exist_ok=True)

    print("Converting to nnU-Net format...")

    train_cases = _convert_flat_split(
        flat_train_dir, imagesTr, labelsTr, "nnU-Net train"
    )
    test_cases = _convert_flat_split(
        flat_test_dir, imagesTs, labelsTs, "nnU-Net test"
    )

    # Create dataset.json
    dataset_json = {
        "channel_names": {"0": "CT"},
        "labels": {"background": 0, "nodule": 1},
        "numTraining": len(train_cases),
        "file_ending": ".png",
        "name": "Dataset001_LIDC",
        "description": "LIDC-IDRI Lung Nodule Segmentation (majority vote GT)",
        "reference": "LIDC-IDRI",
        "licence": "CC BY 3.0",
        "release": "1.0"
    }
    with open(os.path.join(dataset_dir, "dataset.json"), "w", encoding="utf-8") as f:
        json.dump(dataset_json, f, indent=2)

    # The mapping is built from the files actually converted above, so it
    # cannot drift from the written test cases.
    with open(
        os.path.join(dataset_dir, "test_case_mapping.json"), "w", encoding="utf-8"
    ) as f:
        json.dump(test_cases, f, indent=2)

    print(f"nnU-Net dataset created at {dataset_dir}")
    print(f" Training: {len(train_cases)} cases")
    print(f" Testing: {len(test_cases)} cases")
155
+
156
+
157
def main():
    """Command-line entry point.

    Parses ``--data_root`` and ``--skip_nnunet``, flattens the ``training``
    and ``testing`` folders under the data root into ``flat_train`` and
    ``flat_test``, and, unless skipped, converts both into an nnU-Net
    dataset under ``nnUNet_raw``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--data_root", type=str, default="data", help="Root data directory")
    cli.add_argument("--skip_nnunet", action="store_true", help="Skip nnU-Net format conversion")
    options = cli.parse_args()

    root = options.data_root
    rule = "=" * 60

    # Raw inputs and flattened outputs all live under the data root.
    raw_train = os.path.join(root, "training")
    raw_test = os.path.join(root, "testing")
    flat_train = os.path.join(root, "flat_train")
    flat_test = os.path.join(root, "flat_test")

    print(rule)
    print("Preparing flat dataset with majority-vote masks")
    print(rule)

    train_total = process_split(raw_train, flat_train, "Training")
    test_total = process_split(raw_test, flat_test, "Testing")

    print(f"\nTotal: {train_total} train, {test_total} test slices")

    if not options.skip_nnunet:
        print("\n" + rule)
        print("Preparing nnU-Net format dataset")
        print(rule)
        prepare_nnunet_format(flat_train, flat_test, os.path.join(root, "nnUNet_raw"))

    print("\nDone!")
186
+
187
+
188
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()