zy7_oldserver
1
fd601de
import monai
from monai.data import Dataset
from torch.utils.data import DataLoader
import torch
from .basics import get_file_list, check_batch_data, get_transforms,crop_volumes, load_volumes
def load_batch_slices3D(train_volume_ds, val_volume_ds, train_batch_size=5, val_batch_size=1, ifcheck=True,
                        *, patch_depth=32, train_num_workers=2, val_num_workers=1):
    """Build a patch-based training DataLoader and a whole-volume validation DataLoader.

    Each training item is split by ``monai.data.PatchIterd`` into patches that keep the
    full extent of the first two spatial axes and take ``patch_depth`` slices along the
    last one, starting at the origin. Validation items are batched unchanged.

    Args:
        train_volume_ds: training dataset whose items are dicts with "image" and "label" keys.
        val_volume_ds: validation dataset, passed straight to a DataLoader.
        train_batch_size: number of patches per training batch.
        val_batch_size: number of items per validation batch.
        ifcheck: if True, run ``check_batch_data`` on the built loaders and datasets.
        patch_depth: patch size along the last spatial axis (default 32, the value that
            used to be hard-coded).
        train_num_workers: worker processes for the training loader (default 2).
        val_num_workers: worker processes for the validation loader (default 1).

    Returns:
        (train_loader, val_loader)
    """
    # Pinned host memory only pays off when batches are later copied to a CUDA device.
    pin_memory = torch.cuda.is_available()

    # None in patch_size keeps that whole axis (MONAI PatchIter convention), so only the
    # last axis is cut into patch_depth-sized chunks.
    patch_func = monai.data.PatchIterd(
        keys=["image", "label"],
        patch_size=(None, None, patch_depth),
        start_pos=(0, 0, 0),
    )

    # Training: iterate over the patches of every volume; patch coordinates are not yielded.
    train_patch_ds = monai.data.GridPatchDataset(
        data=train_volume_ds, patch_iter=patch_func, with_coordinates=False)
    train_loader = DataLoader(
        train_patch_ds,
        batch_size=train_batch_size,
        num_workers=train_num_workers,
        pin_memory=pin_memory,
    )

    # Validation: whole items, no patching.
    val_loader = DataLoader(
        val_volume_ds,
        num_workers=val_num_workers,
        batch_size=val_batch_size,
        pin_memory=pin_memory,
    )

    if ifcheck:
        check_batch_data(train_loader, val_loader, train_patch_ds, val_volume_ds,
                         train_batch_size, val_batch_size)
    return train_loader, val_loader
def mydataloader_3d(data_pelvis_path,
                    normalize='zscore',
                    pad='minimum',
                    train_number=10,
                    val_number=2,
                    train_batch_size=1,
                    val_batch_size=1,
                    saved_name_train='./train_ds_2d.csv',
                    saved_name_val='./val_ds_2d.csv',
                    resized_size=(512, 512, 128),
                    div_size=(16, 16, 16),
                    ifcheck_volume=True,):
    """Assemble whole-volume train/validation DataLoaders plus the transform pipeline.

    Steps:
      1. build one transform pipeline from ``normalize``, ``pad``, ``resized_size``
         and ``div_size`` via ``get_transforms``;
      2. split the data under ``data_pelvis_path`` with ``get_file_list``, using
         ``train_number`` / ``val_number``;
      3. crop both splits with ``crop_volumes(..., center_crop=0)``;
      4. load them with ``load_volumes``, passing ``ifsave=True``, the two
         ``saved_name_*`` paths and ``ifcheck=ifcheck_volume``;
      5. wrap each loaded dataset in a plain DataLoader (default shuffle/workers).

    NOTE(review): the default CSV names end in "_2d" although this builds 3D
    loaders; kept unchanged for compatibility.

    Returns:
        (train_loader, val_loader, transforms), where transforms is the single
        pipeline object handed to ``load_volumes``.
    """
    volume_transforms = get_transforms(normalize, pad, resized_size, div_size)

    train_files, val_files = get_file_list(data_pelvis_path, train_number, val_number)
    train_cropped, val_cropped = crop_volumes(train_files, val_files, center_crop=0)
    train_volumes, val_volumes = load_volumes(
        volume_transforms,
        train_cropped, val_cropped,
        train_files, val_files,
        saved_name_train, saved_name_val,
        ifsave=True,
        ifcheck=ifcheck_volume,
    )

    # Batches hold whole volumes. To train on fixed-depth patches instead, call
    # load_batch_slices3D(train_volumes, val_volumes, train_batch_size, val_batch_size=...).
    train_loader, val_loader = (
        DataLoader(volumes, batch_size=size)
        for volumes, size in ((train_volumes, train_batch_size),
                              (val_volumes, val_batch_size))
    )
    return train_loader, val_loader, volume_transforms