import math
import random
from typing import Callable

from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler


class SubsetSampler(Sampler):
    """
    Samples elements sequentially from a given list of indices.

    Args:
        indices (list): a sequence of indices
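
    Example:
        >>> list(SubsetSampler([3, 1, 2]))
        [3, 1, 2]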
| """ |
|
|

    def __init__(self, indices):
        super().__init__(indices)
        self.indices = indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)


class PerfectBatchSampler(Sampler):
| """ |
| Samples a mini-batch of indices for a balanced class batching |
| |
| Args: |
| dataset_items(list): dataset items to sample from. |
| classes (list): list of classes of dataset_items to sample from. |
| batch_size (int): total number of samples to be sampled in a mini-batch. |
| num_gpus (int): number of GPU in the data parallel mode. |
| shuffle (bool): if True, samples randomly, otherwise samples sequentially. |
| drop_last (bool): if True, drops last incomplete batch. |
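
    The sampler yields lists of indices, so it can drive a `DataLoader` through
    its `batch_sampler` argument.

    Example (illustrative; assumes items are dicts carrying a "class_name" key,
    the default `label_key`):
        >>> items = [{"class_name": c} for c in ["a", "a", "b", "b", "a", "b"]]
        >>> sampler = PerfectBatchSampler(items, classes=["a", "b"], batch_size=4,
        ...                               num_classes_in_batch=2, shuffle=False)
        >>> next(iter(sampler))
        [0, 2, 1, 3]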
| """ |
|
|

    def __init__(
        self,
        dataset_items,
        classes,
        batch_size,
        num_classes_in_batch,
        num_gpus=1,
        shuffle=True,
        drop_last=False,
        label_key="class_name",
    ):
        super().__init__(dataset_items)
        assert (
            batch_size % (num_classes_in_batch * num_gpus) == 0
        ), "Batch size must be divisible by the number of classes per batch times the number of data parallel devices (if enabled)."

        # Group dataset indices by class label.
        label_indices = {}
        for idx, item in enumerate(dataset_items):
            label_indices.setdefault(item[label_key], []).append(idx)

        # One sampler per class: random within each class when shuffling, sequential otherwise.
        if shuffle:
            self._samplers = [SubsetRandomSampler(label_indices[key]) for key in classes]
        else:
            self._samplers = [SubsetSampler(label_indices[key]) for key in classes]

        self._batch_size = batch_size
        self._drop_last = drop_last
        self._dp_devices = num_gpus
        self._num_classes_in_batch = num_classes_in_batch

    def __iter__(self):
        batch = []
        # When only a subset of the classes fits in a batch, randomly pick
        # which class samplers feed the next batch.
        if self._num_classes_in_batch != len(self._samplers):
            valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)
        else:
            valid_samplers_idx = None

        iters = [iter(s) for s in self._samplers]
        done = False

        while True:
            b = []
            # Draw one index from each active class sampler per pass.
            for i, it in enumerate(iters):
                if valid_samplers_idx is not None and i not in valid_samplers_idx:
                    continue
                idx = next(it, None)
                if idx is None:
                    # The first exhausted class sampler ends the epoch.
                    done = True
                    break
                b.append(idx)
            if done:
                break
            batch += b
            if len(batch) == self._batch_size:
                yield batch
                batch = []
                # Re-draw the subset of active classes for the next batch.
                if valid_samplers_idx is not None:
                    valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)

        if not self._drop_last:
            if len(batch) > 0:
                # Keep the leftover batch only if it splits evenly across the
                # data-parallel devices; otherwise trim it so that it does.
                groups = len(batch) // self._num_classes_in_batch
                if groups % self._dp_devices == 0:
                    yield batch
                else:
                    batch = batch[: (groups // self._dp_devices) * self._dp_devices * self._num_classes_in_batch]
                    if len(batch) > 0:
                        yield batch

    def __len__(self):
        # The epoch length is bounded by the smallest class subset.
        class_batch_size = self._batch_size // self._num_classes_in_batch
        return min(((len(s) + class_batch_size - 1) // class_batch_size) for s in self._samplers)


def identity(x):
    """Return the input unchanged; used as the default sort key."""
    return x


class SortedSampler(Sampler):
| """Samples elements sequentially, always in the same order. |
| |
| Taken from https://github.com/PetrochukM/PyTorch-NLP |
| |
| Args: |
| data (iterable): Iterable data. |
| sort_key (callable): Specifies a function of one argument that is used to extract a |
| numerical comparison key from each list element. |
| |
| Example: |
| >>> list(SortedSampler(range(10), sort_key=lambda i: -i)) |
| [9, 8, 7, 6, 5, 4, 3, 2, 1, 0] |
| |
| """ |

    def __init__(self, data, sort_key: Callable = identity):
        super().__init__(data)
        self.data = data
        self.sort_key = sort_key
        # Precompute the element order by sorting on the extracted keys.
        zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)]
        zip_ = sorted(zip_, key=lambda r: r[1])
        self.sorted_indexes = [item[0] for item in zip_]

    def __iter__(self):
        return iter(self.sorted_indexes)

    def __len__(self):
        return len(self.data)


class BucketBatchSampler(BatchSampler):
| """Bucket batch sampler |
| |
| Adapted from https://github.com/PetrochukM/PyTorch-NLP |
| |
| Args: |
| sampler (torch.data.utils.sampler.Sampler): |
| batch_size (int): Size of mini-batch. |
| drop_last (bool): If `True` the sampler will drop the last batch if its size would be less |
| than `batch_size`. |
| data (list): List of data samples. |
| sort_key (callable, optional): Callable to specify a comparison key for sorting. |
| bucket_size_multiplier (int, optional): Buckets are of size |
| `batch_size * bucket_size_multiplier`. |

    Example:
        >>> from torch.utils.data.sampler import WeightedRandomSampler
        >>> sampler = WeightedRandomSampler(weights, len(weights))
        >>> sampler = BucketBatchSampler(sampler, data=data_items, batch_size=32, drop_last=True)
    """

    def __init__(
        self,
        sampler,
        data,
        batch_size,
        drop_last,
        sort_key: Callable = identity,
        bucket_size_multiplier=100,
    ):
        super().__init__(sampler, batch_size, drop_last)
        self.data = data
        self.sort_key = sort_key
        _bucket_size = batch_size * bucket_size_multiplier
        if hasattr(sampler, "__len__"):
            _bucket_size = min(_bucket_size, len(sampler))
        self.bucket_sampler = BatchSampler(sampler, _bucket_size, False)

    def __iter__(self):
        for idxs in self.bucket_sampler:
            # Sort each bucket by the key so that similar items batch together,
            # then shuffle the resulting batches to keep training stochastic.
            bucket_data = [self.data[idx] for idx in idxs]
            sorted_sampler = SortedSampler(bucket_data, self.sort_key)
            for batch_idx in SubsetRandomSampler(list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))):
                sorted_idxs = [idxs[i] for i in batch_idx]
                yield sorted_idxs

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        return math.ceil(len(self.sampler) / self.batch_size)