import os
import shutil
from typing import List, Optional

import numpy as np

from batchgenerators.utilities.file_and_folder_operations import join, load_pickle, isfile
from nnunetv2.training.dataloading.utils import get_case_identifiers
|
|
|
|
class nnUNetDataset(object):
    def __init__(self, folder: str, case_identifiers: Optional[List[str]] = None,
                 num_images_properties_loading_threshold: int = 0,
                 folder_with_segs_from_previous_stage: Optional[str] = None):
        """
        Build a lightweight index of a preprocessed nnU-Net dataset folder.

        This does not actually load the dataset. It merely creates a dictionary where the keys are
        training case names and the values are dictionaries containing the relevant information for
        that case:
            dataset[case_identifier]['data_file'] -> full path to the npz file of the training case
            dataset[case_identifier]['properties_file'] -> the pkl file containing the case properties
            dataset[case_identifier]['seg_from_prev_stage_file'] -> npz with the previous-stage
                segmentation (only present if folder_with_segs_from_previous_stage is given)

        In addition, if the total number of cases is <= num_images_properties_loading_threshold we
        load all the pickle files (containing auxiliary information). This is done for small datasets
        so that we don't spend too much CPU time on reading pkl files on the fly during training.
        However, for large datasets storing all the aux info (which also contains locations of
        foreground voxels in the images) can cause too much RAM utilization. In that case it is
        better to load on the fly. If properties are loaded into RAM, the info dicts each get an
        additional entry:
            dataset[case_identifier]['properties'] -> pkl file content

        Note: nnUNetDataset[key] = value is supported (it writes through to self.dataset), but this
        class is intended to be used read-only.
        """
        super().__init__()
        if case_identifiers is None:
            case_identifiers = get_case_identifiers(folder)
        # sort for a deterministic case order regardless of filesystem enumeration order
        case_identifiers.sort()

        self.dataset = {}
        for c in case_identifiers:
            self.dataset[c] = {
                'data_file': join(folder, f"{c}.npz"),
                'properties_file': join(folder, f"{c}.pkl"),
            }
            if folder_with_segs_from_previous_stage is not None:
                self.dataset[c]['seg_from_prev_stage_file'] = \
                    join(folder_with_segs_from_previous_stage, f"{c}.npz")

        if len(case_identifiers) <= num_images_properties_loading_threshold:
            for i in self.dataset.keys():
                self.dataset[i]['properties'] = load_pickle(self.dataset[i]['properties_file'])

        # opt-in via environment variable: keep memory-mapped .npy files open across load_case calls
        self.keep_files_open = os.environ.get('nnUNet_keep_files_open', '').lower() in ('true', '1', 't')

    def __getitem__(self, key):
        """Return a shallow copy of the info dict for `key`; loads 'properties' from disk if not cached."""
        ret = {**self.dataset[key]}
        if 'properties' not in ret.keys():
            ret['properties'] = load_pickle(ret['properties_file'])
        return ret

    def __setitem__(self, key, value):
        # writes through to the underlying dict
        return self.dataset.__setitem__(key, value)

    def keys(self):
        return self.dataset.keys()

    def __len__(self):
        return self.dataset.__len__()

    def items(self):
        return self.dataset.items()

    def values(self):
        return self.dataset.values()

    def load_case(self, key):
        """
        Load and return (data, seg, properties) for one case.

        For data and seg each, the lookup order is: a file handle kept open by a previous call
        (when nnUNet_keep_files_open is set), then an unpacked .npy file (opened as a read-only
        memmap), then the 'data'/'seg' arrays of the case's npz archive. If a previous-stage
        segmentation file is registered, it is stacked onto seg as an additional channel.
        """
        entry = self[key]
        npz_content = None  # lazily opened npz archive, shared between the data and seg branches

        if 'open_data_file' in entry:
            data = entry['open_data_file']
        elif isfile(entry['data_file'][:-4] + ".npy"):
            data = np.load(entry['data_file'][:-4] + ".npy", 'r')
            if self.keep_files_open:
                self.dataset[key]['open_data_file'] = data
        else:
            npz_content = np.load(entry['data_file'])
            data = npz_content['data']

        if 'open_seg_file' in entry:
            seg = entry['open_seg_file']
        elif isfile(entry['data_file'][:-4] + "_seg.npy"):
            seg = np.load(entry['data_file'][:-4] + "_seg.npy", 'r')
            if self.keep_files_open:
                self.dataset[key]['open_seg_file'] = seg
        else:
            # reuse the archive opened for 'data' instead of opening it a second time
            if npz_content is None:
                npz_content = np.load(entry['data_file'])
            seg = npz_content['seg']

        if 'seg_from_prev_stage_file' in entry:
            if isfile(entry['seg_from_prev_stage_file'][:-4] + ".npy"):
                seg_prev = np.load(entry['seg_from_prev_stage_file'][:-4] + ".npy", 'r')
            else:
                seg_prev = np.load(entry['seg_from_prev_stage_file'])['seg']
            seg = np.vstack((seg, seg_prev[None]))

        return data, seg, entry['properties']
|
|
class nnUNetDatasetMask(nnUNetDataset):
    """
    nnUNetDataset that additionally loads a per-case mask, stored either as an unpacked
    <case>_mask.npy file next to the data or under the 'mask' key of the case's npz archive.
    """
    # NOTE: the original redundant __init__ (pure pass-through to super().__init__ with the same
    # arguments) has been removed; the inherited constructor is identical.

    def load_case(self, key):
        """
        Load and return (data, seg, properties, mask) for one case.

        data/seg/properties come from the parent class. The mask follows the same lookup order as
        data and seg: an already-open handle (nnUNet_keep_files_open), then a read-only memmap of
        <case>_mask.npy, then the 'mask' array of the npz archive.
        """
        data, seg, properties = super().load_case(key)

        entry = self[key]
        if 'open_mask_file' in entry:
            # bug fix: previously the cached handle was stored but never read back, so the
            # keep_files_open cache had no effect for masks
            mask = entry['open_mask_file']
        elif isfile(entry['data_file'][:-4] + "_mask.npy"):
            mask = np.load(entry['data_file'][:-4] + "_mask.npy", 'r')
            if self.keep_files_open:
                self.dataset[key]['open_mask_file'] = mask
        else:
            mask = np.load(entry['data_file'])['mask']

        return data, seg, properties, mask
|
|
|
|
if __name__ == '__main__':
    # Manual smoke tests against a locally preprocessed Liver dataset.
    folder = '/media/fabian/data/nnUNet_preprocessed/Dataset003_Liver/3d_lowres'

    # threshold 0: properties are not preloaded, so they must be read from disk on access
    ds = nnUNetDataset(folder, num_images_properties_loading_threshold=0)
    assert 'properties' in ds['liver_0'].keys()

    # threshold 1000: properties were cached at construction time, so temporarily removing
    # the pkl file must not break access
    ds = nnUNetDataset(folder, num_images_properties_loading_threshold=1000)
    shutil.move(join(folder, 'liver_0.pkl'), join(folder, 'liver_XXX.pkl'))
    assert 'properties' in ds['liver_0'].keys()
    shutil.move(join(folder, 'liver_XXX.pkl'), join(folder, 'liver_0.pkl'))

    # threshold 0 again: with the pkl file gone, accessing the case must raise FileNotFoundError
    ds = nnUNetDataset(folder, num_images_properties_loading_threshold=0)
    shutil.move(join(folder, 'liver_0.pkl'), join(folder, 'liver_XXX.pkl'))
    try:
        ds['liver_0'].keys()
        raise RuntimeError('we should not have come here')
    except FileNotFoundError:
        print('all good')
    shutil.move(join(folder, 'liver_XXX.pkl'), join(folder, 'liver_0.pkl'))