"""Slice-level data loading utilities (MONAI) for paired source/target volumes."""
| import monai | |
| from monai.transforms import ( | |
| Compose, | |
| LoadImaged, | |
| Rotate90d, | |
| ScaleIntensityd, | |
| EnsureChannelFirstd, | |
| ResizeWithPadOrCropd, | |
| DivisiblePadd, | |
| ThresholdIntensityd, | |
| NormalizeIntensityd, | |
| SqueezeDimd, | |
| Identityd, | |
| CenterSpatialCropd, | |
| ) | |
| from monai.data import Dataset | |
| from torch.utils.data import DataLoader | |
| import torch | |
| from .basics import get_file_list, check_batch_data, get_transforms, load_volumes, crop_volumes | |
| ##### slices ##### | |
def load_batch_slices(train_volume_ds, val_volume_ds, train_batch_size=8, val_batch_size=1, window_width=1, ifcheck=True):
    """Wrap volume datasets into slice-level DataLoaders.

    Each volume is iterated as patches of spatial shape (H, W, window_width);
    when ``window_width == 1`` the trailing singleton dim is squeezed so every
    item is a 2D slice.

    Args:
        train_volume_ds: volume-level dataset for training.
        val_volume_ds: volume-level dataset for validation.
        train_batch_size: slices per training batch.
        val_batch_size: slices per validation batch.
        window_width: number of consecutive slices per patch.
        ifcheck: if True, run check_batch_data on the resulting loaders.

    Returns:
        (train_loader, val_loader)
    """
    slicer = monai.data.PatchIterd(
        keys=["source", "target"],
        patch_size=(None, None, window_width),  # None => full extent of the first two dims
        start_pos=(0, 0, 0),
    )
    # Drop the trailing singleton dimension only for single-slice windows.
    squeeze = (
        Compose([SqueezeDimd(keys=["source", "target"], dim=-1)])
        if window_width == 1
        else None
    )

    def _split_loader(volume_ds, batch_size):
        # One GridPatchDataset + DataLoader pair per split.
        patch_ds = monai.data.GridPatchDataset(
            data=volume_ds, patch_iter=slicer, transform=squeeze, with_coordinates=False)
        loader = DataLoader(
            patch_ds,
            batch_size=batch_size,
            num_workers=0,
            pin_memory=torch.cuda.is_available(),
        )
        return patch_ds, loader

    train_patch_ds, train_loader = _split_loader(train_volume_ds, train_batch_size)
    _, val_loader = _split_loader(val_volume_ds, val_batch_size)

    if ifcheck:
        # NOTE: the original passes the *volume* dataset for validation here.
        check_batch_data(train_loader, val_loader, train_patch_ds, val_volume_ds,
                         train_batch_size, val_batch_size)
    return train_loader, val_loader
def myslicesloader(data_pelvis_path,
                   normalize='minmax',
                   pad='minimum',
                   train_number=1,
                   val_number=1,
                   train_batch_size=8,
                   val_batch_size=1,
                   saved_name_train='./train_ds_2d.csv',
                   saved_name_val='./val_ds_2d.csv',
                   resized_size=(512,512,None),
                   div_size=(16,16,None),
                   center_crop=20,
                   ifcheck_volume=True,
                   ifcheck_sclices=False,):
    """End-to-end 2D-slice pipeline: file lists -> volume transforms -> slice loaders.

    Args:
        data_pelvis_path: root directory of the dataset.
        normalize / pad: intensity-normalization and padding modes forwarded
            to get_transforms.
        train_number / val_number: how many cases go into each split.
        train_batch_size / val_batch_size: slices per batch for each split.
        saved_name_train / saved_name_val: CSV paths where the file lists are saved.
        resized_size / div_size: spatial resize and divisibility constraints.
        center_crop: crop amount forwarded to crop_volumes.
        ifcheck_volume / ifcheck_sclices: enable sanity checks at the volume /
            slice stage respectively.

    Returns:
        (train_ds, val_ds, train_loader, val_loader, train_transforms, val_transforms)
    """
    # Volume-level transforms for both image and label keys.
    train_transforms = get_transforms(normalize, pad, resized_size, div_size, mode='train', prob=0.8)
    val_transforms = get_transforms(normalize, pad, resized_size, div_size, mode='val')

    # Discover the case files, then crop each split.
    train_ds, val_ds = get_file_list(data_pelvis_path, train_number, val_number)
    train_crop_ds, val_crop_ds = crop_volumes(train_ds, val_ds, center_crop)

    # Materialise the transformed volume datasets (also saves file lists as CSV).
    train_ds, val_ds = load_volumes(train_transforms, val_transforms,
                                    train_crop_ds, val_crop_ds,
                                    train_ds, val_ds,
                                    saved_name_train, saved_name_val,
                                    ifsave=True,
                                    ifcheck=ifcheck_volume)

    # Slice each volume into single-slice (window_width=1) batches.
    train_loader, val_loader = load_batch_slices(train_ds,
                                                 val_ds,
                                                 train_batch_size,
                                                 val_batch_size=val_batch_size,
                                                 window_width=1,
                                                 ifcheck=ifcheck_sclices)
    return train_ds, val_ds, train_loader, val_loader, train_transforms, val_transforms
def len_patchloader(train_volume_ds, train_batch_size):
    """Count total slices and total per-volume batches in a slice-wise training set.

    Args:
        train_volume_ds: sequence of dicts whose 'source' array's last dim is
            the slice axis.
        train_batch_size: number of slices per batch.

    Returns:
        (slice_number, batch_number) summed over all volumes.
    """
    import math

    slice_number = 0
    batch_number = 0
    # Single pass over the dataset: indexing a MONAI dataset re-runs its
    # transforms, so the original two separate sums loaded every volume twice.
    for volume in train_volume_ds:
        n_slices = volume['source'].shape[-1]
        slice_number += n_slices
        # Batches never cross volume boundaries, hence per-volume ceil.
        batch_number += math.ceil(n_slices / train_batch_size)
    print('total slices in training set:', slice_number)
    print('total batches in training set:', batch_number)
    return slice_number, batch_number
if __name__ == '__main__':
    dataset_path = r"F:\yang_Projects\Datasets\Task1\pelvis"
    train_volume_ds, _, train_loader, _, _, _ = myslicesloader(dataset_path,
                                                               normalize='none',
                                                               train_number=2,
                                                               val_number=1,
                                                               train_batch_size=4,
                                                               val_batch_size=1,
                                                               saved_name_train='./train_ds_2d.csv',
                                                               saved_name_val='./val_ds_2d.csv',
                                                               resized_size=(512, 512, None),
                                                               div_size=(16, 16, None),
                                                               ifcheck_volume=False,
                                                               ifcheck_sclices=False,)
    from tqdm import tqdm
    parameter_file = r'.\test.txt'
    # Open the log once instead of reopening it for every batch.
    with open(parameter_file, 'a') as f:
        for data in tqdm(train_loader):
            # BUG FIX: batches are keyed 'source'/'target' (set by PatchIterd in
            # load_batch_slices); the old 'image'/'label' keys raised KeyError.
            f.write('source batch:' + str(data["source"].shape) + '\n')
            f.write('target batch:' + str(data["target"].shape) + '\n')
            f.write('\n')