| import os
|
| import torch
|
| import torchvision
|
| from torch.utils.data import Dataset, DataLoader
|
| from torchvision import datasets, transforms
|
| import nibabel as nib
|
| from skimage.transform import rescale, resize, downscale_local_mean
|
| from scipy.ndimage import zoom
|
| import numpy as np
|
|
|
|
|
|
|
| import sys
|
| sys.path.append('./')
|
| from Dataloader.dataloader_utils import *
|
|
|
| EPS = 1e-7
|
|
|
def get_dataloader(data_name='cmr', mode='train'):
    """Return the dataset class selected by ``data_name`` and ``mode``.

    The class object itself is returned, not an instance; the caller is
    responsible for instantiating it.

    Args:
        data_name: ``'cmr'`` or ``'lct'``.
        mode: ``'train'`` selects ``CMR_loader`` / ``LCT_loader``;
            ``'aug'`` selects ``CMR_tgt_loader`` / ``LCT_tgt_loader``.

    Returns:
        The matching ``Dataset`` subclass.

    Raises:
        ValueError: if ``data_name`` or ``mode`` is not a supported value.
            (Previously a message was printed and the function then failed
            with ``UnboundLocalError`` on its ``return`` statement.)
    """
    if data_name == 'cmr':
        if mode == 'train':
            return CMR_loader
        if mode == 'aug':
            return CMR_tgt_loader
    elif data_name == 'lct':
        if mode == 'train':
            return LCT_loader
        if mode == 'aug':
            return LCT_tgt_loader
    else:
        raise ValueError(f'dataloader not exist: {data_name!r}')
    # A known data_name fell through both mode checks above.
    raise ValueError(f'mode not exist: {mode!r}')
|
|
|
class LCT_loader(Dataset):
    """Dataset over the ``.npy`` arrays found directly inside one directory.

    Every item is returned as ``(array, array, index)``: the same array
    object serves as both elements of the pair.

    Args:
        data_root_path: directory that holds the ``.npy`` files. A trailing
            path separator is optional.
        target_res: accepted for signature compatibility; not used here.
        transforms: optional callable applied to each array after the
            leading axis has been added.
        noise_scale: stored on the instance; not used by ``__getitem__``.
        patient_index: accepted for signature compatibility; not used here.
    """

    def __init__(self, data_root_path='Data/Src_data/CTLung_processed/', target_res=(256, 256), transforms=None, noise_scale=0.0, patient_index=None):
        self.d_p = data_root_path
        # os.listdir() order is unspecified; sorting keeps the index -> file
        # mapping identical from run to run. os.path.join() makes the
        # trailing separator on data_root_path optional.
        self.files = sorted(
            os.path.join(data_root_path, name)
            for name in os.listdir(data_root_path)
            if name.endswith('.npy')
        )
        self.transforms = transforms
        self.noise_scale = noise_scale

    def __getitem__(self, item):
        """Load file number ``item`` and return ``(array, array, item)``."""
        array = np.load(self.files[item])
        # Min-max normalisation is skipped whenever the directory path
        # contains the substring 'process'.
        if 'process' not in self.d_p:
            array = (array - array.min()) / (array.max() - array.min() + EPS)
        # Prepend a singleton axis; the indexing requires a 3-D stored array.
        array = array[None, :, :, :]
        if self.transforms is not None:
            array = self.transforms(array)
        return array, array, item

    def __len__(self):
        """Number of ``.npy`` files discovered at construction time."""
        return len(self.files)
|
|
|
class LCT_tgt_loader(Dataset):
    """Paired image / mask volumes read from ``<root>/Tr`` and ``<root>/Gt``.

    The two directory listings are sorted by name and paired by position,
    so the i-th file of ``Tr`` is assumed to belong to the i-th file of
    ``Gt``.

    Args:
        data_root_path: directory containing the ``Gt`` and ``Tr``
            sub-directories. A trailing path separator is optional.
        noise_scale: stored on the instance; not used by ``__getitem__``.
        patient_index: accepted for signature compatibility; not used here.
        transforms: stored on the instance; not applied by ``__getitem__``.
            Previously the attribute was bound to the module-level name
            ``transforms`` because the constructor had no such parameter.
    """

    def __init__(self, data_root_path="Data/Tgt_data/lct/", noise_scale=0.0, patient_index=None, transforms=None):
        gt_dir = os.path.join(data_root_path, "Gt")
        tr_dir = os.path.join(data_root_path, "Tr")
        self.files_gt = sorted(os.path.join(gt_dir, name) for name in os.listdir(gt_dir))
        self.files_tr = sorted(os.path.join(tr_dir, name) for name in os.listdir(tr_dir))

        self.transforms = transforms
        self.noise_scale = noise_scale

    def __getitem__(self, item):
        """Return ``(image, mask, item)`` for the ``item``-th file pair.

        Both arrays get a leading singleton axis; the indexing requires the
        stored volumes to be 3-D.
        """
        image = nib.load(self.files_tr[item]).get_fdata()
        mask = nib.load(self.files_gt[item]).get_fdata()

        image = image[None, :, :, :]
        mask = mask[None, :, :, :]

        # Logs which pair was served; the return value carries only the index.
        print(self.files_tr[item],self.files_gt[item])

        return image, mask, item

    def __len__(self):
        """Number of file pairs; both listings must have the same length."""
        assert len(self.files_gt) == len(self.files_tr), \
            'Gt and Tr contain a different number of files'
        return len(self.files_gt)
|
|
|
class LCT_seg(Dataset):
    """Paired image / mask volumes read from ``<root>/Tr`` and ``<root>/Gt``.

    The two directory listings are sorted by name and paired by position,
    so the i-th file of ``Tr`` is assumed to belong to the i-th file of
    ``Gt``.

    Args:
        data_root_path: directory containing the ``Gt`` and ``Tr``
            sub-directories. A trailing path separator is optional.
        noise_scale: stored on the instance; not used by ``__getitem__``.
        patient_index: accepted for signature compatibility; not used here.
        transforms: stored on the instance; not applied by ``__getitem__``.
            Previously the attribute was bound to the module-level name
            ``transforms`` because the constructor had no such parameter.
    """

    def __init__(self, data_root_path="/home/data/jzheng/CTLung_processed/testset/modality_0001/", noise_scale=0.0, patient_index=None, transforms=None):
        gt_dir = os.path.join(data_root_path, "Gt")
        tr_dir = os.path.join(data_root_path, "Tr")
        self.files_gt = sorted(os.path.join(gt_dir, name) for name in os.listdir(gt_dir))
        self.files_tr = sorted(os.path.join(tr_dir, name) for name in os.listdir(tr_dir))

        self.transforms = transforms
        self.noise_scale = noise_scale

    def __getitem__(self, item):
        """Return ``(image, mask, item)`` for the ``item``-th file pair.

        Both arrays get a leading singleton axis; the indexing requires the
        stored volumes to be 3-D.
        """
        image = nib.load(self.files_tr[item]).get_fdata()
        mask = nib.load(self.files_gt[item]).get_fdata()

        image = image[None, :, :, :]
        mask = mask[None, :, :, :]

        # Logs which pair was served; the return value carries only the index.
        print(self.files_tr[item],self.files_gt[item])

        return image, mask, item

    def __len__(self):
        """Number of file pairs; both listings must have the same length."""
        assert len(self.files_gt) == len(self.files_tr), \
            'Gt and Tr contain a different number of files'
        return len(self.files_gt)
|
|
|
class CMR_loader_preprocess(Dataset):
    """Dataset that loads, resizes and normalises every file of a directory.

    Each item is ``(array, file_path)``.

    Args:
        data_path: directory whose entries are all loaded with nibabel.
            A trailing path separator is optional.
        target_res: output shape handed to ``skimage.transform.resize``.
        transforms: optional callable applied to the final array.
        noise_scale: when greater than zero, enables the thresholding and
            random intensity scaling step in ``__getitem__``.
    """

    def __init__(self, data_path='Data/CTLung_processed/', target_res=(256, 256), transforms=None, noise_scale=0.0):
        self.d_p = data_path
        self.target_res = target_res
        # Sorted so that the index -> file mapping does not depend on the
        # unspecified order of os.listdir().
        self.files = sorted(os.path.join(self.d_p, name) for name in os.listdir(self.d_p))
        self.transforms = transforms
        self.noise_scale = noise_scale

    def __getitem__(self, item):
        """Return ``(array, path)`` for the ``item``-th file."""
        path = self.files[item]
        array = nib.load(path).get_fdata()
        array = resize(array, self.target_res, anti_aliasing=True, preserve_range=True)
        # Leading singleton axis; the indexing requires a 2-D array here.
        array = array[None, :, :]
        # remove_background / thresh_img are not defined in this file;
        # presumably they come from the Dataloader.dataloader_utils star
        # import - verify their contract there.
        array = remove_background(array)
        array = (array - array.min()) / (array.max() - array.min() + EPS)

        if self.noise_scale > 0:
            array = thresh_img(array, [0, self.noise_scale])
            # One random gain per sample, drawn from N(1, 2 * noise_scale).
            array = array * (np.random.normal(1, self.noise_scale * 2))

        if self.transforms is not None:
            array = self.transforms(array)
        return array, path

    def __len__(self):
        """Number of directory entries found at construction time."""
        return len(self.files)
|
|
|
class CMR_loader(Dataset):
    """Dataset over every file of a directory, loaded with nibabel.

    Each item is ``(array, array, index)``: the same array object serves as
    both elements of the pair.

    Args:
        data_path: directory whose entries are all loaded. A trailing path
            separator is optional. When the path contains the substring
            ``'resize'`` the data is treated as already pre-processed and
            the resize / background-removal steps are skipped.
        target_res: output shape handed to ``skimage.transform.resize``
            (only used for data that is not pre-processed).
        transforms: optional callable applied to the final array.
        noise_scale: stored on the instance; not used by ``__getitem__``.
    """

    def __init__(self, data_path='Data/Src_data/CMR_processed_rmbg_resize/', target_res=(256, 256), transforms=None, noise_scale=0.0):
        self.d_p = data_path
        self.ndims = 2
        self.target_res = target_res
        # Sorted so that the index -> file mapping does not depend on the
        # unspecified order of os.listdir().
        self.files = sorted(os.path.join(self.d_p, name) for name in os.listdir(self.d_p))
        self.transforms = transforms

        self.noise_scale = noise_scale
        self.preprocessed = 'resize' in data_path

    def __getitem__(self, item):
        """Return ``(array, array, item)`` for the ``item``-th file."""
        array = nib.load(self.files[item]).get_fdata()
        if not self.preprocessed:
            array = resize(array, self.target_res, anti_aliasing=True, preserve_range=True)
        # Leading singleton axis; the indexing requires a 2-D array here.
        array = array[None, :, :]
        if not self.preprocessed:
            # remove_background is not defined in this file; presumably it
            # comes from the Dataloader.dataloader_utils star import.
            array = remove_background(array)
        array = (array - array.min()) / (array.max() - array.min() + EPS)

        if self.transforms is not None:
            array = self.transforms(array)
        return array, array, item

    def __len__(self):
        """Number of directory entries found at construction time."""
        return len(self.files)

    def get_transform(self, degrees=np.pi, translate=0.125):
        """Install a ToTensor + RandomAffine pipeline as ``self.transforms``.

        Replaces whatever transform was set before and returns ``None``.

        Args:
            degrees: forwarded to ``RandomAffine(degrees=...)``.
            translate: maximum translation fraction, repeated ``self.ndims``
                times to form the ``translate`` argument.
        """
        # NOTE(review): torchvision interprets `degrees` in degrees, so the
        # default np.pi means roughly +/-3.14 degrees - confirm that radians
        # were not intended.
        # NOTE(review): ToTensor documents numpy input as H x W x C, while
        # __getitem__ hands over an array with a leading singleton axis -
        # verify the resulting tensor layout.
        self.transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.RandomAffine(degrees=degrees,translate=[translate]*self.ndims,interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
        ])
        return
|
|
|
class CMR_tgt_loader(Dataset):
    """In-memory dataset of 2-D slices with per-label ground-truth channels.

    All volumes of the selected patients are loaded, sliced along the third
    axis, resized and stored at construction time. Each item is
    ``(image, gt, patient_id)`` where ``image`` has a leading singleton axis
    and ``gt`` stacks one channel per label 1..4.

    Args:
        data_path: directory containing the ``Tr`` (images) and ``Gt``
            (ground truth) sub-directories.
        target_res: 2-D output shape handed to ``skimage.transform.resize``.
        is_3d: accepted for signature compatibility; not used here.
        patient_index: patient ids to load. ``None`` or an empty sequence
            selects ids 1..100.
    """

    _LABELS = (1, 2, 3, 4)

    def __init__(self,
                 data_path='Data/Tgt_data/cmr/',
                 target_res=(256, 256),
                 is_3d=False,
                 patient_index=None,
                 ):
        # The trailing '' makes both paths end with a separator.
        self.d_p = os.path.join(data_path, 'Tr', '')
        self.gt_p = os.path.join(data_path, 'Gt', '')
        self.img_files = sorted(os.listdir(self.d_p))
        self.gt_files = sorted(os.listdir(self.gt_p))
        if patient_index is None or len(patient_index) == 0:
            patient_index = list(range(1, 101))
        self.p_indice = patient_index
        self.target_res_2d = target_res
        self.img_samples = []
        self.gt_samples = []
        self.p_id = []

        # Images and ground truth are paired by position in the sorted lists.
        assert len(self.gt_files) == len(self.img_files)
        print(self.p_indice)

        # Parse each file name once. The patient id is the last three
        # characters of the part before the first underscore.
        pairs_by_patient = {}
        for gt_f, img_f in zip(self.gt_files, self.img_files):
            patient = int(gt_f.split('_')[0][-3:])
            pairs_by_patient.setdefault(patient, []).append((gt_f, img_f))

        # Samples are stored in patient_index order, then in file order.
        for i in self.p_indice:
            for gt_f, img_f in pairs_by_patient.get(i, ()):
                self._append_volume(i, img_f, gt_f)

    def _append_volume(self, patient, img_f, gt_f):
        """Load one image / ground-truth pair and store all of its slices."""
        img_volume = nib.load(os.path.join(self.d_p, img_f)).get_fdata()
        gt_volume = nib.load(os.path.join(self.gt_p, gt_f)).get_fdata()
        assert img_volume.shape == gt_volume.shape
        for si in range(img_volume.shape[2]):
            img = resize(img_volume[:, :, si], self.target_res_2d, anti_aliasing=True, preserve_range=True)
            img = (img - img.min()) / (img.max() - img.min() + EPS)

            gt = gt_volume[:, :, si]
            # NOTE(review): before resizing, channel k holds the value k (not
            # 1) where the label is present, and anti-aliased resizing then
            # blurs the label edges - confirm downstream code expects this.
            channels = [
                resize(gt * (gt == label), self.target_res_2d, anti_aliasing=True, preserve_range=True)
                for label in self._LABELS
            ]

            self.img_samples.append(img[np.newaxis, :, :])
            self.gt_samples.append(np.array(channels))
            self.p_id.append(patient)

    def __getitem__(self, item):
        """Return ``(image, gt, patient_id)`` for the ``item``-th slice."""
        return self.img_samples[item], self.gt_samples[item], self.p_id[item]

    def __len__(self):
        """Number of stored slices."""
        assert len(self.img_samples) == len(self.gt_samples)
        return len(self.img_samples)
|
|
|
class acdc_seg(Dataset):
    """In-memory dataset of 2-D slices with per-label ground-truth channels.

    All volumes of the selected patients are loaded, sliced along the third
    axis, resized and stored at construction time. Each item is
    ``(image, gt, patient_id)`` where ``image`` has a leading singleton axis
    and ``gt`` stacks one channel per label 1..4.

    Args:
        data_path: directory holding the image volumes. A trailing path
            separator is optional.
        gt_path: directory holding the ground-truth volumes. A trailing
            path separator is optional.
        target_res: 2-D output shape handed to ``skimage.transform.resize``.
        is_3d: accepted for signature compatibility; not used here.
        patient_index: patient ids to load. ``None`` or an empty sequence
            selects ids 1..100.
    """

    _LABELS = (1, 2, 3, 4)

    def __init__(self,
                 data_path='/home/data/jzheng/acdc/train_images/',
                 gt_path='/home/data/jzheng/acdc/train_gt/',
                 target_res=(256, 256),
                 is_3d=False,
                 patient_index=None,
                 ):
        self.d_p = data_path
        self.gt_p = gt_path
        self.img_files = sorted(os.listdir(self.d_p))
        self.gt_files = sorted(os.listdir(self.gt_p))
        if patient_index is None or len(patient_index) == 0:
            patient_index = list(range(1, 101))
        self.p_indice = patient_index
        self.target_res_2d = target_res
        self.img_samples = []
        self.gt_samples = []
        self.p_id = []

        # Images and ground truth are paired by position in the sorted lists.
        assert len(self.gt_files) == len(self.img_files)
        print(self.p_indice)

        # Parse each file name once. The patient id is the last three
        # characters of the part before the first underscore.
        pairs_by_patient = {}
        for gt_f, img_f in zip(self.gt_files, self.img_files):
            patient = int(gt_f.split('_')[0][-3:])
            pairs_by_patient.setdefault(patient, []).append((gt_f, img_f))

        # Samples are stored in patient_index order, then in file order.
        for i in self.p_indice:
            for gt_f, img_f in pairs_by_patient.get(i, ()):
                self._append_volume(i, img_f, gt_f)

    def _append_volume(self, patient, img_f, gt_f):
        """Load one image / ground-truth pair and store all of its slices."""
        # os.path.join keeps this working when the directories are given
        # without a trailing separator.
        img_volume = nib.load(os.path.join(self.d_p, img_f)).get_fdata()
        gt_volume = nib.load(os.path.join(self.gt_p, gt_f)).get_fdata()
        assert img_volume.shape == gt_volume.shape
        for si in range(img_volume.shape[2]):
            img = resize(img_volume[:, :, si], self.target_res_2d, anti_aliasing=True, preserve_range=True)
            img = (img - img.min()) / (img.max() - img.min() + EPS)

            gt = gt_volume[:, :, si]
            # NOTE(review): before resizing, channel k holds the value k (not
            # 1) where the label is present, and anti-aliased resizing then
            # blurs the label edges - confirm downstream code expects this.
            channels = [
                resize(gt * (gt == label), self.target_res_2d, anti_aliasing=True, preserve_range=True)
                for label in self._LABELS
            ]

            self.img_samples.append(img[np.newaxis, :, :])
            self.gt_samples.append(np.array(channels))
            self.p_id.append(patient)

    def __getitem__(self, item):
        """Return ``(image, gt, patient_id)`` for the ``item``-th slice."""
        return self.img_samples[item], self.gt_samples[item], self.p_id[item]

    def __len__(self):
        """Number of stored slices."""
        assert len(self.img_samples) == len(self.gt_samples)
        return len(self.img_samples)
|
|
|
class acdc_gan(Dataset):
    """Dataset of volumes served as 2-D slices or as resized 3-D volumes.

    Every item is a pair whose two elements are the same array object.

    * 2-D mode (``is_3d`` false): all files are loaded at construction time.
      The slices between 10 % and 90 % of the depth are resized to
      ``target_res[1:]``, min-max normalised and stored with a leading
      singleton axis.
    * 3-D mode: a file is loaded on access, zoomed to ``target_res``, moved
      to depth-first axis order and min-max normalised.

    Args:
        train_path: directory whose entries are all loaded with nibabel.
            A trailing path separator is optional.
        target_res: ``(depth, width, height)`` of the 3-D output; the last
            two entries are the 2-D slice resolution.
        is_3d: selects 3-D mode when true.
        transforms: optional callable applied to 3-D items only.
    """

    def __init__(self,
                 train_path='/home/data/jzheng/acdc/images/',
                 target_res=(32, 256, 256),
                 is_3d=False,
                 transforms=None
                 ):
        self.t_p = train_path
        # Sorted so that the index -> file mapping does not depend on the
        # unspecified order of os.listdir().
        self.files = sorted(os.listdir(self.t_p))
        self.sample_list_2d = []
        self.is_3d = is_3d
        self.target_res = target_res
        self.res_2d = (target_res[1], target_res[2])
        self.transforms = transforms

        if not self.is_3d:
            for f in self.files:
                img = nib.load(os.path.join(self.t_p, f)).get_fdata()
                depth = img.shape[2]
                # Drop the first and last ~10 % of the slices.
                f_i = int(round(depth * 0.1))
                b_i = int(round(depth * 0.9))
                interval_slice = img[:, :, f_i:b_i]
                for ii in range(interval_slice.shape[2]):
                    single_slice = interval_slice[:, :, ii]
                    single_slice = resize(single_slice, self.res_2d, anti_aliasing=True, preserve_range=True)
                    single_slice = (single_slice - single_slice.min()) / (single_slice.max() - single_slice.min() + EPS)
                    self.sample_list_2d.append(single_slice[None, :, :])

    def __len__(self):
        """Number of stored slices (2-D mode) or of files (3-D mode)."""
        if not self.is_3d:
            return len(self.sample_list_2d)
        return len(self.files)

    def __getitem__(self, index):
        """Return the ``index``-th sample twice, as ``(sample, sample)``."""
        if not self.is_3d:
            sample = self.sample_list_2d[index]
            return sample, sample

        # Load the file selected by `index`. The previous implementation
        # looped over all files and returned on the first iteration, so
        # every index produced the first file.
        img = nib.load(os.path.join(self.t_p, self.files[index])).get_fdata()
        target_d_ratio = self.target_res[0] / img.shape[2]
        target_w_ratio = self.target_res[1] / img.shape[0]
        target_h_ratio = self.target_res[2] / img.shape[1]

        resize_img = zoom(img, (target_w_ratio, target_h_ratio, target_d_ratio))

        # (W, H, D) -> (D, W, H); same result as swapaxes(0, 2) followed by
        # swapaxes(1, 2).
        resize_img = np.moveaxis(resize_img, 2, 0)
        resize_img = (resize_img - resize_img.min()) / (resize_img.max() - resize_img.min() + EPS)
        # Test the instance attribute. The previous check used the
        # module-level name `transforms`, which was always non-None, so a
        # None transform was called and raised TypeError.
        if self.transforms is not None:
            resize_img = self.transforms(resize_img)
        return resize_img, resize_img
|
|
|
class acdc_gan_single_slice(Dataset):
    """Dataset that serves the middle slice of every volume in a directory.

    Each item is ``(slice, slice)``: the same 128 x 128 array object serves
    as both elements of the pair. No channel axis is added.

    Args:
        train_path: directory whose entries are all loaded with nibabel.
            A trailing path separator is optional.
    """

    def __init__(self, train_path='/well/papiez/shared/ACDC/clean_training/images/'):
        self.t_p = train_path
        # Sorted so that the index -> file mapping does not depend on the
        # unspecified order of os.listdir().
        self.files = sorted(os.listdir(self.t_p))

    def __len__(self):
        """Number of directory entries found at construction time."""
        return len(self.files)

    def __getitem__(self, index):
        """Return the resized, min-max normalised middle slice twice."""
        volume = nib.load(os.path.join(self.t_p, self.files[index])).get_fdata()
        # Middle position along the third axis.
        mid_d = volume.shape[2] // 2
        mid_slice = volume[:, :, mid_d]
        mid_slice = resize(mid_slice, (128, 128), anti_aliasing=True, preserve_range=True)
        mid_slice = (mid_slice - mid_slice.min()) / (mid_slice.max() - mid_slice.min() + EPS)

        return mid_slice, mid_slice
|
|
|
|
|
|
|
|
|
|
|