# Copyright (c) OpenMMLab. All rights reserved.
import math

import numpy as np
import torch
from mmcv.runner import get_dist_info
from torch.utils.data import Sampler

from .sampler import SAMPLER


@SAMPLER.register_module()
class DistributedGroupSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedGroupSampler instance as a DataLoader
    sampler, and load a subset of the original dataset that is exclusive
    to it.

    .. note::
        Dataset is assumed to be of constant size and to expose a ``flag``
        array that assigns every sample to a group (e.g. by aspect ratio).

    Args:
        dataset: Dataset used for sampling.
        samples_per_gpu (int, optional): Number of samples per GPU in one
            batch; indices are emitted in contiguous chunks of this size,
            each drawn from a single group. Default: 1.
        num_replicas (int, optional): Number of processes participating in
            distributed training.
        rank (int, optional): Rank of the current process within
            ``num_replicas``.
        seed (int, optional): Random seed used to shuffle the sampler.
            This number should be identical across all processes in the
            distributed group. Default: 0.
    """

    def __init__(
        self, dataset, samples_per_gpu=1, num_replicas=None, rank=None, seed=0
    ):
        _rank, _num_replicas = get_dist_info()
        if num_replicas is None:
            num_replicas = _num_replicas
        if rank is None:
            rank = _rank
        self.dataset = dataset
        self.samples_per_gpu = samples_per_gpu
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.seed = seed if seed is not None else 0

        assert hasattr(self.dataset, "flag")
        self.flag = self.dataset.flag
        self.group_sizes = np.bincount(self.flag)

        # Round every group up to a whole number of per-GPU chunks on every
        # replica, so each rank receives the same number of complete batches.
        self.num_samples = 0
        for size in self.group_sizes:
            self.num_samples += (
                int(
                    math.ceil(
                        size * 1.0 / self.samples_per_gpu / self.num_replicas
                    )
                )
                * self.samples_per_gpu
            )
        self.total_size = self.num_samples * self.num_replicas
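        # A worked example of the rounding above (illustrative numbers, not
        # taken from any config): for a group of 7 samples with
        # samples_per_gpu=2 and num_replicas=4, ceil(7 / 2 / 4) = 1, so each
        # rank draws 1 * 2 = 2 indices from that group, the group contributes
        # 2 * 4 = 8 slots to total_size, and the single missing index is
        # filled by the padding step in __iter__.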

        # skip iterations for auto-resume: the first pass through __iter__
        # after resuming can drop the indices that were already consumed
        self.skip_iter_at_epoch = False
        self.start_iter = 0

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch + self.seed)

        indices = []
        for i, size in enumerate(self.group_sizes):
            if size > 0:
                indice = np.where(self.flag == i)[0]
                assert len(indice) == size
                # add .numpy() to avoid bug when selecting indice in parrots.
                # TODO: check whether torch.randperm() can be replaced by
                # numpy.random.permutation().
                indice = indice[
                    list(torch.randperm(int(size), generator=g).numpy())
                ].tolist()
                extra = int(
                    math.ceil(
                        size * 1.0 / self.samples_per_gpu / self.num_replicas
                    )
                ) * self.samples_per_gpu * self.num_replicas - len(indice)
                # pad indice by repeating it until the group fills a whole
                # number of per-GPU chunks on every replica
                tmp = indice.copy()
                for _ in range(extra // size):
                    indice.extend(tmp)
                indice.extend(tmp[: extra % size])
                indices.extend(indice)
        assert len(indices) == self.total_size

        # shuffle whole chunks of `samples_per_gpu` indices so the samples
        # inside each per-GPU batch still come from a single group
        indices = [
            indices[j]
            for i in list(
                torch.randperm(len(indices) // self.samples_per_gpu, generator=g)
            )
            for j in range(
                i * self.samples_per_gpu, (i + 1) * self.samples_per_gpu
            )
        ]
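        # For illustration (toy values assumed, not from this file): with
        # samples_per_gpu=2 and indices [a, b, c, d, e, f], the chunks are
        # [a, b], [c, d], [e, f]; torch.randperm(3) reorders whole chunks,
        # e.g. to [c, d, a, b, e, f], so no per-GPU batch ever mixes groups.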

        # subsample: each rank takes its own contiguous slice of chunks
        offset = self.num_samples * self.rank
        indices = indices[offset : offset + self.num_samples]
        assert len(indices) == self.num_samples

        # skip iterations, only once at the first epoch after resuming
        if self.skip_iter_at_epoch:
            indices = indices[self.start_iter:]
        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch

    def skip_iter_at_epoch_x(self, inner_iter):
        # The previous run stopped in the middle of an epoch; on resume,
        # the next pass through __iter__ starts from that iteration.
        if inner_iter > 0:
            self.skip_iter_at_epoch = True
            self.start_iter = inner_iter
        else:
            self.skip_iter_at_epoch = False
            self.start_iter = 0
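
    # A minimal resume sketch (hedged: `loader`, `resumed_epoch` and
    # `resumed_inner_iter` are assumed names from a hypothetical training
    # loop, not part of this file):
    #
    #     sampler.set_epoch(resumed_epoch)
    #     sampler.skip_iter_at_epoch_x(resumed_inner_iter)
    #     for batch in loader:  # first pass skips the already-consumed slice
    #         ...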