# FelixzeroSun's picture
# Upload folder using huggingface_hub
# 19c1f58 verified
import os
from typing import List
import numpy as np
import shutil
from batchgenerators.utilities.file_and_folder_operations import join, load_pickle, isfile
from nnunetv2.training.dataloading.utils import get_case_identifiers
class nnUNetDataset(object):
    """Dictionary-like index over a folder of preprocessed training cases.

    Nothing is read from disk at construction time except (optionally) the
    per-case properties pickles. Image and segmentation arrays are only loaded
    when `load_case` is called.
    """

    def __init__(self, folder: str, case_identifiers: List[str] = None,
                 num_images_properties_loading_threshold: int = 0,
                 folder_with_segs_from_previous_stage: str = None):
        """
        This does not actually load the dataset. It merely creates a dictionary where the keys are training case
        names and the values are dictionaries containing the relevant information for that case.

        dataset[training_case] -> info

        Info has the following key:value pairs:
        - dataset[case_identifier]['data_file'] -> the full path to the npz file associated with the training case
        - dataset[case_identifier]['properties_file'] -> the pkl file containing the case properties
        - dataset[case_identifier]['seg_from_prev_stage_file'] -> npz file in folder_with_segs_from_previous_stage
          (only present if folder_with_segs_from_previous_stage is not None)

        In addition, if the total number of cases is <= num_images_properties_loading_threshold we load all the
        pickle files (containing auxiliary information) up front. This is done for small datasets so that we don't
        spend too much CPU time on reading pkl files on the fly during training. For large datasets, keeping all
        the aux info in RAM can cause too much RAM utilization; in that case it is better to load on the fly.

        If properties are loaded into the RAM, the info dicts each will have an additional entry:
        - dataset[case_identifier]['properties'] -> pkl file content

        NOTE ON WRITING: nnUNetDataset[key] returns a shallow COPY of the stored info dict, so
        nnUNetDataset[key]['foo'] = bar does not persist. nnUNetDataset[key] = value replaces the whole entry
        (it forwards to nnUNetDataset.dataset[key] = value). There is rarely a reason to do either.

        :param folder: directory that holds the <case>.npz / <case>.pkl files
        :param case_identifiers: cases to index. If None, they are discovered with get_case_identifiers(folder).
            The list passed in is NOT modified.
        :param num_images_properties_loading_threshold: preload all properties if len(cases) <= this value
        :param folder_with_segs_from_previous_stage: optional directory with <case>.npz segmentations of a
            previous stage
        """
        super().__init__()
        if case_identifiers is None:
            case_identifiers = get_case_identifiers(folder)
        # sorted() builds a new list; list.sort() would silently reorder the caller's list in place
        case_identifiers = sorted(case_identifiers)

        self.dataset = {}
        for c in case_identifiers:
            info = {
                'data_file': join(folder, f"{c}.npz"),
                'properties_file': join(folder, f"{c}.pkl"),
            }
            if folder_with_segs_from_previous_stage is not None:
                info['seg_from_prev_stage_file'] = join(folder_with_segs_from_previous_stage, f"{c}.npz")
            self.dataset[c] = info

        if len(case_identifiers) <= num_images_properties_loading_threshold:
            for info in self.dataset.values():
                info['properties'] = load_pickle(info['properties_file'])

        # opt-in via environment variable: cache the memory-mapped npy arrays between load_case calls
        self.keep_files_open = os.environ.get('nnUNet_keep_files_open', '').lower() in ('true', '1', 't')

    def __getitem__(self, key):
        """Return a shallow copy of the info dict for `key`; 'properties' is read from the pkl if not cached."""
        ret = {**self.dataset[key]}
        if 'properties' not in ret:
            ret['properties'] = load_pickle(ret['properties_file'])
        return ret

    def __setitem__(self, key, value):
        """Replace the stored info dict for `key`."""
        return self.dataset.__setitem__(key, value)

    def keys(self):
        """Case identifiers (view on the underlying dict)."""
        return self.dataset.keys()

    def __len__(self):
        """Number of indexed cases."""
        return self.dataset.__len__()

    def items(self):
        """(case identifier, stored info dict) pairs. Unlike __getitem__, properties are not loaded on demand."""
        return self.dataset.items()

    def values(self):
        """Stored info dicts. Unlike __getitem__, properties are not loaded on demand."""
        return self.dataset.values()

    @staticmethod
    def _read_npz_array(npz_file, array_name):
        """Read one array out of an npz archive and close the archive again."""
        with np.load(npz_file) as archive:
            return archive[array_name]

    def load_case(self, key):
        """
        Load image data, segmentation and properties of one case.

        For data and seg the lookup order is:
        1. an array cached by a previous call (only stored when self.keep_files_open is True)
        2. <case>.npy / <case>_seg.npy next to the npz, opened memory-mapped in read-only mode
        3. the 'data' / 'seg' array inside <case>.npz

        If the case has a 'seg_from_prev_stage_file', that segmentation (npy preferred, else 'seg' from the npz)
        gets a leading axis added and is stacked below seg along the first axis.

        :param key: case identifier
        :return: data, seg, properties
        """
        entry = self[key]
        # data_file always ends in ".npz" (see __init__); strip the extension to derive sibling file names
        stem = entry['data_file'][:-4]

        if 'open_data_file' in entry:
            data = entry['open_data_file']
        elif isfile(stem + ".npy"):
            data = np.load(stem + ".npy", 'r')
            if self.keep_files_open:
                # entry is a copy, the cache must go into the stored dict
                self.dataset[key]['open_data_file'] = data
        else:
            data = self._read_npz_array(entry['data_file'], 'data')

        if 'open_seg_file' in entry:
            seg = entry['open_seg_file']
        elif isfile(stem + "_seg.npy"):
            seg = np.load(stem + "_seg.npy", 'r')
            if self.keep_files_open:
                self.dataset[key]['open_seg_file'] = seg
        else:
            seg = self._read_npz_array(entry['data_file'], 'seg')

        if 'seg_from_prev_stage_file' in entry:
            prev_stem = entry['seg_from_prev_stage_file'][:-4]
            if isfile(prev_stem + ".npy"):
                seg_prev = np.load(prev_stem + ".npy", 'r')
            else:
                seg_prev = self._read_npz_array(entry['seg_from_prev_stage_file'], 'seg')
            seg = np.vstack((seg, seg_prev[None]))

        return data, seg, entry['properties']
class nnUNetDatasetMask(nnUNetDataset):
    """nnUNetDataset variant whose load_case additionally returns a per-case mask array."""

    def __init__(self, folder: str, case_identifiers: List[str] = None,
                 num_images_properties_loading_threshold: int = 0,
                 folder_with_segs_from_previous_stage: str = None):
        """Same arguments as nnUNetDataset.__init__; everything is forwarded unchanged."""
        super().__init__(folder, case_identifiers, num_images_properties_loading_threshold,
                         folder_with_segs_from_previous_stage)

    def load_case(self, key):
        """
        Load data, seg and properties via the parent class and add the mask.

        Mask lookup order:
        1. an array cached by a previous call (only stored when self.keep_files_open is True)
        2. <case>_mask.npy next to the npz, opened memory-mapped in read-only mode
        3. the 'mask' array inside <case>.npz (raises KeyError if the archive has no such array)

        :param key: case identifier
        :return: data, seg, properties, mask
        """
        data, seg, properties = super().load_case(key)

        # Read the stored dict directly: only file paths / cached handles are needed here, and going through
        # self[key] would re-read the properties pickle that super().load_case has just loaded.
        stored = self.dataset[key]
        mask_npy = stored['data_file'][:-4] + "_mask.npy"

        if 'open_mask_file' in stored:
            # reuse the handle cached below; without this check the cache would be written but never used
            mask = stored['open_mask_file']
        elif isfile(mask_npy):
            mask = np.load(mask_npy, 'r')
            if self.keep_files_open:
                stored['open_mask_file'] = mask
        else:
            # context manager closes the npz archive after the array has been read
            with np.load(stored['data_file']) as archive:
                mask = archive['mask']
        return data, seg, properties, mask
if __name__ == '__main__':
    # this is a mini test. Todo: We can move this to tests in the future (requires simulated dataset)
    # It temporarily renames a properties file of the real dataset. Every rename is undone in a `finally`
    # so that a failing check cannot leave the preprocessed folder in a broken state.
    folder = '/media/fabian/data/nnUNet_preprocessed/Dataset003_Liver/3d_lowres'
    properties_file = join(folder, 'liver_0.pkl')
    hidden_properties_file = join(folder, 'liver_XXX.pkl')

    # threshold 0 -> properties are not preloaded, but __getitem__ must load them on demand
    ds = nnUNetDataset(folder, num_images_properties_loading_threshold=0)
    ks = ds['liver_0'].keys()
    assert 'properties' in ks

    # threshold 1000 -> properties are preloaded into RAM at construction
    ds = nnUNetDataset(folder, num_images_properties_loading_threshold=1000)
    # now rename the properties file so that it does not exist anymore
    shutil.move(properties_file, hidden_properties_file)
    try:
        # we should still be able to access the properties because they have already been loaded
        ks = ds['liver_0'].keys()
        assert 'properties' in ks
    finally:
        # move file back
        shutil.move(hidden_properties_file, properties_file)

    # threshold 0 again -> nothing preloaded
    ds = nnUNetDataset(folder, num_images_properties_loading_threshold=0)
    # now rename the properties file so that it does not exist anymore
    shutil.move(properties_file, hidden_properties_file)
    try:
        # on-demand loading must now fail because the pkl is gone
        try:
            ks = ds['liver_0'].keys()
        except FileNotFoundError:
            print('all good')
        else:
            raise RuntimeError('we should not have come here')
    finally:
        # move file back
        shutil.move(hidden_properties_file, properties_file)