|
|
import warnings |
|
|
|
|
|
from collections import deque |
|
|
from typing import Any, Callable, Iterator, List, Optional, Sized, Tuple, TypeVar, Deque |
|
|
|
|
|
from torch.utils.data.datapipes._decorator import functional_datapipe |
|
|
from torch.utils.data.datapipes._hook_iterator import _SnapshotState |
|
|
from torch.utils.data.datapipes.datapipe import IterDataPipe |
|
|
from torch.utils.data.datapipes.utils.common import StreamWrapper, _check_unpickable_fn |
|
|
|
|
|
# Public names exported by this module: the DataPipes that combine or split other `IterDataPipe`s.
__all__ = [
    "ConcaterIterDataPipe",
    "DemultiplexerIterDataPipe",
    "ForkerIterDataPipe",
    "MultiplexerIterDataPipe",
    "ZipperIterDataPipe",
]

# Covariant type variable used in the annotations below for the element type carried by a DataPipe.
T_co = TypeVar('T_co', covariant=True)
|
|
|
|
|
|
|
|
@functional_datapipe('concat')
class ConcaterIterDataPipe(IterDataPipe):
    r"""
    Chains several Iterable DataPipes end to end (functional name: ``concat``). Every element of the
    first input DataPipe is yielded before any element of the next one, and so on.

    Args:
        datapipes: Iterable DataPipes being concatenated

    Raises:
        ValueError: if no DataPipe is given
        TypeError: if any input is not an ``IterDataPipe``

    Example:
        >>> # xdoctest: +REQUIRES(module:torchdata)
        >>> import random
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp1 = IterableWrapper(range(3))
        >>> dp2 = IterableWrapper(range(5))
        >>> list(dp1.concat(dp2))
        [0, 1, 2, 0, 1, 2, 3, 4]
    """
    datapipes: Tuple[IterDataPipe]
    length: Optional[int]

    def __init__(self, *datapipes: IterDataPipe):
        if not datapipes:
            raise ValueError("Expected at least one DataPipe, but got nothing")
        for candidate in datapipes:
            if not isinstance(candidate, IterDataPipe):
                raise TypeError("Expected all inputs to be `IterDataPipe`")
        self.datapipes = datapipes
        # `None` means "not computed yet"; `__len__` fills this in lazily.
        self.length = None

    def __iter__(self) -> Iterator:
        for source in self.datapipes:
            yield from source

    def __len__(self) -> int:
        # The result is cached: a non-negative total, or `-1` as the marker for "no valid length".
        if self.length is None:
            if all(isinstance(dp, Sized) for dp in self.datapipes):
                self.length = sum(len(dp) for dp in self.datapipes)
            else:
                self.length = -1
        if self.length == -1:
            raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
        return self.length
|
|
|
|
|
|
|
|
@functional_datapipe('fork')
class ForkerIterDataPipe(IterDataPipe):
    r"""
    Creates multiple instances of the same Iterable DataPipe (functional name: ``fork``).

    Args:
        datapipe: Iterable DataPipe being copied
        num_instances: number of instances of the datapipe to create
        buffer_size: this restricts how far ahead the leading child DataPipe
           can read relative to the slowest child DataPipe.
           Defaults to ``1000``. Use ``-1`` for the unlimited buffer.

    Returns:
        ``datapipe`` itself when ``num_instances`` is ``1``; otherwise a list of ``num_instances``
        child DataPipes that share one buffering container.

    Raises:
        ValueError: if ``num_instances`` is smaller than ``1``

    Example:
        >>> # xdoctest: +REQUIRES(module:torchdata)
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> source_dp = IterableWrapper(range(5))
        >>> dp1, dp2 = source_dp.fork(num_instances=2)
        >>> list(dp1)
        [0, 1, 2, 3, 4]
        >>> list(dp2)
        [0, 1, 2, 3, 4]
    """
    def __new__(cls, datapipe: IterDataPipe, num_instances: int, buffer_size: int = 1000):
        if num_instances < 1:
            raise ValueError(f"Expected `num_instances` larger than 0, but {num_instances} is found")
        if num_instances == 1:
            # A single "copy" needs no shared buffer, so the source is handed back unchanged.
            return datapipe
        # All children pull from one container, which owns the source iterator and the buffer.
        container = _ForkerIterDataPipe(datapipe, num_instances, buffer_size)
        return [_ChildDataPipe(container, i) for i in range(num_instances)]
|
|
|
|
|
|
|
|
class _ForkerIterDataPipe(IterDataPipe):
    r"""
    Container to hold instance-specific information on behalf of ForkerIterDataPipe. It tracks
    the state of its child DataPipes, maintains the buffer, and yields the next value
    as requested by the child DataPipes.

    Bookkeeping used by ``get_next_element_by_instance`` (positions are 1-based):

    * ``child_pointers[i]`` is incremented once per fetch attempt of child ``i``, i.e. it is the
      position of the element that child is currently asking for.
    * ``leading_ptr`` is the highest position for which the source iterator has been asked.
    * ``slowest_ptr`` is the smallest child pointer as of its last update; ``buffer[0]`` holds the
      element at position ``slowest_ptr + 1``.
    * ``end_ptr`` stays ``None`` until the source raises ``StopIteration``; it is then set to the
      position of that failed request (number of source elements plus one).

    Args:
        datapipe: source DataPipe shared by all children
        num_instances: number of child DataPipes reading from this container
        buffer_size: maximum allowed distance between ``leading_ptr`` and ``slowest_ptr``;
            a negative value disables the check (and triggers a ``UserWarning``)
    """
    def __init__(self, datapipe: IterDataPipe, num_instances: int, buffer_size: int = 1000):
        self.main_datapipe = datapipe
        # Created lazily on the first request from any child.
        self._datapipe_iterator: Optional[Iterator[Any]] = None
        self.num_instances = num_instances
        # Elements already read from the source that at least one child has not consumed yet.
        self.buffer: Deque = deque()
        self.buffer_size = buffer_size
        if self.buffer_size < 0:
            warnings.warn(
                "Unlimited buffer size is set for `fork`, "
                "please be aware of OOM at random places",
                UserWarning
            )
        self.child_pointers: List[int] = [0] * num_instances  # position each child has requested so far
        self.slowest_ptr = 0  # position reached by the slowest child
        self.leading_ptr = 0  # position reached by the fastest child
        self.end_ptr: Optional[int] = None  # position at which the source turned out to be exhausted

    def __len__(self):
        return len(self.main_datapipe)

    def get_next_element_by_instance(self, instance_id: int):
        r"""
        Generator yielding the source elements, in order, for the child identified by ``instance_id``.

        Raises:
            BufferError: if ``buffer_size`` is non-negative and the leading child gets more than
                ``buffer_size`` positions ahead of the slowest one
        """
        if self._datapipe_iterator is None:
            self._datapipe_iterator = iter(self.main_datapipe)
            # NOTE(review): presumably needed so the hooked `reset` of IterDataPipe behaves correctly;
            # the snapshot machinery lives outside this file - confirm there.
            self._snapshot_state = _SnapshotState.Iterating
        # Keep going until the end of the source is known and this child has reached it.
        while self.end_ptr is None or self.child_pointers[instance_id] + 1 < self.end_ptr:
            self.child_pointers[instance_id] += 1
            # Another child already fetched this position: serve it from the buffer.
            if self.buffer and self.child_pointers[instance_id] <= self.leading_ptr:
                idx = self.child_pointers[instance_id] - self.slowest_ptr - 1
                return_val = self.buffer[idx]
            else:  # This child is now the leader: pull one more element from the source.
                self.leading_ptr = self.child_pointers[instance_id]
                try:
                    return_val = next(self._datapipe_iterator)
                    self.buffer.append(return_val)
                except StopIteration:
                    # Record where the source ended; the loop condition then terminates this child.
                    self.end_ptr = self.leading_ptr
                    continue
            # This child was (one of) the slowest; if every child has now moved past the oldest
            # buffered element, that element can be dropped.
            if self.child_pointers[instance_id] == self.slowest_ptr + 1:
                new_min = min(self.child_pointers)
                if self.slowest_ptr < new_min:
                    self.slowest_ptr = new_min
                    self.buffer.popleft()
            if self.buffer_size >= 0 and self.leading_ptr > self.buffer_size + self.slowest_ptr:
                raise BufferError("ForkerIterDataPipe buffer overflow," +
                                  f"buffer size {self.buffer_size} is insufficient.")
            yield return_val

        # Drop the source iterator once every child pointer sits exactly one below `end_ptr`.
        if all(p + 1 == self.end_ptr for p in self.child_pointers):
            self._datapipe_iterator = None

    def is_every_instance_exhausted(self) -> bool:
        r"""
        Return ``True`` once the end of the source is known and every child has reached it.
        """
        # The child that hit `StopIteration` ends with its pointer equal to `end_ptr`, while the
        # other children stop one position earlier, hence both values count as "exhausted".
        return self.end_ptr is not None and\
            all(self.end_ptr == ptr or self.end_ptr - 1 == ptr for ptr in self.child_pointers)

    def reset(self) -> None:
        r"""
        Start over: create a fresh source iterator and clear the buffer and all pointers.
        """
        self._datapipe_iterator = iter(self.main_datapipe)
        self.buffer = deque()
        self.child_pointers = [0] * self.num_instances
        self.slowest_ptr = 0
        self.leading_ptr = 0
        self.end_ptr = None

    def __getstate__(self):
        # Only configuration and the two iterator-tracking fields are serialized; the live source
        # iterator, the buffer and the pointers are not part of the state.
        # NOTE(review): `_valid_iterator_id` / `_number_of_samples_yielded` are not assigned in this
        # class - presumably inherited from `IterDataPipe`; confirm.
        state = (
            self.main_datapipe,
            self.num_instances,
            self.buffer_size,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        )
        # A globally installed hook, if any, gets the final say on the serialized form.
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(state)
        return state

    def __setstate__(self, state):
        (
            self.main_datapipe,
            self.num_instances,
            self.buffer_size,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        ) = state
        # Everything that `__getstate__` left out is rebuilt in its initial (not yet iterating) form.
        self._datapipe_iterator = None
        self.buffer = deque()
        self.child_pointers = [0] * self.num_instances
        self.slowest_ptr = 0
        self.leading_ptr = 0
        self.end_ptr = None

    def __del__(self):
        # Release references to buffered elements when the container is destroyed.
        self.buffer.clear()
|
|
|
|
|
|
|
|
class _ChildDataPipe(IterDataPipe):
    r"""
    Iterable Datapipe that is a child of a main DataPipe. The instance of this class
    will pass its instance_id to get the next value from its main DataPipe.

    Note:
        ChildDataPipe, like all other IterDataPipe, follows the single iterator per IterDataPipe constraint.
        Since ChildDataPipes share a common buffer, when an iterator is created for one of the ChildDataPipes,
        the previous iterators for all ChildDataPipes must be invalidated, with the exception when a ChildDataPipe
        hasn't had an iterator created from it since the last invalidation. See the example below.

    Example:
        >>> # xdoctest: +REQUIRES(module:torchdata)
        >>> # Single Iterator per IterDataPipe Invalidation
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> source_dp = IterableWrapper(range(10))
        >>> cdp1, cdp2 = source_dp.fork(num_instances=2)
        >>> it1, it2 = iter(cdp1), iter(cdp2)
        >>> it3 = iter(cdp1)
        >>> # The line above invalidates `it1` and `it2`, and resets `ForkerIterDataPipe`.
        >>> it4 = iter(cdp2)
        >>> # The line above doesn't invalidate `it3`, because an iterator for `cdp2` hasn't been created since
        >>> # the last invalidation.

    Args:
        main_datapipe: Main DataPipe with a method 'get_next_element_by_instance(instance_id)'
        instance_id: integer identifier of this instance

    Raises:
        NotImplementedError: if ``main_datapipe`` is missing any of the required methods
            (``get_next_element_by_instance``, ``is_every_instance_exhausted``, ``reset``)
            or exposes one of them as a non-callable attribute
    """
    # Marker attribute identifying child DataPipes.
    # NOTE(review): presumably read by the iterator hook outside this file - confirm there.
    _is_child_datapipe: bool = True

    def __init__(self, main_datapipe: IterDataPipe, instance_id: int):
        required_attrs = ["get_next_element_by_instance", "is_every_instance_exhausted", "reset"]
        # Look each attribute up with a `None` default so that a *missing* method is reported through
        # the same `NotImplementedError` as a non-callable one, instead of leaking an `AttributeError`.
        required_ops = [getattr(main_datapipe, attr, None) for attr in required_attrs]
        if any(not callable(op) for op in required_ops):
            raise NotImplementedError(f"Main Datapipe must have methods {required_attrs} implemented.")
        self.main_datapipe: IterDataPipe = main_datapipe
        self.instance_id = instance_id

    def __iter__(self):
        # Deliberately not a generator function: the main DataPipe's generator is returned as-is.
        # NOTE(review): setting the iterator ID and calling `reset` are presumably done by the
        # iterator hook wrapped around this method - confirm outside this file.
        return self.main_datapipe.get_next_element_by_instance(self.instance_id)

    def __len__(self):
        return len(self.main_datapipe)

    def _set_main_datapipe_valid_iterator_id(self) -> int:
        r"""
        Update the valid iterator ID for both this DataPipe object and `main_datapipe`.
        `main_datapipe.reset()` is called when the ID is incremented to a new generation.

        Returns:
            The iterator ID that is valid for this child after the update.
        """
        # 1. First time any child iterator is created
        if self.main_datapipe._valid_iterator_id is None:
            self.main_datapipe._valid_iterator_id = 0
        # 2. This instance is already in the same generation as `main_datapipe`,
        #    so a new generation has to be started
        elif self.main_datapipe._valid_iterator_id == self._valid_iterator_id:
            self.main_datapipe._valid_iterator_id += 1
            # Whenever a new generation of iterators is created, `main_datapipe` must reset
            if not self.main_datapipe.is_every_instance_exhausted():
                warnings.warn("Some child DataPipes are not exhausted when __iter__ is called. We are resetting "
                              "the buffer and each child DataPipe will read from the start again.", UserWarning)
            self.main_datapipe.reset()
        # 3. Otherwise this instance is behind the others and only needs to catch up by
        #    adopting the ID of `main_datapipe`
        self._valid_iterator_id = self.main_datapipe._valid_iterator_id
        return self._valid_iterator_id

    def _check_valid_iterator_id(self, iterator_id) -> bool:
        r"""
        Check the valid iterator ID against that of DataPipe object and that of `main_datapipe`.
        """
        return iterator_id == self._valid_iterator_id and iterator_id == self.main_datapipe._valid_iterator_id
|
|
|
|
|
|
|
|
@functional_datapipe('demux')
class DemultiplexerIterDataPipe(IterDataPipe):
    r"""
    Splits the input DataPipe into multiple child DataPipes, using the given
    classification function (functional name: ``demux``). A list of the child DataPipes is returned from this operation.

    Args:
        datapipe: Iterable DataPipe being filtered
        num_instances: number of instances of the DataPipe to create
        classifier_fn: a function that maps values to an integer within the range ``[0, num_instances - 1]`` or ``None``
        drop_none: defaults to ``False``, if ``True``, the function will skip over elements classified as ``None``
        buffer_size: this defines the maximum number of inputs that the buffer can hold across all child
            DataPipes while waiting for their values to be yielded.
            Defaults to ``1000``. Use ``-1`` for the unlimited buffer.

    Raises:
        ValueError: if ``num_instances`` is smaller than ``1``

    Examples:
        >>> # xdoctest: +REQUIRES(module:torchdata)
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> def odd_or_even(n):
        ...     return n % 2
        >>> source_dp = IterableWrapper(range(5))
        >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even)
        >>> list(dp1)
        [0, 2, 4]
        >>> list(dp2)
        [1, 3]
        >>> # It can also filter out any element that gets `None` from the `classifier_fn`
        >>> def odd_or_even_no_zero(n):
        ...     return n % 2 if n != 0 else None
        >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True)
        >>> list(dp1)
        [2, 4]
        >>> list(dp2)
        [1, 3]
    """
    def __new__(cls, datapipe: IterDataPipe, num_instances: int,
                classifier_fn: Callable[[T_co], Optional[int]], drop_none: bool = False, buffer_size: int = 1000):
        if num_instances < 1:
            raise ValueError(f"Expected `num_instances` larger than 0, but {num_instances} is found")

        _check_unpickable_fn(classifier_fn)

        # Even for `num_instances == 1` the container is still used (there is no shortcut like the
        # one in `fork`), so classification results keep being validated by the container.
        container = _DemultiplexerIterDataPipe(datapipe, num_instances, classifier_fn, drop_none, buffer_size)
        return [_ChildDataPipe(container, i) for i in range(num_instances)]
|
|
|
|
|
|
|
|
class _DemultiplexerIterDataPipe(IterDataPipe):
    r"""
    Shared backend of DemultiplexerIterDataPipe. It owns the source iterator and one buffer per
    child, classifies every element read from the source, and hands each child the next element
    that belongs to it.
    """

    def __init__(self, datapipe: IterDataPipe[T_co], num_instances: int,
                 classifier_fn: Callable[[T_co], Optional[int]], drop_none: bool, buffer_size: int):
        self.main_datapipe = datapipe
        # Created lazily by `get_next_element_by_instance`.
        self._datapipe_iterator: Optional[Iterator[Any]] = None
        self.num_instances = num_instances
        self.buffer_size = buffer_size
        if self.buffer_size < 0:
            warnings.warn(
                "Unlimited buffer size is set for `demux`, "
                "please be aware of OOM at random places",
                UserWarning
            )
        # Number of elements currently parked across all child buffers.
        self.current_buffer_usage = 0
        self.child_buffers: List[Deque[T_co]] = [deque() for _ in range(num_instances)]
        self.classifier_fn = classifier_fn
        self.drop_none = drop_none
        self.main_datapipe_exhausted = False

    def _find_next(self, instance_id: int) -> T_co:
        r"""
        Read from the source until an element classified as ``instance_id`` shows up and return it.
        Elements belonging to other children are parked in their buffers on the way.

        Raises:
            StopIteration: if the source is (or becomes) exhausted
            ValueError: if the source iterator was never created, or the classifier returns a
                value outside ``[0, num_instances - 1]`` (including ``None`` when not dropped)
            BufferError: if parking an element exceeds a non-negative ``buffer_size``
        """
        while True:
            if self.main_datapipe_exhausted:
                raise StopIteration
            source_iter = self._datapipe_iterator
            if source_iter is None:
                raise ValueError(
                    "_datapipe_iterator has not been set, likely because this private method is called directly "
                    "without invoking get_next_element_by_instance() first.")
            item = next(source_iter)
            label = self.classifier_fn(item)
            if label is None:
                if self.drop_none:
                    # Dropped elements may wrap open streams; close them before moving on.
                    StreamWrapper.close_streams(item)
                    continue
                invalid = True
            else:
                invalid = label >= self.num_instances or label < 0
            if invalid:
                raise ValueError(f"Output of the classification fn should be between 0 and {self.num_instances - 1}. " +
                                 f"{label} is returned.")
            if label == instance_id:
                return item
            # Not for the requesting child: park it for its owner and account for the space used.
            self.child_buffers[label].append(item)
            self.current_buffer_usage += 1
            if self.buffer_size >= 0 and self.current_buffer_usage > self.buffer_size:
                raise BufferError(
                    f"DemultiplexerIterDataPipe buffer overflow, buffer size {self.buffer_size} is insufficient.")

    def get_next_element_by_instance(self, instance_id: int):
        r"""
        Generator yielding, in source order, the elements classified as ``instance_id``.
        """
        if self._datapipe_iterator is None and not self.main_datapipe_exhausted:
            self._datapipe_iterator = iter(self.main_datapipe)
            self._snapshot_state = _SnapshotState.Iterating
        while True:
            # The buffer is looked up afresh on every pass because `reset` replaces the deques.
            if self.child_buffers[instance_id]:
                self.current_buffer_usage -= 1
                yield self.child_buffers[instance_id].popleft()
                continue
            try:
                yield self._find_next(instance_id)
            except StopIteration:
                self.main_datapipe_exhausted = True
                self._datapipe_iterator = None
                return

    def is_every_instance_exhausted(self) -> bool:
        # Done only when the source is drained and nothing is left parked for any child.
        return self.main_datapipe_exhausted and not any(self.child_buffers)

    def reset(self) -> None:
        r"""
        Forget the source iterator and all buffered elements so iteration can start over.
        """
        self._datapipe_iterator = None
        self.current_buffer_usage = 0
        self.child_buffers = [deque() for _ in range(self.num_instances)]
        self.main_datapipe_exhausted = False

    def __getstate__(self):
        # Configuration plus the iterator-tracking fields; buffers and the live iterator are left out.
        state = (
            self.main_datapipe,
            self.num_instances,
            self.buffer_size,
            self.classifier_fn,
            self.drop_none,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        )
        hook = IterDataPipe.getstate_hook
        return state if hook is None else hook(state)

    def __setstate__(self, state):
        (
            self.main_datapipe,
            self.num_instances,
            self.buffer_size,
            self.classifier_fn,
            self.drop_none,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        ) = state
        # Transient iteration state always restarts from scratch.
        self._datapipe_iterator = None
        self.current_buffer_usage = 0
        self.child_buffers = [deque() for _ in range(self.num_instances)]
        self.main_datapipe_exhausted = False

    def __del__(self):
        # Release references to any elements still parked in the child buffers.
        for child_buffer in self.child_buffers:
            child_buffer.clear()
|
|
|
|
|
|
|
|
@functional_datapipe('mux')
class MultiplexerIterDataPipe(IterDataPipe):
    r"""
    Yields one element at a time from each of the input Iterable DataPipes (functional name: ``mux``). As in,
    one element from the 1st input DataPipe, then one element from the 2nd DataPipe in the next iteration,
    and so on. It ends when the shortest input DataPipe is exhausted.

    Args:
        datapipes: Iterable DataPipes that will take turn to yield their elements, until the shortest DataPipe is exhausted

    Example:
        >>> # xdoctest: +REQUIRES(module:torchdata)
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp1, dp2, dp3 = IterableWrapper(range(3)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
        >>> list(dp1.mux(dp2, dp3))
        [0, 10, 20, 1, 11, 21, 2, 12, 22]
    """
    def __init__(self, *datapipes):
        self.datapipes = datapipes
        # `None` = not computed yet, `-1` = cached "no valid length"; see `__len__`.
        self.length: Optional[int] = None
        # Holds one full round (one element per input) before any of it is yielded.
        self.buffer: List = []

    def __iter__(self):
        iterators = [iter(x) for x in self.datapipes]
        # With no inputs there is nothing to yield and the loop is skipped entirely.
        while iterators:
            for it in iterators:
                try:
                    self.buffer.append(next(it))
                except StopIteration:
                    # A round is only emitted when every input contributed; the partial round
                    # collected so far is discarded and iteration ends.
                    self.buffer.clear()
                    return
            yield from self.buffer
            self.buffer.clear()

    def __len__(self):
        r"""
        Return ``len(shortest input) * number of inputs`` (``0`` when there are no inputs).

        Raises:
            TypeError: if any input DataPipe is not ``Sized``
        """
        if self.length is not None:
            if self.length == -1:
                raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
            return self.length
        if all(isinstance(dp, Sized) for dp in self.datapipes):
            # `default=0` keeps an input-less instance from crashing in `min()`; this matters even
            # for `list(dp)`, which probes `__len__` and only tolerates `TypeError`.
            shortest = min((len(dp) for dp in self.datapipes), default=0)
            self.length = shortest * len(self.datapipes)
        else:
            self.length = -1
        return len(self)

    def reset(self) -> None:
        # Drop any partially collected round.
        self.buffer = []

    def __getstate__(self):
        # The round buffer is transient and therefore not serialized.
        state = (
            self.datapipes,
            self.length,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        )
        # A globally installed hook, if any, gets the final say on the serialized form.
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(state)
        return state

    def __setstate__(self, state):
        (
            self.datapipes,
            self.length,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        ) = state
        self.buffer = []

    def __del__(self):
        # Release references to any elements still held in the round buffer.
        self.buffer.clear()
|
|
|
|
|
|
|
|
@functional_datapipe('zip')
class ZipperIterDataPipe(IterDataPipe[Tuple[T_co]]):
    r"""
    Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``).
    The output is stopped as soon as the shortest input DataPipe is exhausted.

    Args:
        *datapipes: Iterable DataPipes being aggregated

    Raises:
        TypeError: if any input is not an ``IterDataPipe``

    Example:
        >>> # xdoctest: +REQUIRES(module:torchdata)
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp1, dp2, dp3 = IterableWrapper(range(5)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
        >>> list(dp1.zip(dp2, dp3))
        [(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)]
    """
    datapipes: Tuple[IterDataPipe]
    length: Optional[int]

    def __init__(self, *datapipes: IterDataPipe):
        if not all(isinstance(dp, IterDataPipe) for dp in datapipes):
            raise TypeError("All inputs are required to be `IterDataPipe` "
                            "for `ZipIterDataPipe`.")
        super().__init__()
        self.datapipes = datapipes
        # `None` = not computed yet, `-1` = cached "no valid length"; see `__len__`.
        self.length = None

    def __iter__(self) -> Iterator[Tuple[T_co]]:
        iterators = [iter(datapipe) for datapipe in self.datapipes]
        try:
            yield from zip(*iterators)
        finally:
            # Whatever the longer inputs still hold is drained so that any streams wrapped in the
            # leftover items can be closed. Note that this consumes the remaining elements.
            unused = []
            for iterator in iterators:
                try:
                    unused += list(iterator)
                except RuntimeError:
                    # Best effort: an iterator that can no longer be advanced is skipped.
                    # NOTE(review): presumably raised when the single-iterator constraint has
                    # invalidated it - confirm against the iterator hook.
                    pass

            for item in unused:
                StreamWrapper.close_streams(item)

    def __len__(self) -> int:
        r"""
        Return the length of the shortest input (``0`` when there are no inputs).

        Raises:
            TypeError: if any input DataPipe is not ``Sized``
        """
        if self.length is not None:
            if self.length == -1:
                raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
            return self.length
        if all(isinstance(dp, Sized) for dp in self.datapipes):
            # `default=0` keeps an input-less instance from crashing in `min()`; this matters even
            # for `list(dp)`, which probes `__len__` and only tolerates `TypeError`.
            self.length = min((len(dp) for dp in self.datapipes), default=0)
        else:
            self.length = -1
        return len(self)
|
|
|