from typing import Union

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from utils.omega_parser import DataLoading
from data.dataset import EPUDataset
from data.loading import EPUDatasetFromConfig


def to_dataloader(
    dataset: Union[Dataset, EPUDataset, EPUDatasetFromConfig],
    loading_cfg: DataLoading,
) -> DataLoader:
    """Wrap *dataset* in a ``DataLoader`` configured from *loading_cfg*.

    Args:
        dataset: The dataset to iterate over.
        loading_cfg: Loading options. The attributes ``batch_size``,
            ``shuffle``, ``num_workers``, ``pin_memory`` and
            ``persistent_workers`` are read and forwarded to ``DataLoader``.

    Returns:
        A ``DataLoader`` over *dataset*. ``drop_last`` is always ``False``,
        so a final batch smaller than ``batch_size`` is kept.
    """
    num_workers = loading_cfg.num_workers

    # DataLoader raises ValueError when persistent_workers=True is combined
    # with num_workers == 0 (there are no worker processes to keep alive).
    # Only request persistent workers when workers actually exist, so a
    # config switched to in-process loading still produces a loader.
    persistent_workers = bool(loading_cfg.persistent_workers) and num_workers > 0

    return DataLoader(
        dataset,
        batch_size=loading_cfg.batch_size,
        shuffle=loading_cfg.shuffle,
        num_workers=num_workers,
        pin_memory=loading_cfg.pin_memory,
        persistent_workers=persistent_workers,
        drop_last=False,
    )