Spaces:
Sleeping
Sleeping
File size: 5,248 Bytes
ef677f1 b32645b ef677f1 dbe81c1 ef677f1 dbe81c1 ef677f1 dbe81c1 a4c8f0c ef677f1 dbe81c1 ef677f1 dbe81c1 ef677f1 dbe81c1 ef677f1 dbe81c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import torch
import numpy as np
from sklearn.linear_model import Ridge, LogisticRegression
def to_torch(x):
    """Convert a numpy array to a float32 torch tensor."""
    tensor = torch.from_numpy(x)
    return tensor.float()
def to_cuda(x, use_cuda):
    """Move `x` to the GPU when `use_cuda` is truthy; otherwise return it unchanged."""
    if use_cuda:
        return x.cuda()
    return x
def to_numpy(x):
    """Convert a torch tensor (possibly on GPU / in the autograd graph) to a numpy array."""
    detached = x.detach()
    return detached.cpu().numpy()
def rmse(a, b, mean=torch.mean):
    """Root-mean-square error between `a` and `b`.

    The reduction is pluggable via `mean` (e.g. pass `torch.sum` for a
    root-sum-of-squares instead).
    """
    sq_err = (a - b) ** 2
    return mean(sq_err) ** 0.5
def latent_loss(z, use_cuda=True):
    """Whitening losses for a batch of latent codes `z` of shape (n, dim).

    Compares the Gram matrix z^T z against len(z) * I and the per-dimension
    batch mean against zero, both via RMSE.

    Returns (loss_C, loss_mu, C, mu): the two losses plus the Gram matrix
    and mean they were computed from.
    """
    dim = z.shape[-1]
    C = z.T @ z
    mu = torch.mean(z, dim=0)
    eye_tgt = to_cuda(torch.eye(dim).float(), use_cuda) * len(z)
    zero_tgt = to_cuda(torch.zeros(dim).float(), use_cuda)
    return rmse(C, eye_tgt), rmse(mu, zero_tgt), C, mu
def decor_loss(z, demo, use_cuda=True):
    """Penalize association between latent dimensions and each demographic.

    For each demographic column, the mean-centered column (scaled by its
    std) is projected onto every latent dimension, normalized by each
    latent's sum of squares, and the RMSE of that projection vector against
    zero is the per-demographic loss.

    Parameters: z (n, latent_dim) tensor; demo (n, n_demo) tensor.
    Returns (losses, ps): stacked per-demographic losses of shape (n_demo,)
    and the list of raw projection vectors.
    """
    ps = []
    losses = []
    # Hoisted out of the loop: the zero target is loop-invariant, and the
    # original rebuilt it (an allocation, plus a transfer on GPU) per column.
    tgt = to_cuda(torch.zeros(z.shape[-1]).float(), use_cuda)
    for di in range(demo.shape[1]):
        d = demo[:, di]
        d = d - torch.mean(d)
        p = torch.einsum('n,nz->z', d, z)
        p = p / torch.std(d)  # torch.std uses the unbiased (n-1) estimator
        p = p / torch.einsum('nz,nz->z', z, z)
        losses.append(rmse(p, tgt))
        ps.append(p)
    return torch.stack(losses), ps
def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
    """Pack per-demographic arrays into a single (n_samples, n_columns) float tensor.

    A 'continuous' demographic contributes one column of raw values; a
    'categorical' demographic contributes one one-hot column per known value,
    in the order listed in its `pred_stats` entry. Assumes each element of
    `demo` is a 1-D numpy array — TODO confirm against callers.

    Raises:
        ValueError: when a categorical value was not present in `pred_stats`
            (i.e. not seen at training time). ValueError subclasses Exception,
            so existing `except Exception` handlers still catch it; the
            original printed the message and raised a bare Exception.
    """
    columns = []
    for demo_idx, (d, t, s) in enumerate(zip(demo, demo_types, pred_stats)):
        if t == 'continuous':
            col = torch.from_numpy(d).float()
            columns.append(col.cuda() if use_cuda else col)
        elif t == 'categorical':
            # Set membership: O(1) per element instead of scanning the list.
            known = set(s)
            for dd in d:
                if dd not in known:
                    raise ValueError(
                        f'Model not trained with value {dd} for categorical '
                        f'demographic {demo_idx}')
            # One-hot encode in the (sorted) order recorded at training time.
            for ss in s:
                onehot = torch.zeros(len(d))
                onehot[(d == ss).astype('bool')] = 1
                columns.append(onehot.cuda() if use_cuda else onehot)
    # Stack columns then transpose to (n_samples, n_columns).
    return torch.stack(columns).permute(1, 0)
def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
              loss_C_mult, loss_mu_mult, loss_rec_mult, loss_decor_mult,
              loss_pred_mult, lr, weight_decay, alpha, LR_C, ret_obj):
    """Train `vae` with reconstruction, latent-whitening, and decorrelation losses.

    Parameters
    ----------
    vae : model exposing .enc(x), .dec(z, demo), .parameters(), .use_cuda
    x : numpy array (n_samples, n_features) of inputs
    demo : list of 1-D arrays, one per demographic
    demo_types : list of 'continuous' / 'categorical', parallel to `demo`
    nepochs, pperiod, bsize : epoch count, validation/print period, batch size
    loss_*_mult : scalar weights for the corresponding loss terms
    lr, weight_decay : Adam hyperparameters
    alpha, LR_C : Ridge / LogisticRegression regularization strengths
    ret_obj : object on which `pred_stats` is stored for later inference

    Returns
    -------
    (train_losses, val_losses) : per-epoch mean training loss, and the
    full-data reconstruction RMSE recorded every `pperiod` epochs.
    """
    # Fit linear predictors of each demographic from the raw inputs.
    pred_w = []
    pred_i = []
    pred_stats = []
    train_losses = []
    val_losses = []
    for i, (d, t) in enumerate(zip(demo, demo_types)):
        print(f'Fitting auxiliary guidance model for demographic {i} {t}...', end='')
        if t == 'continuous':
            pred_stats.append([np.mean(d), np.std(d)])
            reg = Ridge(alpha=alpha).fit(x, d)
            pred_w.append(to_cuda(to_torch(reg.coef_), vae.use_cuda))
            pred_i.append(reg.intercept_)
        elif t == 'categorical':
            pred_stats.append(sorted(set(d)))
            reg = LogisticRegression(C=LR_C).fit(x, d)
            if len(reg.coef_) == 1:
                # Binary case: sklearn returns a single weight vector (for
                # the second class); mirror it so there is one predictor per
                # one-hot column, negated class first.
                reg_w = to_cuda(to_torch(reg.coef_[0]), vae.use_cuda)
                reg_i = reg.intercept_[0]
                pred_w.append(-reg_w)
                pred_i.append(-reg_i)
                pred_w.append(reg_w)
                pred_i.append(reg_i)
            else:
                # BUGFIX: the original reused `i` here, shadowing the outer
                # demographic index; use a dedicated class index instead.
                for ci in range(len(reg.coef_)):
                    pred_w.append(to_cuda(to_torch(reg.coef_[ci]), vae.use_cuda))
                    pred_i.append(reg.intercept_[ci])
        print(' done')
    ret_obj.pred_stats = pred_stats
    # NOTE(review): pred_w/pred_i and loss_pred_mult are prepared/accepted but
    # never used in the loop below — the prediction-guidance loss appears to
    # be disabled here; confirm before removing. (An unused
    # CrossEntropyLoss instance was also constructed in the original.)
    # Convert input and demographics to pytorch.
    x = to_cuda(to_torch(x), vae.use_cuda)
    demo_t = demo_to_torch(demo, demo_types, pred_stats, vae.use_cuda)
    # Training loop
    optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)
    for e in range(nepochs):
        epoch_losses = []
        vae.train()
        for bs in range(0, len(x), bsize):
            xb = x[bs:(bs+bsize)]
            db = demo_t[bs:(bs+bsize)]
            optim.zero_grad()
            # Encode, decode, and score the batch.
            z = vae.enc(xb)
            y = vae.dec(z, db)
            loss_C, loss_mu, _, _ = latent_loss(z, vae.use_cuda)
            loss_decor, _ = decor_loss(z, db, vae.use_cuda)
            loss_decor = sum(loss_decor)
            loss_rec = rmse(xb, y)
            # Weighted sum of all loss terms.
            total_loss = (loss_C_mult*loss_C + loss_mu_mult*loss_mu +
                          loss_rec_mult*loss_rec + loss_decor_mult*loss_decor)
            total_loss.backward()
            optim.step()
            epoch_losses.append(total_loss.item())
        # Record mean training loss for the epoch.
        train_losses.append(np.mean(epoch_losses))
        # Periodic full-data validation (reconstruction RMSE only).
        if e % pperiod == 0:
            vae.eval()
            with torch.no_grad():
                z = vae.enc(x)
                y = vae.dec(z, demo_t)
                val_loss = rmse(x, y).item()
            val_losses.append(val_loss)
            print(f'Epoch {e}/{nepochs} - '
                  f'Train Loss: {train_losses[-1]:.4f} - '
                  f'Val Loss: {val_loss:.4f}')
    return train_losses, val_losses