# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from typing import Iterator, Optional
from operator import itemgetter
import numpy as np
import torch
from torch.utils.data import (
Dataset,
Sampler,
DistributedSampler,
WeightedRandomSampler
)
class DatasetFromSampler(Dataset):
    """Dataset view over the index sequence produced by a sampler.

    The wrapped sampler is materialized into a list lazily, on the first
    ``__getitem__`` call, so construction stays cheap and stochastic
    samplers are not consumed prematurely.
    """
    def __init__(self, sampler: Sampler):
        self.sampler = sampler        # wrapped sampler
        self.sampler_list = None      # lazily-built cache of the sampler's output
    def __getitem__(self, index: int):
        """Return the ``index``-th element yielded by the wrapped sampler."""
        cached = self.sampler_list
        if cached is None:
            cached = list(self.sampler)
            self.sampler_list = cached
        return cached[index]
    def __len__(self) -> int:
        """Length of the wrapped sampler."""
        return len(self.sampler)
class DistributedSamplerWrapper(DistributedSampler):
    """Convert any PyTorch Sampler to a DistributedSampler.

    Wraps the sampler's output in a ``DatasetFromSampler`` so the stock
    ``DistributedSampler`` machinery can shard (and optionally shuffle)
    the sampler's indices across replicas.
    """
    def __init__(
        self,
        sampler,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ):
        """
        Args:
            sampler: any sampler (or index sequence) to distribute.
            num_replicas: number of processes participating; inferred from
                the process group when ``None``.
            rank: rank of the current process; inferred when ``None``.
            shuffle: whether the DistributedSampler shuffles the shard
                assignment each epoch.
        """
        super(DistributedSamplerWrapper, self).__init__(
            DatasetFromSampler(sampler),
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
        )
        self.sampler = sampler
    def __iter__(self) -> Iterator[int]:
        # Re-materialize the wrapped sampler every epoch so stochastic
        # samplers (e.g. weighted sampling) draw a fresh sequence.
        self.dataset = DatasetFromSampler(self.sampler)
        indexes_of_indexes = super().__iter__()
        subsampler_indexes = self.dataset
        # Fix: the previous `iter(itemgetter(*idx)(dataset))` returned a
        # bare element (not a tuple) when exactly one index was selected,
        # so `iter()` raised TypeError for a per-replica shard of size 1.
        # A list comprehension handles every shard size uniformly.
        return iter([subsampler_indexes[i] for i in indexes_of_indexes])
class CustomWeightedRandomSampler(WeightedRandomSampler):
    """WeightedRandomSampler generalized to allow for more than 2^24 samples.

    Draws indices with ``np.random.choice`` instead of ``torch.multinomial``,
    sidestepping the latter's category-count limit.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    def __iter__(self):
        weights = self.weights
        # Normalize to a probability vector for np.random.choice.
        probabilities = weights.numpy() / torch.sum(weights).numpy()
        drawn = np.random.choice(
            range(0, len(weights)),
            size=self.num_samples,
            p=probabilities,
            replace=self.replacement,
        )
        return iter(torch.from_numpy(drawn).tolist())
class DistributedWeightedSampler(DistributedSamplerWrapper):
    """Distributed, weighted sampling without replacement.

    Builds a weighted permutation of the whole dataset (one draw per
    weight, ``replacement=False``) and shards it across replicas via
    ``DistributedSamplerWrapper``.
    """
    def __init__(
        self,
        weights,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ):
        """
        Args:
            weights: per-sample weights; dataset size is ``len(weights)``.
            num_replicas: number of participating processes (inferred when None).
            rank: rank of the current process (inferred when None).
            shuffle: whether the distributed shard assignment is shuffled.
        """
        # Every index appears exactly once, ordered by weighted draws.
        base_sampler = CustomWeightedRandomSampler(
            weights=weights,
            num_samples=len(weights),
            replacement=False,
        )
        super(DistributedWeightedSampler, self).__init__(
            sampler=base_sampler,
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
        )
|