# Page metadata captured alongside this file (not part of the program):
# Spaces status: Runtime error
# File size: 5,217 Bytes
# Revision: fd601de
import monai
from monai.transforms import (
Compose,
LoadImaged,
Rotate90d,
ScaleIntensityd,
EnsureChannelFirstd,
ResizeWithPadOrCropd,
DivisiblePadd,
ThresholdIntensityd,
NormalizeIntensityd,
SqueezeDimd,
Identityd,
CenterSpatialCropd,
)
from monai.data import Dataset
from torch.utils.data import DataLoader
import torch
from .basics import get_file_list, check_batch_data, get_transforms, load_volumes, crop_volumes
##### slices #####
def load_batch_slices(train_volume_ds, val_volume_ds, train_batch_size=8, val_batch_size=1, window_width=1, ifcheck=True):
    """Build training/validation DataLoaders that yield patches cut from volumes.

    Each item of ``train_volume_ds`` / ``val_volume_ds`` is split by
    ``monai.data.PatchIterd`` over its "source" and "target" entries into
    patches that keep the full extent of the first two spatial axes and are
    ``window_width`` wide along the last one, starting at offset (0, 0, 0).
    When ``window_width == 1`` the trailing size-1 axis is squeezed away, so
    each patch is a 2D slice.

    Args:
        train_volume_ds: dataset of volume dicts for training; presumably each
            holds "source"/"target" arrays with three spatial axes (TODO confirm).
        val_volume_ds: dataset of volume dicts for validation (same layout).
        train_batch_size: batch size of the training loader.
        val_batch_size: batch size of the validation loader.
        window_width: number of consecutive positions per patch along the last axis.
        ifcheck: if True, call ``check_batch_data`` on the built loaders.

    Returns:
        Tuple ``(train_loader, val_loader)`` of ``torch.utils.data.DataLoader``.
    """
    slab_iter = monai.data.PatchIterd(
        keys=["source", "target"],
        patch_size=(None, None, window_width),  # None keeps the whole extent of that axis
        start_pos=(0, 0, 0),
    )

    squeeze = None
    if window_width == 1:
        # One-position patches keep a trailing axis of size 1; drop it.
        squeeze = Compose([SqueezeDimd(keys=["source", "target"], dim=-1)])

    use_pin = torch.cuda.is_available()

    def _patch_loader(volumes, batch_size):
        # Wrap a volume dataset as a patch dataset plus a single-process loader.
        patches = monai.data.GridPatchDataset(
            data=volumes, patch_iter=slab_iter, transform=squeeze, with_coordinates=False)
        loader = DataLoader(patches, batch_size=batch_size, num_workers=0, pin_memory=use_pin)
        return patches, loader

    train_patch_ds, train_loader = _patch_loader(train_volume_ds, train_batch_size)
    _, val_loader = _patch_loader(val_volume_ds, val_batch_size)

    if ifcheck:
        # NOTE(review): the training side passes the patch dataset but the
        # validation side passes the raw volume dataset — confirm intended.
        check_batch_data(train_loader, val_loader, train_patch_ds, val_volume_ds, train_batch_size, val_batch_size)
    return train_loader, val_loader
def myslicesloader(data_pelvis_path,
                   normalize='minmax',
                   pad='minimum',
                   train_number=1,
                   val_number=1,
                   train_batch_size=8,
                   val_batch_size=1,
                   saved_name_train='./train_ds_2d.csv',
                   saved_name_val='./val_ds_2d.csv',
                   resized_size=(512, 512, None),
                   div_size=(16, 16, None),
                   center_crop=20,
                   ifcheck_volume=True,
                   ifcheck_sclices=False,):
    """Build volume datasets and 2D-slice DataLoaders for training and validation.

    Pipeline: volume transforms from ``get_transforms`` (training with
    ``prob=0.8``) -> file lists from ``get_file_list`` -> ``crop_volumes`` ->
    ``load_volumes`` (called with ``ifsave=True``) -> single-slice loaders from
    ``load_batch_slices`` with ``window_width=1``.

    Args:
        data_pelvis_path: dataset root passed to ``get_file_list``.
        normalize, pad, resized_size, div_size: forwarded to ``get_transforms``;
            their meaning is defined in ``.basics``.
        train_number, val_number: forwarded to ``get_file_list`` (presumably the
            number of cases per split — confirm in ``.basics``).
        train_batch_size, val_batch_size: batch sizes of the slice loaders.
        saved_name_train, saved_name_val: output names forwarded to ``load_volumes``.
        center_crop: forwarded to ``crop_volumes``.
        ifcheck_volume: ``ifcheck`` flag for ``load_volumes``.
        ifcheck_sclices: ``ifcheck`` flag for ``load_batch_slices``.

    Returns:
        ``(train_ds, val_ds, train_loader, val_loader, train_transforms, val_transforms)``.
    """
    # Volume-level transforms shared by source and target.
    train_transforms = get_transforms(normalize, pad, resized_size, div_size, mode='train', prob=0.8)
    val_transforms = get_transforms(normalize, pad, resized_size, div_size, mode='val')

    files_train, files_val = get_file_list(data_pelvis_path, train_number, val_number)
    cropped_train, cropped_val = crop_volumes(files_train, files_val, center_crop)

    # load_volumes receives both the cropped entries and the original file lists.
    volumes_train, volumes_val = load_volumes(
        train_transforms, val_transforms,
        cropped_train, cropped_val,
        files_train, files_val,
        saved_name_train, saved_name_val,
        ifsave=True,
        ifcheck=ifcheck_volume,
    )

    loader_train, loader_val = load_batch_slices(
        volumes_train,
        volumes_val,
        train_batch_size,
        val_batch_size=val_batch_size,
        window_width=1,
        ifcheck=ifcheck_sclices,
    )
    return volumes_train, volumes_val, loader_train, loader_val, train_transforms, val_transforms
def len_patchloader(train_volume_ds, train_batch_size):
    """Count slices and batches in a volume dataset, printing both.

    The size of the last axis of ``item['source']`` is taken as each volume's
    slice count. Batches are counted per volume as
    ``ceil(slices / train_batch_size)`` and summed.

    NOTE(review): a DataLoader over an iterable patch dataset (such as the one
    built by ``load_batch_slices``) presumably batches consecutive patches
    across volume boundaries, giving ``ceil(slice_number / train_batch_size)``
    batches, which can be smaller than this per-volume sum — confirm which
    count callers need.

    Args:
        train_volume_ds: indexable dataset (supports ``len`` and ``[i]``) whose
            items are dicts with a ``'source'`` entry exposing ``.shape``.
        train_batch_size: positive batch size.

    Returns:
        Tuple ``(slice_number, batch_number)``.

    Raises:
        ValueError: if ``train_batch_size`` is not positive.
    """
    import math

    if train_batch_size <= 0:
        raise ValueError(f"train_batch_size must be positive, got {train_batch_size}")

    # Fetch each item exactly once: indexing a transform-backed dataset may
    # reload (and randomly re-augment) the volume, so a second pass would be
    # slow and could even report a different depth than the first.
    # Explicit indexing (not iteration) keeps the original Dataset contract.
    depths = [train_volume_ds[i]['source'].shape[-1] for i in range(len(train_volume_ds))]

    slice_number = sum(depths)
    print('total slices in training set:', slice_number)
    batch_number = sum(math.ceil(depth / train_batch_size) for depth in depths)
    print('total batches in training set:', batch_number)
    return slice_number, batch_number
if __name__ == '__main__':
    # Smoke test: build the 2D slice loaders and log the shape of every
    # training batch to a text file.
    # NOTE(review): this module uses a relative import (`from .basics ...`),
    # so it presumably has to be run as part of its package
    # (`python -m <package>.<module>`), not as a standalone script.
    from tqdm import tqdm

    dataset_path = r"F:\yang_Projects\Datasets\Task1\pelvis"
    train_volume_ds, _, train_loader, _, _, _ = myslicesloader(
        dataset_path,
        normalize='none',
        train_number=2,
        val_number=1,
        train_batch_size=4,
        val_batch_size=1,
        saved_name_train='./train_ds_2d.csv',
        saved_name_val='./val_ds_2d.csv',
        resized_size=(512, 512, None),
        div_size=(16, 16, None),
        ifcheck_volume=False,
        ifcheck_sclices=False,
    )

    # Forward slashes work on Windows as well as POSIX, unlike '.\test.txt',
    # which is a literal file name containing a backslash on POSIX.
    parameter_file = './test.txt'
    # Open the log once instead of re-opening it for every batch.
    with open(parameter_file, 'a', encoding='utf-8') as f:
        for data in tqdm(train_loader):
            # The patch loaders only carry the "source" and "target" keys
            # (see the PatchIterd keys); "image"/"label" raised KeyError.
            f.write('image batch:' + str(data["source"].shape) + '\n')
            f.write('label batch:' + str(data["target"].shape) + '\n')
            f.write('\n')