| """ | |
| Utilities for data loading and processing. | |
| """ | |
from itertools import islice

from torch.utils.data import IterableDataset
class ShardedIterableDataset(IterableDataset):
    """
    A super simple implementation of a sharded iterable dataset that enables
    data parallelism across multiple workers. Each worker receives a disjoint,
    strided shard of the wrapped dataset: worker ``rank`` gets the items at
    positions ``rank``, ``rank + world_size``, ``rank + 2 * world_size``, ...

    NOTE: Also works fine if there is only one worker (``rank=0``,
    ``world_size=1``), in which case every item is yielded.

    Args:
        dataset: Any iterable to shard. It is re-iterated from the start on
            every call to ``__iter__``.
        rank: Zero-based index of this worker; must be non-negative.
        world_size: Total number of workers; must be at least 1.
    """

    def __init__(self, dataset, rank: int, world_size: int) -> None:
        super().__init__()
        self.dataset = dataset
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        """
        Yield this worker's shard of the underlying dataset.

        Yields:
            Every ``world_size``-th item of ``dataset``, starting at index
            ``rank``. Nothing is yielded when the dataset holds ``rank`` items
            or fewer.

        Raises:
            ValueError: Raised by ``itertools.islice`` when ``rank`` is
                negative or ``world_size`` is less than 1.
        """
        # NOTE: islice both skips to this worker's first sample and strides
        # over the other workers' samples. Unlike bare next() calls inside a
        # generator, it stops cleanly when the dataset runs out early instead
        # of leaking StopIteration (which PEP 479 turns into a RuntimeError).
        yield from islice(self.dataset, self.rank, None, self.world_size)