| | |
| |
|
| | import random |
| | from collections import deque |
| | from typing import Any, Collection, Deque, Iterable, Iterator, List, Sequence |
| |
|
# Type alias for a data loader: any iterable object.
# NOTE(review): each item a loader yields is passed to ``deque.extend`` by
# ``_pooled_next``, so items are presumably batches (iterables of samples).
Loader = Iterable[Any]
| |
|
| |
|
def _pooled_next(iterator: Iterator[Any], pool: Deque[Any]) -> Any:
    """Return the next single sample from a batched iterator, via a buffer.

    Samples are buffered in ``pool``. When the pool is empty, batches are
    pulled from ``iterator`` and all their items are appended to the pool
    until at least one sample is available.

    Args:
        iterator: Iterator whose items are iterables (batches) of samples.
        pool: Buffer of not-yet-returned samples belonging to ``iterator``;
            mutated in place.

    Returns:
        The oldest buffered sample.

    Raises:
        StopIteration: If ``iterator`` is exhausted while the pool is empty.
    """
    # Loop rather than refilling once: an empty batch would otherwise leave
    # the pool empty and make ``popleft`` raise IndexError, which callers
    # (expecting StopIteration on exhaustion) do not handle.
    while not pool:
        pool.extend(next(iterator))
    return pool.popleft()
| |
|
| |
|
class CombinedDataLoader:
    """Combine several data loaders by sampling from them with given ratios.

    Each output batch of ``batch_size`` samples is assembled by repeatedly
    choosing a loader at random, weighted by ``ratios``, and taking the next
    sample from it. Items yielded by a loader are treated as iterables of
    samples: they are buffered per loader and handed out one sample at a time.
    Iteration ends as soon as a chosen loader is exhausted; the partially
    assembled batch at that point is discarded.

    Attributes:
        loaders: The underlying loaders, in the same order as ``ratios``.
        batch_size: Number of samples in each yielded batch.
        ratios: Relative sampling weights, one per loader, passed as the
            ``weights`` argument of ``random.choices``.
    """

    # Number of batches' worth of loader choices drawn per call to
    # ``random.choices``; larger values amortize the weighted-sampling setup.
    BATCH_COUNT = 100

    def __init__(self, loaders: Collection[Loader], batch_size: int, ratios: Sequence[float]):
        """Store the loaders, batch size and sampling ratios.

        Args:
            loaders: Loaders to combine.
            batch_size: Number of samples per yielded batch; must be >= 1.
            ratios: One relative sampling weight per loader.

        Raises:
            ValueError: If ``batch_size`` is less than 1, or if the number of
                ratios differs from the number of loaders.
        """
        # A non-positive batch size would make ``__iter__`` spin forever.
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")
        if len(ratios) != len(loaders):
            raise ValueError(
                f"expected one ratio per loader: got {len(ratios)} ratios "
                f"for {len(loaders)} loaders"
            )
        self.loaders = loaders
        self.batch_size = batch_size
        self.ratios = ratios

    def __iter__(self) -> Iterator[List[Any]]:
        """Yield lists of ``batch_size`` samples mixed from the loaders.

        Yields nothing when there are no loaders.
        """
        iters = [iter(loader) for loader in self.loaders]
        if not iters:
            return
        # One independent buffer per loader. ``[deque()] * n`` would alias a
        # single deque n times and hand out one loader's samples as another's.
        pools = [deque() for _ in iters]
        population = range(len(iters))
        size = self.batch_size

        while True:
            indices = random.choices(population, self.ratios, k=size * self.BATCH_COUNT)
            # Walk the drawn choices in fixed-size windows instead of
            # re-slicing the remaining list after every batch.
            for start in range(0, len(indices), size):
                try:
                    batch = [
                        _pooled_next(iters[i], pools[i])
                        for i in indices[start : start + size]
                    ]
                except StopIteration:
                    return
                yield batch
| |
|