from torch.utils.data import DataLoader from PatientDataset import PatientMultiModalDataset def make_loader( split_dir: str, batch_size: int = 4, n_slices: int = 10, img_size: int = 64, num_workers: int = 4, shuffle: bool = True, pin_memory: bool = True, ): ds = PatientMultiModalDataset( split_dir=split_dir, n_slices=n_slices, img_size=(img_size, img_size), clinical_dim=6, radiomics_dim=128, pet_dim=5, seed=0, require_space=True, ) return DataLoader( ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, drop_last=False, )