|
|
import copy |
|
|
import warnings |
|
|
from torch.utils.data.datapipes.datapipe import IterDataPipe |
|
|
|
|
|
__all__ = ["IterableWrapperIterDataPipe", ] |
|
|
|
|
|
|
|
|
class IterableWrapperIterDataPipe(IterDataPipe): |
|
|
r""" |
|
|
Wraps an iterable object to create an IterDataPipe. |
|
|
|
|
|
Args: |
|
|
iterable: Iterable object to be wrapped into an IterDataPipe |
|
|
deepcopy: Option to deepcopy input iterable object for each |
|
|
iterator. The copy is made when the first element is read in ``iter()``. |
|
|
|
|
|
.. note:: |
|
|
If ``deepcopy`` is explicitly set to ``False``, users should ensure |
|
|
that the data pipeline doesn't contain any in-place operations over |
|
|
the iterable instance to prevent data inconsistency across iterations. |
|
|
|
|
|
Example: |
|
|
>>> # xdoctest: +SKIP |
|
|
>>> from torchdata.datapipes.iter import IterableWrapper |
|
|
>>> dp = IterableWrapper(range(10)) |
|
|
>>> list(dp) |
|
|
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] |
|
|
""" |
|
|
def __init__(self, iterable, deepcopy=True): |
|
|
self.iterable = iterable |
|
|
self.deepcopy = deepcopy |
|
|
|
|
|
def __iter__(self): |
|
|
source_data = self.iterable |
|
|
if self.deepcopy: |
|
|
try: |
|
|
source_data = copy.deepcopy(self.iterable) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
except TypeError: |
|
|
warnings.warn( |
|
|
"The input iterable can not be deepcopied, " |
|
|
"please be aware of in-place modification would affect source data." |
|
|
) |
|
|
for data in source_data: |
|
|
yield data |
|
|
|
|
|
def __len__(self): |
|
|
return len(self.iterable) |
|
|
|