| from torch.utils.data.sampler import Sampler
|
| import random
|
| import numpy as np
|
|
|
| class BalancedBatchSampler(Sampler):
|
| def __init__(self, dataset, labels=None):
|
| self.labels = labels
|
| self.dataset = dict()
|
| self.balanced_max = 0
|
|
|
| for idx in range(0, len(dataset)):
|
| label = self._get_label(dataset, idx)
|
| if label not in self.dataset:
|
| self.dataset[label] = list()
|
| self.dataset[label].append(idx)
|
| self.balanced_max = len(self.dataset[label]) \
|
| if len(self.dataset[label]) > self.balanced_max else self.balanced_max
|
|
|
|
|
| for label in self.dataset:
|
| while len(self.dataset[label]) < self.balanced_max:
|
| self.dataset[label].append(random.choice(self.dataset[label]))
|
| self.keys = list(self.dataset.keys())
|
| self.currentkey = 0
|
| self.indices = [-1]*len(self.keys)
|
|
|
| def __iter__(self):
|
| while self.indices[self.currentkey] < self.balanced_max - 1:
|
| self.indices[self.currentkey] += 1
|
| yield self.dataset[self.keys[self.currentkey]][self.indices[self.currentkey]]
|
| self.currentkey = (self.currentkey + 1) % len(self.keys)
|
| self.indices = [-1]*len(self.keys)
|
|
|
| def _get_label(self, dataset, idx, labels = None):
|
| _, label, _ = dataset[idx]
|
| return label
|
|
|
| def __len__(self):
|
| return self.balanced_max*len(self.keys)
|
|
|
| class UnderSampler(Sampler):
|
| def __init__(self, dataset, labels=None, under_sample_rate=0.2):
|
| self.under_sample_rate = under_sample_rate
|
| self.dataset_full = dataset
|
|
|
| if labels is not None:
|
| self.labels = labels
|
| elif hasattr(dataset, 'get_labels') and callable(getattr(dataset, 'get_labels')):
|
| self.labels = dataset.get_labels()
|
| else:
|
| self.labels = [self._get_label(dataset, idx) for idx in range(len(dataset))]
|
|
|
| self.dataset = {}
|
| for idx, label in enumerate(self.labels):
|
| if label not in self.dataset:
|
| self.dataset[label] = []
|
| self.dataset[label].append(idx)
|
|
|
| self.under_represented_label = min(self.dataset, key=lambda x: len(self.dataset[x]))
|
|
|
| self.minority_size = len(self.dataset[self.under_represented_label])
|
| self.majority_sizes = {label: int(len(indices) * self.under_sample_rate)
|
| for label, indices in self.dataset.items()
|
| if label != self.under_represented_label}
|
|
|
| self._length = self.minority_size + sum(self.majority_sizes.values())
|
|
|
| def _get_label(self, dataset, idx, labels=None):
|
| _, label, _ = dataset[idx]
|
| return label
|
|
|
| def __len__(self):
|
| return self._length
|
|
|
| def __iter__(self):
|
| minority_indices = np.array(self.dataset[self.under_represented_label])
|
|
|
| under_sampled_indices = list(minority_indices)
|
|
|
| for label, indices in self.dataset.items():
|
| if label != self.under_represented_label:
|
| indices = np.array(indices)
|
| sample_size = self.majority_sizes[label]
|
|
|
| if sample_size < len(indices):
|
| sampled = indices[np.random.choice(len(indices), sample_size, replace=False)]
|
| under_sampled_indices.extend(sampled)
|
| else:
|
| under_sampled_indices.extend(indices)
|
|
|
| under_sampled_indices = np.array(under_sampled_indices)
|
| np.random.shuffle(under_sampled_indices)
|
|
|
| return iter(under_sampled_indices.tolist()) |