import csv
import os
import random
from pathlib import Path
import nibabel as nib
from batchgenerators.utilities.file_and_folder_operations import load_json, save_json
from nnunetv2.dataset_conversion.Dataset027_ACDC import make_out_dirs
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
from nnunetv2.paths import nnUNet_preprocessed
def read_csv(csv_file: str):
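    """Read the M&Ms dataset information csv and map each patient code to its ED/ES frame indices and vendor.

    Expects the columns 'External code', 'ED', 'ES' and 'Vendor' to be present in the csv header.
    """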
    patient_info = {}

    with open(csv_file) as csvfile:
        reader = csv.reader(csvfile)
        headers = next(reader)
        patient_index = headers.index("External code")
        ed_index = headers.index("ED")
        es_index = headers.index("ES")
        vendor_index = headers.index("Vendor")

        for row in reader:
            patient_info[row[patient_index]] = {
                "ed": int(row[ed_index]),
                "es": int(row[es_index]),
                "vendor": row[vendor_index],
            }

    return patient_info

# ------------------------------------------------------------------------------
# Conversion to nnUNet format
# ------------------------------------------------------------------------------
def convert_mnms(src_data_folder: Path, csv_file_name: str, dataset_id: int):
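    """Convert the labeled training and test data to the nnUNet raw format and write the dataset.json.

    Only the ED and ES frames of each 4D short-axis ('sa') series are exported, so every patient
    contributes two training cases.
    """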
    out_dir, out_train_dir, out_labels_dir, out_test_dir = make_out_dirs(dataset_id, task_name="MNMs")

    patients_train = [f for f in (src_data_folder / "Training" / "Labeled").iterdir() if f.is_dir()]
    patients_test = [f for f in (src_data_folder / "Testing").iterdir() if f.is_dir()]

    patient_info = read_csv(str(src_data_folder / csv_file_name))

    save_cardiac_phases(patients_train, patient_info, out_train_dir, out_labels_dir)
    save_cardiac_phases(patients_test, patient_info, out_test_dir)

    # There are non-orthonormal direction cosines in the test and validation data.
    # Not sure if the data should be fixed, or we should skip the problematic data.
    # patients_val = [f for f in (src_data_folder / "Validation").iterdir() if f.is_dir()]
    # save_cardiac_phases(patients_val, patient_info, out_train_dir, out_labels_dir)

    generate_dataset_json(
        str(out_dir),
        channel_names={
            0: "cineMRI",
        },
        labels={"background": 0, "LVBP": 1, "LVM": 2, "RV": 3},
        file_ending=".nii.gz",
        num_training_cases=len(patients_train) * 2,  # 2 since we have ED and ES for each patient
    )

def save_cardiac_phases(
    patients: list[Path], patient_info: dict[str, dict[str, int]], out_dir: Path, labels_dir: Path = None
):
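    """Extract and save the ED and ES phases for each patient, optionally also saving the ground truth labels."""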
    for patient in patients:
        print(f"Processing patient: {patient.name}")

        image = nib.load(patient / f"{patient.name}_sa.nii.gz")
        ed_frame = patient_info[patient.name]["ed"]
        es_frame = patient_info[patient.name]["es"]

        save_extracted_nifti_slice(image, ed_frame=ed_frame, es_frame=es_frame, out_dir=out_dir, patient=patient)

        if labels_dir:
            label = nib.load(patient / f"{patient.name}_sa_gt.nii.gz")
            save_extracted_nifti_slice(label, ed_frame=ed_frame, es_frame=es_frame, out_dir=labels_dir, patient=patient)

def save_extracted_nifti_slice(image, ed_frame: int, es_frame: int, out_dir: Path, patient: Path):
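    """Save the ED and ES frames of a 4D nifti as separate 3D volumes.

    Output names follow the nnUNet convention, e.g. '<patient>_frame01_0000.nii.gz' for images and
    '<patient>_frame01.nii.gz' for labels (frame numbers are zero-padded to two digits).
    """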
    # Save only extracted diastole and systole slices from the 4D H x W x D x time volume.
    image_ed = nib.Nifti1Image(image.dataobj[..., ed_frame], image.affine)
    image_es = nib.Nifti1Image(image.dataobj[..., es_frame], image.affine)

    # Labels do not have modality identifiers. Labels always end with 'gt'.
    suffix = ".nii.gz" if image.get_filename().endswith("_gt.nii.gz") else "_0000.nii.gz"

    nib.save(image_ed, str(out_dir / f"{patient.name}_frame{ed_frame:02d}{suffix}"))
    nib.save(image_es, str(out_dir / f"{patient.name}_frame{es_frame:02d}{suffix}"))

# ------------------------------------------------------------------------------
# Create custom splits
# ------------------------------------------------------------------------------
def create_custom_splits(src_data_folder: Path, csv_file: str, dataset_id: int, num_val_patients: int = 25):
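    """Append the vendor-specific splits from table 3 of the M&Ms challenge paper to splits_final.json.

    Requires the default splits_final.json to exist already, i.e. preprocessing and the initial split
    creation (start of training) must have been run first.
    """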
    existing_splits = os.path.join(nnUNet_preprocessed, f"Dataset{dataset_id}_MNMs", "splits_final.json")
    splits = load_json(existing_splits)
    patients_train = [f.name for f in (src_data_folder / "Training" / "Labeled").iterdir() if f.is_dir()]

    # Filter out any patients not in the training set
    patient_info = {
        patient: data
        for patient, data in read_csv(str(src_data_folder / csv_file)).items()
        if patient in patients_train
    }

    # Get train and validation patients for both vendors
    patients_a = [patient for patient, patient_data in patient_info.items() if patient_data["vendor"] == "A"]
    patients_b = [patient for patient, patient_data in patient_info.items() if patient_data["vendor"] == "B"]
    train_a, val_a = get_vendor_split(patients_a, num_val_patients)
    train_b, val_b = get_vendor_split(patients_b, num_val_patients)

    # Build filenames from corresponding patient frames
    train_a = [f"{patient}_frame{patient_info[patient][frame]:02d}" for patient in train_a for frame in ["es", "ed"]]
    train_b = [f"{patient}_frame{patient_info[patient][frame]:02d}" for patient in train_b for frame in ["es", "ed"]]
    train_a_mix_1, train_a_mix_2 = train_a[: len(train_a) // 2], train_a[len(train_a) // 2 :]
    train_b_mix_1, train_b_mix_2 = train_b[: len(train_b) // 2], train_b[len(train_b) // 2 :]
    val_a = [f"{patient}_frame{patient_info[patient][frame]:02d}" for patient in val_a for frame in ["es", "ed"]]
    val_b = [f"{patient}_frame{patient_info[patient][frame]:02d}" for patient in val_b for frame in ["es", "ed"]]

    for train_set in [train_a, train_b, train_a_mix_1 + train_b_mix_1, train_a_mix_2 + train_b_mix_2]:
        # For each train set, we evaluate on A, B and (A + B) respectively.
        # See table 3 from the original paper for more details.
        splits.append({"train": train_set, "val": val_a})
        splits.append({"train": train_set, "val": val_b})
        splits.append({"train": train_set, "val": val_a + val_b})

    save_json(splits, existing_splits)

def get_vendor_split(patients: list[str], num_val_patients: int):
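    """Randomly shuffle the patients of one vendor and split them into training and validation subsets.

    Note: shuffles the passed list in place.
    """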
    random.shuffle(patients)
    total_patients = len(patients)
    num_training_patients = total_patients - num_val_patients

    return patients[:num_training_patients], patients[num_training_patients:]

if __name__ == "__main__":
    import argparse

    class RawTextArgumentDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter):
        pass

    parser = argparse.ArgumentParser(add_help=False, formatter_class=RawTextArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "-h",
        "--help",
        action="help",
        default=argparse.SUPPRESS,
        help="MNMs conversion utility. This script can be used to convert MNMs data into the expected nnUNet "
        "format. It can also be used to create additional custom splits, for explicitly training on combinations "
        "of vendors A and B (see `--custom_splits`).\n"
        "If you wish to generate the custom splits, run the following pipeline:\n\n"
        "(1) Run `Dataset114_MNMs -i <raw_data_dir>`\n"
        "(2) Run `nnUNetv2_plan_and_preprocess -d 114 --verify_dataset_integrity`\n"
        "(3) Start training, but stop after the initial splits are created: `nnUNetv2_train 114 2d 0`\n"
        "(4) Re-run `Dataset114_MNMs`, with `-s True`.\n"
        "(5) Re-run training.\n",
    )
    parser.add_argument(
        "-i",
        "--input_folder",
        type=str,
        default="./data/M&Ms/OpenDataset/",
        help="The downloaded MNMs dataset dir. Should contain a csv file, as well as the Training, Validation and "
        "Testing folders.",
    )
    parser.add_argument(
        "-c",
        "--csv_file_name",
        type=str,
        default="211230_M&Ms_Dataset_information_diagnosis_opendataset.csv",
        help="The csv file containing the dataset information.",
    )
    parser.add_argument("-d", "--dataset_id", type=int, default=114, help="nnUNet Dataset ID.")
    parser.add_argument(
        "-s",
        "--custom_splits",
        type=bool,
        default=False,
        help="Whether to append custom splits for training and testing on different vendors. If True, will create "
        "splits for training on patients from vendor A, vendor B or a mix of A and B. Each split is evaluated on a "
        "hold-out validation set of patients from A, B or A and B combined. See section 2.4 and table 3 of "
        "https://arxiv.org/abs/2011.07592 for more info.",
    )
    args = parser.parse_args()
    args.input_folder = Path(args.input_folder)

    if args.custom_splits:
        print("Appending custom splits...")
        create_custom_splits(args.input_folder, args.csv_file_name, args.dataset_id)
    else:
        print("Converting...")
        convert_mnms(args.input_folder, args.csv_file_name, args.dataset_id)

    print("Done!")