| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.utils.data.sampler import Sampler | |
| from tqdm import * | |
class BalancedSampler(Sampler):
    """Sampler that yields dataset indices in class-balanced groups.

    Each "batch" produced by :meth:`__iter__` consists of
    ``batch_size // images_per_class`` distinct classes, with
    ``images_per_class`` indices drawn (with replacement) from each class.
    The indices of all batches are concatenated into one flat sequence.

    Args:
        data_source: Object exposing ``all_labels`` (one label per sample)
            and supporting ``len()``.
        batch_size: Nominal number of indices per batch.
        images_per_class: Number of indices drawn for every sampled class.

    Note:
        Sampling uses the global ``np.random`` state. ``np.random.choice``
        raises ``ValueError`` during iteration when more distinct classes are
        requested per batch than exist in ``all_labels``.
    """

    def __init__(self, data_source, batch_size, images_per_class=3):
        self.data_source = data_source
        self.ys = np.array(data_source.all_labels)
        self.num_groups = batch_size // images_per_class
        self.batch_size = batch_size
        self.num_instances = images_per_class
        self.num_samples = len(self.ys)

        # Group sample indices by class once, instead of rescanning every
        # label for each sampled class of each batch. ``inverse`` maps every
        # sample to the position of its label among the sorted unique labels,
        # so labels do not have to be the contiguous integers 0..C-1.
        classes, inverse = np.unique(self.ys, return_inverse=True)
        self.num_classes = len(classes)
        # A stable sort keeps the indices of each class in ascending order.
        order = np.argsort(inverse, kind="stable")
        counts = np.bincount(inverse, minlength=len(classes))
        self._class_indices = np.split(order, np.cumsum(counts)[:-1])

    def __len__(self):
        """Return the number of indices one pass of ``__iter__`` yields."""
        num_batches = len(self.data_source) // self.batch_size
        return num_batches * self.num_groups * self.num_instances

    def __iter__(self):
        """Return an iterator over a freshly sampled flat index sequence."""
        num_batches = len(self.data_source) // self.batch_size
        ret = []
        for _ in range(num_batches):
            # Draw positions into the unique-label table, not raw label
            # values, so arbitrary label values are handled correctly.
            sampled_classes = np.random.choice(
                self.num_classes, self.num_groups, replace=False
            )
            for class_pos in sampled_classes:
                class_sel = np.random.choice(
                    self._class_indices[class_pos],
                    size=self.num_instances,
                    replace=True,
                )
                ret.extend(np.random.permutation(class_sel))
        return iter(ret)