|
|
| import numpy as np
|
| from torch.utils.data.sampler import BatchSampler, Sampler
|
|
|
|
|
class GroupedBatchSampler(BatchSampler):
    """
    Wraps another sampler to yield mini-batches of indices.

    It enforces that each batch only contains elements from the same group.
    It also tries to provide mini-batches that follow an ordering which is
    as close as possible to the ordering from the original sampler.

    Note:
        Indices are buffered per group on the instance. A batch is emitted
        only once a group's buffer reaches ``batch_size``; indices still in a
        partially filled buffer when the base sampler is exhausted are not
        emitted and stay buffered for the next iteration over this object.
    """

    def __init__(self, sampler, group_ids, batch_size):
        """
        Args:
            sampler (Sampler): Base sampler.
            group_ids (list[int]): If the sampler produces indices in range [0, N),
                `group_ids` must be a list of `N` ints which contains the group id of each sample.
                The group ids must be a set of integers in the range [0, num_groups).
            batch_size (int): Size of mini-batch.

        Raises:
            ValueError: If `sampler` is not a `torch.utils.data.Sampler`, if
                `group_ids` is not one-dimensional, or if `batch_size` is not
                a positive integer.
        """
        if not isinstance(sampler, Sampler):
            raise ValueError(
                "sampler should be an instance of "
                "torch.utils.data.Sampler, but got sampler={}".format(sampler)
            )

        group_ids = np.asarray(group_ids)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if group_ids.ndim != 1:
            raise ValueError(
                "group_ids should be a 1-D sequence, but got an array with "
                "ndim={}".format(group_ids.ndim)
            )

        # A non-positive or non-integer batch size could never be matched by a
        # buffer length, so iteration would yield nothing while the buffers
        # keep growing.
        if not isinstance(batch_size, (int, np.integer)) or batch_size <= 0:
            raise ValueError(
                "batch_size should be a positive integer value, "
                "but got batch_size={}".format(batch_size)
            )

        self.sampler = sampler
        self.group_ids = group_ids
        self.batch_size = batch_size

        # Buffer the indices of each group until the batch size is reached.
        groups = np.unique(self.group_ids).tolist()
        self.buffer_per_group = {k: [] for k in groups}

    def __iter__(self):
        """Yield lists of `batch_size` indices that all share one group id."""
        for idx in self.sampler:
            group_id = self.group_ids[idx]
            group_buffer = self.buffer_per_group[group_id]
            group_buffer.append(idx)
            if len(group_buffer) == self.batch_size:
                # Yield a copy: the buffer itself is cleared in place right
                # after and reused for the next batch of this group.
                yield group_buffer[:]
                del group_buffer[:]

    def __len__(self):
        """Always raises: the number of batches is not well-defined."""
        raise NotImplementedError("len() of GroupedBatchSampler is not well-defined.")
|
|
|