InsafQ committed on
Commit
3e7c536
·
verified ·
1 Parent(s): 356166f

Add _ctgan/conditional.py

Browse files
Files changed (1) hide show
  1. _ctgan/conditional.py +131 -0
_ctgan/conditional.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
class ConditionalGenerator(object):
    """Build and sample one-hot condition vectors over the discrete columns.

    ``output_info`` is walked left to right while tracking a running column
    offset into ``data``. A ``'tanh'`` item, and the ``'softmax'`` item that
    follows it, are skipped (presumably the value/mode pair of a continuous
    column -- confirm against the code that produces ``output_info``). Every
    other ``'softmax'`` item is treated as one discrete column whose options
    occupy ``item[0]`` consecutive columns of ``data``.

    Args:
        data (numpy.ndarray): 2-D array whose column count must equal the sum
            of ``item[0]`` over ``output_info``.
        output_info (list): Sequence of items where ``item[0]`` is a width and
            ``item[1]`` is either ``'tanh'`` or ``'softmax'``.
        log_frequency (bool): If true, option frequencies are passed through
            ``log(count + 1)`` before being normalised into probabilities.

    Attributes:
        model (list): One array per discrete column holding, for every row of
            ``data``, the argmax over that column's span.
        interval (numpy.ndarray): One ``(offset, width)`` row per discrete
            column; ``offset`` is the position of the column's first option in
            the condition vector.
        n_col (int): Number of discrete columns.
        n_opt (int): Total number of options over all discrete columns.
        p (numpy.ndarray): ``(n_col, widest column)`` matrix of option
            probabilities, zero-padded on the right.

    Raises:
        AssertionError: If an item has an activation other than ``'tanh'`` or
            ``'softmax'``, or if the widths do not add up to ``data.shape[1]``.
    """

    def __init__(self, data, output_info, log_frequency):
        spans, total_width = self._discrete_spans(output_info)

        self.model = [np.argmax(data[:, lo:hi], axis=-1) for lo, hi in spans]

        if total_width != data.shape[1]:
            raise AssertionError(
                'output_info describes {} columns but data has {}'.format(
                    total_width, data.shape[1]))

        self.n_col = len(spans)
        self.n_opt = 0
        widest = max([0] + [hi - lo for lo, hi in spans])
        self.p = np.zeros((self.n_col, widest))

        intervals = []
        for col, (lo, hi) in enumerate(spans):
            width = hi - lo
            freq = np.sum(data[:, lo:hi], axis=0)
            if log_frequency:
                # Dampens the dominance of very frequent options.
                freq = np.log(freq + 1)
            self.p[col, :width] = freq / np.sum(freq)
            intervals.append((self.n_opt, width))
            self.n_opt += width

        self.interval = np.asarray(intervals)

    @staticmethod
    def _discrete_spans(output_info):
        """Return ``(spans, total_width)`` for the discrete columns.

        ``spans`` lists the ``(start, end)`` column range of every
        ``'softmax'`` item that does not directly follow a ``'tanh'`` item;
        ``total_width`` is the sum of all item widths.
        """
        spans = []
        offset = 0
        after_tanh = False
        for item in output_info:
            width, activation = item[0], item[1]
            if activation == 'tanh':
                after_tanh = True
            elif activation == 'softmax':
                if after_tanh:
                    # This softmax belongs to the preceding tanh item.
                    after_tanh = False
                else:
                    spans.append((offset, offset + width))
            else:
                raise AssertionError(
                    'unsupported activation {!r} in output_info'.format(
                        activation))
            offset += width

        return spans, offset

    def random_choice_prob_index(self, idx):
        """Draw one option per entry of ``idx`` from the matching row of ``p``.

        Args:
            idx (numpy.ndarray): Discrete-column indices (rows of ``self.p``).

        Returns:
            numpy.ndarray: For each entry, the index of the drawn option
            within its column.
        """
        probs = self.p[idx]
        draws = np.expand_dims(np.random.rand(probs.shape[0]), axis=1)
        cumulative = probs.cumsum(axis=1)

        # Inverse-CDF sampling: first option whose running total exceeds the
        # uniform draw.
        reached = cumulative > draws
        choice = reached.argmax(axis=1)

        # If the running total never exceeds the draw (rounding can leave it
        # marginally below 1), argmax over an all-False row would silently
        # yield option 0. Use the last option with non-zero probability
        # instead, i.e. the first position where the total hits its final
        # value.
        missed = ~reached.any(axis=1)
        if missed.any():
            last_option = (cumulative >= cumulative[:, -1:]).argmax(axis=1)
            choice = np.where(missed, last_option, choice)

        return choice

    def sample(self, batch):
        """Sample ``batch`` condition vectors using the probabilities in ``p``.

        Args:
            batch (int): Number of condition vectors to draw.

        Returns:
            tuple or None: ``None`` when there are no discrete columns.
            Otherwise ``(vec1, mask1, idx, opt1prime)`` where ``vec1`` is a
            ``(batch, n_opt)`` float32 one-hot matrix of the chosen option,
            ``mask1`` is a ``(batch, n_col)`` float32 one-hot matrix of the
            chosen column, ``idx`` is the chosen column per row and
            ``opt1prime`` is the chosen option within that column.
        """
        if self.n_col == 0:
            return None

        rows = np.arange(batch)
        # A column is picked uniformly; the option within it follows self.p.
        idx = np.random.choice(np.arange(self.n_col), batch)

        vec1 = np.zeros((batch, self.n_opt), dtype='float32')
        mask1 = np.zeros((batch, self.n_col), dtype='float32')
        mask1[rows, idx] = 1

        opt1prime = self.random_choice_prob_index(idx)
        # Shift the within-column option by the column's offset.
        opt1 = self.interval[idx, 0] + opt1prime
        vec1[rows, opt1] = 1

        return vec1, mask1, idx, opt1prime

    def sample_zero(self, batch):
        """Sample ``batch`` condition vectors from the values seen in ``data``.

        For each row a discrete column is picked uniformly, then one of the
        per-row category indices stored in ``self.model`` for that column is
        picked uniformly.

        Args:
            batch (int): Number of condition vectors to draw.

        Returns:
            numpy.ndarray or None: ``None`` when there are no discrete
            columns, otherwise a ``(batch, n_opt)`` float32 one-hot matrix.
        """
        if self.n_col == 0:
            return None

        vec = np.zeros((batch, self.n_opt), dtype='float32')
        idx = np.random.choice(np.arange(self.n_col), batch)
        # Kept as a per-row loop so the sequence of random draws is unchanged.
        for row, col in enumerate(idx):
            pick = int(np.random.choice(self.model[col]))
            vec[row, pick + self.interval[col, 0]] = 1

        return vec