|
|
import random |
|
|
import torch |
|
|
|
|
|
from torch.utils.data import Sampler, SequentialSampler |
|
|
from torch.utils.data.datapipes._decorator import functional_datapipe |
|
|
from torch.utils.data.datapipes.datapipe import IterDataPipe |
|
|
from typing import Dict, Iterator, List, Optional, Sized, Tuple, Type, TypeVar |
|
|
|
|
|
# Public names exported by ``from <this module> import *``.
__all__ = ["SamplerIterDataPipe", "ShufflerIterDataPipe"]
|
|
|
|
|
# Covariant element type used to parameterize the DataPipes in this module.
T_co = TypeVar('T_co', covariant=True)


class SamplerIterDataPipe(IterDataPipe[T_co]):
    r"""
    Generates sample elements using the provided ``Sampler`` (defaults to :class:`SequentialSampler`).

    The sampler class is instantiated once, at construction time, and iterating
    this DataPipe iterates that sampler instance.

    Args:
        datapipe: IterDataPipe to sample from. It must implement ``__len__``
        sampler: Sampler class (not an instance) to generate sample elements from
            input DataPipe. Default is :class:`SequentialSampler` for IterDataPipe
        sampler_args: Extra positional arguments forwarded to ``sampler``
            (``None`` means no extra positional arguments)
        sampler_kwargs: Extra keyword arguments forwarded to ``sampler``
            (``None`` means no extra keyword arguments)
    """
    datapipe: IterDataPipe
    sampler: Sampler

    def __init__(self,
                 datapipe: IterDataPipe,
                 sampler: Type[Sampler] = SequentialSampler,
                 sampler_args: Optional[Tuple] = None,
                 sampler_kwargs: Optional[Dict] = None
                 ) -> None:
        """Validate ``datapipe`` and build the sampler instance.

        Raises:
            AssertionError: if ``datapipe`` does not implement ``__len__``.
        """
        assert isinstance(datapipe, Sized), \
            "Sampler class requires input datapipe implemented `__len__`"
        super().__init__()
        self.datapipe = datapipe
        # ``None`` defaults avoid sharing a mutable default between instances.
        self.sampler_args = () if sampler_args is None else sampler_args
        self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs
        # Positional extras are unpacked first, which is the order Python binds
        # them in anyway. ``data_source`` is always passed by keyword, so
        # ``sampler_args`` must not also fill the ``data_source`` parameter.
        self.sampler = sampler(
            *self.sampler_args,
            data_source=self.datapipe,
            **self.sampler_kwargs,
        )

    def __iter__(self) -> Iterator[T_co]:
        """Return a fresh iterator over the sampler's output."""
        return iter(self.sampler)

    def __len__(self) -> int:
        """Return the sampler's length.

        Raises:
            TypeError: if the sampler does not provide a valid length.
        """
        if isinstance(self.sampler, Sized) and len(self.sampler) >= 0:
            return len(self.sampler)
        raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
|
|
|
|
|
|
|
|
@functional_datapipe('shuffle')
class ShufflerIterDataPipe(IterDataPipe[T_co]):
    r"""
    Shuffles the input DataPipe with a buffer (functional name: ``shuffle``). The buffer
    with ``buffer_size`` is filled with elements from the datapipe first. Then,
    each item will be yielded from the buffer by reservoir sampling via iterator.

    ``buffer_size`` is required to be larger than ``0``. For ``buffer_size == 1``, the
    datapipe is not shuffled. In order to fully shuffle all elements from datapipe,
    ``buffer_size`` is required to be greater than or equal to the size of datapipe.

    When it is used with :class:`torch.utils.data.DataLoader`, the methods to
    set up random seed are different based on :attr:`num_workers`.

    For single-process mode (:attr:`num_workers == 0`), the random seed is set before
    the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process
    mode (:attr:`num_workers > 0`), `worker_init_fn` is used to set up a random seed
    for each worker process.

    Args:
        datapipe: The IterDataPipe being shuffled
        buffer_size: The buffer size for shuffling (default to ``10000``)
        unbatch_level: Specifies if it is necessary to unbatch source data before
            applying the shuffle

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp = IterableWrapper(range(10))
        >>> shuffle_dp = dp.shuffle()
        >>> list(shuffle_dp)
        [0, 4, 1, 6, 3, 2, 9, 5, 7, 8]
    """
    datapipe: IterDataPipe[T_co]
    buffer_size: int
    _buffer: List[T_co]
    _enabled: bool
    _seed: Optional[int]
    _rng: random.Random

    def __init__(self,
                 datapipe: IterDataPipe[T_co],
                 *,
                 buffer_size: int = 10000,
                 unbatch_level: int = 0
                 ) -> None:
        """Set up the shuffle buffer and a private random number generator.

        Raises:
            AssertionError: if ``buffer_size`` is not larger than ``0``.
        """
        super().__init__()
        self._buffer: List[T_co] = []
        assert buffer_size > 0, "buffer_size should be larger than 0"
        if unbatch_level == 0:
            self.datapipe = datapipe
        else:
            # Flatten the source first so that individual elements are shuffled.
            self.datapipe = datapipe.unbatch(unbatch_level=unbatch_level)
        self.buffer_size = buffer_size
        self._enabled = True
        self._seed = None
        # A dedicated generator, so shuffling does not consume the global
        # ``random`` module state.
        self._rng = random.Random()

    def set_shuffle(self, shuffle=True):
        """Enable or disable shuffling and return ``self`` to allow chaining.

        When disabled, iteration yields the source elements in their original order.
        """
        self._enabled = shuffle
        return self

    def set_seed(self, seed: int):
        """Store the seed used by the next :meth:`reset` and return ``self``."""
        self._seed = seed
        return self

    def __iter__(self) -> Iterator[T_co]:
        """Yield the source elements, shuffled through the buffer when enabled."""
        if not self._enabled:
            for x in self.datapipe:
                yield x
        else:
            for x in self.datapipe:
                if len(self._buffer) == self.buffer_size:
                    # Buffer is full: emit a random resident element and put
                    # the incoming one in its slot.
                    idx = self._rng.randint(0, len(self._buffer) - 1)
                    val, self._buffer[idx] = self._buffer[idx], x
                    yield val
                else:
                    self._buffer.append(x)
            # Source exhausted: drain what is left in random order.
            while self._buffer:
                idx = self._rng.randint(0, len(self._buffer) - 1)
                yield self._buffer.pop(idx)

    def __len__(self) -> int:
        """Return the length of the source DataPipe.

        Raises:
            TypeError: if the source DataPipe does not implement ``__len__``.
        """
        if isinstance(self.datapipe, Sized):
            return len(self.datapipe)
        raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))

    def reset(self) -> None:
        """Drop buffered elements and, when shuffling is enabled, reseed the generator.

        The seed stored by :meth:`set_seed` is used if present; otherwise a random
        64-bit integer is drawn through ``torch``. The stored seed is cleared
        afterwards, so it applies to a single reset only.
        """
        # NOTE(review): presumably called by the IterDataPipe machinery before a
        # new iteration starts -- the caller is not visible in this file.
        self._buffer = []
        if self._enabled:
            if self._seed is None:
                self._seed = int(torch.empty((), dtype=torch.int64).random_().item())
            self._rng.seed(self._seed)
            self._seed = None

    def __getstate__(self):
        """Return the picklable state tuple (order must match ``__setstate__``)."""
        # ``_valid_iterator_id`` and ``_number_of_samples_yielded`` are not set in
        # this class; presumably they are maintained by the IterDataPipe base.
        state = (
            self.datapipe,
            self.buffer_size,
            self._enabled,
            self._seed,
            self._buffer,
            # Only the generator's internal state is stored; ``__setstate__``
            # rebuilds a ``random.Random`` from it.
            self._rng.getstate(),
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        )
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(state)
        return state

    def __setstate__(self, state):
        """Restore the attributes from a tuple produced by ``__getstate__``."""
        (
            self.datapipe,
            self.buffer_size,
            self._enabled,
            self._seed,
            self._buffer,
            rng_state,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        ) = state
        self._rng = random.Random()
        self._rng.setstate(rng_state)

    def __del__(self):
        """Release buffered elements when the DataPipe is destroyed."""
        # The instance may never have been initialized (e.g. created through
        # ``__new__`` during unpickling and ``__setstate__`` failed), in which
        # case ``_buffer`` does not exist and there is nothing to clear.
        buffer = getattr(self, "_buffer", None)
        if buffer is not None:
            buffer.clear()
|
|
|