File size: 723 Bytes
a19a7aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from torch.utils.data import DataLoader
from PatientDataset import PatientMultiModalDataset

def make_loader(
    split_dir: str,
    batch_size: int = 4,
    n_slices: int = 10,
    img_size: int = 64,
    num_workers: int = 4,
    shuffle: bool = True,
    pin_memory: bool = True,
    *,
    clinical_dim: int = 6,
    radiomics_dim: int = 128,
    pet_dim: int = 5,
    seed: int = 0,
    require_space: bool = True,
    drop_last: bool = False,
) -> DataLoader:
    """Build a ``DataLoader`` over a ``PatientMultiModalDataset`` for one split.

    The dataset is constructed from ``split_dir`` and then wrapped in a
    ``torch.utils.data.DataLoader``. All keyword-only parameters default to
    the values that were previously hard-coded, so existing callers get the
    same configuration as before.

    Args:
        split_dir: Forwarded to the dataset as ``split_dir``.
        batch_size: Number of samples per batch.
        n_slices: Forwarded to the dataset as ``n_slices``.
        img_size: Edge length; the dataset receives the square tuple
            ``(img_size, img_size)``.
        num_workers: Number of loader worker processes.
        shuffle: Whether the loader reshuffles samples every epoch.
        pin_memory: Whether the loader pins host memory for its batches.
        clinical_dim: Forwarded to the dataset as ``clinical_dim``
            (presumably the clinical feature width -- confirm against
            ``PatientMultiModalDataset``).
        radiomics_dim: Forwarded to the dataset as ``radiomics_dim``
            (presumably the radiomics feature width -- confirm).
        pet_dim: Forwarded to the dataset as ``pet_dim``
            (presumably the PET feature width -- confirm).
        seed: Forwarded to the dataset as ``seed``.
        require_space: Forwarded to the dataset unchanged; its meaning is
            defined by ``PatientMultiModalDataset``.
        drop_last: Whether the loader discards a final incomplete batch.

    Returns:
        A ``DataLoader`` iterating over the constructed dataset.
    """
    dataset = PatientMultiModalDataset(
        split_dir=split_dir,
        n_slices=n_slices,
        # The public API takes a single edge length; the dataset wants a pair.
        img_size=(img_size, img_size),
        clinical_dim=clinical_dim,
        radiomics_dim=radiomics_dim,
        pet_dim=pet_dim,
        seed=seed,
        require_space=require_space,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
    )