TabGAN / _ctgan /sampler.py
InsafQ's picture
Add _ctgan/sampler.py
67b97d0 verified
import numpy as np
class Sampler:
    """Draw rows from a 2-D data matrix, optionally conditioned on a category.

    At construction time the columns of ``data`` are walked according to
    ``output_info`` and, for every indexed ``'softmax'`` group, the row
    indices whose entry is non-zero are stored per column of that group.
    ``sample`` then uses this index to pick rows where a requested
    column/option pair is active.

    Attributes:
        data: The matrix passed to the constructor (kept by reference).
        model: One entry per indexed ``'softmax'`` group; each entry is a
            list with one array of row indices per column of the group.
        n: ``len(data)``, i.e. the number of rows.
    """

    def __init__(self, data, output_info):
        """Build the per-category row index.

        Args:
            data: 2-D array indexed as ``data[:, k]`` with ``shape[1]``
                columns (presumably a NumPy array -- confirm with caller).
            output_info: Iterable of items where ``item[0]`` is the number
                of columns the item spans and ``item[1]`` is either
                ``'tanh'`` or ``'softmax'``.

        Raises:
            AssertionError: If an item has any other ``item[1]`` value, or
                if the widths consumed do not add up to ``data.shape[1]``.
        """
        super().__init__()
        self.data = data
        self.model = []
        self.n = len(data)

        st = 0  # index of the first column of the current item
        skip = False  # True when the next 'softmax' item must not be indexed
        for item in output_info:
            width, activation = item[0], item[1]
            if activation == 'tanh':
                st += width
                skip = True
            elif activation == 'softmax':
                if skip:
                    # A 'softmax' item that follows a 'tanh' item is stepped
                    # over without being indexed (presumably it belongs to
                    # the same continuous column -- TODO confirm).
                    skip = False
                    st += width
                    continue

                # Row indices with a non-zero entry, one array per column.
                self.model.append(
                    [np.nonzero(data[:, st + j])[0] for j in range(width)]
                )
                st += width
            else:
                raise AssertionError(
                    "unsupported activation {!r} in output_info; "
                    "expected 'tanh' or 'softmax'".format(activation)
                )

        if st != data.shape[1]:
            raise AssertionError(
                "output_info spans {} columns but data has {}".format(
                    st, data.shape[1]
                )
            )

    def sample(self, n, col, opt):
        """Return randomly chosen rows of ``self.data``.

        Args:
            n: Number of rows to draw when ``col`` is ``None``. It is not
                used when ``col`` is given; the result then has one row per
                ``(col, opt)`` pair.
            col: ``None`` for unconditional sampling, otherwise an iterable
                of indices into ``self.model``.
            opt: Iterable parallel to ``col`` giving, for each entry, the
                column within that group that must be non-zero.

        Returns:
            ``self.data`` indexed by the chosen row indices.
        """
        if col is None:
            # Unconditional: n row indices drawn with replacement.
            idx = np.random.choice(np.arange(self.n), n)
            return self.data[idx]

        # Conditional: one row per pair, drawn from the rows recorded for it.
        idx = [np.random.choice(self.model[c][o]) for c, o in zip(col, opt)]
        return self.data[idx]