PyTorch
File size: 12,770 Bytes
41250f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7092701
41250f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
from transformers import PreTrainedModel, PretrainedConfig
#from model_gain_dann import GainDANNConfig

import pandas as pd
import numpy as np
import os
import torch
import torch.nn as nn


#----------------------------------------------------------------------------------------------
#------------------------------------------Encoder class --------------------------------------
#----------------------------------------------------------------------------------------------

# Encoder
class Encoder(nn.Module):
    """Map ``input_dim`` features to a ``latent_dim`` code with a two-layer MLP.

    Layer order inside ``self.encoder`` (the indices are part of the saved
    ``state_dict`` keys, ``encoder.<i>.*``):

        0 Linear(input_dim, hidden_dim)
        1 BatchNorm1d(hidden_dim)
        2 ReLU
        3 Dropout(p=0.3)
        4 Linear(hidden_dim, latent_dim)
        5 ReLU
        6 BatchNorm1d(latent_dim)
    """

    def __init__(self, input_dim: int, hidden_dim: int, latent_dim: int):
        super().__init__()
        hidden_stage = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
        ]
        # Note: the final BatchNorm comes *after* the ReLU, so the code can be negative.
        latent_stage = [
            nn.Linear(hidden_dim, latent_dim),
            nn.ReLU(),
            nn.BatchNorm1d(latent_dim),
        ]
        self.encoder = nn.Sequential(*hidden_stage, *latent_stage)

    def forward(self, x):
        """Encode ``x`` (rows of ``input_dim`` features) into ``latent_dim`` features."""
        return self.encoder(x)


#----------------------------------------------------------------------------------------------
#------------------------------------------Decoder class --------------------------------------
#----------------------------------------------------------------------------------------------

# Decoder
class Decoder(nn.Module):
    """Map a ``latent_dim`` code back to ``target_dim`` features.

    ``self.decoder`` is Linear(latent_dim, hidden_dim) -> Dropout(0.3) ->
    Linear(hidden_dim, target_dim); no output activation is applied.

    NOTE(review): there is no non-linearity between the two Linear layers, so with
    Dropout disabled (eval mode) the decoder is a single affine map. Inserting one
    would shift the ``decoder.<i>`` indices and break existing checkpoints.
    """

    def __init__(self, latent_dim: int, hidden_dim: int, target_dim: int):
        super().__init__()
        to_hidden = nn.Linear(latent_dim, hidden_dim)
        to_target = nn.Linear(hidden_dim, target_dim)
        self.decoder = nn.Sequential(to_hidden, nn.Dropout(0.3), to_target)

    def forward(self, x):
        """Decode ``x`` (rows of ``latent_dim`` features) to ``target_dim`` features."""
        return self.decoder(x)
    
#----------------------------------------------------------------------------------------------
#-------------------------------------DomainClassifier class ----------------------------------
#----------------------------------------------------------------------------------------------


class DomainClassifier(nn.Module):
    """Predict which of ``n_class`` domains an ``input_dim`` feature vector comes from.

    A one-hidden-layer MLP (hidden width ``input_dim``). The output is the raw
    score of the last Linear layer, shape ``(rows, n_class)``; no softmax is
    applied here.
    """

    def __init__(self, input_dim: int, n_class: int):
        super().__init__()
        hidden = nn.Linear(input_dim, input_dim)
        # The last layer is a linear (logistic-regression style) scoring head.
        head = nn.Linear(input_dim, n_class)
        self.domain_classifier = nn.Sequential(hidden, nn.ReLU(), head)

    def forward(self, x):
        """Return unnormalised per-domain scores for every row of ``x``."""
        return self.domain_classifier(x)
    
#----------------------------------------------------------------------------------------------
#--------------------------------- class for GradientReverseal --------------------------------
#---------------------------------------------------------------------------------------------- 

class GradientReversalFunction(torch.autograd.Function):
    """Identity in the forward pass; scales incoming gradients by ``-lambd`` backward.

    This is the gradient-reversal trick used in domain-adversarial training: layers
    upstream of this op receive the negated (and ``lambd``-scaled) gradient.
    """

    @staticmethod
    def forward(ctx, x, lambd=1.0):
        # Keep the scale factor for the backward pass.
        ctx.lambd = lambd
        # Return a view rather than ``x`` itself so autograd tracks a new output tensor.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = -grad_output * ctx.lambd
        # ``lambd`` is not treated as a differentiable input.
        return reversed_grad, None

class GradientReversalLayer(nn.Module):
    """``nn.Module`` wrapper around ``GradientReversalFunction``.

    Args:
        lambd: scale passed to ``GradientReversalFunction`` on every call
            (stored as ``self.lambd``).
    """

    def __init__(self, lambd=1.0):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        reverse_gradient = GradientReversalFunction.apply
        return reverse_gradient(x, self.lambd)
    
#----------------------------------------------------------------------------------------------
#------------------------------------------Params class ---------------------------------------
#----------------------------------------------------------------------------------------------

class Params:
    """Hyper-parameters and I/O settings for a GAIN run.

    Every argument is stored unchanged on the instance under the same name, with
    one exception. ``output_folder`` defaults to ``"<cwd>/results/"``, where
    ``<cwd>`` is the current working directory *when the instance is created*.
    Previously the default was baked in once at import time, so a later
    ``os.chdir`` was silently ignored.

    Notable fields, as used elsewhere in this module:
        lr_D, lr_G: Adam learning rates for the discriminator / generator.
        num_iterations: length of the per-iteration arrays kept by ``Metrics``.
        miss_rate, hint_rate: probabilities for GAIN mask / hint generation
            (TODO confirm how the training loop consumes them).

    NOTE: the ``input`` parameter shadows the builtin inside ``__init__``; the
    name is kept for backward compatibility with keyword callers.
    """

    def __init__(
        self,
        input=None,
        output="imputed",
        ref=None,
        output_folder=None,
        header=None,
        num_iterations=2001,
        batch_size=128,
        alpha=10,
        miss_rate=0.1,
        hint_rate=0.9,
        lr_D=0.001,
        lr_G=0.001,
        override=0,
        output_all=0,
    ):
        if output_folder is None:
            # Resolve lazily so the default follows the cwd at construction time.
            output_folder = f"{os.getcwd()}/results/"

        self.input = input
        self.output = output
        self.output_folder = output_folder
        self.ref = ref
        self.header = header
        self.num_iterations = num_iterations
        self.batch_size = batch_size
        self.alpha = alpha
        self.miss_rate = miss_rate
        self.hint_rate = hint_rate
        self.lr_D = lr_D
        self.lr_G = lr_G
        self.override = override
        self.output_all = output_all


#----------------------------------------------------------------------------------------------
#------------------------------------------Metrics class --------------------------------------
#----------------------------------------------------------------------------------------------

class Metrics:
    """Per-iteration bookkeeping for a GAIN training run.

    Each tracked series is a float64 NumPy array of zeros with one slot per
    iteration (``hypers.num_iterations`` entries). ``data_imputed`` and
    ``ref_data_imputed`` start as ``None``.
    """

    def __init__(self, hypers: Params):
        self.hypers = hypers

        n_iterations = hypers.num_iterations
        tracked_series = (
            "loss_D", "loss_D_evaluate",
            "loss_G", "loss_G_evaluate",
            "loss_MSE_train", "loss_MSE_train_evaluate",
            "loss_MSE_test",
            "cpu", "cpu_evaluate",
            "ram", "ram_evaluate",
            "ram_percentage", "ram_percentage_evaluate",
        )
        for series_name in tracked_series:
            setattr(self, series_name, np.zeros(n_iterations))

        self.data_imputed = None
        self.ref_data_imputed = None


#----------------------------------------------------------------------------------------------
#----------------------------------Functions for Hint Generation ------------------------------
#----------------------------------------------------------------------------------------------

def generate_hint(mask, hint_rate):
    """Return the GAIN hint: ``mask`` with a random subset of its entries kept.

    ``generate_mask(mask, 1 - hint_rate)`` yields 1.0 where a uniform draw exceeds
    ``1 - hint_rate`` (i.e. with probability ``hint_rate``), and the result is
    multiplied elementwise with ``mask``.
    """
    reveal = generate_mask(mask, 1 - hint_rate)
    return mask * reveal


def generate_mask(data, miss_rate):
    """Draw a random 0/1 mask with the same (rows, columns) as ``data``.

    Each entry is 1.0 when a ``np.random.uniform(0, 1)`` draw is strictly greater
    than ``miss_rate`` and 0.0 otherwise. The result is a float64 NumPy array.
    """
    n_rows, n_cols = data.shape[0], data.shape[1]
    draws = np.random.uniform(0.0, 1.0, size=(n_rows, n_cols))
    return (draws > miss_rate).astype(float)

#----------------------------------------------------------------------------------------------
#------------------------------------------Network class --------------------------------------
#----------------------------------------------------------------------------------------------

class Network:
    """GAIN generator/discriminator pair with one Adam optimizer each.

    On construction, every parameter whose name contains ``"weight"`` and that has
    at least two dimensions is re-initialised with Xavier-normal. Other parameters
    (biases, 1-D weights) keep their layer's default initialisation.

    NOTE(review): this is a plain Python class, not an ``nn.Module``. When an
    instance is stored on an ``nn.Module`` (e.g. as an attribute of a model),
    ``net_G`` and ``net_D`` are NOT registered as submodules. As a result they:

    * are excluded from the owner's ``parameters()`` / ``state_dict()``, and so
      presumably from checkpoints saved from it;
    * are not moved by ``.to(device)``;
    * are not switched by ``.train()`` / ``.eval()``.

    Subclassing ``nn.Module`` would fix this. It would also expose these weights
    to any optimizer built from the owner's ``parameters()``, so confirm with the
    training code before changing it.

    Args:
        hypers: run settings; ``hypers.lr_D`` / ``hypers.lr_G`` set the Adam
            learning rates.
        net_G: generator network.
        net_D: discriminator network.
        metrics: bookkeeping object, stored as ``self.metrics``.
    """

    def __init__(self, hypers: Params, net_G, net_D, metrics: Metrics):
        # D first, then G, to keep the original order of random draws.
        # xavier_normal_ needs >= 2 dims to compute fan-in/fan-out; 1-D "weight"
        # tensors (e.g. BatchNorm) are skipped instead of raising.
        for net in (net_D, net_G):
            for name, param in net.named_parameters():
                if "weight" in name and param.dim() >= 2:
                    nn.init.xavier_normal_(param)

        self.hypers = hypers
        self.net_G = net_G
        self.net_D = net_D
        self.metrics = metrics

        self.optimizer_D = torch.optim.Adam(net_D.parameters(), lr=hypers.lr_D)
        self.optimizer_G = torch.optim.Adam(net_G.parameters(), lr=hypers.lr_G)

    def generate_sample(self, data, mask):
        """Run the generator on ``data`` with its masked-out entries replaced by noise.

        Args:
            data: 2-D tensor (rows, features). Entries where ``mask`` is 0 are
                treated as missing.
            mask: tensor of 0/1 values combined elementwise with ``data`` and
                concatenated to it along dim 1. It must therefore have the same
                number of rows as ``data``.

        Returns:
            ``net_G`` applied to ``cat([noisy_data, mask], dim=1)`` cast to float32.
        """
        n_rows, n_cols = data.shape[0], data.shape[1]
        # Small uniform noise in [0, 0.01), created on the same device as ``data``.
        # Previously it was always allocated on the CPU, which raised a
        # device-mismatch error for GPU inputs. CPU values are unchanged.
        noise = torch.rand((n_rows, n_cols), device=data.device) * 0.01
        noisy_data = mask * data + (1 - mask) * noise
        generator_input = torch.cat((noisy_data, mask), 1).float()
        return self.net_G(generator_input)

#----------------------------------------------------------------------------------------------
#-----------------------------------------GAIN_DANN class -------------------------------------
#----------------------------------------------------------------------------------------------    

class GAIN_DANN(nn.Module):
    """Autoencoder that imputes missing values in latent space (GAIN), with a DANN domain head.

    ``forward`` pipeline:

    1. NaNs in the input are zero-filled and recorded in a 0/1 mask.
    2. The filled input is encoded.
    3. The GAIN generator proposes values for the masked positions of the code.
    4. The imputed code is decoded back to ``input_dim`` features.
    5. A domain classifier predicts the domain of the encoded sample.

    NOTE(review): the input-space mask is multiplied elementwise with the latent
    code and concatenated to it as the generator input (``2 * latent_dim``
    features). This module therefore only works when ``latent_dim == input_dim``.

    Args:
        input_dim: number of input (and reconstructed) features.
        latent_dim: width of the encoder output / generator input.
        n_class: number of domains predicted by the domain classifier.
        params: run settings forwarded to the GAIN ``Network``.
        metrics: bookkeeping object forwarded to the GAIN ``Network``.
        hint_rate: stored as ``self.hint_rate``; not otherwise used by this module.
    """

    def __init__(self, input_dim: int, latent_dim: int, n_class: int, params: Params, metrics: Metrics, hint_rate=0.9):
        super(GAIN_DANN, self).__init__()
        self.hint_rate = hint_rate

        # Submodules are created in the original order so parameter init draws match.
        self.encoder = Encoder(input_dim=input_dim, hidden_dim=128, latent_dim=latent_dim)

        # Gradient reversal layer: sits between the encoder and the domain classifier.
        self.grl = GradientReversalLayer()

        self.domain_classifier = DomainClassifier(latent_dim, n_class=n_class)

        # GAIN generator / discriminator operating on [latent code, mask].
        self.gain = Network(hypers=params,
                            net_G=nn.Sequential(
                                nn.Linear(latent_dim * 2, latent_dim),
                                nn.ReLU(),
                                nn.Linear(latent_dim, latent_dim),
                                nn.ReLU(),
                                nn.Linear(latent_dim, latent_dim),
                                nn.Sigmoid(),
                            ),
                            net_D=nn.Sequential(
                                nn.Linear(latent_dim * 2, latent_dim),
                                nn.ReLU(),
                                nn.Linear(latent_dim, latent_dim),
                                nn.ReLU(),
                                nn.Linear(latent_dim, latent_dim),
                                nn.Sigmoid(),
                            ),
                            metrics=metrics)

        self.decoder = Decoder(latent_dim=latent_dim, hidden_dim=128, target_dim=input_dim)

    def forward(self, x):
        """Impute and reconstruct ``x``, and predict its domain.

        Missing values are marked as NaN in ``x``.

        Returns:
            ``(x_reconstructed, x_domain)``:

            * ``x_reconstructed`` is the decoder output.
            * ``x_domain`` holds the ``argmax`` over domain scores along dim 1.

            Because of the ``argmax``, no gradient reaches the domain classifier
            through ``x_domain``.
        """
        # TODO: x is expected to be scaled before it reaches this model — confirm with caller.
        missing = torch.isnan(x)
        mask = (~missing).float()
        x_filled = x.masked_fill(missing, 0)  # zeros where values are missing

        # 1. Encode
        x_encoded = self.encoder(x_filled)

        # 2. GAIN imputation in latent space (observed positions are kept as-is)
        sample = self.gain.generate_sample(x_encoded, mask)
        x_imputed = x_encoded * mask + sample * (1 - mask)

        # 2.1. Domain classifier. The GRL is the identity in the forward pass. In the
        # backward pass it reverses the domain gradient into the encoder, as in DANN.
        # It previously wrapped the generator input instead, which reversed the
        # reconstruction gradient flowing back into the encoder.
        domain_scores = self.domain_classifier(self.grl(x_encoded))
        x_domain = torch.argmax(domain_scores, dim=1)

        # 3. Decode
        x_reconstructed = self.decoder(x_imputed)

        # TODO: transform the reconstruction back to the original (pre-scaling) scale.
        return x_reconstructed, x_domain
    


#----------------------------------------------------------------------------------------------
#---------------------------------GAIN_DANN class for HuggingFace -----------------------------
#----------------------------------------------------------------------------------------------    


class GainDANNConfig(PretrainedConfig):
    """Hugging Face configuration for the GAIN-DANN model.

    Holds the model sizes (``input_dim``, ``latent_dim``, ``n_class``) and the GAIN
    training hyper-parameters as attributes of the same name. Any extra keyword
    arguments are forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "gain_dann"

    def __init__(self, input_dim=3013, latent_dim=3013, n_class=17, hint_rate=0.9, lr_D=0.001, lr_G=0.001,
                 num_iterations=2001, batch_size=128, alpha=10, miss_rate=0.1, override=0, output_all=0, **kwargs):
        super().__init__(**kwargs)
        model_settings = {
            "input_dim": input_dim,
            "latent_dim": latent_dim,
            "n_class": n_class,
            "hint_rate": hint_rate,
            "lr_D": lr_D,
            "lr_G": lr_G,
            "num_iterations": num_iterations,
            "batch_size": batch_size,
            "alpha": alpha,
            "miss_rate": miss_rate,
            "override": override,
            "output_all": output_all,
        }
        for attr_name, attr_value in model_settings.items():
            setattr(self, attr_name, attr_value)




class GainDANN(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``GAIN_DANN``.

    The wrapped network is available as ``self.model``; ``forward`` delegates to it.
    """

    config_class = GainDANNConfig

    def __init__(self, config):
        super().__init__(config)

        # The learning rates and hint rate must be present on the config. The other
        # training settings fall back to the same defaults as ``Params`` when absent.
        hyper_kwargs = {
            "lr_D": config.lr_D,
            "lr_G": config.lr_G,
            "hint_rate": config.hint_rate,
        }
        optional_settings = (
            ("num_iterations", 2001),
            ("batch_size", 128),
            ("alpha", 10),
            ("miss_rate", 0.1),
            ("override", 0),
            ("output_all", 0),
        )
        for setting, fallback in optional_settings:
            hyper_kwargs[setting] = getattr(config, setting, fallback)

        params = Params(**hyper_kwargs)
        metrics = Metrics(params)
        self.model = GAIN_DANN(
            input_dim=config.input_dim,
            latent_dim=config.latent_dim,
            n_class=config.n_class,
            params=params,
            metrics=metrics,
            hint_rate=config.hint_rate,
        )

    def forward(self, x):
        """Return whatever the wrapped ``GAIN_DANN`` returns for ``x``."""
        return self.model(x)