| | |
| | |
| | |
| | |
| |
|
| | import math |
| | from typing import Iterator, Optional, TypeVar |
| |
|
| | import torch |
| | import torch.distributed as dist |
| | from torch.utils.data.dataset import Dataset |
| | from torch.utils.data.sampler import Sampler |
| |
|
| |
|
__all__ = ["DistributedSampler"]

# Covariant type variable parameterizing the element type yielded by the sampler.
_T_co = TypeVar("_T_co", covariant=True)


class DistributedSampler(Sampler[_T_co]):
    r"""Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a
    :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the
    original dataset that is exclusive to it.

    In addition to the standard behaviour, this sampler can resume part-way
    through an epoch: given a global ``consumed_samples`` count, each replica
    skips the first ``consumed_samples // num_replicas`` indices of its shard.

    .. note::
        Dataset is assumed to be of constant size and that any instance of it always
        returns the same elements in the same order.

    Args:
        dataset: Dataset used for sampling.
        num_replicas (int, optional): Number of processes participating in
            distributed training. By default, :attr:`world_size` is retrieved from the
            current distributed group.
        rank (int, optional): Rank of the current process within :attr:`num_replicas`.
            By default, :attr:`rank` is retrieved from the current distributed
            group.
        shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
            indices.
        seed (int, optional): random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: ``0``.
        drop_last (bool, optional): if ``True``, then the sampler will drop the
            tail of the data to make it evenly divisible across the number of
            replicas. If ``False``, the sampler will add extra indices to make
            the data evenly divisible across the replicas. Default: ``False``.
        consumed_samples (int, optional): number of samples already consumed in
            the current epoch, counted across all replicas. Each replica skips
            ``consumed_samples // num_replicas`` of its own indices; a count past
            the end of the epoch is clamped, yielding an empty iteration. Must be
            non-negative. Default: ``0``.

    Raises:
        RuntimeError: if ``num_replicas`` or ``rank`` has to be inferred but
            :mod:`torch.distributed` is not available.
        ValueError: if ``rank`` is outside ``[0, num_replicas - 1]`` or
            ``consumed_samples`` is negative.

    .. warning::
        In distributed mode, calling the :meth:`set_epoch` method at
        the beginning of each epoch **before** creating the :class:`DataLoader` iterator
        is necessary to make shuffling work properly across multiple epochs. Otherwise,
        the same ordering will be always used.

    Example::

        >>> # xdoctest: +SKIP
        >>> sampler = DistributedSampler(dataset) if is_distributed else None
        >>> loader = DataLoader(dataset, shuffle=(sampler is None),
        ...                     sampler=sampler)
        >>> for epoch in range(start_epoch, n_epochs):
        ...     if is_distributed:
        ...         sampler.set_epoch(epoch)
        ...     train(loader)
    """

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
        consumed_samples: int = 0,
    ) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        # If the dataset length is evenly divisible by the number of replicas,
        # nothing needs to be dropped: the dataset splits equally as-is.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
            # Round down to the largest length that is evenly divisible, so every
            # rank receives the same number of samples.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed
        # Per-replica number of leading indices to skip. The (misspelled) name is
        # kept because it is a public attribute.
        self.consumed_indicies = self._per_replica_skip(consumed_samples)

    def _per_replica_skip(self, consumed_samples: int) -> int:
        """Convert a global consumed-sample count into this replica's skip count.

        Negative counts are rejected. Counts past the end of the epoch are
        clamped to ``num_samples`` so ``len(self)`` can never go negative.
        """
        if consumed_samples < 0:
            raise ValueError(
                f"consumed_samples should be a non-negative integer, but got {consumed_samples}"
            )
        return min(consumed_samples // self.num_replicas, self.num_samples)

    def __iter__(self) -> Iterator[_T_co]:
        if self.shuffle:
            # Deterministic shuffle from seed + epoch, so every replica computes
            # the same permutation and the shards stay disjoint.
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        if not self.drop_last:
            # Pad with repeated indices so the total is evenly divisible.
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                # Padding exceeds the dataset size (tiny dataset, many replicas):
                # repeat the whole list as often as needed.
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # Drop the tail so the total is evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # Strided subsample: this replica takes every num_replicas-th index, starting at rank.
        indices = indices[self.rank : self.total_size : self.num_replicas]
        # Resume support: skip the indices this replica already consumed this epoch.
        indices = indices[self.consumed_indicies :]
        assert len(indices) == self.num_samples - self.consumed_indicies

        return iter(indices)

    def __len__(self) -> int:
        # Never negative: consumed_indicies is clamped to num_samples.
        return self.num_samples - self.consumed_indicies

    def set_epoch(self, epoch: int, consumed_samples: int = 0) -> None:
        r"""
        Set the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
            consumed_samples (int, optional): Samples already consumed in this
                epoch, counted across all replicas. Each replica skips
                ``consumed_samples // num_replicas`` of its indices (clamped to
                the epoch length). Defaults to ``0``, i.e. start from the beginning.

        Raises:
            ValueError: if ``consumed_samples`` is negative.
        """
        # Validate before mutating, so a rejected call leaves the sampler unchanged.
        skip = self._per_replica_skip(consumed_samples)
        self.epoch = epoch
        self.consumed_indicies = skip
| |
|