InsafQ commited on
Commit
91b2f4f
·
verified ·
1 Parent(s): 67b97d0

Add _ctgan/synthesizer.py

Browse files
Files changed (1) hide show
  1. _ctgan/synthesizer.py +310 -0
_ctgan/synthesizer.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch import optim
6
+ from torch.nn import functional
7
+ from tqdm.autonotebook import tqdm
8
+
9
+ from _ctgan.conditional import ConditionalGenerator
10
+ from _ctgan.models import Discriminator, Generator
11
+ from _ctgan.sampler import Sampler
12
+ from _ctgan.transformer import DataTransformer
13
+
14
+
15
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Args:
        patience (int): How long to wait after the last time the validation
            loss improved. Default: 7.
        verbose (bool): Kept for API compatibility; currently unused.
        delta (float): Minimum change in the monitored quantity to qualify
            as an improvement. Default: 0.
    """

    def __init__(self, patience=7, verbose=False, delta=0):
        self.patience = patience
        self.counter = 0  # epochs elapsed since the last improvement
        self.best_score = None  # best (negated) loss observed so far
        self.early_stop = False  # becomes True once patience is exhausted
        self.val_loss_min = np.inf
        self.delta = delta
        self.verbose = verbose

    def __call__(self, val_loss):
        """Record a new validation loss and update the early-stop state."""
        # Negate the loss so that "higher score is better".
        score = -val_loss

        if self.best_score is None:
            # First observation: nothing to compare against yet.
            self.best_score = score
        elif score < self.best_score + self.delta:
            # No sufficient improvement: count one strike towards patience.
            self.counter += 1
            if self.counter >= self.patience:
                # Fixed typo in the log message ("stoping" -> "stopping").
                logging.info(
                    "Early stopping for GAN. Best score: {:.2f} with patience = {}".format(
                        self.best_score, self.patience
                    )
                )
                self.early_stop = True
        else:
            # Improvement: record the new best and reset the strike counter.
            self.best_score = score
            self.counter = 0
49
+
50
+
51
class _CTGANSynthesizer:
    """Conditional Table GAN Synthesizer.

    This is the core class of the CTGAN project, where the different components
    are orchestrated together.

    For more details about the process, please check the [Modeling Tabular data using
    Conditional GAN](https://arxiv.org/abs/1907.00503) paper.

    Args:
        embedding_dim (int):
            Size of the random sample passed to the Generator. Defaults to 128.
        gen_dim (tuple or list of ints):
            Size of the output samples for each one of the Residuals. A Residual Layer
            will be created for each one of the values provided. Defaults to (256, 256).
        dis_dim (tuple or list of ints):
            Size of the output samples for each one of the Discriminator Layers. A Linear Layer
            will be created for each one of the values provided. Defaults to (256, 256).
        l2scale (float):
            Weight Decay for the Adam Optimizer (generator only). Defaults to 1e-6.
        batch_size (int):
            Number of data samples to process in each step. Must be even
            (validated in ``fit``). Defaults to 500.
        patience (int):
            Number of epochs without improvement of the averaged GAN losses
            before ``fit`` stops early. Defaults to 25.
    """

    def __init__(
        self,
        embedding_dim=128,
        gen_dim=(256, 256),
        dis_dim=(256, 256),
        l2scale=1e-6,
        batch_size=500,
        patience=25,
    ):

        self.embedding_dim = embedding_dim
        self.gen_dim = gen_dim
        self.dis_dim = dis_dim
        self.patience = patience
        self.l2scale = l2scale
        self.batch_size = batch_size
        # Train on the first GPU when available, otherwise fall back to CPU.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def _apply_activate(self, data):
        """Apply the per-span output activations to raw generator output.

        ``self.transformer.output_info`` is consumed as a sequence of
        ``(width, activation_name)`` pairs: spans tagged ``"tanh"`` get
        ``torch.tanh`` and spans tagged ``"softmax"`` get a Gumbel-softmax
        relaxation (tau=0.2), which keeps discrete sampling differentiable
        for the generator update.
        """
        data_t = []
        st = 0  # running column offset into ``data``
        for item in self.transformer.output_info:
            if item[1] == "tanh":
                ed = st + item[0]
                data_t.append(torch.tanh(data[:, st:ed]))
                st = ed
            elif item[1] == "softmax":
                ed = st + item[0]
                data_t.append(functional.gumbel_softmax(data[:, st:ed], tau=0.2))
                st = ed
            else:
                # Any other activation tag indicates a bug in the transformer.
                raise AssertionError

        return torch.cat(data_t, dim=1)

    def _cond_loss(self, data, c, m):
        """Cross-entropy between generated discrete spans and the condition vector.

        Args:
            data: raw (pre-activation) generator output, ``(batch, data_dim)``.
            c: conditional vector the batch was conditioned on.
            m: per-sample mask selecting which discrete column was conditioned;
                unconditioned columns contribute zero loss.

        Returns:
            Masked cross-entropy summed over the batch and divided by batch size.
        """
        loss = []
        st = 0  # offset into ``data`` (includes tanh spans)
        st_c = 0  # offset into ``c`` (discrete spans only)
        skip = False
        for item in self.transformer.output_info:
            if item[1] == "tanh":
                st += item[0]
                # NOTE(review): the softmax span immediately following a tanh
                # span is skipped below — presumably it is the mode-indicator
                # of a continuous column and carries no condition; verify
                # against DataTransformer.output_info.
                skip = True

            elif item[1] == "softmax":
                if skip:
                    skip = False
                    st += item[0]
                    continue

                ed = st + item[0]
                ed_c = st_c + item[0]
                # Target class is the argmax of the one-hot condition slice.
                tmp = functional.cross_entropy(
                    data[:, st:ed],
                    torch.argmax(c[:, st_c:ed_c], dim=1),
                    reduction="none",
                )
                loss.append(tmp)
                st = ed
                st_c = ed_c

            else:
                raise AssertionError

        loss = torch.stack(loss, dim=1)

        return (loss * m).sum() / data.size()[0]

    def fit(self, train_data, discrete_columns=(), epochs=300, log_frequency=True):
        """Fit the CTGAN Synthesizer models to the training data.

        Trains a WGAN-GP style pair (one discriminator step, then one
        generator step, per batch) with conditional vectors, and stops early
        when the averaged per-epoch losses plateau (see ``patience``).

        Args:
            train_data (numpy.ndarray or pandas.DataFrame):
                Training Data. It must be a 2-dimensional numpy array or a
                pandas.DataFrame.
            discrete_columns (list-like):
                List of discrete columns to be used to generate the Conditional
                Vector. If ``train_data`` is a Numpy array, this list should
                contain the integer indices of the columns. Otherwise, if it is
                a ``pandas.DataFrame``, this list should contain the column names.
            epochs (int):
                Number of training epochs. Defaults to 300.
            log_frequency (boolean):
                Whether to use log frequency of categorical levels in conditional
                sampling. Defaults to ``True``.

        Raises:
            ValueError: if ``self.batch_size`` is odd.
        """
        # Fit the reversible transformer and encode the raw table into the
        # continuous + one-hot representation the GAN operates on.
        self.transformer = DataTransformer()
        self.transformer.fit(train_data, discrete_columns)
        train_data = self.transformer.transform(train_data)

        data_sampler = Sampler(train_data, self.transformer.output_info)

        data_dim = self.transformer.output_dimensions
        self.cond_generator = ConditionalGenerator(
            train_data, self.transformer.output_info, log_frequency
        )

        # Generator input = noise + conditional vector; discriminator input =
        # data representation + conditional vector.
        self.generator = Generator(
            self.embedding_dim + self.cond_generator.n_opt, self.gen_dim, data_dim
        ).to(self.device)

        discriminator = Discriminator(
            data_dim + self.cond_generator.n_opt, self.dis_dim
        ).to(self.device)

        # Weight decay is applied to the generator only.
        optimizerG = optim.Adam(
            self.generator.parameters(),
            lr=2e-4,
            betas=(0.5, 0.9),
            weight_decay=self.l2scale,
        )
        optimizerD = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.9))

        if self.batch_size % 2 != 0:
            raise ValueError("batch_size should even, but {} is provided".format(self.batch_size))
        # Standard-normal noise parameters, reused for every batch.
        mean = torch.zeros(self.batch_size, self.embedding_dim, device=self.device)
        std = mean + 1

        train_losses = []
        early_stopping = EarlyStopping(patience=self.patience, verbose=False)

        steps_per_epoch = max(len(train_data) // self.batch_size, 1)

        for i in tqdm(range(epochs), desc="Training CTGAN, epochs:"):
            for id_ in range(steps_per_epoch):
                # ---- Discriminator (critic) update ----
                fakez = torch.normal(mean=mean, std=std)

                condvec = self.cond_generator.sample(self.batch_size)
                if condvec is None:
                    # No discrete columns: sample real rows unconditionally.
                    c1, m1, col, opt = None, None, None, None
                    real = data_sampler.sample(self.batch_size, col, opt)
                else:
                    c1, m1, col, opt = condvec
                    c1 = torch.from_numpy(c1).to(self.device)
                    m1 = torch.from_numpy(m1).to(self.device)
                    fakez = torch.cat([fakez, c1], dim=1)

                    # Shuffle the conditions used for the real samples so that
                    # real and fake batches share the same condition
                    # distribution without being paired row-for-row.
                    perm = np.arange(self.batch_size)
                    np.random.shuffle(perm)
                    real = data_sampler.sample(self.batch_size, col[perm], opt[perm])
                    c2 = c1[perm]

                fake = self.generator(fakez)
                fakeact = self._apply_activate(fake)

                real = torch.from_numpy(real.astype("float32")).to(self.device)

                if c1 is not None:
                    fake_cat = torch.cat([fakeact, c1], dim=1)
                    real_cat = torch.cat([real, c2], dim=1)
                else:
                    real_cat = real
                    fake_cat = fake

                y_fake = discriminator(fake_cat)
                y_real = discriminator(real_cat)

                # WGAN-GP: gradient penalty backpropagated separately, with
                # retain_graph so the critic loss can reuse the same graph.
                pen = discriminator.calc_gradient_penalty(
                    real_cat, fake_cat, self.device
                )
                loss_d = -(torch.mean(y_real) - torch.mean(y_fake))
                train_losses.append(loss_d.item())
                optimizerD.zero_grad()
                pen.backward(retain_graph=True)
                loss_d.backward()
                optimizerD.step()

                # ---- Generator update (fresh noise and condition) ----
                fakez = torch.normal(mean=mean, std=std)
                condvec = self.cond_generator.sample(self.batch_size)

                if condvec is None:
                    c1, m1, col, opt = None, None, None, None
                else:
                    c1, m1, col, opt = condvec
                    c1 = torch.from_numpy(c1).to(self.device)
                    m1 = torch.from_numpy(m1).to(self.device)
                    fakez = torch.cat([fakez, c1], dim=1)

                fake = self.generator(fakez)
                fakeact = self._apply_activate(fake)

                if c1 is not None:
                    y_fake = discriminator(torch.cat([fakeact, c1], dim=1))
                else:
                    y_fake = discriminator(fakeact)

                if condvec is None:
                    cross_entropy = 0
                else:
                    # Penalize generated discrete values that disagree with
                    # the condition the generator was given.
                    cross_entropy = self._cond_loss(fake, c1, m1)

                loss_g = -torch.mean(y_fake) + cross_entropy
                train_losses.append(loss_g.item())
                optimizerG.zero_grad()
                loss_g.backward()
                optimizerG.step()
            # NOTE(review): indentation was lost in the diff; the early-stop
            # check is assumed to run once per epoch, over the average of both
            # critic and generator losses collected that epoch — confirm
            # against the original file.
            early_stopping(np.average(train_losses))
            if early_stopping.early_stop:
                logging.info("Early stopping in GAN training!")
                break
            train_losses = []

    def sample(self, n):
        """Sample data similar to the training data.

        Generates full batches of size ``batch_size`` and truncates to ``n``
        rows, then maps them back to the original column space via the fitted
        transformer.

        Args:
            n (int):
                Number of rows to sample.

        Returns:
            numpy.ndarray or pandas.DataFrame
        """
        # Enough whole batches to cover n rows (may overshoot by < batch_size).
        steps = n // self.batch_size + 1
        data = []
        for i in range(steps):
            mean = torch.zeros(self.batch_size, self.embedding_dim)
            std = mean + 1
            fakez = torch.normal(mean=mean, std=std).to(self.device)

            # sample_zero draws conditions from the marginal frequencies
            # rather than a specific target column.
            condvec = self.cond_generator.sample_zero(self.batch_size)
            if condvec is None:
                pass
            else:
                c1 = condvec
                c1 = torch.from_numpy(c1).to(self.device)
                fakez = torch.cat([fakez, c1], dim=1)

            fake = self.generator(fakez)
            fakeact = self._apply_activate(fake)
            data.append(fakeact.detach().cpu().numpy())

        data = np.concatenate(data, axis=0)
        data = data[:n]

        return self.transformer.inverse_transform(data, None)