import math
import os

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data.sampler import Sampler


class SubsetRandomSampler(torch.utils.data.Sampler):
    """Samples elements randomly from a given list of indices, without
    replacement.

    Arguments:
        indices (sequence): a sequence of indices
    """

    def __init__(self, indices):
        self.epoch = 0
        self.indices = indices

    def __iter__(self):
        return (self.indices[i] for i in torch.randperm(len(self.indices)))

    def __len__(self):
        return len(self.indices)

    def set_epoch(self, epoch):
        # Kept for interface parity with the distributed samplers; the
        # permutation in __iter__ uses the default RNG and does not depend
        # on the stored epoch.
        self.epoch = epoch
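
# Usage sketch (illustrative, not part of the original file): SubsetRandomSampler
# is typically handed to a DataLoader so that only the listed indices are drawn,
# in a fresh random order on each pass. The helper name, index choice, and batch
# size below are assumptions made for the example.
def _subset_sampler_example(dataset):
    indices = list(range(0, len(dataset), 2))  # e.g. keep every other sample
    sampler = SubsetRandomSampler(indices)
    return torch.utils.data.DataLoader(dataset, batch_size=4, sampler=sampler)
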
class NodeDistributedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In that case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
        local_rank (optional): Rank of the current process within its node.
        local_size (optional): Number of processes per node.
    """

    def __init__(self,
                 dataset,
                 num_replicas=None,
                 rank=None,
                 local_rank=None,
                 local_size=None):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError(
                    'Requires distributed package to be available')
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError(
                    'Requires distributed package to be available')
            rank = dist.get_rank()
        if local_rank is None:
            local_rank = int(os.environ.get('LOCAL_RANK', 0))
        if local_size is None:
            local_size = int(os.environ.get('LOCAL_SIZE', 1))
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.num_parts = local_size
        self.rank = rank
        self.local_rank = local_rank
        self.epoch = 0
        # Per-rank sample count, rounded up so every rank draws the same number.
        self.num_samples = int(
            math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # Padded length of the shard owned by each local rank
        # (the total size split across the node-local parts).
        self.total_size_parts = self.num_samples * self.num_replicas // self.num_parts
    def __iter__(self):
        # Epoch-dependent generator: reshuffles this rank's slice each epoch.
        g = torch.Generator()
        g.manual_seed(self.epoch)

        # Fixed-seed generator: keeps the dataset-to-shard assignment stable
        # across epochs, so each local rank always sees the same shard.
        t = torch.Generator()
        t.manual_seed(0)

        indices = torch.randperm(len(self.dataset), generator=t).tolist()
        # Keep only the indices assigned to this local rank's shard
        # (one part per local process).
        indices = [i for i in indices if i % self.num_parts == self.local_rank]

        # Pad with leading indices so the shard divides evenly across ranks.
        indices += indices[:(self.total_size_parts - len(indices))]
        assert len(indices) == self.total_size_parts

        # Subsample this rank's slice of the shard (stride = replicas per part).
        indices = indices[self.rank // self.num_parts:self.total_size_parts:
                          self.num_replicas // self.num_parts]

        # Shuffle the slice with the epoch-seeded generator.
        index = torch.randperm(len(indices), generator=g).tolist()
        indices = list(np.array(indices)[index])

        assert len(indices) == self.num_samples

        return iter(indices)
    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
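
# Usage sketch (illustrative, not part of the original file): in a typical
# distributed training loop the sampler is constructed once and set_epoch()
# is called before each epoch so the per-rank shuffle changes deterministically.
# The helper name and batch size below are assumptions for the example, and it
# presumes torch.distributed has already been initialized.
def _node_distributed_example(dataset, num_epochs=1):
    sampler = NodeDistributedSampler(dataset)
    loader = torch.utils.data.DataLoader(dataset, batch_size=2, sampler=sampler)
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # reseed the epoch-dependent shuffle
        for batch in loader:
            pass  # training step would go here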