Spaces:
Sleeping
Sleeping
File size: 4,328 Bytes
ef677f1 b32645b ef677f1 b32645b ef677f1 b32645b ef677f1 b32645b ef677f1 b32645b ef677f1 b32645b ef677f1 b32645b e4a8a19 b32645b e4a8a19 b32645b e4a8a19 b32645b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from utils import to_torch, to_cuda, to_numpy, demo_to_torch
from sklearn.base import BaseEstimator
class VAE(nn.Module):
    """Encoder/decoder network pair with a conditioned decoder.

    The encoder maps ``input_dim`` features to a ``latent_dim`` code through
    one 1000-unit hidden layer.  The decoder maps the code concatenated with
    ``demo_dim`` extra features back to ``input_dim`` outputs through another
    1000-unit hidden layer.  Each hidden layer applies ReLU and then batch
    normalization.

    Args:
        input_dim: Number of features in each input/output sample.
        latent_dim: Width of the latent code.
        demo_dim: Number of conditioning features appended to the code.
        use_cuda: Flag forwarded to ``to_cuda`` for every layer and tensor
            created here (presumably selects GPU placement; see ``utils``).
    """

    def __init__(self, input_dim, latent_dim, demo_dim, use_cuda=True):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.demo_dim = demo_dim
        self.use_cuda = use_cuda

        hidden_width = 1000
        decoder_in = latent_dim + demo_dim

        # Layers are built in the original order so that parameter
        # registration (and random initialisation order) is unchanged.
        # Encoder
        enc_hidden = nn.Linear(input_dim, hidden_width).float()
        self.enc1 = to_cuda(enc_hidden, use_cuda)
        enc_out = nn.Linear(hidden_width, latent_dim).float()
        self.enc2 = to_cuda(enc_out, use_cuda)
        # Decoder
        dec_hidden = nn.Linear(decoder_in, hidden_width).float()
        self.dec1 = to_cuda(dec_hidden, use_cuda)
        dec_out = nn.Linear(hidden_width, input_dim).float()
        self.dec2 = to_cuda(dec_out, use_cuda)
        # Batch normalization layers (bn1: encoder hidden, bn2: decoder hidden)
        self.bn1 = to_cuda(nn.BatchNorm1d(hidden_width), use_cuda)
        self.bn2 = to_cuda(nn.BatchNorm1d(hidden_width), use_cuda)

    def enc(self, x):
        """Encode ``x`` (last dimension ``input_dim``) into a latent code."""
        activated = F.relu(self.enc1(x))
        normalized = self.bn1(activated)
        return self.enc2(normalized)

    def gen(self, n):
        """Draw ``n`` latent codes from a standard normal distribution."""
        noise = torch.randn(n, self.latent_dim).float()
        return to_cuda(noise, self.use_cuda)

    def dec(self, z, demo):
        """Decode latent codes ``z`` conditioned on ``demo``.

        ``z`` and ``demo`` are concatenated along dimension 1 before being
        fed to the decoder layers.
        """
        conditioned = torch.cat([z, demo], dim=1)
        conditioned = to_cuda(conditioned, self.use_cuda)
        activated = F.relu(self.dec1(conditioned))
        normalized = self.bn2(activated)
        return self.dec2(normalized)
class DemoVAE(BaseEstimator):
    """Scikit-learn style estimator that builds, trains and applies a ``VAE``.

    Hyperparameters are the keys of :meth:`get_default_params`; they are
    stored as plain attributes and forwarded to ``utils.train_vae`` by
    :meth:`fit`.  ``self.pred_stats`` is read by :meth:`transform` and
    :meth:`save` but is not assigned in this class -- presumably
    ``train_vae`` sets it on the estimator it receives (TODO confirm).
    """

    def __init__(self, **params):
        """Create an estimator; unspecified hyperparameters take defaults."""
        self.set_params(**params)

    @staticmethod
    def get_default_params():
        """Return a fresh dict of every hyperparameter and its default."""
        return dict(
            latent_dim=32,
            use_cuda=True,
            nepochs=1000,
            pperiod=100,
            bsize=16,
            loss_C_mult=1,
            loss_mu_mult=1,
            loss_rec_mult=100,
            loss_decor_mult=10,
            loss_pred_mult=0.001,
            alpha=100,
            LR_C=100,
            lr=1e-4,
            weight_decay=0,
        )

    def get_params(self, deep=True):
        """Return the current hyperparameters as a dict.

        ``deep`` is accepted for scikit-learn compatibility and is unused.
        """
        return {name: getattr(self, name) for name in self.get_default_params()}

    def set_params(self, **params):
        """Update the given hyperparameters and return ``self``.

        Only keys present in ``params`` are changed.  A hyperparameter that
        has never been set on this instance is initialised to its default,
        which is what makes ``__init__`` work.  Keys that are not known
        hyperparameters are ignored, as before.
        """
        for name, default in self.get_default_params().items():
            if name in params:
                setattr(self, name, params[name])
            elif not hasattr(self, name):
                # First initialisation only: never clobber a value the
                # caller configured earlier.
                setattr(self, name, default)
        return self

    def fit(self, x, demo, demo_types):
        """Build a new ``VAE`` sized for the data and train it.

        Args:
            x: Training data; ``x.shape[1]`` is used as the input width.
            demo: Iterable with one sequence of values per conditioning
                variable, aligned with ``demo_types``.
            demo_types: For each variable, ``'continuous'`` (contributes one
                decoder input) or ``'categorical'`` (contributes one decoder
                input per distinct value).

        Returns:
            ``self``.

        Raises:
            ValueError: If a type is neither of the two supported strings.
        """
        # Imported here as in the original (keeps ``utils`` off the module
        # import path until training is requested).
        from utils import train_vae

        # Calculate demo_dim
        demo_dim = 0
        for values, kind in zip(demo, demo_types):
            if kind == 'continuous':
                demo_dim += 1
            elif kind == 'categorical':
                demo_dim += len(set(values))
            else:
                raise ValueError(f'Demographic type "{kind}" not supported')

        # Initialize VAE
        self.input_dim = x.shape[1]
        self.demo_dim = demo_dim
        self.vae = VAE(self.input_dim, self.latent_dim, demo_dim, self.use_cuda)

        # Train VAE
        train_vae(
            self.vae, x, demo, demo_types,
            self.nepochs, self.pperiod, self.bsize,
            self.loss_C_mult, self.loss_mu_mult, self.loss_rec_mult,
            self.loss_decor_mult, self.loss_pred_mult,
            self.lr, self.weight_decay, self.alpha, self.LR_C,
            self,
        )
        return self

    def transform(self, x, demo, demo_types):
        """Decode samples conditioned on ``demo`` and return them via ``to_numpy``.

        Args:
            x: Either an integer ``n`` (Python or NumPy), in which case ``n``
                latent codes are sampled with ``VAE.gen``, or data to be
                encoded with ``VAE.enc``.
            demo: Conditioning values, converted by ``demo_to_torch``.
            demo_types: Types matching ``demo``.

        Returns:
            The decoder output converted by ``to_numpy``.
        """
        use_cuda = self.vae.use_cuda
        # Inference: BatchNorm must use its running statistics rather than
        # the statistics of this batch (which would also update the running
        # stats and fail for a single sample); no autograd graph is needed.
        self.vae.eval()
        with torch.no_grad():
            if isinstance(x, (int, np.integer)):
                z = self.vae.gen(int(x))
            else:
                z = self.vae.enc(to_cuda(to_torch(x), use_cuda))
            demo_t = demo_to_torch(demo, demo_types, self.pred_stats, use_cuda)
            y = self.vae.dec(z, demo_t)
        return to_numpy(y)

    def get_latents(self, x):
        """Encode ``x`` and return the latent codes via ``to_numpy``."""
        # Same inference settings as ``transform``.
        self.vae.eval()
        with torch.no_grad():
            z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
        return to_numpy(z)

    def save(self, path):
        """Write network weights, hyperparameters and fit state to ``path``."""
        torch.save({
            'model_state_dict': self.vae.state_dict(),
            'params': self.get_params(),
            'pred_stats': self.pred_stats,
            'input_dim': self.input_dim,
            'demo_dim': self.demo_dim,
        }, path)

    def load(self, path):
        """Restore the state written by :meth:`save` from ``path``."""
        # NOTE(review): depending on the PyTorch version, torch.load may
        # unpickle arbitrary objects -- only load checkpoints you trust.
        checkpoint = torch.load(path)
        self.set_params(**checkpoint['params'])
        self.pred_stats = checkpoint['pred_stats']
        self.input_dim = checkpoint['input_dim']
        self.demo_dim = checkpoint['demo_dim']
        self.vae = VAE(self.input_dim, self.latent_dim, self.demo_dim, self.use_cuda)
        self.vae.load_state_dict(checkpoint['model_state_dict'])
        # A freshly constructed module is in training mode; a restored model
        # is meant for inference.
        self.vae.eval()
|