import numpy as np


class ConditionalGenerator(object):
    """Build and sample one-hot condition vectors over selected softmax blocks.

    ``output_info`` describes consecutive column blocks of ``data``; only
    the first two fields of every entry are read: ``item[0]`` is the block
    width and ``item[1]`` its activation (``'tanh'`` or ``'softmax'``).
    A ``'softmax'`` block is used for conditioning unless it is the first
    softmax block after one or more ``'tanh'`` blocks, in which case it is
    skipped (presumably the mode indicator of a continuous column --
    confirm against the code that produces ``output_info``).

    Args:
        data (numpy.ndarray): 2-D array whose column count equals the sum
            of all block widths in ``output_info``.
        output_info (list): Sequence of ``(width, activation, ...)`` entries.
        log_frequency (bool): If true, category probabilities are computed
            from ``log(column_sum + 1)`` instead of the raw column sums.

    Attributes:
        model (list): One integer array per conditioned block holding the
            per-row ``argmax`` of that block in ``data``.
        interval (numpy.ndarray): One ``(offset, width)`` row per
            conditioned block; ``offset`` is the block's first position in
            the concatenated condition vector.
        n_col (int): Number of conditioned blocks.
        n_opt (int): Total width of all conditioned blocks.
        p (numpy.ndarray): ``(n_col, max_width)`` matrix of per-block
            category probabilities, zero-padded on the right.

    Raises:
        AssertionError: If an activation other than ``'tanh'`` or
            ``'softmax'`` is encountered, or if the block widths do not add
            up to ``data.shape[1]``.
    """

    def __init__(self, data, output_info, log_frequency):
        # Validate the whole layout before touching ``data`` so that a
        # mismatch is reported as such rather than as a slicing error.
        spans, total_width = self._discrete_spans(output_info)
        if total_width != data.shape[1]:
            raise AssertionError(
                'output_info describes {} columns but data has {}'.format(
                    total_width, data.shape[1]))

        self.model = [
            np.argmax(data[:, start:end], axis=-1) for start, end in spans
        ]

        widths = [end - start for start, end in spans]
        self.n_col = len(spans)
        self.n_opt = sum(widths)
        self.p = np.zeros((self.n_col, max(widths) if widths else 0))

        intervals = []
        offset = 0
        for row, (start, end) in enumerate(spans):
            width = end - start
            # Column sums act as category frequencies for the block.
            freq = np.sum(data[:, start:end], axis=0)
            if log_frequency:
                freq = np.log(freq + 1)
            self.p[row, :width] = freq / np.sum(freq)
            intervals.append((offset, width))
            offset += width
        self.interval = np.asarray(intervals)

    @staticmethod
    def _discrete_spans(output_info):
        """Return ``(spans, total_width)`` for the conditioned softmax blocks.

        ``spans`` is a list of ``(start, end)`` column ranges, in order;
        ``total_width`` is the sum of all block widths in ``output_info``.
        """
        spans = []
        start = 0
        skip_next_softmax = False
        for item in output_info:
            # Index rather than unpack: entries may carry extra fields.
            width, activation = item[0], item[1]
            if activation == 'tanh':
                skip_next_softmax = True
            elif activation == 'softmax':
                if skip_next_softmax:
                    skip_next_softmax = False
                else:
                    spans.append((start, start + width))
            else:
                raise AssertionError(
                    'unsupported activation {!r} in output_info'.format(
                        activation))
            start += width
        return spans, start

    def random_choice_prob_index(self, idx):
        """Draw one category per entry of ``idx`` according to ``self.p``.

        Args:
            idx (numpy.ndarray): Integer row indices into ``self.p``.

        Returns:
            numpy.ndarray: For each entry of ``idx``, the index of the
            category drawn from that row's distribution.
        """
        cumulative = self.p[idx].cumsum(axis=1)
        # Scale the uniform draw by each row's total so it is always
        # strictly below the last cumulative value. Without this, a total
        # that rounds to just under 1 can leave every comparison False and
        # make ``argmax`` fall back to category 0.
        draws = np.random.rand(cumulative.shape[0], 1) * cumulative[:, -1:]
        return (cumulative > draws).argmax(axis=1)

    def sample(self, batch):
        """Sample ``batch`` condition vectors using the probabilities in ``p``.

        Args:
            batch (int): Number of condition vectors to draw.

        Returns:
            tuple or None: ``None`` when there are no conditioned blocks;
            otherwise ``(vec, mask, idx, opt)`` where ``vec`` is a
            ``(batch, n_opt)`` float32 one-hot matrix, ``mask`` is a
            ``(batch, n_col)`` float32 one-hot matrix marking the chosen
            block, ``idx`` holds the chosen block per row and ``opt`` the
            chosen category within that block.
        """
        if self.n_col == 0:
            return None

        rows = np.arange(batch)
        idx = np.random.choice(np.arange(self.n_col), batch)
        opt1prime = self.random_choice_prob_index(idx)
        # Shift the within-block category by the block's offset.
        opt1 = self.interval[idx, 0] + opt1prime

        vec1 = np.zeros((batch, self.n_opt), dtype='float32')
        vec1[rows, opt1] = 1
        mask1 = np.zeros((batch, self.n_col), dtype='float32')
        mask1[rows, idx] = 1
        return vec1, mask1, idx, opt1prime

    def sample_zero(self, batch):
        """Sample ``batch`` condition vectors from the stored ``model`` arrays.

        For every row a block is chosen uniformly, then a category is
        picked uniformly from that block's entries in ``self.model``.

        Args:
            batch (int): Number of condition vectors to draw.

        Returns:
            numpy.ndarray or None: ``None`` when there are no conditioned
            blocks; otherwise a ``(batch, n_opt)`` float32 one-hot matrix.
        """
        if self.n_col == 0:
            return None

        vec = np.zeros((batch, self.n_opt), dtype='float32')
        idx = np.random.choice(np.arange(self.n_col), batch)
        for i, col in enumerate(idx):
            pick = int(np.random.choice(self.model[col]))
            vec[i, pick + self.interval[col, 0]] = 1
        return vec