|
|
|
|
| import random
|
| from collections import deque
|
| from typing import Any, Collection, Deque, Iterable, Iterator, List, Sequence
|
|
|
| Loader = Iterable[Any]
|
|
|
|
|
| def _pooled_next(iterator: Iterator[Any], pool: Deque[Any]):
|
| if not pool:
|
| pool.extend(next(iterator))
|
| return pool.popleft()
|
|
|
|
|
class CombinedDataLoader:
    """
    Combines data loaders using the provided sampling ratios

    Iterating yields lists of ``batch_size`` items. For every item a source
    loader is picked at random, with ``ratios`` passed to ``random.choices``
    as the relative weights. Values produced by each loader are unpacked into
    a per-loader buffer and handed out one item at a time.

    Iteration ends as soon as a loader that was picked has no more data; the
    batch being assembled at that moment is discarded.
    """

    # Loader indices are drawn in bulk: batch_size * BATCH_COUNT per
    # random.choices call, then consumed batch_size at a time.
    BATCH_COUNT = 100

    def __init__(self, loaders: Collection[Loader], batch_size: int, ratios: Sequence[float]):
        """Store the loaders, the output batch size and the sampling weights.

        Args:
            loaders: Source loaders; each value a loader yields must itself be
                an iterable of items.
            batch_size: Number of items in every yielded batch.
            ratios: One weight per loader, in the same order as ``loaders``.
        """
        self.loaders = loaders
        self.batch_size = batch_size
        self.ratios = ratios

    def __iter__(self) -> Iterator[List[Any]]:
        """Yield mixed batches until one of the selected loaders is exhausted."""
        iters = [iter(loader) for loader in self.loaders]
        indices = []
        # One independent buffer per loader. ``[deque()] * n`` would create n
        # references to a single deque, so items pulled from one loader would
        # be served for indices that selected a different loader.
        pool = [deque() for _ in iters]

        while True:
            if not indices:
                # Refill the queue of pre-sampled loader indices.
                k = self.batch_size * self.BATCH_COUNT
                indices = random.choices(range(len(self.loaders)), self.ratios, k=k)
            try:
                batch = [_pooled_next(iters[i], pool[i]) for i in indices[: self.batch_size]]
            except StopIteration:
                # A selected loader ran dry; stop and drop the partial batch.
                break
            indices = indices[self.batch_size :]
            yield batch
|
|
|