# GainPro — modeling_gain_dann.py (PyTorch implementation)
from transformers import PreTrainedModel, PretrainedConfig
#from model_gain_dann import GainDANNConfig
import pandas as pd
import numpy as np
import os
import torch
import torch.nn as nn
#----------------------------------------------------------------------------------------------
#------------------------------------------Encoder class --------------------------------------
#----------------------------------------------------------------------------------------------
# Encoder
class Encoder(nn.Module):
    """Map an input vector to a latent representation.

    Layer order: Linear -> BatchNorm -> ReLU -> Dropout(0.3)
                 -> Linear -> ReLU -> BatchNorm.
    """

    def __init__(self, input_dim: int, hidden_dim: int, latent_dim: int):
        super().__init__()
        stages = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, latent_dim),
            nn.ReLU(),
            nn.BatchNorm1d(latent_dim),
        ]
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Return the latent encoding of ``x``."""
        return self.encoder(x)
#----------------------------------------------------------------------------------------------
#------------------------------------------Decoder class --------------------------------------
#----------------------------------------------------------------------------------------------
# Decoder
class Decoder(nn.Module):
    """Map a latent representation back to the target (input) space.

    Layer order: Linear -> Dropout(0.3) -> Linear; no output activation.
    """

    def __init__(self, latent_dim: int, hidden_dim: int, target_dim: int):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, target_dim),
        )

    def forward(self, x):
        """Return the reconstruction of ``x`` in the target space."""
        return self.decoder(x)
#----------------------------------------------------------------------------------------------
#-------------------------------------DomainClassifier class ----------------------------------
#----------------------------------------------------------------------------------------------
class DomainClassifier(nn.Module):
    """Distinguish the domain of the input.

    A two-layer MLP producing ``n_class`` raw logits (the final layer acts
    as a logistic regressor over the latent features).
    """

    def __init__(self, input_dim: int, n_class: int):
        super().__init__()
        self.domain_classifier = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, n_class),
        )

    def forward(self, x):
        """Return per-domain logits for ``x``."""
        return self.domain_classifier(x)
#----------------------------------------------------------------------------------------------
#--------------------------------- class for GradientReverseal --------------------------------
#----------------------------------------------------------------------------------------------
class GradientReversalFunction(torch.autograd.Function):
    """Identity in the forward pass; multiplies gradients by ``-lambd`` in backward."""

    @staticmethod
    def forward(ctx, x, lambd=1.0):
        # Stash the scaling factor for the backward pass.
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the gradient sign; lambd itself receives no gradient.
        return -ctx.lambd * grad_output, None


class GradientReversalLayer(nn.Module):
    """Module wrapper around :class:`GradientReversalFunction` with a fixed ``lambd``."""

    def __init__(self, lambd=1.0):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        return GradientReversalFunction.apply(x, self.lambd)
#----------------------------------------------------------------------------------------------
#------------------------------------------Params class ---------------------------------------
#----------------------------------------------------------------------------------------------
class Params:
    """Hyper-parameter and I/O-settings container for a GAIN training run.

    Every constructor argument is exposed as an attribute of the same name.
    Note the ``output_folder`` default is computed once, at class-definition
    time, from the working directory of the importing process.
    """

    def __init__(
        self,
        input=None,
        output="imputed",
        ref=None,
        output_folder=f"{os.getcwd()}/results/",
        header=None,
        num_iterations=2001,
        batch_size=128,
        alpha=10,
        miss_rate=0.1,
        hint_rate=0.9,
        lr_D=0.001,
        lr_G=0.001,
        override=0,
        output_all=0,
    ):
        # Bind every constructor argument (except self) to a same-named
        # attribute; the argument list above is the single source of truth.
        for name, value in list(locals().items()):
            if name != "self":
                setattr(self, name, value)
#----------------------------------------------------------------------------------------------
#------------------------------------------Metrics class --------------------------------------
#----------------------------------------------------------------------------------------------
class Metrics:
    """Per-iteration metric buffers for a GAIN training/evaluation run."""

    # Names of the per-iteration float buffers allocated in __init__.
    _SERIES = (
        "loss_D", "loss_D_evaluate",
        "loss_G", "loss_G_evaluate",
        "loss_MSE_train", "loss_MSE_train_evaluate",
        "loss_MSE_test",
        "cpu", "cpu_evaluate",
        "ram", "ram_evaluate",
        "ram_percentage", "ram_percentage_evaluate",
    )

    def __init__(self, hypers: Params):
        self.hypers = hypers
        # One zero-initialised slot per training iteration for each series.
        for name in self._SERIES:
            setattr(self, name, np.zeros(hypers.num_iterations))
        # Populated after imputation completes.
        self.data_imputed = None
        self.ref_data_imputed = None
#----------------------------------------------------------------------------------------------
#----------------------------------Functions for Hint Generation ------------------------------
#----------------------------------------------------------------------------------------------
def generate_hint(mask, hint_rate):
    """Build a GAIN hint matrix from an observation ``mask``.

    ``generate_mask(mask, 1 - hint_rate)`` draws 1 with probability
    ``hint_rate``; multiplying by ``mask`` ensures entries that are missing
    in the data are never revealed as hints.
    """
    reveal = generate_mask(mask, 1 - hint_rate)
    return mask * reveal
def generate_mask(data, miss_rate):
    """Return a float 0/1 matrix shaped like ``data``.

    Each entry is drawn i.i.d.: 1 with probability ``1 - miss_rate``
    (kept/observed) and 0 with probability ``miss_rate`` (missing).
    """
    rows, cols = data.shape[0], data.shape[1]
    draws = np.random.uniform(0.0, 1.0, size=(rows, cols))
    return (draws > miss_rate).astype(float)
#----------------------------------------------------------------------------------------------
#------------------------------------------Network class --------------------------------------
#----------------------------------------------------------------------------------------------
class Network:
    """Wrap the GAIN generator/discriminator pair with their Adam optimizers.

    On construction, every ``weight`` parameter of both networks is
    Xavier-normal initialised (biases keep PyTorch defaults).
    """

    def __init__(self, hypers: Params, net_G, net_D, metrics: Metrics):
        # Xavier-initialise only weight matrices; the name filter skips biases.
        # (Assumes both nets contain only >=2-D weights, e.g. Linear layers —
        # xavier_normal_ would raise on 1-D BatchNorm weights.)
        for name, param in net_D.named_parameters():
            if "weight" in name:
                nn.init.xavier_normal_(param)
        for name, param in net_G.named_parameters():
            if "weight" in name:
                nn.init.xavier_normal_(param)
        self.hypers = hypers
        self.net_G = net_G
        self.net_D = net_D
        self.metrics = metrics
        self.optimizer_D = torch.optim.Adam(net_D.parameters(), lr=hypers.lr_D)
        self.optimizer_G = torch.optim.Adam(net_G.parameters(), lr=hypers.lr_G)

    def generate_sample(self, data, mask):
        """Run the generator on ``data`` with missing slots replaced by noise.

        Parameters
        ----------
        data : tensor of observed values; entries where ``mask`` is 0 are
            overwritten with small uniform noise before generation.
        mask : same-shaped tensor, 1 = observed, 0 = missing.

        Returns the generator output for the concatenated [data | mask] input.
        """
        # Fix: the first parameter was misleadingly named ``cls`` although this
        # is a plain instance method; renamed to ``self``.
        size, dim = data.shape[0], data.shape[1]
        # Small uniform noise in [0, 0.01) fills the missing slots.
        # Fix: allocate on the same device/dtype as ``data`` so GPU inputs
        # work (the original always used the CPU default tensor).
        Z = torch.rand((size, dim), device=data.device, dtype=data.dtype) * 0.01
        noised = mask * data + (1 - mask) * Z
        # Generator input is [noised data | mask], as in the GAIN paper.
        input_G = torch.cat((noised, mask), 1).float()
        return self.net_G(input_G)
#----------------------------------------------------------------------------------------------
#-----------------------------------------GAIN_DANN class -------------------------------------
#----------------------------------------------------------------------------------------------
class GAIN_DANN(nn.Module):
    """Imputation model: Encoder -> (GRL + DomainClassifier, GAIN) -> Decoder.

    The encoder maps possibly-incomplete inputs to a latent space, GAIN
    imputes the missing latent entries, a domain classifier (fed through a
    gradient-reversal layer during training) predicts each sample's domain,
    and the decoder reconstructs the input space.

    NOTE(review): the forward pass applies the input-space ``mask``
    (``input_dim`` columns) to latent tensors, so it only works when
    ``latent_dim == input_dim`` (true for the default config) — confirm.
    """

    def __init__(self, input_dim: int, latent_dim: int, n_class: int,
                 params: Params, metrics: Metrics, hint_rate=0.9):
        super().__init__()
        self.encoder = Encoder(input_dim=input_dim, hidden_dim=128, latent_dim=latent_dim)
        # Gradient reversal: identity forward, negated gradient backward.
        self.grl = GradientReversalLayer()
        self.domain_classifier = DomainClassifier(latent_dim, n_class=n_class)
        # Fix: ``hint_rate`` was previously accepted but silently discarded;
        # store it so training code can read it from the model.
        self.hint_rate = hint_rate
        # GAIN generator/discriminator operate in latent space; both consume
        # [latent | mask], hence the latent_dim * 2 input width.
        # (Leftover debug prints of latent_dim were removed.)
        self.gain = Network(
            hypers=params,
            net_G=nn.Sequential(
                nn.Linear(latent_dim * 2, latent_dim),
                nn.ReLU(),
                nn.Linear(latent_dim, latent_dim),
                nn.ReLU(),
                nn.Linear(latent_dim, latent_dim),
                nn.Sigmoid(),
            ),
            net_D=nn.Sequential(
                nn.Linear(latent_dim * 2, latent_dim),
                nn.ReLU(),
                nn.Linear(latent_dim, latent_dim),
                nn.ReLU(),
                nn.Linear(latent_dim, latent_dim),
                nn.Sigmoid(),
            ),
            metrics=metrics,
        )
        self.decoder = Decoder(latent_dim=latent_dim, hidden_dim=128, target_dim=input_dim)

    def forward(self, x):
        """
        Forward pass for GAIN_DANN.

        Handles missing values (NaNs) by replacing them with noise and using
        a mask. Returns ``(x_reconstructed, x_domain)`` where ``x_domain`` is
        the argmax domain index per sample.
        """
        # NOTE(review): x is assumed to be already scaled (original TODO).
        x_filled = x.clone()
        x_filled[torch.isnan(x_filled)] = 0  # zero-fill missing values
        mask = (~torch.isnan(x)).float()  # 1 = observed, 0 = missing
        # 1. Encode
        x_encoded = self.encoder(x_filled)
        # Identity at inference time; only matters for gradients in training.
        x_grl = self.grl(x_encoded)
        # 2. GAIN imputation in latent space: keep observed entries, take
        # generated values where the input was missing.
        sample = self.gain.generate_sample(x_grl, mask)
        x_imputed = x_encoded * mask + sample * (1 - mask)
        # 2.1. Domain prediction as a hard label via argmax.
        x_domain = torch.argmax(self.domain_classifier(x_encoded), dim=1)
        # 3. Decode back to input space.
        # NOTE(review): inverse scaling back to the original range is still
        # TODO (carried over from the original).
        x_reconstructed = self.decoder(x_imputed)
        return x_reconstructed, x_domain
#----------------------------------------------------------------------------------------------
#---------------------------------GAIN_DANN class for HuggingFace -----------------------------
#----------------------------------------------------------------------------------------------
class GainDANNConfig(PretrainedConfig):
    """HuggingFace configuration for the GAIN-DANN imputation model.

    Holds the architecture sizes (``input_dim``, ``latent_dim``, ``n_class``)
    and the GAIN training hyper-parameters; every constructor argument is
    exposed as a same-named attribute.
    """

    model_type = "gain_dann"

    def __init__(self, input_dim=3013, latent_dim=3013, n_class=17, hint_rate=0.9,
                 lr_D=0.001, lr_G=0.001, num_iterations=2001, batch_size=128,
                 alpha=10, miss_rate=0.1, override=0, output_all=0, **kwargs):
        # Capture our own arguments before delegating the rest to the base.
        own = {k: v for k, v in locals().items() if k not in ("self", "kwargs")}
        super().__init__(**kwargs)
        for key, value in own.items():
            setattr(self, key, value)
class GainDANN(PreTrainedModel):
    """HuggingFace ``PreTrainedModel`` wrapper around the GAIN_DANN module."""

    config_class = GainDANNConfig

    def __init__(self, config):
        super().__init__(config)
        # Translate the HF config into the internal Params container; the
        # getattr defaults keep older serialized configs (missing fields)
        # loadable.
        params = Params(
            lr_D=config.lr_D,
            lr_G=config.lr_G,
            hint_rate=config.hint_rate,
            num_iterations=getattr(config, "num_iterations", 2001),
            batch_size=getattr(config, "batch_size", 128),
            alpha=getattr(config, "alpha", 10),
            miss_rate=getattr(config, "miss_rate", 0.1),
            override=getattr(config, "override", 0),
            output_all=getattr(config, "output_all", 0),
        )
        self.model = GAIN_DANN(
            input_dim=config.input_dim,
            latent_dim=config.latent_dim,
            n_class=config.n_class,
            params=params,
            metrics=Metrics(params),
            hint_rate=config.hint_rate,
        )

    def forward(self, x):
        """Delegate straight to the wrapped GAIN_DANN module."""
        return self.model(x)