File size: 4,277 Bytes
3e7c536 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 | import numpy as np
class ConditionalGenerator(object):
    """Sample conditional vectors over the discrete (softmax) columns of ``data``.

    ``output_info`` describes the column layout of ``data`` as consecutive
    spans. Each item is a tuple whose first element is the span width and
    whose second element is the activation name, ``'tanh'`` or ``'softmax'``.

    Every ``'softmax'`` span is a discrete column this generator can condition
    on, except a ``'softmax'`` span that directly follows a ``'tanh'`` span,
    which is skipped. Presumably that skipped span is the mode indicator of a
    continuous column; confirm against the code that builds ``output_info``.

    Args:
        data (numpy.ndarray): 2-D array whose columns follow ``output_info``.
            Discrete spans are presumably one-hot encoded per row.
        output_info (list): Sequence of ``(width, activation, ...)`` tuples.
        log_frequency (bool): If true, per-category counts are replaced by
            ``log(count + 1)`` before normalisation. This gives rare
            categories relatively more weight than their raw frequency.

    Attributes:
        model (list): One array per discrete column holding the ``argmax``
            category of every row of ``data``.
        interval (numpy.ndarray): One ``(offset, width)`` row per discrete
            column. ``offset`` is where the column's categories start inside
            a conditional vector.
        n_col (int): Number of discrete columns.
        n_opt (int): Total number of categories over all discrete columns,
            i.e. the length of a conditional vector.
        p (numpy.ndarray): ``(n_col, max_width)`` matrix. Row ``i`` holds the
            sampling probabilities of column ``i``'s categories, zero-padded
            on the right.

    Raises:
        AssertionError: If an activation other than ``'tanh'``/``'softmax'``
            appears, or the widths in ``output_info`` do not add up to
            ``data.shape[1]``.
    """

    def __init__(self, data, output_info, log_frequency):
        spans = list(self._discrete_spans(output_info))

        # Every span (tanh, skipped softmax, conditional softmax) advances the
        # column cursor by its width, so the layout must cover ``data`` exactly.
        total_width = sum(item[0] for item in output_info)
        if total_width != data.shape[1]:
            raise AssertionError(
                'output_info covers %d columns but data has %d'
                % (total_width, data.shape[1]))

        self.model = [np.argmax(data[:, start:end], axis=-1)
                      for start, end in spans]

        max_interval = max([end - start for start, end in spans], default=0)
        self.p = np.zeros((len(spans), max_interval))
        self.interval = []
        self.n_col = 0
        self.n_opt = 0
        for start, end in spans:
            width = end - start
            weights = np.sum(data[:, start:end], axis=0)
            if log_frequency:
                weights = np.log(weights + 1)
            total = np.sum(weights)
            if total == 0 and width > 0:
                # No observations (e.g. an empty ``data``): use a uniform
                # distribution instead of dividing 0/0 into NaNs.
                weights = np.full(width, 1.0 / width)
            else:
                weights = weights / total
            self.p[self.n_col, :width] = weights
            self.interval.append((self.n_opt, width))
            self.n_opt += width
            self.n_col += 1
        self.interval = np.asarray(self.interval)

    @staticmethod
    def _discrete_spans(output_info):
        """Yield ``(start, end)`` column ranges of the conditional softmax spans.

        Raises:
            AssertionError: On an activation other than ``'tanh'``/``'softmax'``.
        """
        start = 0
        skip = False
        for item in output_info:
            width, activation = item[0], item[1]
            if activation == 'tanh':
                # The softmax span right after a tanh span is not conditioned on.
                skip = True
            elif activation == 'softmax':
                if skip:
                    skip = False
                else:
                    yield start, start + width
            else:
                raise AssertionError(
                    'unknown activation %r in output_info' % (activation,))
            start += width

    def random_choice_prob_index(self, idx):
        """Draw one category for each discrete-column index in ``idx``.

        Args:
            idx (numpy.ndarray): Discrete-column indices (rows of ``self.p``).

        Returns:
            numpy.ndarray: For each entry of ``idx``, a category index local
            to that column, drawn according to ``self.p[idx]``.
        """
        probs = self.p[idx]
        cumulative = probs.cumsum(axis=1)
        # Scale the uniform draw by each row's real total. If the total rounds
        # to something like 0.9999999999999998, an unscaled ``r`` could exceed
        # every cumulative value. ``argmax`` of an all-False row would then
        # silently return category 0.
        r = np.random.rand(probs.shape[0])[:, None] * cumulative[:, -1:]
        return (cumulative > r).argmax(axis=1)

    def sample(self, batch):
        """Sample ``batch`` conditional vectors.

        A discrete column is chosen uniformly for each sample. A category of
        that column is then drawn from ``self.p``.

        Args:
            batch (int): Number of samples.

        Returns:
            tuple or None: ``None`` when there are no discrete columns.
            Otherwise ``(vec1, mask1, idx, opt1prime)``:

            - ``vec1``: ``(batch, n_opt)`` float32, one-hot at the chosen
              category's global position.
            - ``mask1``: ``(batch, n_col)`` float32, one-hot at the chosen
              column.
            - ``idx``: the chosen column per sample.
            - ``opt1prime``: the chosen category per sample, local to its
              column.
        """
        if self.n_col == 0:
            return None
        idx = np.random.choice(np.arange(self.n_col), batch)
        rows = np.arange(batch)

        mask1 = np.zeros((batch, self.n_col), dtype='float32')
        mask1[rows, idx] = 1

        opt1prime = self.random_choice_prob_index(idx)
        # Convert the column-local category into its position in the vector.
        opt1 = self.interval[idx, 0] + opt1prime
        vec1 = np.zeros((batch, self.n_opt), dtype='float32')
        vec1[rows, opt1] = 1
        return vec1, mask1, idx, opt1prime

    def sample_zero(self, batch):
        """Sample ``batch`` conditional vectors from the empirical data.

        For each sample, a discrete column is chosen uniformly. The category
        is then taken from a uniformly chosen entry of ``self.model[col]``,
        which follows the category frequencies observed in ``data``.

        Args:
            batch (int): Number of samples.

        Returns:
            numpy.ndarray or None: ``(batch, n_opt)`` float32 one-hot vectors,
            or ``None`` when there are no discrete columns.
        """
        if self.n_col == 0:
            return None
        vec = np.zeros((batch, self.n_opt), dtype='float32')
        idx = np.random.choice(np.arange(self.n_col), batch)
        for row, col in enumerate(idx):
            pick = int(np.random.choice(self.model[col]))
            vec[row, pick + self.interval[col, 0]] = 1
        return vec
|