File size: 1,248 Bytes
345ae4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import numpy as np 
class BalanceSampler(torch.utils.data.sampler.Sampler):
    """Sampler that rebalances class frequencies by over/undersampling.

    Each class is resampled to ``avg`` indices, where ``avg`` is the mean
    per-class label count truncated to an int:

    * classes with fewer than ``avg`` positives are oversampled with
      replacement (classes with no positives at all are skipped);
    * classes with at least ``avg`` positives are undersampled without
      replacement.

    The pooled indices are then shuffled and truncated to ``num_samples``.

    Args:
        data: Indexable dataset supporting ``len()``. ``data[i][2]`` must be
            a label tensor. It is presumably a 1-D one-hot / multi-hot vector
            of the same length for every item (TODO confirm against the
            dataset). An entry equal to 1 marks membership in that class.
        num_samples: Maximum number of indices yielded per epoch. Defaults
            to 6300, the value that was previously hard-coded. If fewer
            pooled indices are available, all of them are yielded instead
            of raising.
    """

    def __init__(self, data, num_samples=6300):
        self.data = data
        self.num_samples = num_samples

        # Stack every item's label vector: shape (len(data), num_classes).
        self.labels = torch.stack(
            [self.data[entry_idx][2] for entry_idx in range(len(self.data))]
        )
        # Per-class count of positive labels.
        self.sums = self.labels.sum(dim=0)
        # torch.mean rejects integer tensors (e.g. int64 one-hot labels),
        # so cast first; int() truncates exactly as before.
        self.avg = int(self.sums.float().mean().item())

    def __len__(self):
        # NOTE(review): this is the dataset length, not the number of indices
        # __iter__ yields (min(num_samples, pooled index count)). It is kept
        # unchanged because __getitem__ suggests this object may also be used
        # as a dataset.
        return len(self.data)

    def _class_indices(self, class_idx):
        """Return a NumPy array of dataset indices whose label for class_idx is 1."""
        # .cpu() so the NumPy conversion also works if labels live on a GPU.
        return torch.where(self.labels[:, class_idx] == 1)[0].cpu().numpy()

    def __iter__(self):
        training = []
        minority_classes = torch.where(self.sums < self.avg)[0]
        majority_classes = torch.where(self.sums >= self.avg)[0]

        # Oversample minority classes (with replacement) up to avg each.
        for class_idx in minority_classes:
            class_indices = self._class_indices(class_idx)
            if class_indices.size == 0:
                # No examples to draw from; np.random.choice would raise.
                continue
            oversampled = np.random.choice(class_indices, size=self.avg, replace=True)
            training.extend(oversampled.tolist())

        # Undersample majority classes (without replacement) down to avg each.
        for class_idx in majority_classes:
            class_indices = self._class_indices(class_idx)
            undersampled = np.random.choice(class_indices, size=self.avg, replace=False)
            training.extend(undersampled.tolist())

        pool = np.asarray(training, dtype=np.int64)
        # Shuffle and truncate. Clamp so a pool smaller than num_samples
        # yields everything instead of raising ValueError.
        size = min(self.num_samples, pool.size)
        chosen = np.random.choice(pool, size=size, replace=False)
        # Plain Python ints rather than NumPy scalars.
        return iter(chosen.tolist())

    def __getitem__(self, index):
        return self.data[index]