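"""Conversion utility for the M&Ms (Multi-Centre, Multi-Vendor & Multi-Disease Cardiac
Segmentation) challenge data: converts the raw dataset into the nnU-Net format and can
optionally append custom vendor-based splits (see the --help text below)."""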
import csv
import os
import random
from pathlib import Path

import nibabel as nib
from batchgenerators.utilities.file_and_folder_operations import load_json, save_json

from nnunetv2.dataset_conversion.Dataset027_ACDC import make_out_dirs
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
from nnunetv2.paths import nnUNet_preprocessed


def read_csv(csv_file: str):
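    """Parse the M&Ms dataset information csv and return a dict mapping each external patient
    code to its ED/ES frame indices and scanner vendor."""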
    patient_info = {}

    with open(csv_file) as csvfile:
        reader = csv.reader(csvfile)
        headers = next(reader)
        patient_index = headers.index("External code")
        ed_index = headers.index("ED")
        es_index = headers.index("ES")
        vendor_index = headers.index("Vendor")

        for row in reader:
            patient_info[row[patient_index]] = {
                "ed": int(row[ed_index]),
                "es": int(row[es_index]),
                "vendor": row[vendor_index],
            }

    return patient_info


# ------------------------------------------------------------------------------
# Conversion to nnUNet format
# ------------------------------------------------------------------------------
def convert_mnms(src_data_folder: Path, csv_file_name: str, dataset_id: int):
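    """Extract the ED/ES frames of the labeled training and test patients into the nnU-Net raw
    dataset layout and write the accompanying dataset.json."""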
    out_dir, out_train_dir, out_labels_dir, out_test_dir = make_out_dirs(dataset_id, task_name="MNMs")
    patients_train = [f for f in (src_data_folder / "Training" / "Labeled").iterdir() if f.is_dir()]
    patients_test = [f for f in (src_data_folder / "Testing").iterdir() if f.is_dir()]

    patient_info = read_csv(str(src_data_folder / csv_file_name))

    save_cardiac_phases(patients_train, patient_info, out_train_dir, out_labels_dir)
    save_cardiac_phases(patients_test, patient_info, out_test_dir)

    # There are non-orthonormal direction cosines in the test and validation data.
    # Not sure if the data should be fixed, or we should skip the problematic data.
    # patients_val = [f for f in (src_data_folder / "Validation").iterdir() if f.is_dir()]
    # save_cardiac_phases(patients_val, patient_info, out_train_dir, out_labels_dir)

    generate_dataset_json(
        str(out_dir),
        channel_names={
            0: "cineMRI",
        },
        labels={"background": 0, "LVBP": 1, "LVM": 2, "RV": 3},
        file_ending=".nii.gz",
        num_training_cases=len(patients_train) * 2,  # 2 since we have ED and ES for each patient
    )


def save_cardiac_phases(
    patients: list[Path], patient_info: dict[str, dict], out_dir: Path, labels_dir: Path = None
):
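    """For each patient, extract the ED and ES frames from the 4D cine MRI (and, if a labels
    directory is given, from the ground-truth segmentation) and save them as 3D niftis."""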
    for patient in patients:
        print(f"Processing patient: {patient.name}")

        image = nib.load(patient / f"{patient.name}_sa.nii.gz")
        ed_frame = patient_info[patient.name]["ed"]
        es_frame = patient_info[patient.name]["es"]

        save_extracted_nifti_slice(image, ed_frame=ed_frame, es_frame=es_frame, out_dir=out_dir, patient=patient)

        if labels_dir:
            label = nib.load(patient / f"{patient.name}_sa_gt.nii.gz")
            save_extracted_nifti_slice(label, ed_frame=ed_frame, es_frame=es_frame, out_dir=labels_dir, patient=patient)


def save_extracted_nifti_slice(image, ed_frame: int, es_frame: int, out_dir: Path, patient: Path):
    # Save only the extracted end-diastole (ED) and end-systole (ES) frames from the 4D (H x W x D x time) volume.
    image_ed = nib.Nifti1Image(image.dataobj[..., ed_frame], image.affine)
    image_es = nib.Nifti1Image(image.dataobj[..., es_frame], image.affine)

    # Labels do not have modality identifiers. Labels always end with 'gt'.
    suffix = ".nii.gz" if image.get_filename().endswith("_gt.nii.gz") else "_0000.nii.gz"

    nib.save(image_ed, str(out_dir / f"{patient.name}_frame{ed_frame:02d}{suffix}"))
    nib.save(image_es, str(out_dir / f"{patient.name}_frame{es_frame:02d}{suffix}"))


# ------------------------------------------------------------------------------
# Create custom splits
# ------------------------------------------------------------------------------
def create_custom_splits(src_data_folder: Path, csv_file: str, dataset_id: int, num_val_patients: int = 25):
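    """Append vendor-specific train/validation splits (train on vendor A, vendor B or an A/B mix,
    validate on A, B and A+B) to the splits_final.json created by nnU-Net.
    See section 2.4 and table 3 of https://arxiv.org/abs/2011.07592."""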
    existing_splits = os.path.join(nnUNet_preprocessed, f"Dataset{dataset_id}_MNMs", "splits_final.json")
    splits = load_json(existing_splits)

    patients_train = [f.name for f in (src_data_folder / "Training" / "Labeled").iterdir() if f.is_dir()]
    # Filter out any patients not in the training set
    patient_info = {
        patient: data
        for patient, data in read_csv(str(src_data_folder / csv_file)).items()
        if patient in patients_train
    }

    # Get train and validation patients for both vendors
    patients_a = [patient for patient, patient_data in patient_info.items() if patient_data["vendor"] == "A"]
    patients_b = [patient for patient, patient_data in patient_info.items() if patient_data["vendor"] == "B"]
    train_a, val_a = get_vendor_split(patients_a, num_val_patients)
    train_b, val_b = get_vendor_split(patients_b, num_val_patients)

    # Build filenames from corresponding patient frames
    train_a = [f"{patient}_frame{patient_info[patient][frame]:02d}" for patient in train_a for frame in ["es", "ed"]]
    train_b = [f"{patient}_frame{patient_info[patient][frame]:02d}" for patient in train_b for frame in ["es", "ed"]]
    train_a_mix_1, train_a_mix_2 = train_a[: len(train_a) // 2], train_a[len(train_a) // 2 :]
    train_b_mix_1, train_b_mix_2 = train_b[: len(train_b) // 2], train_b[len(train_b) // 2 :]
    val_a = [f"{patient}_frame{patient_info[patient][frame]:02d}" for patient in val_a for frame in ["es", "ed"]]
    val_b = [f"{patient}_frame{patient_info[patient][frame]:02d}" for patient in val_b for frame in ["es", "ed"]]

    for train_set in [train_a, train_b, train_a_mix_1 + train_b_mix_1, train_a_mix_2 + train_b_mix_2]:
        # For each train set, we evaluate on A, B and (A + B) respectively
        # See table 3 from the original paper for more details.
        splits.append({"train": train_set, "val": val_a})
        splits.append({"train": train_set, "val": val_b})
        splits.append({"train": train_set, "val": val_a + val_b})

    save_json(splits, existing_splits)


def get_vendor_split(patients: list[str], num_val_patients: int):
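    """Shuffle the patients in place (no fixed seed, so the split differs between runs) and
    split them into a training set and a hold-out set of `num_val_patients` patients."""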
    random.shuffle(patients)
    total_patients = len(patients)
    num_training_patients = total_patients - num_val_patients
    return patients[:num_training_patients], patients[num_training_patients:]


if __name__ == "__main__":
    import argparse

    class RawTextArgumentDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter):
        pass

    parser = argparse.ArgumentParser(add_help=False, formatter_class=RawTextArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "-h",
        "--help",
        action="help",
        default=argparse.SUPPRESS,
        help="MNMs conversion utility helper. This script can be used to convert MNMs data into the expected nnUNet "
        "format. It can also be used to create additional custom splits, for explicitly training on combinations "
        "of vendors A and B (see `--custom-splits`).\n"
        "If you wish to generate the custom splits, run the following pipeline:\n\n"
        "(1) Run `Dataset114_MNMs -i <raw_Data_dir>\n"
        "(2) Run `nnUNetv2_plan_and_preprocess -d 114 --verify_dataset_integrity`\n"
        "(3) Start training, but stop after initial splits are created: `nnUNetv2_train 114 2d 0`\n"
        "(4) Re-run `Dataset114_MNMs`, with `-s True`.\n"
        "(5) Re-run training.\n",
    )
    parser.add_argument(
        "-i",
        "--input_folder",
        type=str,
        default="./data/M&Ms/OpenDataset/",
        help="The downloaded MNMs dataset dir. Should contain a csv file, as well as Training, Validation and Testing "
        "folders.",
    )
    parser.add_argument(
        "-c",
        "--csv_file_name",
        type=str,
        default="211230_M&Ms_Dataset_information_diagnosis_opendataset.csv",
        help="The csv file containing the dataset information.",
    )
    parser.add_argument("-d", "--dataset_id", type=int, default=114, help="nnUNet Dataset ID.")
    parser.add_argument(
        "-s",
        "--custom_splits",
        # Note: argparse's `type=bool` treats any non-empty string (including "False") as True,
        # so the value is parsed explicitly here.
        type=lambda x: str(x).lower() in ("true", "1", "yes"),
        default=False,
        help="Whether to append custom splits for training and testing on different vendors. If True, will create "
        "splits for training on patients from vendors A, B or a mix of A and B. Splits are evaluated on hold-out "
        "validation sets of patients from A, B or A and B combined. See section 2.4 and table 3 from "
        "https://arxiv.org/abs/2011.07592 for more info.",
    )

    args = parser.parse_args()
    args.input_folder = Path(args.input_folder)

    if args.custom_splits:
        print("Appending custom splits...")
        create_custom_splits(args.input_folder, args.csv_file_name, args.dataset_id)
    else:
        print("Converting...")
        convert_mnms(args.input_folder, args.csv_file_name, args.dataset_id)

    print("Done!")