MM-DLS / mm-dls /ImageDataLoader.py
FangDai's picture
Upload 11 files
a19a7aa verified
from torch.utils.data import DataLoader
from PatientDataset import PatientMultiModalDataset
def make_loader(
    split_dir: str,
    batch_size: int = 4,
    n_slices: int = 10,
    img_size: int = 64,
    num_workers: int = 4,
    shuffle: bool = True,
    pin_memory: bool = True,
    *,
    clinical_dim: int = 6,
    radiomics_dim: int = 128,
    pet_dim: int = 5,
    seed: int = 0,
    require_space: bool = True,
    drop_last: bool = False,
) -> DataLoader:
    """Build a ``DataLoader`` over a ``PatientMultiModalDataset`` for one split.

    The dataset-specific settings that used to be hard-coded here are exposed
    as keyword-only parameters; their defaults reproduce the previous
    behaviour, so existing callers are unaffected.

    Args:
        split_dir: Forwarded to the dataset as ``split_dir``.
        batch_size: Number of samples per batch handed to ``DataLoader``.
        n_slices: Forwarded to the dataset as ``n_slices``
            (presumably slices per patient -- confirm in PatientDataset).
        img_size: Side length; the dataset receives the square tuple
            ``(img_size, img_size)`` as its ``img_size`` argument.
        num_workers: ``DataLoader`` worker process count.
        shuffle: Whether ``DataLoader`` reshuffles the data each epoch.
        pin_memory: Forwarded to ``DataLoader`` as ``pin_memory``.
        clinical_dim: Forwarded to the dataset as ``clinical_dim``.
        radiomics_dim: Forwarded to the dataset as ``radiomics_dim``.
        pet_dim: Forwarded to the dataset as ``pet_dim``.
        seed: Forwarded to the dataset as ``seed``.
        require_space: Forwarded to the dataset as ``require_space``
            (semantics are defined by PatientMultiModalDataset).
        drop_last: Whether ``DataLoader`` drops a trailing incomplete batch.

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping the constructed dataset.
    """
    ds = PatientMultiModalDataset(
        split_dir=split_dir,
        n_slices=n_slices,
        # The public API takes a single int; the dataset expects a 2-tuple.
        img_size=(img_size, img_size),
        clinical_dim=clinical_dim,
        radiomics_dim=radiomics_dim,
        pet_dim=pet_dim,
        seed=seed,
        require_space=require_space,
    )
    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
    )