# Copyright (c) OpenMMLab. All rights reserved.
import math
import numpy as np
import torch
from mmcv.runner import get_dist_info
from torch.utils.data import Sampler
from .sampler import SAMPLER
import random
from IPython import embed
@SAMPLER.register_module()
class DistributedGroupSampler(Sampler):
    """Distributed sampler that keeps every per-GPU batch within one group.

    Dataset indices are partitioned into groups by ``dataset.flag``. Each
    group is shuffled and padded by cyclically repeating its own shuffled
    indices until its length is a multiple of
    ``samples_per_gpu * num_replicas``. The concatenated groups are then cut
    into chunks of ``samples_per_gpu`` consecutive indices, the chunks are
    shuffled, and every rank takes a contiguous, equally sized slice. Because
    every padded group length is a multiple of ``samples_per_gpu``, no chunk
    mixes two groups, so each batch of ``samples_per_gpu`` samples drawn by a
    process comes from a single group.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass its own instance as a DataLoader sampler and load a
    subset of the dataset that is exclusive to it (apart from the repeated
    padding indices described above).

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling. Must expose a ``flag`` attribute
            holding one non-negative integer group id per sample.
        samples_per_gpu (int, optional): Number of consecutive indices that
            form one per-process batch. Default: 1.
        num_replicas (int, optional): Number of processes participating in
            distributed training. Defaults to the world size reported by
            ``get_dist_info()``.
        rank (int, optional): Rank of the current process within
            ``num_replicas``. Defaults to the rank reported by
            ``get_dist_info()``.
        seed (int, optional): Random seed used to shuffle the sampler. This
            number should be identical across all processes in the
            distributed group. ``None`` is treated as 0. Default: 0.
    """

    def __init__(
        self, dataset, samples_per_gpu=1, num_replicas=None, rank=None, seed=0
    ):
        _rank, _num_replicas = get_dist_info()
        if num_replicas is None:
            num_replicas = _num_replicas
        if rank is None:
            rank = _rank
        self.dataset = dataset
        self.samples_per_gpu = samples_per_gpu
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.seed = seed if seed is not None else 0

        assert hasattr(self.dataset, "flag")
        # np.asarray is a no-op for ndarray flags; it lets a plain list work
        # with the element-wise ``self.flag == group_id`` test in __iter__.
        self.flag = np.asarray(self.dataset.flag)
        self.group_sizes = np.bincount(self.flag)

        # Per-rank sample count: every group contributes its padded share.
        self.num_samples = sum(
            self._group_samples_per_replica(size) for size in self.group_sizes
        )
        self.total_size = self.num_samples * self.num_replicas

        # skip iteration for auto-resume (armed by skip_iter_at_epoch_x)
        self.skip_iter_at_epoch = False
        self.start_iter = 0

    def _group_samples_per_replica(self, size):
        """Return how many indices of a ``size``-sample group each rank gets.

        The group is padded up to the next multiple of
        ``samples_per_gpu * num_replicas``; integer ceil-division keeps the
        result exact.
        """
        chunk = self.samples_per_gpu * self.num_replicas
        return -(-int(size) // chunk) * self.samples_per_gpu

    def __iter__(self):
        # Deterministically shuffle based on epoch and seed so that every
        # rank builds the same global permutation before slicing its share.
        g = torch.Generator()
        g.manual_seed(self.epoch + self.seed)

        indices = []
        for group_id, size in enumerate(self.group_sizes):
            if size == 0:
                continue
            group = np.where(self.flag == group_id)[0]
            assert len(group) == size
            # add .numpy() to avoid bug when selecting indice in parrots.
            # TODO: check whether torch.randperm() can be replaced by
            # numpy.random.permutation().
            group = group[
                list(torch.randperm(int(size), generator=g).numpy())
            ].tolist()
            # Pad the group by cyclically repeating its shuffled indices.
            padded_len = self._group_samples_per_replica(size) * self.num_replicas
            extra = padded_len - len(group)
            shuffled = group.copy()
            for _ in range(extra // size):
                group.extend(shuffled)
            group.extend(shuffled[: extra % size])
            indices.extend(group)
        assert len(indices) == self.total_size

        # Shuffle whole chunks of samples_per_gpu indices. Each padded group
        # length is a multiple of samples_per_gpu, so no chunk spans groups.
        spg = self.samples_per_gpu
        chunk_order = torch.randperm(len(indices) // spg, generator=g).tolist()
        indices = [
            indices[j]
            for chunk in chunk_order
            for j in range(chunk * spg, (chunk + 1) * spg)
        ]

        # subsample: each rank takes its own contiguous slice
        offset = self.num_samples * self.rank
        indices = indices[offset : offset + self.num_samples]
        assert len(indices) == self.num_samples

        # skip iteration, only once at the first epoch to resume
        if self.skip_iter_at_epoch:
            # NOTE(review): start_iter is applied as a count of sample
            # indices. If callers pass an iteration count while
            # samples_per_gpu > 1, this may need start_iter * samples_per_gpu
            # — confirm against the resume hook.
            indices = indices[self.start_iter :]
            # Clear the resume state so later epochs are not truncated too.
            self.skip_iter_at_epoch = False
            self.start_iter = 0

        return iter(indices)

    def __len__(self):
        # Full per-rank epoch length; not reduced by a pending resume skip.
        return self.num_samples

    def set_epoch(self, epoch):
        """Set the epoch used (with ``seed``) to seed the next shuffle."""
        self.epoch = epoch

    def skip_iter_at_epoch_x(self, inner_iter):
        """Arm a one-time skip of the first ``inner_iter`` indices.

        Used when a previous run ended in the middle of an epoch: the next
        call to ``__iter__`` drops the first ``inner_iter`` indices of this
        rank's slice and then clears the skip. A non-positive ``inner_iter``
        disarms any pending skip.
        """
        if inner_iter > 0:
            self.skip_iter_at_epoch = True
            self.start_iter = inner_iter
        else:
            self.skip_iter_at_epoch = False
            self.start_iter = 0