Mhara's picture
Upload folder using huggingface_hub
dae5c90 verified
Raw
History Blame Contribute Delete
3.73 kB
from torch.utils.data.sampler import Sampler
import random
import numpy as np
class BalancedBatchSampler(Sampler):
    """Yield dataset indices round-robin across classes, oversampling small classes.

    Indices are grouped by label. Every class with fewer samples than the
    largest class is padded with randomly re-drawn indices of that class until
    all classes have ``balanced_max`` entries. Iteration then emits one index
    per class in turn (class 0, class 1, ..., class 0, class 1, ...), so any
    batch whose size is a multiple of the number of classes is class-balanced.

    Args:
        dataset: Indexable, sized collection. When ``labels`` is not given,
            each item must unpack into exactly three values with the label in
            the middle position.
        labels: Optional per-sample labels, indexable by sample position. When
            provided, they are used instead of reading labels from ``dataset``.
    """

    def __init__(self, dataset, labels=None):
        self.labels = labels
        self.dataset = dict()
        self.balanced_max = 0

        # Group sample indices by their class label.
        for idx in range(len(dataset)):
            label = self._get_label(dataset, idx)
            self.dataset.setdefault(label, []).append(idx)

        if self.dataset:
            self.balanced_max = max(len(members) for members in self.dataset.values())

        # Oversample the classes with fewer elements than the max.
        for members in self.dataset.values():
            while len(members) < self.balanced_max:
                members.append(random.choice(members))

        self.keys = list(self.dataset.keys())
        # Kept only for backward compatibility with code that inspects these
        # attributes; iteration no longer stores its cursor on the instance.
        self.currentkey = 0
        self.indices = [-1] * len(self.keys)

    def __iter__(self):
        """Yield ``balanced_max * len(keys)`` indices, cycling through the classes.

        The cursor lives in local variables, so every call starts a fresh,
        complete pass even if a previous iterator was abandoned midway or is
        still being consumed.
        """
        for position in range(self.balanced_max):
            for key in self.keys:
                yield self.dataset[key][position]

    @staticmethod
    def _hashable_label(label):
        """Convert scalar tensor/array labels to plain Python scalars.

        Tensor objects hash by identity, which would make every sample look
        like its own class when used as a dict key.
        """
        to_scalar = getattr(label, "item", None)
        if not callable(to_scalar):
            return label
        try:
            return to_scalar()
        except (TypeError, ValueError, RuntimeError):
            # Not a single-element container; use the label as given.
            return label

    def _get_label(self, dataset, idx, labels=None):
        """Return the class label of sample ``idx``.

        Explicit ``labels`` (argument first, then ``self.labels``) take
        precedence; otherwise the label is the middle element of the
        three-value item ``dataset[idx]``.
        """
        if labels is None:
            labels = self.labels
        if labels is not None:
            label = labels[idx]
        else:
            _, label, _ = dataset[idx]
        return self._hashable_label(label)

    def __len__(self):
        """Return the number of indices produced by one full iteration."""
        return self.balanced_max * len(self.keys)
class UnderSampler(Sampler):
    """Keep the smallest class whole and randomly subsample every other class.

    The class with the fewest samples (first one wins on ties) is always
    emitted completely. Every other class contributes
    ``int(class_size * under_sample_rate)`` randomly chosen indices, clamped
    to the range ``[0, class_size]``. The selection is re-drawn and shuffled
    on every iteration using NumPy's global random state.

    Args:
        dataset: Indexable, sized collection. Labels are taken from, in order
            of preference: the ``labels`` argument, ``dataset.get_labels()``
            if that is callable, or the middle element of each three-value
            item ``dataset[idx]``.
        labels: Optional iterable of per-sample labels.
        under_sample_rate: Fraction of each non-minority class to keep.
    """

    def __init__(self, dataset, labels=None, under_sample_rate=0.2):
        self.under_sample_rate = under_sample_rate
        self.dataset_full = dataset

        if labels is not None:
            self.labels = labels
        elif callable(getattr(dataset, 'get_labels', None)):
            self.labels = dataset.get_labels()
        else:
            self.labels = [self._get_label(dataset, idx) for idx in range(len(dataset))]

        # Group sample indices by their class label.
        self.dataset = {}
        for idx, label in enumerate(self.labels):
            self.dataset.setdefault(self._hashable_label(label), []).append(idx)

        if self.dataset:
            self.under_represented_label = min(self.dataset, key=lambda x: len(self.dataset[x]))
            self.minority_size = len(self.dataset[self.under_represented_label])
        else:
            # No samples at all: nothing to draw from.
            self.under_represented_label = None
            self.minority_size = 0

        # Clamp so the advertised length always equals what __iter__ yields:
        # a rate above 1 cannot produce more than the class holds, and a
        # negative rate cannot produce a negative sample count.
        self.majority_sizes = {
            label: max(0, min(len(indices), int(len(indices) * self.under_sample_rate)))
            for label, indices in self.dataset.items()
            if label != self.under_represented_label
        }
        self._length = self.minority_size + sum(self.majority_sizes.values())

    @staticmethod
    def _hashable_label(label):
        """Convert scalar tensor/array labels to plain Python scalars.

        Tensor objects hash by identity, which would make every sample look
        like its own class when used as a dict key.
        """
        to_scalar = getattr(label, "item", None)
        if not callable(to_scalar):
            return label
        try:
            return to_scalar()
        except (TypeError, ValueError, RuntimeError):
            # Not a single-element container; use the label as given.
            return label

    def _get_label(self, dataset, idx, labels=None):
        """Return the middle element of the three-value item ``dataset[idx]``."""
        _, label, _ = dataset[idx]
        return label

    def __len__(self):
        """Return the number of indices produced by one iteration."""
        return self._length

    def __iter__(self):
        """Return an iterator over a freshly drawn, shuffled selection of indices."""
        selected = list(self.dataset.get(self.under_represented_label, []))
        for label, indices in self.dataset.items():
            if label == self.under_represented_label:
                continue
            sample_size = self.majority_sizes[label]
            if sample_size < len(indices):
                pool = np.array(indices)
                picks = pool[np.random.choice(len(pool), sample_size, replace=False)]
                selected.extend(picks.tolist())
            else:
                selected.extend(indices)

        # Explicit integer dtype keeps an empty selection from becoming floats.
        shuffled = np.array(selected, dtype=np.int64)
        np.random.shuffle(shuffled)
        return iter(shuffled.tolist())