PyTorch
diogo-ferreira-2002 committed on
Commit
0ebcc5e
·
verified ·
1 Parent(s): 1ed6c61

Upload 3 files

Browse files
Files changed (3) hide show
  1. config.json +15 -0
  2. gain_dann_hela_dic.bin +3 -0
  3. model_gain_dann.py +362 -0
config.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_name" : "gain_dann",
3
+ "input_dim" : 3013,
4
+ "latent_dim" : 3013,
5
+ "n_class" : 17,
6
+ "lr_D": 0.001,
7
+ "lr_G": 0.001,
8
+ "num_iterations": 2001,
9
+ "batch_size": 128,
10
+ "alpha": 10,
11
+ "miss_rate": 0.1,
12
+ "override": 0,
13
+ "output_all": 0
14
+
15
+ }
gain_dann_hela_dic.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:985e5142126b1c75a77831bd19a62dd65744fd8e20f15834fe1117a6011d8d4e
3
+ size 42784021
model_gain_dann.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel, PretrainedConfig
2
+ from model_gain_dann import GainDANNConfig
3
+
4
+ import pandas as pd
5
+ import numpy as np
6
+ import os
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ #----------------------------------------------------------------------------------------------
12
+ #------------------------------------------Encoder class --------------------------------------
13
+ #----------------------------------------------------------------------------------------------
14
+
15
+ # Encoder
16
+ class Encoder(nn.Module):
17
+ def __init__(self, input_dim: int, hidden_dim: int, latent_dim: int):
18
+ super(Encoder, self).__init__()
19
+ self.encoder = nn.Sequential(
20
+ nn.Linear(input_dim, hidden_dim),
21
+ nn.BatchNorm1d(hidden_dim),
22
+ nn.ReLU(),
23
+ nn.Dropout(0.3),
24
+ nn.Linear(hidden_dim, latent_dim),
25
+ nn.ReLU(),
26
+ nn.BatchNorm1d(latent_dim)
27
+ )
28
+
29
+
30
+ def forward(self, x):
31
+ return self.encoder(x)
32
+
33
+
34
+ #----------------------------------------------------------------------------------------------
35
+ #------------------------------------------Decoder class --------------------------------------
36
+ #----------------------------------------------------------------------------------------------
37
+
38
+ # Decoder
39
+ class Decoder(nn.Module):
40
+ def __init__(self, latent_dim: int, hidden_dim: int, target_dim: int):
41
+ super(Decoder, self).__init__()
42
+ self.decoder = nn.Sequential(
43
+ nn.Linear(latent_dim, hidden_dim),
44
+ nn.Dropout(0.3),
45
+ nn.Linear(hidden_dim, target_dim),
46
+ )
47
+
48
+ def forward(self, x):
49
+ return self.decoder(x)
50
+
51
+ #----------------------------------------------------------------------------------------------
52
+ #-------------------------------------DomainClassifier class ----------------------------------
53
+ #----------------------------------------------------------------------------------------------
54
+
55
+
56
+ class DomainClassifier(nn.Module):
57
+ """ Distinguish the domain of the input.
58
+ """
59
+
60
+ def __init__(self, input_dim: int, n_class: int):
61
+ super(DomainClassifier, self).__init__()
62
+
63
+ # in the end is a logistic regressor
64
+ self.domain_classifier = nn.Sequential(
65
+ nn.Linear(input_dim, input_dim),
66
+ nn.ReLU(),
67
+ nn.Linear(input_dim, n_class)
68
+ )
69
+
70
+ def forward(self, x):
71
+ return self.domain_classifier(x)
72
+
73
+ #----------------------------------------------------------------------------------------------
74
+ #--------------------------------- class for GradientReversal ---------------------------------
75
+ #----------------------------------------------------------------------------------------------
76
+
77
+ class GradientReversalFunction(torch.autograd.Function):
78
+ @staticmethod
79
+ def forward(ctx, x, lambd=1.0):
80
+ ctx.lambd = lambd
81
+ return x.view_as(x)
82
+
83
+ @staticmethod
84
+ def backward(ctx, grad_output):
85
+ return grad_output.neg() * ctx.lambd, None
86
+
87
+ class GradientReversalLayer(nn.Module):
88
+ def __init__(self, lambd=1.0):
89
+ super(GradientReversalLayer, self).__init__()
90
+ self.lambd = lambd
91
+
92
+ def forward(self, x):
93
+ return GradientReversalFunction.apply(x, self.lambd)
94
+
95
+ #----------------------------------------------------------------------------------------------
96
+ #------------------------------------------Params class ---------------------------------------
97
+ #----------------------------------------------------------------------------------------------
98
+
99
+ class Params:
100
+
101
+ def __init__(
102
+ self,
103
+ input=None,
104
+ output="imputed",
105
+ ref=None,
106
+ output_folder=f"{os.getcwd()}/results/",
107
+ header=None,
108
+ num_iterations=2001,
109
+ batch_size=128,
110
+ alpha=10,
111
+ miss_rate=0.1,
112
+ hint_rate=0.9,
113
+ lr_D=0.001,
114
+ lr_G=0.001,
115
+ override=0,
116
+ output_all=0,
117
+ ):
118
+ self.input = input
119
+ self.output = output
120
+ self.output_folder = output_folder
121
+ self.ref = ref
122
+ self.header = header
123
+ self.num_iterations = num_iterations
124
+ self.batch_size = batch_size
125
+ self.alpha = alpha
126
+ self.miss_rate = miss_rate
127
+ self.hint_rate = hint_rate
128
+ self.lr_D = lr_D
129
+ self.lr_G = lr_G
130
+ self.override = override
131
+ self.output_all = output_all
132
+
133
+
134
+ #----------------------------------------------------------------------------------------------
135
+ #------------------------------------------Metrics class --------------------------------------
136
+ #----------------------------------------------------------------------------------------------
137
+
138
+ class Metrics:
139
+ def __init__(self, hypers: Params):
140
+
141
+ self.hypers = hypers
142
+
143
+ self.loss_D = np.zeros(hypers.num_iterations)
144
+ self.loss_D_evaluate = np.zeros(hypers.num_iterations)
145
+
146
+ self.loss_G = np.zeros(hypers.num_iterations)
147
+ self.loss_G_evaluate = np.zeros(hypers.num_iterations)
148
+
149
+ self.loss_MSE_train = np.zeros(hypers.num_iterations)
150
+ self.loss_MSE_train_evaluate = np.zeros(hypers.num_iterations)
151
+
152
+ self.loss_MSE_test = np.zeros(hypers.num_iterations)
153
+
154
+ self.cpu = np.zeros(hypers.num_iterations)
155
+ self.cpu_evaluate = np.zeros(hypers.num_iterations)
156
+
157
+ self.ram = np.zeros(hypers.num_iterations)
158
+ self.ram_evaluate = np.zeros(hypers.num_iterations)
159
+
160
+ self.ram_percentage = np.zeros(hypers.num_iterations)
161
+ self.ram_percentage_evaluate = np.zeros(hypers.num_iterations)
162
+
163
+ self.data_imputed = None
164
+ self.ref_data_imputed = None
165
+
166
+
167
+ #----------------------------------------------------------------------------------------------
168
+ #----------------------------------Functions for Hint Generation ------------------------------
169
+ #----------------------------------------------------------------------------------------------
170
+
171
+ def generate_hint(mask, hint_rate):
172
+ hint_mask = generate_mask(mask, 1 - hint_rate)
173
+ hint = mask * hint_mask
174
+
175
+ return hint
176
+
177
+
178
+ def generate_mask(data, miss_rate):
179
+ dim = data.shape[1]
180
+ size = data.shape[0]
181
+ A = np.random.uniform(0.0, 1.0, size=(size, dim))
182
+ B = A > miss_rate
183
+ mask = 1.0 * B
184
+
185
+ return mask
186
+
187
+ #----------------------------------------------------------------------------------------------
188
+ #------------------------------------------Network class --------------------------------------
189
+ #----------------------------------------------------------------------------------------------
190
+
191
+ class Network:
192
+ def __init__(self, hypers: Params, net_G, net_D, metrics: Metrics):
193
+
194
+ # for w in net_D.parameters():
195
+ # nn.init.normal_(w, 0, 0.02)
196
+ # for w in net_G.parameters():
197
+ # nn.init.normal_(w, 0, 0.02)
198
+
199
+ # for w in net_D.parameters():
200
+ # nn.init.xavier_normal_(w)
201
+ # for w in net_G.parameters():
202
+ # nn.init.xavier_normal_(w)
203
+
204
+ for name, param in net_D.named_parameters():
205
+ if "weight" in name:
206
+ nn.init.xavier_normal_(param)
207
+ # nn.init.uniform_(param)
208
+
209
+ for name, param in net_G.named_parameters():
210
+ if "weight" in name:
211
+ nn.init.xavier_normal_(param)
212
+ # nn.init.uniform_(param)
213
+
214
+ self.hypers = hypers
215
+ self.net_G = net_G
216
+ self.net_D = net_D
217
+ self.metrics = metrics
218
+
219
+ self.optimizer_D = torch.optim.Adam(net_D.parameters(), lr=hypers.lr_D)
220
+ self.optimizer_G = torch.optim.Adam(net_G.parameters(), lr=hypers.lr_G)
221
+
222
+ # print(summary(net_G))
223
+
224
+ def generate_sample(cls, data, mask):
225
+ dim = data.shape[1]
226
+ size = data.shape[0]
227
+
228
+ Z = torch.rand((size, dim)) * 0.01
229
+ missing_data_with_noise = mask * data + (1 - mask) * Z
230
+ input_G = torch.cat((missing_data_with_noise, mask), 1).float()
231
+
232
+ return cls.net_G(input_G)
233
+
234
+ #----------------------------------------------------------------------------------------------
235
+ #-----------------------------------------GAIN_DANN class -------------------------------------
236
+ #----------------------------------------------------------------------------------------------
237
+
238
+ class GAIN_DANN(nn.Module):
239
+ def __init__(self, input_dim: int, latent_dim: int, n_class: int, params: Params, metrics: Metrics):
240
+ super(GAIN_DANN, self).__init__()
241
+ self.encoder = Encoder(input_dim=input_dim, hidden_dim=128, latent_dim=latent_dim)
242
+
243
+ # gradient reversal layer
244
+ self.grl = GradientReversalLayer()
245
+
246
+ self.domain_classifier = DomainClassifier(latent_dim, n_class=n_class)
247
+ print("latent_dim1:", latent_dim)
248
+ # gain
249
+ self.gain = Network(hypers=params,
250
+ net_G= nn.Sequential(
251
+ nn.Linear(latent_dim* 2, latent_dim),
252
+ nn.ReLU(),
253
+ nn.Linear(latent_dim, latent_dim),
254
+ nn.ReLU(),
255
+ nn.Linear(latent_dim, latent_dim),
256
+ nn.Sigmoid(),
257
+ ),
258
+ net_D= nn.Sequential(
259
+ nn.Linear(latent_dim * 2, latent_dim),
260
+ nn.ReLU(),
261
+ nn.Linear(latent_dim, latent_dim),
262
+ nn.ReLU(),
263
+ nn.Linear(latent_dim, latent_dim),
264
+ nn.Sigmoid(),
265
+ ),
266
+ metrics=metrics)
267
+
268
+ self.decoder = Decoder(latent_dim=latent_dim, hidden_dim=128, target_dim=input_dim)
269
+ print("latent_dim2:",latent_dim)
270
+
271
+
272
+
273
+ def forward(self, x):
274
+ """
275
+ Forward pass for GAIN_DANN.
276
+ Handles missing values (NaNs) by replacing them with noise and using a mask.
277
+ """
278
+
279
+ #todo x must be scaled
280
+
281
+ x_filled = x.clone()
282
+ x_filled[torch.isnan(x_filled)] = 0 # x filled with zeros in the place of missing values
283
+
284
+ mask = (~torch.isnan(x)).float()
285
+
286
+ # 1. Encode
287
+ x_encoded = self.encoder(x_filled)
288
+ x_grl = self.grl(x_encoded) # as a matter of fact, this is not needed, this layer is important for the training process
289
+
290
+ # 2. Gain
291
+ sample = self.gain.generate_sample(x_grl, mask)
292
+ x_imputed = x_encoded * mask + sample * (1 - mask)
293
+
294
+ # 2.1. Domain Classifier
295
+ x_domain = self.domain_classifier(x_encoded)
296
+ x_domain = torch.argmax(x_domain, dim=1)
297
+
298
+ # 3. Decoder
299
+ x_reconstructed = self.decoder(x_imputed)
300
+
301
+ #todo voltar a transformar para a escala antes de ser scaled
302
+
303
+ return x_reconstructed, x_domain
304
+
305
+
306
+
307
+ #----------------------------------------------------------------------------------------------
308
+ #---------------------------------GAIN_DANN class for HuggingFace -----------------------------
309
+ #----------------------------------------------------------------------------------------------
310
+
311
+
312
+ class GainDANNConfig(PretrainedConfig):
313
+ model_type = "gain_dann"
314
+
315
+ def __init__(self, input_dim=3013, latent_dim=3013, n_class=17, hint_rate=0.9, lr_D=0.001, lr_G=0.001,
316
+ num_iterations=2001, batch_size=128, alpha=10, miss_rate=0.1, override=0, output_all=0, **kwargs):
317
+ super().__init__(**kwargs)
318
+ self.input_dim = input_dim
319
+ self.latent_dim = latent_dim
320
+ self.n_class = n_class
321
+ self.hint_rate = hint_rate
322
+ self.lr_D = lr_D
323
+ self.lr_G = lr_G
324
+ self.num_iterations = num_iterations
325
+ self.batch_size = batch_size
326
+ self.alpha = alpha
327
+ self.miss_rate = miss_rate
328
+ self.override = override
329
+ self.output_all = output_all
330
+
331
+
332
+
333
+
334
+ class GainDANN(PreTrainedModel):
335
+ config_class = GainDANNConfig
336
+
337
+ def __init__(self, config):
338
+ super().__init__(config)
339
+ params = Params(lr_D=config.lr_D,
340
+ lr_G=config.lr_G,
341
+ hint_rate=config.hint_rate,
342
+ num_iterations=getattr(config, "num_iterations", 2001),
343
+ batch_size=getattr(config, "batch_size", 128),
344
+ alpha=getattr(config, "alpha", 10),
345
+ miss_rate=getattr(config, "miss_rate", 0.1),
346
+ override=getattr(config, "override", 0),
347
+ output_all=getattr(config, "output_all", 0))
348
+ metrics = Metrics(params)
349
+ self.model = GAIN_DANN(
350
+ input_dim=config.input_dim,
351
+ latent_dim=config.latent_dim,
352
+ n_class=config.n_class,
353
+ params=params,
354
+ metrics=metrics,
355
+ hint_rate=config.hint_rate
356
+ )
357
+
358
+ def forward(self, x):
359
+ return self.model(x)
360
+
361
+
362
+