File size: 723 Bytes
a19a7aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
from torch.utils.data import DataLoader
from PatientDataset import PatientMultiModalDataset
def make_loader(
    split_dir: str,
    batch_size: int = 4,
    n_slices: int = 10,
    img_size: int = 64,
    num_workers: int = 4,
    shuffle: bool = True,
    pin_memory: bool = True,
    *,
    clinical_dim: int = 6,
    radiomics_dim: int = 128,
    pet_dim: int = 5,
    seed: int = 0,
    require_space: bool = True,
    drop_last: bool = False,
) -> DataLoader:
    """Build a ``DataLoader`` over a ``PatientMultiModalDataset`` split.

    The dataset-related arguments are forwarded unchanged to
    ``PatientMultiModalDataset``; their exact semantics are defined by that
    class (not visible here), so the descriptions below only state how this
    function passes them on.

    Args:
        split_dir: Forwarded as ``split_dir`` to the dataset.
        batch_size: Number of samples per batch produced by the loader.
        n_slices: Forwarded as ``n_slices`` to the dataset.
        img_size: Single edge length; the dataset receives the square tuple
            ``(img_size, img_size)`` as its ``img_size`` argument.
        num_workers: Number of worker processes used by the loader.
        shuffle: Whether the loader reshuffles the data every epoch.
        pin_memory: Whether the loader pins batches in page-locked memory.
        clinical_dim: Forwarded as ``clinical_dim`` to the dataset
            (presumably the clinical feature width -- confirm against the
            dataset class).
        radiomics_dim: Forwarded as ``radiomics_dim`` to the dataset.
        pet_dim: Forwarded as ``pet_dim`` to the dataset.
        seed: Forwarded as ``seed`` to the dataset.
        require_space: Forwarded as ``require_space`` to the dataset.
        drop_last: Whether the loader drops a final incomplete batch.

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping the constructed dataset.
    """
    # The keyword-only parameters above were previously hard-coded literals;
    # their defaults reproduce the old values, so existing callers see
    # identical behaviour.
    dataset = PatientMultiModalDataset(
        split_dir=split_dir,
        n_slices=n_slices,
        img_size=(img_size, img_size),
        clinical_dim=clinical_dim,
        radiomics_dim=radiomics_dim,
        pet_dim=pet_dim,
        seed=seed,
        require_space=require_space,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
    )
|