rishikasrinivas commited on
Commit
345ae4e
·
verified ·
1 Parent(s): 03012a7

Create sampler.py

Browse files
Files changed (1) hide show
  1. sampler.py +36 -0
sampler.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ class BalanceSampler(torch.utils.data.sampler.Sampler):
4
+ def __init__ (self, data):
5
+ self.data = data
6
+
7
+ self.labels = torch.stack([self.data[entry_idx][2] for entry_idx in range(len(self.data))])
8
+ self.sums = self.labels.sum(dim=0)
9
+ self.avg = int(torch.mean(self.sums).item())
10
+
11
+
12
+ def __len__(self):
13
+ return len(self.data)
14
+
15
+ def __iter__(self):
16
+ training = []
17
+ minority_classes = torch.where(self.sums < self.avg)[0]
18
+ majority_classes = torch.where(self.sums >= self.avg)[0]
19
+
20
+ for class_idx in minority_classes:
21
+ class_indices = torch.where(self.labels[:, class_idx] == 1)[0]
22
+ oversampled_indices = np.random.choice(class_indices, size=self.avg, replace=True)
23
+ training.extend(oversampled_indices.tolist())
24
+
25
+ # Undersample majority classes
26
+ for class_idx in majority_classes:
27
+ class_indices = torch.where(self.labels[:, class_idx] == 1)[0]
28
+ undersampled_indices = np.random.choice(class_indices, size=self.avg, replace=False)
29
+ training.extend(undersampled_indices.tolist())
30
+ training=np.random.choice(training, size=6300, replace=False)
31
+
32
+
33
+ return iter(training)
34
+
35
+ def __getitem__(self, index):
36
+ return self.data[index]