# NOTE: stripped non-code residue from the original paste ("Spaces:", "Runtime error" banners).
import os

import numpy as np
import torch
from torch.utils.data import DataLoader

import monai
from monai.data import Dataset
from monai.transforms import (
    CenterSpatialCropd,
    Compose,
    DivisiblePadd,
    EnsureChannelFirstd,
    Identityd,
    LoadImaged,
    NormalizeIntensityd,
    ResizeWithPadOrCropd,
    Rotate90d,
    ScaleIntensityd,
    SqueezeDimd,
    ThresholdIntensityd,
)

from .basics import get_file_list, get_transforms
def transform_datasets_to_2d(train_ds, val_ds, saved_name_train, saved_name_val, batch_size=8, ifsave=True):
    """Explode 3D CT volumes into per-slice 2D samples for training and validation.

    Each volume is loaded, padded along its last (slice) axis so the slice
    count is divisible by ``batch_size``, then split into individual 2D slices.

    Args:
        train_ds: sequence of dicts with 'image' and 'label' file paths (training).
        val_ds: sequence of dicts with 'image' and 'label' file paths (validation).
        saved_name_train: CSV path for the training per-patient shape list.
        saved_name_val: CSV path for the validation per-patient shape list.
        batch_size: the slice axis is padded to a multiple of this value.
        ifsave: when True, write both shape lists to the CSV paths.

    Returns:
        Tuple of (train_ds_2d, val_ds_2d, all_slices_train, all_slices_val,
        shape_list_train, shape_list_val).
    """
    train_ds_2d, shape_list_train, all_slices_train = _dataset_to_2d_slices(train_ds, batch_size)
    val_ds_2d, shape_list_val, all_slices_val = _dataset_to_2d_slices(val_ds, batch_size)
    if ifsave:
        # fmt='%s' writes each {'patient': ..., 'shape': ...} dict as its repr, one per line.
        np.savetxt(saved_name_train, shape_list_train, delimiter=',', fmt='%s', newline='\n')
        np.savetxt(saved_name_val, shape_list_val, delimiter=',', fmt='%s', newline='\n')
    return train_ds_2d, val_ds_2d, all_slices_train, all_slices_val, shape_list_train, shape_list_val


def _dataset_to_2d_slices(ds, batch_size):
    """Load every volume in *ds* and return (slices, shape_list, total_slice_count)."""
    slices = []
    shape_list = []
    total = 0
    # Transform construction is loop-invariant; build once and reuse per sample.
    loader = LoadImaged(keys=["image", "label"], image_only=True, ensure_channel_first=False, simple_keys=True)
    # Pad only the last axis (-1 leaves other axes untouched) to a multiple of batch_size.
    padder = DivisiblePadd(["image", "label"], (-1, batch_size), mode="minimum")
    for sample in ds:
        data = padder(loader(sample))
        # Patient id: name of the directory that contains the image file.
        name = os.path.basename(os.path.dirname(sample['image']))
        shape_list.append({'patient': name, 'shape': data["image"].shape})
        num_slices = data["image"].shape[-1]
        for i in range(num_slices):
            slices.append({'image': data['image'][:, :, i], 'label': data['label'][:, :, i]})
        total += num_slices
    return slices, shape_list, total
def get_train_val_loaders(train_ds_2d, val_ds_2d, batch_size, val_batch_size, resized_size=(256, 256)):
    """Wrap 2D slice datasets in MONAI ``Dataset``s and PyTorch ``DataLoader``s.

    The same transform pipeline is applied to both the training and the
    validation datasets.

    Args:
        train_ds_2d: list of {'image', 'label'} 2D slice dicts for training.
        val_ds_2d: list of {'image', 'label'} 2D slice dicts for validation.
        batch_size: training loader batch size.
        val_batch_size: validation loader batch size.
        resized_size: spatial size for ResizeWithPadOrCropd.

    Returns:
        Tuple of (train_loader, val_loader, train_transforms_list, train_transforms).
    """
    # NOTE(review): NormalizeIntensityd is applied to the *label* as well as the
    # image — confirm labels are really meant to be z-scored (unusual for masks).
    train_transforms = Compose(
        [
            EnsureChannelFirstd(keys=["image", "label"]),
            NormalizeIntensityd(keys=["image", "label"], nonzero=False, channel_wise=True),  # z-score normalization
            ResizeWithPadOrCropd(keys=["image", "label"], spatial_size=resized_size, mode="minimum"),
            Rotate90d(keys=["image", "label"], k=3),
            DivisiblePadd(["image", "label"], 16, mode="minimum"),
        ]
    )
    # Use the public attribute instead of reaching into __dict__.
    train_transforms_list = train_transforms.transforms
    # NOTE(review): shuffle=False on the *training* loader preserves slice order
    # across epochs — confirm this is intentional.
    train_dataset = Dataset(data=train_ds_2d, transform=train_transforms)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)
    val_dataset = Dataset(data=val_ds_2d, transform=train_transforms)
    val_loader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False, num_workers=0, pin_memory=True)
    return train_loader, val_loader, train_transforms_list, train_transforms
def mydataloader(data_pelvis_path,
                 normalize='zscore',
                 pad='minimum',
                 train_number=1,
                 val_number=1,
                 batch_size=8,
                 val_batch_size=1,
                 saved_name_train='./train_ds_2d.csv', saved_name_val='./val_ds_2d.csv',
                 resized_size=(512, 512),
                 div_size=(16, 16, None),
                 center_crop=20,):
    """End-to-end pipeline: file discovery -> 2D slicing -> data loaders.

    NOTE(review): ``normalize``, ``pad``, ``div_size`` and ``center_crop`` are
    accepted for interface compatibility but are currently unused here.

    Args:
        data_pelvis_path: root directory scanned by ``get_file_list``.
        train_number: number of training volumes to pick up.
        val_number: number of validation volumes to pick up.
        batch_size: training batch size (also controls slice-axis padding).
        val_batch_size: validation batch size.
        saved_name_train: CSV path for the training shape list (unused: ifsave=False).
        saved_name_val: CSV path for the validation shape list (unused: ifsave=False).
        resized_size: spatial size forwarded to the transform pipeline.

    Returns:
        (train_loader, val_loader, train_transforms_list, train_transforms,
         all_slices_train, all_slices_val, shape_list_train, shape_list_val)
    """
    train_ds, val_ds = get_file_list(data_pelvis_path, train_number, val_number)

    # Shape CSVs are not written here (ifsave=False); only counts/lists returned.
    (train_ds_2d, val_ds_2d,
     all_slices_train, all_slices_val,
     shape_list_train, shape_list_val) = transform_datasets_to_2d(
        train_ds, val_ds, saved_name_train, saved_name_val,
        batch_size=batch_size, ifsave=False)

    (train_loader, val_loader,
     train_transforms_list, train_transforms) = get_train_val_loaders(
        train_ds_2d, val_ds_2d,
        batch_size=batch_size, val_batch_size=val_batch_size,
        resized_size=resized_size)

    return (train_loader, val_loader,
            train_transforms_list, train_transforms,
            all_slices_train, all_slices_val,
            shape_list_train, shape_list_val)