Spaces:
Runtime error
Runtime error
| import torch | |
| import tops | |
| from .utils import collate_fn | |
def get_dataloader(
    dataset,
    gpu_transform: torch.nn.Module,
    num_workers,
    batch_size,
    infinite: bool,
    drop_last: bool,
    prefetch_factor: int,
    shuffle,
    channels_last=False,
):
    """Build a ``DataLoader`` for *dataset* and wrap it in ``tops.DataPrefetcher``.

    The sampling strategy is chosen as follows:

    * ``infinite=True``: a ``tops.InfiniteSampler`` sharded by
      ``tops.rank()`` / ``tops.world_size()`` is used. ``drop_last`` is not
      forwarded in this mode.
    * ``tops.world_size() > 1``: a ``torch.utils.data.DistributedSampler``
      sharded by rank is used, and ``drop_last`` is forwarded to the loader.
    * otherwise: no sampler is used; ``shuffle`` and ``drop_last`` are
      forwarded to the loader directly.

    Batches are always collated with ``collate_fn`` and loaded with
    ``pin_memory=True``.

    Args:
        dataset: Dataset to load from.
        gpu_transform: Module handed to ``tops.DataPrefetcher``. Presumably it
            is applied to each batch after transfer to the device; confirm
            against ``tops``.
        num_workers: Number of DataLoader worker processes. ``0`` loads in
            the main process.
        batch_size: Number of samples per batch.
        infinite: Use the infinite sampler instead of finite epochs.
        drop_last: Drop the final incomplete batch. Used only in the
            non-infinite modes.
        prefetch_factor: Batches prefetched per worker. Forwarded only when
            ``num_workers > 0``, because ``DataLoader`` rejects it for
            single-process loading.
        shuffle: Whether to shuffle the sample order.
        channels_last: Forwarded to ``tops.DataPrefetcher``.

    Returns:
        The ``tops.DataPrefetcher`` wrapping the constructed ``DataLoader``.
    """
    sampler = None
    dl_kwargs = dict(
        pin_memory=True,
    )
    if infinite:
        sampler = tops.InfiniteSampler(
            dataset,
            rank=tops.rank(),
            num_replicas=tops.world_size(),
            shuffle=shuffle,
        )
    elif tops.world_size() > 1:
        sampler = torch.utils.data.DistributedSampler(
            dataset,
            shuffle=shuffle,
            num_replicas=tops.world_size(),
            rank=tops.rank(),
        )
        dl_kwargs["drop_last"] = drop_last
    else:
        # No custom sampler: DataLoader itself must handle shuffling.
        # (DataLoader forbids passing both `shuffle` and `sampler`.)
        dl_kwargs["shuffle"] = shuffle
        dl_kwargs["drop_last"] = drop_last
    if num_workers > 0:
        # Prefetching only exists with worker processes. DataLoader raises
        # ValueError if prefetch_factor is given while num_workers == 0.
        dl_kwargs["prefetch_factor"] = prefetch_factor
    dataloader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        collate_fn=collate_fn,
        batch_size=batch_size,
        num_workers=num_workers,
        **dl_kwargs,
    )
    dataloader = tops.DataPrefetcher(
        dataloader, gpu_transform, channels_last=channels_last
    )
    return dataloader