# AphasiaPred / utils.py
# Uploaded by SreekarB ("Upload 13 files"), commit dbe81c1.
import torch
import numpy as np
from sklearn.linear_model import Ridge, LogisticRegression
def to_torch(x):
    """Convert a NumPy array into a float32 torch tensor.

    ``torch.from_numpy`` shares memory with ``x``; ``.float()`` then casts to
    float32 (copying only when ``x`` is not already float32).
    """
    as_tensor = torch.from_numpy(x)
    return as_tensor.float()
def to_cuda(x, use_cuda):
    """Return ``x.cuda()`` when ``use_cuda`` is truthy, otherwise ``x`` itself."""
    if not use_cuda:
        return x
    return x.cuda()
def to_numpy(x):
    """Detach ``x`` from autograd, move it to the CPU and return it as a NumPy array."""
    detached = x.detach()
    return detached.cpu().numpy()
def rmse(a, b, mean=torch.mean):
    """Root of the mean squared difference between ``a`` and ``b``.

    ``mean`` is the reduction applied to the squared error; by default
    ``torch.mean`` over all elements.
    """
    squared_error = (a - b) ** 2
    return mean(squared_error) ** 0.5
def latent_loss(z, use_cuda=True):
    """Penalise the latent Gram matrix and mean for deviating from their targets.

    Args:
        z: latent batch. It is used as ``z.T @ z`` and ``mean(dim=0)``, i.e.
            treated as (n_samples, n_latent).
        use_cuda: kept for backward compatibility. Targets are created on
            ``z``'s own device and dtype, so this flag no longer has to match
            where ``z`` lives. Previously the default ``True`` crashed for CPU
            tensors.

    Returns:
        Tuple ``(loss_C, loss_mu, C, mu)``:
            - ``loss_C``: RMSE of ``C = z.T @ z`` against ``n * I``.
            - ``loss_mu``: RMSE of the per-dimension mean ``mu`` against zero.
            - ``C`` and ``mu``: the two statistics themselves.
    """
    n_latent = z.shape[-1]
    C = z.T @ z
    mu = torch.mean(z, dim=0)
    # n * I is the expected Gram matrix of n samples whose latent dimensions
    # are zero-mean, unit-variance and mutually uncorrelated.
    tgt_C = torch.eye(n_latent, dtype=z.dtype, device=z.device) * len(z)
    # Same computation as rmse(C, tgt_C) / rmse(mu, 0), kept local so the
    # targets never need a separate device transfer.
    loss_C = torch.mean((C - tgt_C) ** 2) ** 0.5
    loss_mu = torch.mean(mu ** 2) ** 0.5
    return loss_C, loss_mu, C, mu
def decor_loss(z, demo, use_cuda=True):
    """Penalise linear association between each demographic column and each latent dim.

    For demographic column ``d`` (centred within the batch), the statistic for
    latent dimension k is::

        p_k = (sum_n d_n * z_nk) / std(d) / (sum_n z_nk ** 2)

    The loss for that column is the RMSE of ``p`` against zero.

    A column that is constant within the batch has std 0. This happens, for
    example, when a one-hot category is present in all or none of the batch's
    samples. A single-sample batch has an undefined (NaN) unbiased std. In both
    cases the statistic is undefined, so the column contributes ``p = 0`` and a
    zero loss. The unguarded division produced NaN here, which poisoned the
    whole backward pass.

    Args:
        z: latent batch, indexed as (n_samples, n_latent) by the einsums.
        demo: per-sample demographic matrix; column ``di`` is ``demo[:, di]``.
        use_cuda: kept for backward compatibility. Tensors created here follow
            ``z``'s device and dtype.

    Returns:
        Tuple ``(losses, ps)``:
            - ``losses``: stacked per-column losses, shape ``(demo.shape[1],)``.
            - ``ps``: list of per-column statistic vectors ``p``.
    """
    n_latent = z.shape[-1]
    # Per-latent-dimension sum of squares; identical for every column.
    z_sq = torch.einsum('nz,nz->z', z, z)
    ps = []
    losses = []
    for di in range(demo.shape[1]):
        d = demo[:, di]
        d = d - torch.mean(d)
        # torch.std is unbiased by default, so a single sample would give NaN.
        d_std = torch.std(d) if len(d) > 1 else None
        if d_std is not None and d_std > 0:  # also False for a NaN std
            p = torch.einsum('n,nz->z', d, z)
            p = p / d_std
            p = p / z_sq
        else:
            p = torch.zeros(n_latent, dtype=z.dtype, device=z.device)
        # Equivalent to rmse(p, zeros): root of the mean squared statistic.
        losses.append(torch.mean(p ** 2) ** 0.5)
        ps.append(p)
    return torch.stack(losses), ps
def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
    """Encode demographics as a float tensor of shape (n_samples, n_columns).

    Column encoding by demographic type:
        - ``continuous``: one column holding its values cast to float32. The
          matching ``pred_stats`` entry is not used.
        - ``categorical``: one-hot encoded, one column per category listed in
          the matching ``pred_stats`` entry, in that order.
        - any other type: contributes no column.

    Columns are moved to the GPU when ``use_cuda`` is truthy.

    Raises:
        Exception: ``'Bad demographic'`` (after printing the offending value)
            when a categorical value is absent from its ``pred_stats`` entry.
    """
    columns = []
    for demo_idx, (values, kind, stats) in enumerate(zip(demo, demo_types, pred_stats)):
        if kind == 'continuous':
            col = torch.from_numpy(values).float()
            columns.append(col.cuda() if use_cuda else col)
        elif kind == 'categorical':
            # Refuse categories the auxiliary models never saw during training.
            for value in values:
                if value not in stats:
                    print(f'Model not trained with value {value} for categorical demographic {demo_idx}')
                    raise Exception('Bad demographic')
            for category in stats:
                mask = (values == category).astype('bool')
                one_hot = torch.zeros(len(values))
                one_hot[mask] = 1
                columns.append(one_hot.cuda() if use_cuda else one_hot)
    # Stacked as (n_columns, n_samples); transpose to samples-first.
    return torch.stack(columns).permute(1, 0)
def _fit_guidance_models(x, demo, demo_types, alpha, LR_C, use_cuda):
    """Fit one sklearn linear model per demographic on the raw features ``x``.

    Model per demographic type:
        - ``continuous``: ``Ridge(alpha=alpha)``.
        - ``categorical``: ``LogisticRegression(C=LR_C)``.
        - any other type: prints progress but fits nothing and records no stats.

    Returns:
        Tuple ``(pred_w, pred_i, pred_stats)``:
            - ``pred_w``: list of weight tensors (on the GPU if ``use_cuda``).
            - ``pred_i``: list of matching intercepts.
            - ``pred_stats``: per demographic, ``[mean, std]`` if continuous,
              or the sorted list of observed values if categorical.

    A binary categorical demographic yields a single sklearn coefficient row.
    It is appended twice, negated then as-is, so there is one entry per
    category, as for the multiclass case.
    """
    pred_w = []
    pred_i = []
    pred_stats = []
    for demo_idx, (d, t) in enumerate(zip(demo, demo_types)):
        print(f'Fitting auxiliary guidance model for demographic {demo_idx} {t}...', end='')
        if t == 'continuous':
            pred_stats.append([np.mean(d), np.std(d)])
            reg = Ridge(alpha=alpha).fit(x, d)
            pred_w.append(to_cuda(to_torch(reg.coef_), use_cuda))
            pred_i.append(reg.intercept_)
        elif t == 'categorical':
            pred_stats.append(sorted(list(set(list(d)))))
            reg = LogisticRegression(C=LR_C).fit(x, d)
            if len(reg.coef_) == 1:
                reg_w = to_cuda(to_torch(reg.coef_[0]), use_cuda)
                reg_i = reg.intercept_[0]
                pred_w.append(-reg_w)
                pred_i.append(-reg_i)
                pred_w.append(reg_w)
                pred_i.append(reg_i)
            else:
                for class_w, class_i in zip(reg.coef_, reg.intercept_):
                    pred_w.append(to_cuda(to_torch(class_w), use_cuda))
                    pred_i.append(class_i)
        # NOTE(review): any other type is skipped without a pred_stats entry,
        # which misaligns pred_stats with demo in demo_to_torch -- confirm.
        print(' done')
    return pred_w, pred_i, pred_stats


def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
              loss_C_mult, loss_mu_mult, loss_rec_mult, loss_decor_mult,
              loss_pred_mult, lr, weight_decay, alpha, LR_C, ret_obj):
    """Train ``vae`` with latent-whitening, decorrelation and reconstruction losses.

    Args:
        vae: model exposing ``enc(x)``, ``dec(z, demo)``, ``use_cuda``,
            ``parameters()``, ``train()`` and ``eval()``.
        x: NumPy feature matrix. It is passed to sklearn ``fit``, so it is
            (n_samples, n_features).
        demo: list of per-demographic value arrays, aligned with ``x``'s rows.
        demo_types: ``'continuous'`` or ``'categorical'`` per demographic.
        nepochs: number of training epochs.
        pperiod: evaluate and print every ``pperiod`` epochs.
        bsize: mini-batch size.
        loss_C_mult, loss_mu_mult, loss_rec_mult, loss_decor_mult: weights of
            the corresponding loss terms.
        loss_pred_mult: currently unused (see NOTE below).
        lr, weight_decay: Adam hyper-parameters.
        alpha: Ridge regularisation for continuous guidance models.
        LR_C: LogisticRegression ``C`` for categorical guidance models.
        ret_obj: receives ``ret_obj.pred_stats``.

    Returns:
        Tuple ``(train_losses, val_losses)``:
            - ``train_losses``: mean total loss per epoch.
            - ``val_losses``: full-data reconstruction RMSE, one entry every
              ``pperiod`` epochs.
    """
    # NOTE(review): only pred_stats is consumed below. The fitted weights and
    # intercepts, and loss_pred_mult, never enter the loss, so the
    # "prediction" term appears unimplemented -- confirm intent.
    _pred_w, _pred_i, pred_stats = _fit_guidance_models(
        x, demo, demo_types, alpha, LR_C, vae.use_cuda)
    ret_obj.pred_stats = pred_stats

    # Convert input and demographics to pytorch on the VAE's device.
    x = to_cuda(to_torch(x), vae.use_cuda)
    demo_t = demo_to_torch(demo, demo_types, pred_stats, vae.use_cuda)

    optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)
    train_losses = []
    val_losses = []
    for e in range(nepochs):
        vae.train()
        epoch_losses = []
        # Contiguous mini-batches in the original sample order (no shuffling).
        for bs in range(0, len(x), bsize):
            xb = x[bs:bs + bsize]
            db = demo_t[bs:bs + bsize]
            optim.zero_grad()
            z = vae.enc(xb)
            y = vae.dec(z, db)
            loss_C, loss_mu, _, _ = latent_loss(z, vae.use_cuda)
            decor_losses, _ = decor_loss(z, db, vae.use_cuda)
            loss_decor = decor_losses.sum()
            loss_rec = rmse(xb, y)
            total_loss = (loss_C_mult * loss_C + loss_mu_mult * loss_mu
                          + loss_rec_mult * loss_rec + loss_decor_mult * loss_decor)
            total_loss.backward()
            optim.step()
            epoch_losses.append(total_loss.item())
        train_losses.append(np.mean(epoch_losses))

        # The "validation" pass reconstructs the full training input in eval
        # mode; there is no held-out split here.
        if e % pperiod == 0:
            vae.eval()
            with torch.no_grad():
                z = vae.enc(x)
                y = vae.dec(z, demo_t)
                val_loss = rmse(x, y).item()
            val_losses.append(val_loss)
            print(f'Epoch {e}/{nepochs} - '
                  f'Train Loss: {train_losses[-1]:.4f} - '
                  f'Val Loss: {val_loss:.4f}')
    return train_losses, val_losses