import math
import torch
from torch.utils.data import Sampler
import torch.distributed as dist


class OrderedDistributedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
    """

    def __init__(self, dataset, num_replicas=None, rank=None):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        # Round up so that every replica yields the same number of samples;
        # total_size is the padded length that divides evenly across replicas.
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        indices = list(range(len(self.dataset)))

        # Add extra samples (wrapping around to the start) so the index list
        # becomes evenly divisible across replicas.
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # Subsample: rank r takes every num_replicas-th index starting at r;
        # indices stay in dataset order, no shuffling is applied.
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples
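

if __name__ == "__main__":
    # Minimal usage sketch (added illustration, not part of the original module).
    # It runs single-process by passing num_replicas/rank explicitly, so no
    # torch.distributed process group is required; in real distributed evaluation
    # both arguments are usually omitted and read from the process group instead.
    # The toy TensorDataset, world size of 4, and batch size of 2 are assumed
    # values for demonstration only.
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.arange(10).float())
    for rank in range(4):
        sampler = OrderedDistributedSampler(dataset, num_replicas=4, rank=rank)
        loader = DataLoader(dataset, batch_size=2, sampler=sampler)
        # With 10 samples and 4 replicas: num_samples = 3, total_size = 12, and
        # the padded index list is [0..9, 0, 1]; rank r takes every 4th index
        # starting at r (rank 0 -> [0, 4, 8], ..., rank 3 -> [3, 7, 1]).
        # The wrapped duplicates should be dropped when aggregating metrics.
        batches = [batch[0].tolist() for batch in loader]
        print(f"rank {rank}: indices={list(sampler)} batches={batches}")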