import itertools
import math
import random
from typing import Iterator, Optional, TypeVar

import torch
import torch.distributed as dist
from torch.utils.data import Dataset, Sampler

__all__ = [
    "DistributedBucketSampler",
]

T_co = TypeVar("T_co", covariant=True)

class DistributedBucketSampler(Sampler[T_co]):
    r"""Sampler that batches samples of similar length together across replicas.

    The sampler sorts the dataset by input length, divides the sorted samples
    into duration buckets, shuffles within each bucket, groups the result into
    batches that span all replicas, and finally shuffles the batch order. Each
    replica therefore receives batches whose samples have similar lengths,
    which keeps padding to a minimum.

    The dataset is expected to implement ``get_sample_length(index)``,
    returning the length (e.g. duration in seconds) of sample ``index``.
    """

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
        batch_size: int = 32,
    ) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size() if dist.is_initialized() else 1
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank() if dist.is_initialized() else 0
        if torch.cuda.is_available():
            torch.cuda.set_device(rank)
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval "
                f"[0, {num_replicas - 1}]"
            )
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
            # Split to the nearest length that is evenly divisible, so that
            # every rank receives the same amount of data when the tail of
            # the dataset is dropped.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.batch_size = batch_size
        self.id_with_length = self._get_sample_lengths()
        self.id_buckets = self.make_buckets(bucket_width=2.0)

    def _get_sample_lengths(self):
        # Pair each sample index with its length and sort by ascending length.
        id_with_lengths = [
            (i, self.dataset.get_sample_length(i)) for i in range(len(self.dataset))
        ]
        id_with_lengths.sort(key=lambda x: x[1])
        return id_with_lengths

    def make_buckets(self, bucket_width: float = 2.0):
        # Walk the length-sorted samples and start a new bucket each time the
        # current length crosses the running ``bucket_width`` boundary.
        buckets = []
        cur = []
        max_sec = bucket_width
        for idx, sec in self.id_with_length:
            if sec < max_sec:
                cur.append(idx)
            else:
                if cur:
                    buckets.append(cur)
                cur = [idx]
                max_sec += bucket_width
        if cur:
            buckets.append(cur)
        return buckets
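
    # A quick illustration of the bucketing above, with hypothetical lengths
    # (not taken from any real dataset): for bucket_width=2.0 and sorted
    # (index, length) pairs [(0, 0.5), (1, 1.9), (2, 2.1), (3, 3.0), (4, 5.5)],
    # the boundary starts at 2.0 and advances by 2.0 whenever a sample crosses
    # it, so make_buckets returns [[0, 1], [2, 3], [4]].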

    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # Deterministically re-shuffle every epoch without touching the
            # global random state.
            rng = random.Random(self.seed + self.epoch)
            shuffled_bucket = []
            for buc in self.id_buckets:
                buc_copy = buc.copy()
                rng.shuffle(buc_copy)
                shuffled_bucket.append(buc_copy)
            # Each grouped batch holds one batch for every replica, drawn from
            # neighbouring (similar-length) positions in the bucket order.
            grouped_batch_size = self.batch_size * self.num_replicas
            shuffled_bucket = list(itertools.chain(*shuffled_bucket))
            n_batch = math.ceil(len(shuffled_bucket) / grouped_batch_size)
            batches = [
                shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size]
                for b in range(n_batch)
            ]
            rng.shuffle(batches)
            indices = list(itertools.chain(*batches))
        else:
            indices = list(range(len(self.dataset)))

        if not self.drop_last:
            # Add extra samples to make the index list evenly divisible.
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[
                    :padding_size
                ]
        else:
            # Remove the tail of the data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # Interleaved subsampling: each rank takes every ``num_replicas``-th
        # index, so every grouped batch contributes ``batch_size`` indices
        # of similar length to each rank.
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)
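
    # Worked example of the interleaved slice above, with hypothetical numbers:
    # for num_replicas=2 and batch_size=2, a grouped batch [a, b, c, d] is
    # split as rank 0 -> [a, c] and rank 1 -> [b, d]. Because the stride equals
    # num_replicas, both ranks draw from the same grouped batch and therefore
    # see batches of matching length.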

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""Set the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different
        random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
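

# Minimal usage sketch (an illustrative assumption, not part of the sampler):
# ``ToyDataset`` below is a hypothetical stand-in; the only contract the
# sampler relies on is ``__len__`` plus ``get_sample_length(index)``.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    class ToyDataset(Dataset):
        def __init__(self, lengths):
            self.lengths = lengths  # hypothetical per-sample durations (seconds)

        def __len__(self):
            return len(self.lengths)

        def __getitem__(self, index):
            return index

        def get_sample_length(self, index):
            return self.lengths[index]

    dataset = ToyDataset([random.uniform(0.5, 10.0) for _ in range(100)])
    # Single-process run: num_replicas/rank are given explicitly, so no
    # process group needs to be initialized.
    sampler = DistributedBucketSampler(
        dataset, num_replicas=1, rank=0, batch_size=4, seed=0
    )
    sampler.set_epoch(0)  # re-seeds the per-epoch shuffle
    loader = DataLoader(dataset, batch_size=4, sampler=sampler)
    for batch in loader:
        print(batch)  # indices of similar-length samples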