peirong26's picture
Upload 187 files
2571f24 verified
import math
import os
import os.path
import nibabel
import nibabel as nib
import numpy as np
import torch
import torch.nn
class BRATSDataset(torch.utils.data.Dataset):
def __init__(self, directory, test_flag=True):
super().__init__()
self.directory = os.path.expanduser(directory)
self.test_flag = test_flag
if test_flag:
self.seqtypes = ["voided", "mask"]
else:
#self.seqtypes = ["diseased", "mask", "healthy"]
self.seqtypes = ["healthyvoided", "healthy", "t1n"]
self.seqtypes_set = set(self.seqtypes)
self.database = []
self.mask_vis = []
for root, dirs, files in os.walk(self.directory):
dirs_sorted = sorted(dirs)
for dir_id in dirs_sorted:
datapoint = dict()
sli_dict = dict()
for ro, di, fi in os.walk(root + "/" + str(dir_id)):
fi_sorted = sorted(fi)
for f in fi_sorted:
seqtype = f.split("-")[-1].split(".")[0]
#print('seqtype:', seqtype)
datapoint[seqtype] = os.path.join(root, dir_id, f)
if seqtype == "mask":
slice_range = []
mask_to_define_rand = np.array(
nibabel.load(datapoint["mask"]).dataobj
)
if test_flag:
mask_to_define_rand = np.pad(
mask_to_define_rand, ((0, 0), (0, 0), (34, 35))
)
mask_to_define_rand = mask_to_define_rand[8:-8, 8:-8, :]
for i in range(0, 224):
mask_slice = mask_to_define_rand[:, :, i]
if np.sum(mask_slice) != 0:
slice_range.append(i)
# assert (
# set(datapoint.keys()) == self.seqtypes_set
# ), f"datapoint is incomplete, keys are {datapoint.keys()}"
self.database.append(datapoint)
self.mask_vis.append(slice_range)
break
def __getitem__(self, x):
filedict = self.database[x]
slicedict = self.mask_vis[x]
#print("input files: ", filedict)
#print("slice dict:", slicedict)
out_single = []
if self.test_flag:
for seqtype in self.seqtypes:
if seqtype == "voided":
nib_img = np.array(nibabel.load(filedict[seqtype]).dataobj).astype(
np.float32
)
path = filedict[seqtype]
t1_numpy_pad = np.pad(nib_img, ((0, 0), (0, 0), (34, 35)))
t1_numpy_crop = t1_numpy_pad[8:-8, 8:-8, :] # crop-pad to (224, 224, 224)
t1_clipped = np.clip(
t1_numpy_crop,
np.quantile(t1_numpy_crop, 0.001),
np.quantile(t1_numpy_crop, 0.999),
)
t1_normalized = (t1_clipped - np.min(t1_clipped)) / (
np.max(t1_clipped) - np.min(t1_clipped)
)
img_preprocessed = torch.tensor(t1_normalized)
elif seqtype == "mask":
nib_img = np.array(nibabel.load(filedict[seqtype]).dataobj).astype(
np.float32
)
path = filedict[seqtype]
mask_numpy_pad = np.pad(nib_img, ((0, 0), (0, 0), (34, 35)))
mask_numpy_crop = mask_numpy_pad[8:-8, 8:-8, :]
img_preprocessed = torch.tensor(mask_numpy_crop)
else:
print("unknown seqtype")
out_single.append(img_preprocessed)
out_single = torch.stack(out_single)
image = out_single[0:2, ...]
path = filedict[seqtype]
return (image, path, slicedict)
else:
for seqtype in self.seqtypes:
nib_img = np.array(nibabel.load(filedict[seqtype]).dataobj).astype(
np.float32
)
path = filedict[seqtype]
img_preprocessed = torch.tensor(nib_img)
out_single.append(img_preprocessed)
out_single = torch.stack(out_single) # "diseased", "mask", "healthy"
image = out_single[0:2, ...] # "diseased", "mask"
label = out_single[2, ...] # "healthy"
label = label.unsqueeze(0)
path = filedict[seqtype]
return (image, label, path, slicedict)
def __len__(self):
return len(self.database)