PawanratRung commited on
Commit
19ebd30
·
verified ·
1 Parent(s): e7f7588

Create combined_loader.py

Browse files
3rdparty/densepose/data/combined_loader.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ import random
4
+ from collections import deque
5
+ from typing import Any, Collection, Deque, Iterable, Iterator, List, Sequence
6
+
7
+ Loader = Iterable[Any]
8
+
9
+
10
+ def _pooled_next(iterator: Iterator[Any], pool: Deque[Any]):
11
+ if not pool:
12
+ pool.extend(next(iterator))
13
+ return pool.popleft()
14
+
15
+
16
+ class CombinedDataLoader:
17
+ """
18
+ Combines data loaders using the provided sampling ratios
19
+ """
20
+
21
+ BATCH_COUNT = 100
22
+
23
+ def __init__(self, loaders: Collection[Loader], batch_size: int, ratios: Sequence[float]):
24
+ self.loaders = loaders
25
+ self.batch_size = batch_size
26
+ self.ratios = ratios
27
+
28
+ def __iter__(self) -> Iterator[List[Any]]:
29
+ iters = [iter(loader) for loader in self.loaders]
30
+ indices = []
31
+ pool = [deque()] * len(iters)
32
+ # infinite iterator, as in D2
33
+ while True:
34
+ if not indices:
35
+ # just a buffer of indices, its size doesn't matter
36
+ # as long as it's a multiple of batch_size
37
+ k = self.batch_size * self.BATCH_COUNT
38
+ indices = random.choices(range(len(self.loaders)), self.ratios, k=k)
39
+ try:
40
+ batch = [_pooled_next(iters[i], pool[i]) for i in indices[: self.batch_size]]
41
+ except StopIteration:
42
+ break
43
+ indices = indices[self.batch_size :]
44
+ yield batch