"""Build 2D slice datasets and data loaders from 3D image/label volumes."""

import os

import monai
import numpy as np
import torch
from monai.data import Dataset
from monai.transforms import (
    Compose,
    LoadImaged,
    Rotate90d,
    ScaleIntensityd,
    EnsureChannelFirstd,
    ResizeWithPadOrCropd,
    DivisiblePadd,
    ThresholdIntensityd,
    NormalizeIntensityd,
    SqueezeDimd,
    Identityd,
    CenterSpatialCropd,
)
from torch.utils.data import DataLoader

from .basics import get_file_list, get_transforms


def _volumes_to_slices(dataset, batch_size):
    """Load every sample of *dataset* and split it into 2D slices.

    Args:
        dataset: iterable of dicts with ``'image'`` and ``'label'`` entries
            that ``LoadImaged`` can read; ``sample['image']`` is also used as
            a file path to derive the patient name.
        batch_size: the last axis of every volume is padded so that its
            length is divisible by this value.

    Returns:
        A tuple ``(slices, shapes, total_slices)`` where ``slices`` is a list
        of ``{'image': ..., 'label': ...}`` dicts (one per slice), ``shapes``
        is a list of ``{'patient': name, 'shape': padded_shape}`` dicts (one
        per volume) and ``total_slices`` is the number of slices produced.
    """
    # The transforms do not depend on the sample, so build them once.
    load = LoadImaged(
        keys=["image", "label"],
        image_only=True,
        ensure_channel_first=False,
        simple_keys=True,
    )
    # k=(-1, batch_size): only the last axis is constrained. Per the MONAI
    # documentation a non-positive k leaves that dimension unpadded.
    # NOTE(review): assumes volumes load as 3D arrays with the slice axis
    # last (no channel axis) -- confirm against the data.
    pad_depth = DivisiblePadd(["image", "label"], (-1, batch_size), mode="minimum")

    slices = []
    shapes = []
    for sample in dataset:
        volume = pad_depth(load(sample))
        image = volume["image"]
        label = volume["label"]

        # The patient name is the directory that contains the image file.
        patient = os.path.basename(os.path.dirname(sample["image"]))
        shapes.append({"patient": patient, "shape": image.shape})

        slices.extend(
            {"image": image[:, :, i], "label": label[:, :, i]}
            for i in range(image.shape[-1])
        )
    return slices, shapes, len(slices)


def transform_datasets_to_2d(train_ds, val_ds, saved_name_train, saved_name_val, batch_size=8, ifsave=True):
    """Convert training and validation volumes into lists of 2D slices.

    Args:
        train_ds: iterable of ``{'image': path, 'label': path}`` samples used
            for training.
        val_ds: iterable of ``{'image': path, 'label': path}`` samples used
            for validation.
        saved_name_train: file the training shape list is written to when
            ``ifsave`` is true.
        saved_name_val: file the validation shape list is written to when
            ``ifsave`` is true.
        batch_size: every volume is padded along its last axis to a multiple
            of this value before slicing.
        ifsave: when true, write both shape lists to disk with ``np.savetxt``.

    Returns:
        ``(train_ds_2d, val_ds_2d, all_slices_train, all_slices_val,
        shape_list_train, shape_list_val)``.
    """
    train_ds_2d, shape_list_train, all_slices_train = _volumes_to_slices(train_ds, batch_size)
    val_ds_2d, shape_list_val, all_slices_val = _volumes_to_slices(val_ds, batch_size)

    # Save shape list to csv
    if ifsave:
        # fmt='%s' writes the string form of each record, one record per line.
        np.savetxt(saved_name_train, shape_list_train, delimiter=',', fmt='%s', newline='\n')
        np.savetxt(saved_name_val, shape_list_val, delimiter=',', fmt='%s', newline='\n')

    return train_ds_2d, val_ds_2d, all_slices_train, all_slices_val, shape_list_train, shape_list_val


def get_train_val_loaders(train_ds_2d, val_ds_2d, batch_size, val_batch_size, resized_size=(256, 256)):
    """Wrap the 2D slice lists in MONAI datasets and PyTorch data loaders.

    The same transform pipeline is applied to the training and the validation
    slices, and neither loader shuffles its data.

    Args:
        train_ds_2d: list of ``{'image': ..., 'label': ...}`` training slices.
        val_ds_2d: list of ``{'image': ..., 'label': ...}`` validation slices.
        batch_size: batch size of the training loader.
        val_batch_size: batch size of the validation loader.
        resized_size: spatial size every slice is padded or cropped to.

    Returns:
        ``(train_loader, val_loader, train_transforms_list, train_transforms)``
        where ``train_transforms_list`` is the sequence of transforms held by
        the ``Compose`` object ``train_transforms``.
    """
    # Define transforms
    train_transforms = Compose(
        [
            EnsureChannelFirstd(keys=["image", "label"]),
            # z-score normalization, applied to both image and label
            NormalizeIntensityd(keys=["image", "label"], nonzero=False, channel_wise=True),
            ResizeWithPadOrCropd(keys=["image", "label"], spatial_size=resized_size, mode="minimum"),
            Rotate90d(keys=["image", "label"], k=3),
            DivisiblePadd(["image", "label"], 16, mode="minimum"),
        ]
    )
    train_transforms_list = train_transforms.transforms

    # Create training dataset and data loader
    train_dataset = Dataset(data=train_ds_2d, transform=train_transforms)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )

    # Create validation dataset and data loader
    val_dataset = Dataset(data=val_ds_2d, transform=train_transforms)
    val_loader = DataLoader(
        val_dataset,
        batch_size=val_batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )

    return train_loader, val_loader, train_transforms_list, train_transforms


def mydataloader(data_pelvis_path,
                 normalize='zscore',
                 pad='minimum',
                 train_number=1,
                 val_number=1,
                 batch_size=8,
                 val_batch_size=1,
                 saved_name_train='./train_ds_2d.csv',
                 saved_name_val='./val_ds_2d.csv',
                 resized_size=(512, 512),
                 div_size=(16, 16, None),
                 center_crop=20,):
    """Build the training and validation loaders for the data under a path.

    Args:
        data_pelvis_path: location handed to ``get_file_list`` to obtain the
            training and validation samples.
        normalize: currently unused by this function.
        pad: currently unused by this function.
        train_number: forwarded to ``get_file_list``.
        val_number: forwarded to ``get_file_list``.
        batch_size: training batch size; also the divisor the last axis of
            every volume is padded to before slicing.
        val_batch_size: validation batch size.
        saved_name_train: forwarded to ``transform_datasets_to_2d``; nothing
            is written because saving is disabled here.
        saved_name_val: forwarded to ``transform_datasets_to_2d``; nothing is
            written because saving is disabled here.
        resized_size: spatial size every slice is padded or cropped to.
        div_size: currently unused by this function.
        center_crop: currently unused by this function.

    Returns:
        ``(train_loader, val_loader, train_transforms_list, train_transforms,
        all_slices_train, all_slices_val, shape_list_train, shape_list_val)``.
    """
    train_ds, val_ds = get_file_list(data_pelvis_path, train_number, val_number)

    # ifsave=False: the shape lists are returned to the caller, not written.
    (train_ds_2d, val_ds_2d,
     all_slices_train, all_slices_val,
     shape_list_train, shape_list_val) = transform_datasets_to_2d(
        train_ds,
        val_ds,
        saved_name_train,
        saved_name_val,
        batch_size=batch_size,
        ifsave=False,
    )

    (train_loader, val_loader,
     train_transforms_list, train_transforms) = get_train_val_loaders(
        train_ds_2d,
        val_ds_2d,
        batch_size=batch_size,
        val_batch_size=val_batch_size,
        resized_size=resized_size,
    )

    return (train_loader, val_loader,
            train_transforms_list, train_transforms,
            all_slices_train, all_slices_val,
            shape_list_train, shape_list_val)