# ius/data/dataloader.py
# Synced from GitHub via hub-sync by pgatoula (commit 99ec8a2, verified).
from typing import Union
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from utils.omega_parser import DataLoading
from data.dataset import EPUDataset
from data.loading import EPUDatasetFromConfig
def to_dataloader(dataset: Union[Dataset, EPUDataset, EPUDatasetFromConfig],
                  loading_cfg: DataLoading,
                  *,
                  drop_last: bool = False) -> DataLoader:
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader`` configured by ``loading_cfg``.

    Args:
        dataset: The dataset to iterate over.
        loading_cfg: Loading configuration. Its ``batch_size``, ``shuffle``,
            ``num_workers``, ``pin_memory`` and ``persistent_workers``
            attributes are forwarded to the ``DataLoader``.
        drop_last: Whether to drop the final incomplete batch. Keyword-only;
            defaults to ``False``, which matches the previous hard-coded
            behaviour.

    Returns:
        The configured ``DataLoader``.
    """
    # PyTorch raises ValueError when persistent_workers=True is combined with
    # num_workers == 0: single-process loading has no worker processes to
    # keep alive. Only request persistence when workers actually exist, so
    # configs with num_workers=0 (e.g. for debugging) do not crash.
    persistent_workers = (bool(loading_cfg.persistent_workers)
                          and loading_cfg.num_workers > 0)
    return DataLoader(dataset,
                      batch_size=loading_cfg.batch_size,
                      shuffle=loading_cfg.shuffle,
                      num_workers=loading_cfg.num_workers,
                      pin_memory=loading_cfg.pin_memory,
                      persistent_workers=persistent_workers,
                      drop_last=drop_last)