import torch
import numpy as np
from sklearn.linear_model import Ridge, LogisticRegression


def to_torch(x):
    """Convert a numpy array to a float32 torch tensor."""
    return torch.from_numpy(x).float()


def to_cuda(x, use_cuda):
    """Move a tensor to the GPU if use_cuda is set."""
    return x.cuda() if use_cuda else x


def to_numpy(x):
    """Detach a tensor from the graph and convert it to a numpy array."""
    return x.detach().cpu().numpy()


def rmse(a, b, mean=torch.mean):
    """Root-mean-square error between two tensors."""
    return mean((a - b)**2)**0.5


def latent_loss(z, use_cuda=True):
    """Penalize deviation of the latent codes from zero mean and unit covariance.

    The Gram matrix z.T @ z is pushed toward len(z) * I and the
    per-dimension mean toward zero.
    """
    C = z.T @ z
    mu = torch.mean(z, dim=0)
    tgt_C = to_cuda(torch.eye(z.shape[-1]).float(), use_cuda) * len(z)
    tgt_mu = to_cuda(torch.zeros(z.shape[-1]).float(), use_cuda)
    loss_C = rmse(C, tgt_C)
    loss_mu = rmse(mu, tgt_mu)
    return loss_C, loss_mu, C, mu


def decor_loss(z, demo, use_cuda=True):
    """Penalize correlation between each demographic column and the latent dims.

    For every demographic column d, a scaled projection of d onto each
    latent dimension is driven toward zero.
    """
    ps = []
    losses = []
    for di in range(demo.shape[1]):
        d = demo[:, di]
        d = d - torch.mean(d)  # center the demographic column
        p = torch.einsum('n,nz->z', d, z)  # projection onto each latent dim
        p = p / torch.std(d)
        p = p / torch.einsum('nz,nz->z', z, z)  # normalize by latent energy per dim
        tgt = to_cuda(torch.zeros(z.shape[-1]).float(), use_cuda)
        loss = rmse(p, tgt)
        losses.append(loss)
        ps.append(p)
    losses = torch.stack(losses)
    return losses, ps


def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
    """Assemble demographics into an (n_samples, n_features) tensor.

    Continuous demographics are kept as-is; categorical ones are one-hot
    encoded using the category sets recorded in pred_stats.
    """
    demo_t = []
    for demo_idx, (d, t, s) in enumerate(zip(demo, demo_types, pred_stats)):
        if t == 'continuous':
            demo_t.append(to_cuda(to_torch(d), use_cuda))
        elif t == 'categorical':
            # Refuse category values never seen during training
            for dd in d:
                if dd not in s:
                    print(f'Model not trained with value {dd} for categorical demographic {demo_idx}')
                    raise ValueError('Bad demographic')
            # One-hot encode: one indicator column per known category
            for ss in s:
                idx = (d == ss).astype('bool')
                onehot = torch.zeros(len(d))
                onehot[idx] = 1
                demo_t.append(to_cuda(onehot, use_cuda))
    demo_t = torch.stack(demo_t).permute(1, 0)  # -> (n_samples, n_features)
    return demo_t


def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
              loss_C_mult, loss_mu_mult, loss_rec_mult, loss_decor_mult,
              loss_pred_mult, lr, weight_decay, alpha, LR_C, ret_obj):
    # Fit linear predictors for the demographics (Ridge for continuous,
    # logistic regression for categorical) and record their statistics.
    pred_w = []
    pred_i = []
    pred_stats = []
    train_losses = []
    val_losses = []
    for i, (d, t) in enumerate(zip(demo, demo_types)):
        print(f'Fitting auxiliary guidance model for demographic {i} {t}...', end='')
        if t == 'continuous':
            pred_stats.append([np.mean(d), np.std(d)])
            reg = Ridge(alpha=alpha).fit(x, d)
            pred_w.append(to_cuda(to_torch(reg.coef_), vae.use_cuda))
            pred_i.append(reg.intercept_)
        elif t == 'categorical':
            pred_stats.append(sorted(set(d)))
            reg = LogisticRegression(C=LR_C).fit(x, d)
            if len(reg.coef_) == 1:
                # Binary case: sklearn stores a single weight vector for the
                # positive class; store its negation and itself so there is
                # one predictor per class.
                reg_w = to_cuda(to_torch(reg.coef_[0]), vae.use_cuda)
                reg_i = reg.intercept_[0]
                pred_w.append(-reg_w)
                pred_i.append(-reg_i)
                pred_w.append(reg_w)
                pred_i.append(reg_i)
            else:
                # Multiclass case: one weight vector per class
                for ci in range(len(reg.coef_)):
                    pred_w.append(to_cuda(to_torch(reg.coef_[ci]), vae.use_cuda))
                    pred_i.append(reg.intercept_[ci])
        print(' done')
    ret_obj.pred_stats = pred_stats

    # Convert input to pytorch
    x = to_cuda(to_torch(x), vae.use_cuda)
    # Convert demographics to pytorch (one-hot encoding categoricals)
    demo_t = demo_to_torch(demo, demo_types, pred_stats, vae.use_cuda)

    # Training loop. Note: ce, pred_w/pred_i, and loss_pred_mult are set up
    # for an auxiliary prediction loss that is not applied in this loop.
    ce = torch.nn.CrossEntropyLoss()
    optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)
    for e in range(nepochs):
        epoch_losses = []
        vae.train()
        for bs in range(0, len(x), bsize):
            xb = x[bs:(bs + bsize)]
            db = demo_t[bs:(bs + bsize)]
            optim.zero_grad()
            # Encode, then reconstruct conditioned on the demographics
            z = vae.enc(xb)
            y = vae.dec(z, db)
            # Latent regularization, decorrelation, and reconstruction terms
            loss_C, loss_mu, _, _ = latent_loss(z, vae.use_cuda)
            loss_decor, _ = decor_loss(z, db, vae.use_cuda)
            loss_decor = loss_decor.sum()
            loss_rec = rmse(xb, y)
            # Weighted total loss
            total_loss = (loss_C_mult*loss_C + loss_mu_mult*loss_mu
                          + loss_rec_mult*loss_rec + loss_decor_mult*loss_decor)
            total_loss.backward()
            optim.step()
            epoch_losses.append(total_loss.item())
        # Record training loss
        train_losses.append(np.mean(epoch_losses))
        # Validation step
        if e % pperiod == 0:
            vae.eval()
            with torch.no_grad():
                z = vae.enc(x)
                y = vae.dec(z, demo_t)
                val_loss = rmse(x, y).item()
                val_losses.append(val_loss)
                print(f'Epoch {e}/{nepochs} - '
                      f'Train Loss: {train_losses[-1]:.4f} - '
                      f'Val Loss: {val_loss:.4f}')
    return train_losses, val_losses
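

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only). The VAE passed to train_vae is not
# defined in this module; ToyVAE below is a hypothetical stand-in that merely
# satisfies the interface train_vae relies on (enc, dec, parameters(),
# use_cuda). RetObj is likewise a placeholder: ret_obj can be any object that
# accepts a pred_stats attribute. All hyperparameter values are arbitrary.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    class ToyVAE(torch.nn.Module):
        def __init__(self, n_in, n_z, n_demo, use_cuda=False):
            super().__init__()
            self.use_cuda = use_cuda
            self.encoder = torch.nn.Linear(n_in, n_z)
            self.decoder = torch.nn.Linear(n_z + n_demo, n_in)

        def enc(self, x):
            return self.encoder(x)

        def dec(self, z, d):
            # Decode conditioned on the demographic features
            return self.decoder(torch.cat([z, d], dim=1))

    class RetObj:
        pass

    rng = np.random.default_rng(0)
    x = rng.normal(size=(64, 10)).astype('float32')
    age = rng.normal(50, 10, size=64)         # continuous demographic
    sex = rng.choice(['F', 'M'], size=64)     # categorical -> 2 one-hot columns

    # n_demo = 1 continuous column + 2 one-hot columns
    vae = ToyVAE(n_in=10, n_z=4, n_demo=3, use_cuda=False)
    train_losses, val_losses = train_vae(
        vae, x, [age, sex], ['continuous', 'categorical'],
        nepochs=5, pperiod=1, bsize=16,
        loss_C_mult=1.0, loss_mu_mult=1.0, loss_rec_mult=1.0,
        loss_decor_mult=1.0, loss_pred_mult=0.0,
        lr=1e-3, weight_decay=0.0, alpha=1.0, LR_C=1.0, ret_obj=RetObj())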