# Provenance: uploaded as baselines/prepare_data.py with huggingface_hub
# by siddharthdhara17 (commit 345e387, verified).
"""
Prepare LIDC-IDRI data for deterministic baselines.
Creates flat directories with majority-vote merged masks.
Also prepares nnU-Net format dataset.
"""
import os
import sys
import glob
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm
import shutil
def majority_vote_mask(mask_paths, min_votes=2):
    """Merge several annotator masks into one consensus mask.

    Each mask is loaded as 8-bit grayscale and binarized (a pixel above 127
    counts as foreground). A pixel is foreground in the result when at least
    ``min_votes`` of the masks mark it as foreground.

    Args:
        mask_paths: Paths of the annotator mask images for one slice. All
            masks must have the same size.
        min_votes: Minimum number of agreeing masks for a pixel to be kept.
            The default of 2 is the ">= 2 of 4 annotators" rule.

    Returns:
        A ``uint8`` array holding only 0 and 255, ready to be saved as PNG.

    Raises:
        ValueError: If ``mask_paths`` is empty, if ``min_votes`` is below 1,
            or if the masks do not all share one shape.
    """
    mask_paths = list(mask_paths)
    # Validate before touching any file so bad calls fail fast and clearly.
    if not mask_paths:
        raise ValueError("majority_vote_mask needs at least one mask path")
    if min_votes < 1:
        raise ValueError(f"min_votes must be >= 1, got {min_votes}")

    binary_masks = []
    for path in mask_paths:
        # The context manager releases the file once the pixels are in memory.
        with Image.open(path) as img:
            pixels = np.array(img.convert("L"))
        binary_masks.append((pixels > 127).astype(np.uint8))

    # Per-pixel count of masks that marked the pixel as foreground.
    votes = np.sum(np.stack(binary_masks, axis=0), axis=0)
    consensus = (votes >= min_votes).astype(np.uint8)
    return consensus * 255  # 0/255 so the mask is visible when saved as PNG
def process_split(data_dir, output_dir, split_name):
    """Flatten one split into ``images/`` and ``masks/`` directories.

    Walks ``data_dir/LIDC-IDRI-*/nodule-*/images/slice-*.png``. For every
    slice that has a same-named mask file in at least two ``mask-*``
    directories of its nodule, a majority-vote mask is written and the image
    is copied, both under the name ``<patient>_<nodule>_<slice>.png``.
    Slices with fewer than two masks are counted as skipped.

    Args:
        data_dir: Directory holding the ``LIDC-IDRI-*`` patient directories.
        output_dir: Destination; ``images`` and ``masks`` are created in it.
        split_name: Label used in the progress bar and the summary line.

    Returns:
        The number of slices written.
    """
    images_dir = os.path.join(output_dir, "images")
    masks_dir = os.path.join(output_dir, "masks")
    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(masks_dir, exist_ok=True)

    patient_dirs = sorted(glob.glob(os.path.join(data_dir, "LIDC-IDRI-*")))
    count = 0
    skipped = 0
    for patient_dir in tqdm(patient_dirs, desc=f"Processing {split_name}"):
        patient_id = os.path.basename(patient_dir)
        for nodule_dir in sorted(glob.glob(os.path.join(patient_dir, "nodule-*"))):
            nodule_id = os.path.basename(nodule_dir)
            # The annotator directories are the same for every slice of this
            # nodule, so list them once instead of once per slice.
            mask_dirs = sorted(glob.glob(os.path.join(nodule_dir, "mask-*")))
            image_files = sorted(
                glob.glob(os.path.join(nodule_dir, "images", "slice-*.png"))
            )
            for img_path in image_files:
                slice_name = os.path.basename(img_path)  # e.g., slice-0.png
                # splitext drops only the extension, unlike str.replace.
                slice_id = os.path.splitext(slice_name)[0]  # e.g., slice-0

                candidates = (os.path.join(d, slice_name) for d in mask_dirs)
                mask_paths = [p for p in candidates if os.path.exists(p)]
                if len(mask_paths) < 2:
                    skipped += 1
                    continue

                # Output filename: LIDC-IDRI-0001_nodule-0_slice-0.png
                out_name = f"{patient_id}_{nodule_id}_{slice_id}.png"

                # Write the mask before the image: if merging fails, no image
                # is left behind without its mask.
                mv_mask = majority_vote_mask(mask_paths)
                Image.fromarray(mv_mask).save(os.path.join(masks_dir, out_name))
                shutil.copy2(img_path, os.path.join(images_dir, out_name))
                count += 1

    print(f"{split_name}: Processed {count} slices, skipped {skipped}")
    return count
def _export_nnunet_split(image_paths, masks_dir, images_out, labels_out, desc):
    """Copy one split into nnU-Net naming and return {case_id: original name}."""
    mapping = {}
    for i, img_path in enumerate(tqdm(image_paths, desc=desc)):
        file_name = os.path.basename(img_path)
        case_id = f"LIDC_{i:05d}"
        # Images carry the channel suffix _0000 (single channel); labels do not.
        shutil.copy2(img_path, os.path.join(images_out, f"{case_id}_0000.png"))
        # Flat masks are stored as 0/255; labels are written as 0/1.
        with Image.open(os.path.join(masks_dir, file_name)) as mask_img:
            mask = np.array(mask_img.convert("L"))
        label = (mask > 127).astype(np.uint8)
        Image.fromarray(label).save(os.path.join(labels_out, f"{case_id}.png"))
        mapping[case_id] = os.path.splitext(file_name)[0]
    return mapping


def prepare_nnunet_format(flat_train_dir, flat_test_dir, nnunet_raw_dir):
    """Convert the flat train/test directories to an nnU-Net v2 dataset.

    Creates ``Dataset001_LIDC`` under ``nnunet_raw_dir`` with ``imagesTr``,
    ``labelsTr``, ``imagesTs`` and ``labelsTs``. Cases are numbered in sorted
    filename order as ``LIDC_00000``, ``LIDC_00001``, ... separately for each
    split. Also writes ``dataset.json`` and ``test_case_mapping.json``; the
    latter maps each test case ID back to its original file name so that
    predictions can be converted back.

    Args:
        flat_train_dir: Flat training directory with ``images`` and ``masks``.
        flat_test_dir: Flat test directory with ``images`` and ``masks``.
        nnunet_raw_dir: Directory in which the dataset folder is created.
    """
    import json  # not among the module-level imports, so imported here

    dataset_dir = os.path.join(nnunet_raw_dir, "Dataset001_LIDC")
    imagesTr = os.path.join(dataset_dir, "imagesTr")
    labelsTr = os.path.join(dataset_dir, "labelsTr")
    imagesTs = os.path.join(dataset_dir, "imagesTs")
    labelsTs = os.path.join(dataset_dir, "labelsTs")
    for d in (imagesTr, labelsTr, imagesTs, labelsTs):
        os.makedirs(d, exist_ok=True)

    print("Converting to nnU-Net format...")

    train_images = sorted(glob.glob(os.path.join(flat_train_dir, "images", "*.png")))
    _export_nnunet_split(
        train_images,
        os.path.join(flat_train_dir, "masks"),
        imagesTr,
        labelsTr,
        "nnU-Net train",
    )

    test_images = sorted(glob.glob(os.path.join(flat_test_dir, "images", "*.png")))
    # The mapping comes from the same list that was converted, so the case IDs
    # always match the files actually written.
    mapping = _export_nnunet_split(
        test_images,
        os.path.join(flat_test_dir, "masks"),
        imagesTs,
        labelsTs,
        "nnU-Net test",
    )

    # NOTE(review): the channel name "CT" presumably selects nnU-Net's CT
    # normalization scheme; confirm that is intended for PNG slices.
    dataset_json = {
        "channel_names": {"0": "CT"},
        "labels": {"background": 0, "nodule": 1},
        "numTraining": len(train_images),
        "file_ending": ".png",
        "name": "Dataset001_LIDC",
        "description": "LIDC-IDRI Lung Nodule Segmentation (majority vote GT)",
        "reference": "LIDC-IDRI",
        "licence": "CC BY 3.0",
        "release": "1.0"
    }
    with open(os.path.join(dataset_dir, "dataset.json"), "w") as f:
        json.dump(dataset_json, f, indent=2)

    with open(os.path.join(dataset_dir, "test_case_mapping.json"), "w") as f:
        json.dump(mapping, f, indent=2)

    print(f"nnU-Net dataset created at {dataset_dir}")
    print(f" Training: {len(train_images)} cases")
    print(f" Testing: {len(test_images)} cases")
def main():
    """Command-line entry point for the data preparation.

    Reads ``--data_root`` (default ``data``) and ``--skip_nnunet``. Flattens
    ``<data_root>/training`` and ``<data_root>/testing`` into
    ``<data_root>/flat_train`` and ``<data_root>/flat_test``, then, unless
    ``--skip_nnunet`` is given, builds the nnU-Net dataset under
    ``<data_root>/nnUNet_raw``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--data_root", type=str, default="data", help="Root data directory")
    cli.add_argument("--skip_nnunet", action="store_true", help="Skip nnU-Net format conversion")
    opts = cli.parse_args()

    root = opts.data_root
    rule = "=" * 60  # horizontal separator used around each stage title
    flat_train = os.path.join(root, "flat_train")
    flat_test = os.path.join(root, "flat_test")

    # Stage 1: flat directories with majority-vote masks.
    print(rule)
    print("Preparing flat dataset with majority-vote masks")
    print(rule)
    n_train = process_split(os.path.join(root, "training"), flat_train, "Training")
    n_test = process_split(os.path.join(root, "testing"), flat_test, "Testing")
    print(f"\nTotal: {n_train} train, {n_test} test slices")

    # Stage 2 (optional): nnU-Net layout built from the flat directories.
    if not opts.skip_nnunet:
        print("\n" + rule)
        print("Preparing nnU-Net format dataset")
        print(rule)
        prepare_nnunet_format(flat_train, flat_test, os.path.join(root, "nnUNet_raw"))

    print("\nDone!")


if __name__ == "__main__":
    main()