Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
| import itertools | |
| import torch | |
| from torch.utils.data.sampler import BatchSampler | |
| from torch.utils.data.sampler import Sampler | |
class GroupedBatchSampler(BatchSampler):
    """
    Wraps another sampler to yield a mini-batch of indices.
    It enforces that elements from the same group should appear in groups of batch_size.
    It also tries to provide mini-batches which follow an ordering that is
    as close as possible to the ordering from the original sampler.

    Arguments:
        sampler (Sampler): Base sampler.
        group_ids (sequence of int or 1-D tensor): group id of every dataset
            element, indexed by dataset index. Its length is taken as the
            dataset size.
        batch_size (int): Size of mini-batch.
        drop_uneven (bool): If ``True``, the sampler will drop the batches whose
            size is less than ``batch_size``

    Note:
        Each dataset index occupies a single slot per epoch, so if the base
        sampler yields the same index more than once, only one occurrence
        is kept.
    """

    def __init__(self, sampler, group_ids, batch_size, drop_uneven=False):
        if not isinstance(sampler, Sampler):
            raise ValueError(
                "sampler should be an instance of "
                "torch.utils.data.Sampler, but got sampler={}".format(sampler)
            )
        self.sampler = sampler
        self.group_ids = torch.as_tensor(group_ids)
        assert self.group_ids.dim() == 1
        self.batch_size = batch_size
        self.drop_uneven = drop_uneven

        # sorted distinct group ids present in the dataset
        self.groups = torch.unique(self.group_ids).sort(0)[0]

        # set by __len__ so that the following __iter__ hands out exactly the
        # batches that were just counted instead of drawing a new epoch
        self._can_reuse_batches = False

    def _prepare_batches(self):
        """Draw one epoch from the base sampler and split it into batches.

        Returns:
            list[list[int]]: batches of dataset indices. Each batch holds
            indices of a single group, in sampler order. Batches are ordered
            by the sampler position of their first element.
        """
        dataset_size = len(self.group_ids)
        # explicit int64 so an empty sampler still yields a valid index tensor
        sampled_ids = torch.as_tensor(list(self.sampler), dtype=torch.int64)

        # potentially not all elements of the dataset were sampled by the
        # sampler (e.g., DistributedSampler). Build a tensor that holds -1 for
        # elements that were not sampled and, otherwise, the position at which
        # the element was sampled. For example, if sampled_ids = [3, 1] and
        # dataset_size = 5, the order is [-1, 1, -1, 0, -1].
        order = torch.full((dataset_size,), -1, dtype=torch.int64)
        order[sampled_ids] = torch.arange(len(sampled_ids))
        mask = order >= 0

        # for every group: sampler positions of its sampled members, ascending
        positions_per_group = [
            order[(self.group_ids == group) & mask].sort()[0]
            for group in self.groups
        ]

        # split each group into chunks of batch_size. Tensor.split returns a
        # single empty chunk for an empty tensor (a group with no sampled
        # element), so empty chunks are skipped rather than indexed into.
        position_batches = [
            chunk
            for positions in positions_per_group
            for chunk in positions.split(self.batch_size)
            if chunk.numel() > 0
        ]

        # each batch is internally ordered, but batches are grouped by
        # cluster. Reorder them to approximately follow the sampler's order,
        # keyed on the position of each batch's first element (keys are
        # unique because every position belongs to exactly one batch).
        position_batches.sort(key=lambda chunk: chunk[0].item())

        # map sampler positions back to dataset indices
        batches = [sampled_ids[chunk].tolist() for chunk in position_batches]

        if self.drop_uneven:
            batches = [batch for batch in batches if len(batch) == self.batch_size]
        return batches

    def __iter__(self):
        if self._can_reuse_batches:
            # __len__ has just prepared these batches; hand them out once
            batches = self._batches
            self._can_reuse_batches = False
        else:
            batches = self._prepare_batches()
        self._batches = batches
        return iter(batches)

    def __len__(self):
        """Return the number of batches.

        On the first call the batches are prepared and kept for the next
        ``__iter__``. Afterwards this reports the length of the most recently
        prepared epoch.
        """
        if not hasattr(self, "_batches"):
            self._batches = self._prepare_batches()
            self._can_reuse_batches = True
        return len(self._batches)