| | import math |
| |
|
| | import torch |
| | from torch.utils.data import DistributedSampler as _DistributedSampler |
| |
|
| |
|
class DistributedSampler(_DistributedSampler):
    """DistributedSampler with a configurable shuffle seed.

    Extends ``torch.utils.data.DistributedSampler`` so the shuffling order
    is driven by ``epoch + seed`` instead of ``epoch`` alone, letting runs
    with different seeds produce different (but still reproducible)
    per-epoch permutations.

    Args:
        dataset: Dataset used for sampling.
        num_replicas (int | None): Number of processes participating in
            distributed training. Resolved by the parent class when None.
        rank (int | None): Rank of the current process. Resolved by the
            parent class when None.
        shuffle (bool): Whether to shuffle indices each epoch.
        seed (int | None): Base seed mixed with the epoch; ``None`` is
            treated as 0.
    """

    def __init__(self,
                 dataset,
                 num_replicas=None,
                 rank=None,
                 shuffle=True,
                 seed=0):
        super().__init__(
            dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        # Normalize None to 0 so `self.epoch + self.seed` is always an int.
        self.seed = seed if seed is not None else 0

    def __iter__(self):
        """Yield this rank's shard of (optionally shuffled) dataset indices."""
        # Deterministically shuffle based on epoch and seed: every replica
        # computes the same permutation for a given epoch, so the
        # rank-interleaved shards below are disjoint and cover the dataset.
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch + self.seed)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # Pad by repeating the index list so the total divides evenly across
        # replicas. Guard against an empty dataset, which would otherwise
        # raise ZeroDivisionError in the ceil below (total_size is 0 then,
        # so no padding is needed).
        if indices:
            indices = (indices *
                       math.ceil(self.total_size / len(indices)))
            indices = indices[:self.total_size]
        assert len(indices) == self.total_size

        # Subsample: each rank takes every `num_replicas`-th index starting
        # at its own rank, producing interleaved, non-overlapping shards.
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)
| |
|