Spaces:
Runtime error
Runtime error
Create combined_loader.py
Browse files
3rdparty/densepose/data/combined_loader.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
import random
|
| 4 |
+
from collections import deque
|
| 5 |
+
from typing import Any, Collection, Deque, Iterable, Iterator, List, Sequence
|
| 6 |
+
|
| 7 |
+
Loader = Iterable[Any]
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _pooled_next(iterator: Iterator[Any], pool: Deque[Any]):
|
| 11 |
+
if not pool:
|
| 12 |
+
pool.extend(next(iterator))
|
| 13 |
+
return pool.popleft()
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class CombinedDataLoader:
|
| 17 |
+
"""
|
| 18 |
+
Combines data loaders using the provided sampling ratios
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
BATCH_COUNT = 100
|
| 22 |
+
|
| 23 |
+
def __init__(self, loaders: Collection[Loader], batch_size: int, ratios: Sequence[float]):
|
| 24 |
+
self.loaders = loaders
|
| 25 |
+
self.batch_size = batch_size
|
| 26 |
+
self.ratios = ratios
|
| 27 |
+
|
| 28 |
+
def __iter__(self) -> Iterator[List[Any]]:
|
| 29 |
+
iters = [iter(loader) for loader in self.loaders]
|
| 30 |
+
indices = []
|
| 31 |
+
pool = [deque()] * len(iters)
|
| 32 |
+
# infinite iterator, as in D2
|
| 33 |
+
while True:
|
| 34 |
+
if not indices:
|
| 35 |
+
# just a buffer of indices, its size doesn't matter
|
| 36 |
+
# as long as it's a multiple of batch_size
|
| 37 |
+
k = self.batch_size * self.BATCH_COUNT
|
| 38 |
+
indices = random.choices(range(len(self.loaders)), self.ratios, k=k)
|
| 39 |
+
try:
|
| 40 |
+
batch = [_pooled_next(iters[i], pool[i]) for i in indices[: self.batch_size]]
|
| 41 |
+
except StopIteration:
|
| 42 |
+
break
|
| 43 |
+
indices = indices[self.batch_size :]
|
| 44 |
+
yield batch
|