import numpy as np


class Sampler:
    """Draw rows of a transformed data matrix, optionally conditioned on a column.

    The constructor walks ``output_info`` -- a sequence of ``(width, activation)``
    items describing consecutive column spans of ``data`` -- and, for every
    ``'softmax'`` span that gets indexed, records which rows are nonzero in each
    of that span's columns.  ``sample`` uses this index to pick rows that are
    nonzero in a requested column of a requested span.

    Attributes:
        data: the matrix passed to the constructor.  It is indexed as
            ``data[:, j]``, so it is expected to be a 2-D NumPy array.
        model: one entry per indexed ``'softmax'`` span.  Each entry is a list
            holding, for every column ``k`` of that span, the array of row
            indices where ``data[:, start + k]`` is nonzero.
        n: number of rows in ``data`` (``len(data)``).
    """

    def __init__(self, data, output_info):
        """Build the per-column row index.

        Args:
            data: 2-D array whose columns are laid out as described by
                ``output_info``.
            output_info: iterable of items where ``item[0]`` is the number of
                columns in the span and ``item[1]`` is ``'tanh'`` or
                ``'softmax'``.  The first ``'softmax'`` span that follows a
                ``'tanh'`` span is consumed but not indexed (presumably the
                mode indicator of a continuous column -- confirm against the
                code that builds ``output_info``).  Every other ``'softmax'``
                span is indexed into ``self.model``.

        Raises:
            AssertionError: if an activation other than ``'tanh'`` or
                ``'softmax'`` is encountered, or if the span widths do not add
                up to ``data.shape[1]``.
        """
        super().__init__()
        self.data = data
        self.model = []
        self.n = len(data)

        st = 0
        # True while a 'tanh' span is waiting for its companion 'softmax' span.
        skip = False
        for item in output_info:
            width, activation = item[0], item[1]
            if activation == 'tanh':
                st += width
                skip = True
            elif activation == 'softmax':
                if skip:
                    # Companion of the preceding tanh span: advance past it
                    # without indexing it.
                    skip = False
                    st += width
                    continue
                # Row indices that are nonzero in each column of this span.
                self.model.append(
                    [np.nonzero(data[:, st + j])[0] for j in range(width)]
                )
                st += width
            else:
                raise AssertionError(
                    "unsupported activation {!r} in output_info; "
                    "expected 'tanh' or 'softmax'".format(activation)
                )

        if st != data.shape[1]:
            raise AssertionError(
                "output_info spans {} columns but data has {}".format(
                    st, data.shape[1]
                )
            )

    def sample(self, n, col, opt):
        """Return rows of ``self.data``.

        Args:
            n: number of rows to draw when ``col`` is None; ignored otherwise.
            col: None for unconditional sampling, or a sequence of indices into
                ``self.model`` (one per requested row).
            opt: sequence paired element-wise with ``col``; ``opt[i]`` selects
                the column within indexed span ``col[i]``.

        Returns:
            If ``col`` is None: ``n`` rows drawn uniformly with replacement.
            Otherwise: one row per ``(col[i], opt[i])`` pair, chosen uniformly
            among the rows where that column is nonzero.  Pairs are formed with
            ``zip``, so extra elements of the longer sequence are ignored.

        Raises:
            ValueError: raised by ``np.random.choice`` if a requested column has
                no nonzero rows.
        """
        if col is None:
            idx = np.random.choice(np.arange(self.n), n)
            return self.data[idx]

        idx = [np.random.choice(self.model[c][o]) for c, o in zip(col, opt)]
        return self.data[idx]