| r"""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch | |
| data from an iterable-style or map-style dataset. This logic is shared in both | |
| single- and multi-processing data loading. | |
| """ | |
| class _BaseDatasetFetcher(object): | |
| def __init__(self, dataset, auto_collation, collate_fn, drop_last): | |
| self.dataset = dataset | |
| self.auto_collation = auto_collation | |
| self.collate_fn = collate_fn | |
| self.drop_last = drop_last | |
| def fetch(self, possibly_batched_index): | |
| raise NotImplementedError() | |
class _IterableDatasetFetcher(_BaseDatasetFetcher):
    """Fetcher that pulls samples sequentially from an iterator over the dataset."""

    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        """Store the shared configuration and open an iterator on ``dataset``."""
        super().__init__(dataset, auto_collation, collate_fn, drop_last)
        self.dataset_iter = iter(dataset)
        # Becomes True once the iterator runs dry while a batch is being built.
        self.ended = False

    def fetch(self, possibly_batched_index):
        """Return the collated next sample, or the collated next batch.

        With auto-collation enabled, one sample is drawn from the iterator for
        each entry of ``possibly_batched_index`` (the entries themselves are
        not used). Without it, a single sample is drawn.

        Raises:
            StopIteration: if the iterator was already found exhausted, if no
                sample could be drawn for the batch, or if ``drop_last`` is set
                and the batch is shorter than ``possibly_batched_index``. In
                the non-collated path it propagates directly from ``next``.
        """
        if self.ended:
            raise StopIteration

        if not self.auto_collation:
            return self.collate_fn(next(self.dataset_iter))

        samples = []
        for _ in possibly_batched_index:
            try:
                sample = next(self.dataset_iter)
            except StopIteration:
                # Remember exhaustion so the following call stops immediately.
                self.ended = True
                break
            samples.append(sample)

        if not samples:
            raise StopIteration
        # The length of the index is only consulted for a non-empty batch
        # when drop_last is requested.
        if self.drop_last and len(samples) < len(possibly_batched_index):
            raise StopIteration
        return self.collate_fn(samples)
class _MapDatasetFetcher(_BaseDatasetFetcher):
    """Fetcher that reads samples from the dataset by index."""

    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        """Store the shared configuration; no additional state is kept."""
        super().__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        """Look up the requested sample(s) and return them collated.

        Without auto-collation the index is passed to the dataset as-is.
        With auto-collation the index is treated as an iterable of indices:
        the dataset's ``__getitems__`` is called with the whole batch when the
        dataset exposes a truthy attribute of that name; otherwise each index
        is looked up individually and the results gathered into a list.
        """
        dataset = self.dataset

        if not self.auto_collation:
            return self.collate_fn(dataset[possibly_batched_index])

        batched_getter = getattr(dataset, "__getitems__", None)
        if batched_getter:
            samples = batched_getter(possibly_batched_index)
        else:
            samples = [dataset[idx] for idx in possibly_batched_index]
        return self.collate_fn(samples)