import os

import numpy as np
import monai
from monai.transforms import (
    Compose,
    LoadImaged,
    Rotate90d,
    ScaleIntensityd,
    EnsureChannelFirstd,
    ResizeWithPadOrCropd,
    DivisiblePadd,
    ThresholdIntensityd,
    NormalizeIntensityd,
    SqueezeDimd,
    ShiftIntensityd,
    Identityd,
    CenterSpatialCropd,
    ScaleIntensityRanged,
)
from torch.utils.data import DataLoader


def get_file_list(data_pelvis_path, train_number, val_number, source='mr', target='ct'):
    """Build train/val lists of dicts pairing source (e.g. MR) and target (e.g. CT) volumes."""
    # list all patient folders, skipping the dataset overview file
    file_list = [i for i in os.listdir(data_pelvis_path) if 'overview' not in i]
    file_list_path = [os.path.join(data_pelvis_path, i) for i in file_list]
    # paths to the source and target volumes inside each patient folder
    source_file_list = [os.path.join(j, f'{source}.nii.gz') for j in file_list_path]
    target_file_list = [os.path.join(j, f'{target}.nii.gz') for j in file_list_path]
    # dict version: source -> image, target -> label
    train_ds = [{'source': i, 'target': j, 'A_paths': i, 'B_paths': j}
                for i, j in zip(source_file_list[:train_number], target_file_list[:train_number])]
    val_ds = [{'source': i, 'target': j, 'A_paths': i, 'B_paths': j}
              for i, j in zip(source_file_list[-val_number:], target_file_list[-val_number:])]
    print('all files in dataset:', len(file_list))
    return train_ds, val_ds


def load_volumes(train_transforms, val_transforms, train_crop_ds, val_crop_ds, train_ds, val_ds,
                 saved_name_train=None, saved_name_val=None, ifsave=False, ifcheck=False):
    """Wrap the cropped data dicts in MONAI Datasets; optionally save the split and sanity-check it."""
    train_volume_ds = monai.data.Dataset(data=train_crop_ds, transform=train_transforms)
    val_volume_ds = monai.data.Dataset(data=val_crop_ds, transform=val_transforms)
    if ifsave:
        save_volumes(train_ds, val_ds, saved_name_train, saved_name_val)
    if ifcheck:
        check_volumes(train_ds, train_volume_ds, val_volume_ds, val_ds)
    return train_volume_ds, val_volume_ds


def crop_volumes(train_ds, val_ds, center_crop, resized_size=(512, 512, None), pad='minimum'):
    """Load the paired volumes and optionally center-crop them along the slice axis.

    resized_size and pad are kept for signature compatibility; padding/resizing
    happens later in get_transforms.
    """
    transforms = [LoadImaged(keys=["source", "target"]),
                  EnsureChannelFirstd(keys=["source", "target"])]
    if center_crop > 0:
        # roi_size of -1 keeps the full extent of the first two spatial dims
        transforms.append(CenterSpatialCropd(keys=["source", "target"], roi_size=(-1, -1, center_crop)))
        print('center crop:', center_crop)
    crop = Compose(transforms)
    train_crop_ds = monai.data.Dataset(data=train_ds, transform=crop)
    val_crop_ds = monai.data.Dataset(data=val_ds, transform=crop)
    return train_crop_ds, val_crop_ds
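# Expected directory layout for get_file_list (inferred from the path handling
# above; the patient folder names are placeholders):
#
#   data_pelvis_path/
#       <patient_id>/mr.nii.gz
#       <patient_id>/ct.nii.gz
#       overview*              <- skipped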
def get_transforms(configs, mode='train'):
    """Build the normalization/augmentation pipeline described by the config."""
    normalize = configs.dataset.normalize
    pad = configs.dataset.pad
    resized_size = configs.dataset.resized_size
    WINDOW_WIDTH = configs.dataset.WINDOW_WIDTH
    WINDOW_LEVEL = configs.dataset.WINDOW_LEVEL
    prob = configs.dataset.augmentationProb
    background = configs.dataset.background
    transform_list = []

    # clip the CT below the window minimum to the background value
    # (t_min/t_max avoid shadowing the built-ins min/max)
    t_min, t_max = WINDOW_LEVEL - (WINDOW_WIDTH / 2), WINDOW_LEVEL + (WINDOW_WIDTH / 2)
    transform_list.append(ThresholdIntensityd(keys=["target"], threshold=t_min, above=True, cval=background))
    # transform_list.append(ThresholdIntensityd(keys=["target"], threshold=t_max, above=False, cval=-1000))
    # filter the source images
    # transform_list.append(ThresholdIntensityd(keys=["source"], threshold=configs.dataset.MRImax, above=False, cval=0))

    if normalize == 'zscore':
        transform_list.append(NormalizeIntensityd(keys=["source", "target"], nonzero=False, channel_wise=True))
        print('zscore normalization')
    elif normalize == 'minmax':
        transform_list.append(ScaleIntensityd(keys=["source", "target"], minv=-1, maxv=1.0))
        print('minmax normalization')
    elif normalize == 'scale4000':
        transform_list.append(ScaleIntensityd(keys=["source"], minv=-1, maxv=1))
        transform_list.append(ScaleIntensityd(keys=["target"], minv=0))
        # intended as x = x * (1 + factor), i.e. divide by 4000; note that MONAI's
        # ScaleIntensityd only applies factor when minv and maxv are None, so with
        # the defaults left in place this call rescales to [0, 1] instead
        transform_list.append(ScaleIntensityd(keys=["target"], factor=-0.99975))
        print('scale4000 normalization')
    elif normalize == 'scale1000':
        transform_list.append(ScaleIntensityd(keys=["source"], minv=0, maxv=1))
        transform_list.append(ScaleIntensityd(keys=["target"], minv=0))
        # same caveat as above: factor is ignored unless minv/maxv are None
        transform_list.append(ScaleIntensityd(keys=["target"], factor=-0.99975))
        print('scale1000 normalization')
    elif normalize == 'inputonlyzscore':
        transform_list.append(NormalizeIntensityd(keys=["source"], nonzero=False, channel_wise=True))
        print('only normalize input MRI images (zscore)')
    elif normalize == 'inputonlyminmax':
        transform_list.append(ScaleIntensityd(keys=["source"], minv=configs.dataset.normmin, maxv=configs.dataset.normmax))
        print('only normalize input MRI images (minmax)')
    elif normalize == 'none':
        print('no normalization')

    # the data dicts built by get_file_list carry no 'mask' key, so allow it to be missing
    transform_list.append(ResizeWithPadOrCropd(keys=["source", "target", "mask"], spatial_size=resized_size,
                                               mode=pad, allow_missing_keys=True))
    # transform_list.append(ScaleIntensityRanged(keys=["target"], a_min=WINDOW_LEVEL - (WINDOW_WIDTH / 2),
    #                                            a_max=WINDOW_LEVEL + (WINDOW_WIDTH / 2), b_min=0, b_max=1, clip=True))

    if mode == 'train':
        from monai.transforms import (  # data augmentation
            RandRotated,
            RandZoomd,
            RandBiasFieldd,
            RandAffined,
            RandShiftIntensityd,
            RandAdjustContrastd,
            RandGaussianSharpend,
            RandGaussianNoised,
        )
        Aug = True
        if Aug:
            # spatial augmentation on image, label and (if present) mask
            transform_list.append(RandRotated(keys=["source", "target", "mask"], range_x=0.1, range_y=0.1,
                                              range_z=0.1, prob=prob, padding_mode="border", keep_size=True,
                                              allow_missing_keys=True))
            transform_list.append(RandZoomd(keys=["source", "target", "mask"], prob=prob, min_zoom=0.9,
                                            max_zoom=1.3, padding_mode="minimum", keep_size=True,
                                            allow_missing_keys=True))
            transform_list.append(RandAffined(keys=["source", "target", "mask"], padding_mode="border",
                                              prob=prob, allow_missing_keys=True))
            # transform_list.append(Rand3DElasticd(keys=["source", "target"], prob=prob, sigma_range=(5, 8),
            #                                      magnitude_range=(100, 200), spatial_size=None, mode='bilinear'))
        intensityAug = False
        if intensityAug:
            print('intensity data augmentation is used')
            # intensity augmentation only applies to the MRI source images
            transform_list.append(RandBiasFieldd(keys=["source"], degree=3, coeff_range=(0.0, 0.1), prob=prob))
            transform_list.append(RandGaussianNoised(keys=["source"], prob=prob, mean=0.0, std=0.01))
            transform_list.append(RandAdjustContrastd(keys=["source"], prob=prob, gamma=(0.5, 1.5)))
            transform_list.append(RandShiftIntensityd(keys=["source"], prob=prob, offsets=20))
            transform_list.append(RandGaussianSharpend(keys=["source"], alpha=(0.2, 0.8), prob=prob))

    # transform_list.append(Rotate90d(keys=["source", "target"], k=3))
    # transform_list.append(DivisiblePadd(keys=["source", "target"], k=div_size, mode="minimum"))
    # transform_list.append(Identityd(keys=["source", "target"]))  # do nothing for the no-norm case
    return Compose(transform_list)


def get_length(dataset, patch_batch_size):
    """Count the total slices in a volume dataset and return ceil(total / patch_batch_size)."""
    loader = DataLoader(dataset, batch_size=1)
    sum_nslices = 0
    for check_data in loader:
        # the slice axis is last: (B, C, H, W, D)
        sum_nslices += check_data['source'].shape[-1]
    if sum_nslices % patch_batch_size == 0:
        return sum_nslices // patch_batch_size
    return sum_nslices // patch_batch_size + 1
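# get_length example (hypothetical numbers): two volumes with 90 and 95 slices
# and patch_batch_size=32 give 185 slices in total, so the function returns
# 185 // 32 + 1 = 6 slice batches per epoch.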
def check_volumes(train_ds, train_volume_ds, val_volume_ds, val_ds):
    """Iterate over both datasets once, printing shapes or the name of a failing patient."""
    # use batch_size=1 to check the volumes because the input volumes have different shapes
    train_loader = DataLoader(train_volume_ds, batch_size=1)
    val_loader = DataLoader(val_volume_ds, batch_size=1)
    train_iterator = iter(train_loader)
    val_iterator = iter(val_loader)
    print('check training data:')
    for idx in range(len(train_loader)):
        # with batch_size=1 the loader index equals the dataset index
        current_item = train_ds[idx]
        current_name = os.path.basename(os.path.dirname(current_item['source']))
        try:
            train_check_data = next(train_iterator)
            print(idx, current_name, 'image:', train_check_data['source'].shape,
                  'label:', train_check_data['target'].shape)
        except Exception:
            print('check data error! Check the input data:', current_name)
    print("checked all training data.")
    print('check validation data:')
    for idx in range(len(val_loader)):
        current_item = val_ds[idx]
        current_name = os.path.basename(os.path.dirname(current_item['source']))
        try:
            val_check_data = next(val_iterator)
            print(idx, current_name, 'image:', val_check_data['source'].shape,
                  'label:', val_check_data['target'].shape)
        except Exception:
            print('check data error! Check the input data:', current_name)
    print("checked all validation data.")


def save_volumes(train_ds, val_ds, saved_name_train, saved_name_val):
    """Write the patient names of the train/val splits to csv-style text files."""
    shape_list_train = []
    shape_list_val = []
    for sample in train_ds:
        name = os.path.basename(os.path.dirname(sample['source']))
        shape_list_train.append({'patient': name})
    for sample in val_ds:
        name = os.path.basename(os.path.dirname(sample['source']))
        shape_list_val.append({'patient': name})
    np.savetxt(saved_name_train, shape_list_train, delimiter=',', fmt='%s', newline='\n')
    np.savetxt(saved_name_val, shape_list_val, delimiter=',', fmt='%s', newline='\n')


def check_batch_data(train_loader, val_loader, train_patch_ds, val_volume_ds, train_batch_size, val_batch_size):
    """Print source/target shapes per batch together with the first dataset item of that batch."""
    for idx, train_check_data in enumerate(train_loader):
        ds_idx = idx * train_batch_size  # index of the first item in this batch
        current_item = train_patch_ds[ds_idx]
        print('check train data:')
        print(current_item, 'image:', train_check_data['source'].shape,
              'label:', train_check_data['target'].shape)
    for idx, val_check_data in enumerate(val_loader):
        ds_idx = idx * val_batch_size
        current_item = val_volume_ds[ds_idx]
        print('check val data:')
        print(current_item, 'image:', val_check_data['source'].shape,
              'label:', val_check_data['target'].shape)
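# Minimal end-to-end sketch of how these helpers compose. The data path, split
# sizes, crop depth and config values below are illustrative assumptions, not
# values taken from this repo; only the attribute names match the code above.
if __name__ == '__main__':
    from types import SimpleNamespace

    cfg = SimpleNamespace(dataset=SimpleNamespace(  # hypothetical config values
        normalize='zscore', pad='minimum', resized_size=(512, 512, 32),
        WINDOW_WIDTH=1000, WINDOW_LEVEL=0, augmentationProb=0.5, background=-1000))

    train_ds, val_ds = get_file_list('data/pelvis', train_number=20, val_number=5)  # hypothetical path/counts
    train_crop_ds, val_crop_ds = crop_volumes(train_ds, val_ds, center_crop=32)
    train_transforms = get_transforms(cfg, mode='train')
    val_transforms = get_transforms(cfg, mode='val')
    train_volume_ds, val_volume_ds = load_volumes(
        train_transforms, val_transforms, train_crop_ds, val_crop_ds,
        train_ds, val_ds, ifcheck=True)
    train_loader = DataLoader(train_volume_ds, batch_size=1, shuffle=True)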