# Spaces: Sleeping  (page-capture residue from the hosting site, not part of the source)
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression, Ridge
def to_torch(x):
    """Convert a numpy array to a float32 torch tensor."""
    tensor = torch.from_numpy(x)
    return tensor.to(torch.float32)
def to_cuda(x, use_cuda):
    """Move `x` to the GPU when `use_cuda` is truthy; otherwise return it unchanged."""
    if use_cuda:
        return x.cuda()
    return x
def to_numpy(x):
    """Convert a torch tensor to a numpy array (detached, on CPU)."""
    detached = x.detach().cpu()
    return detached.numpy()
def rmse(a, b, mean=torch.mean):
    """Root-mean-square error between `a` and `b`.

    `mean` is injectable so a caller can substitute a different reduction.
    """
    squared_error = (a - b) ** 2
    return mean(squared_error) ** 0.5
def latent_loss(z, use_cuda=True):
    """Regularize a latent batch toward whitened coordinates.

    Pushes the Gram matrix z.T @ z toward N*I (unit per-dimension energy,
    zero cross-correlation) and the batch mean toward the zero vector,
    each measured as an RMS deviation.

    Args:
        z: (N, Z) batch of latent codes.
        use_cuda: place the comparison targets on GPU when True.

    Returns:
        (loss_C, loss_mu, C, mu): the two RMS losses plus the raw Gram
        matrix and batch mean for inspection.
    """
    gram = z.T @ z
    batch_mean = torch.mean(z, dim=0)
    dim = z.shape[-1]
    eye_target = torch.eye(dim).float()
    zero_target = torch.zeros(dim).float()
    if use_cuda:
        eye_target = eye_target.cuda()
        zero_target = zero_target.cuda()
    eye_target = eye_target * len(z)
    loss_gram = torch.mean((gram - eye_target) ** 2) ** 0.5
    loss_mean = torch.mean((batch_mean - zero_target) ** 2) ** 0.5
    return loss_gram, loss_mean, gram, batch_mean
def decor_loss(z, demo, use_cuda=True):
    """Penalize linear association between demographics and latent dimensions.

    For each demographic column d (mean-centered), computes a per-latent
    score p_k = (sum_n d_n * z_nk) / (std(d) * sum_n z_nk^2) and takes the
    RMS of p against zero as that column's loss.

    Args:
        z: (N, Z) batch of latent codes.
        demo: (N, D) demographic matrix (numeric / one-hot columns).
        use_cuda: place the zero target on GPU when True.

    Returns:
        (losses, ps): stacked per-column losses, shape (D,), and the list
        of raw per-latent score vectors.
    """
    scores = []
    column_losses = []
    for col in range(demo.shape[1]):
        centered = demo[:, col] - torch.mean(demo[:, col])
        # Project the centered demographic onto each latent dimension,
        # then scale by the column's (unbiased) std and each latent's energy.
        proj = torch.einsum('n,nz->z', centered, z)
        proj = proj / torch.std(centered)
        proj = proj / torch.einsum('nz,nz->z', z, z)
        zero_target = torch.zeros(z.shape[-1]).float()
        if use_cuda:
            zero_target = zero_target.cuda()
        column_losses.append(torch.mean((proj - zero_target) ** 2) ** 0.5)
        scores.append(proj)
    return torch.stack(column_losses), scores
def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
    """Assemble per-subject demographics into a single float tensor.

    Continuous demographics contribute one float column each; categorical
    demographics are one-hot encoded against the levels recorded in
    `pred_stats` (level order defines column order). A categorical value
    absent from `pred_stats` aborts with an exception.

    Args:
        demo: list of 1-D numpy arrays, one per demographic.
        demo_types: matching list of 'continuous' / 'categorical' tags.
        pred_stats: per-demographic stats; for categoricals, the list of
            known levels.
        use_cuda: move columns to GPU when True.

    Returns:
        (N, total_columns) tensor with one row per subject.
    """
    columns = []
    for pos, (values, kind, stats) in enumerate(zip(demo, demo_types, pred_stats)):
        if kind == 'continuous':
            col = torch.from_numpy(values).float()
            columns.append(col.cuda() if use_cuda else col)
        elif kind == 'categorical':
            # Refuse values never seen when the model was trained.
            for value in values:
                if value not in stats:
                    print(f'Model not trained with value {value} for categorical demographic {pos}')
                    raise Exception('Bad demographic')
            # One indicator column per known level.
            for level in stats:
                hits = (values == level).astype('bool')
                indicator = torch.zeros(len(values))
                indicator[hits] = 1
                columns.append(indicator.cuda() if use_cuda else indicator)
    return torch.stack(columns).permute(1, 0)
def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
              loss_C_mult, loss_mu_mult, loss_rec_mult, loss_decor_mult,
              loss_pred_mult, lr, weight_decay, alpha, LR_C, ret_obj):
    """Train `vae` on `x` while decorrelating latents from demographics.

    First fits auxiliary sklearn predictors per demographic (Ridge for
    continuous, LogisticRegression for categorical) and records per-
    demographic stats on `ret_obj.pred_stats`. Then optimizes the VAE with
    a weighted sum of: latent Gram-matrix loss, latent-mean loss,
    reconstruction RMSE, and a latent/demographic decorrelation loss.

    Args:
        vae: model exposing .enc, .dec, .use_cuda, and torch parameters.
        x: (N, F) numpy feature matrix (rebound to a torch tensor below).
        demo: list of 1-D numpy arrays, one per demographic.
        demo_types: 'continuous' / 'categorical' tag per demographic.
        nepochs: number of training epochs.
        pperiod: validation/print period, in epochs.
        bsize: minibatch size.
        loss_C_mult, loss_mu_mult, loss_rec_mult, loss_decor_mult: loss
            term weights.
        loss_pred_mult: NOTE(review): currently unused below — presumably
            meant to weight a demographic-prediction loss; confirm.
        lr, weight_decay: Adam hyperparameters.
        alpha: Ridge regularization strength.
        LR_C: LogisticRegression inverse regularization strength.
        ret_obj: object that receives .pred_stats as a side effect.

    Returns:
        (train_losses, val_losses): per-epoch mean training loss, and
        validation reconstruction RMSE sampled every `pperiod` epochs.
    """
    # Get linear predictors for demographics
    # NOTE(review): pred_w / pred_i are filled here but never read in this
    # function — likely inputs to the unfinished prediction loss; confirm.
    pred_w = []
    pred_i = []
    pred_stats = []
    train_losses = []
    val_losses = []
    for i, d, t in zip(range(len(demo)), demo, demo_types):
        print(f'Fitting auxiliary guidance model for demographic {i} {t}...', end='')
        if t == 'continuous':
            # Stats for a continuous demographic: [mean, std].
            pred_stats.append([np.mean(d), np.std(d)])
            reg = Ridge(alpha=alpha).fit(x, d)
            reg_w = to_cuda(to_torch(reg.coef_), vae.use_cuda)
            reg_i = reg.intercept_
            pred_w.append(reg_w)
            pred_i.append(reg_i)
        elif t == 'categorical':
            # Stats for a categorical demographic: sorted unique levels.
            pred_stats.append(sorted(list(set(list(d)))))
            reg = LogisticRegression(C=LR_C).fit(x, d)
            if len(reg.coef_) == 1:
                # Binary case: sklearn stores a single weight row; expand it
                # to per-class rows as (-w, -i) for class 0 and (w, i) for
                # class 1.
                reg_w = to_cuda(to_torch(reg.coef_[0]), vae.use_cuda)
                reg_i = reg.intercept_[0]
                pred_w.append(-reg_w)
                pred_i.append(-reg_i)
                pred_w.append(reg_w)
                pred_i.append(reg_i)
            else:
                # Multinomial case: one weight row per class.
                # NOTE(review): this `i` shadows the outer demographic index;
                # harmless because zip rebinds `i` each iteration, but worth
                # renaming.
                for i in range(len(reg.coef_)):
                    reg_w = to_cuda(to_torch(reg.coef_[i]), vae.use_cuda)
                    reg_i = reg.intercept_[i]
                    pred_w.append(reg_w)
                    pred_i.append(reg_i)
        print(' done')
    ret_obj.pred_stats = pred_stats
    # Convert input to pytorch
    x = to_cuda(to_torch(x), vae.use_cuda)
    # Convert demographics to pytorch
    demo_t = demo_to_torch(demo, demo_types, pred_stats, vae.use_cuda)
    # Training loop
    # NOTE(review): `ce` is never used below — presumably intended for the
    # prediction loss weighted by loss_pred_mult; confirm before removing.
    ce = torch.nn.CrossEntropyLoss()
    optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)
    for e in range(nepochs):
        epoch_losses = []
        vae.train()
        # Sequential (unshuffled) minibatches over the full dataset.
        for bs in range(0, len(x), bsize):
            xb = x[bs:(bs+bsize)]
            db = demo_t[bs:(bs+bsize)]
            optim.zero_grad()
            # Reconstruct
            z = vae.enc(xb)
            y = vae.dec(z, db)
            loss_C, loss_mu, _, _ = latent_loss(z, vae.use_cuda)
            loss_decor, _ = decor_loss(z, db, vae.use_cuda)
            # Sum the per-demographic-column decorrelation losses.
            loss_decor = sum(loss_decor)
            loss_rec = rmse(xb, y)
            # Calculate total loss
            total_loss = (loss_C_mult*loss_C + loss_mu_mult*loss_mu +
                          loss_rec_mult*loss_rec + loss_decor_mult*loss_decor)
            total_loss.backward()
            optim.step()
            epoch_losses.append(total_loss.item())
        # Record training loss
        train_losses.append(np.mean(epoch_losses))
        # Validation step
        if e % pperiod == 0:
            vae.eval()
            with torch.no_grad():
                # NOTE(review): "validation" reconstructs the full training
                # set — no held-out split is visible here; confirm intent.
                z = vae.enc(x)
                y = vae.dec(z, demo_t)
                val_loss = rmse(x, y).item()
                val_losses.append(val_loss)
                print(f'Epoch {e}/{nepochs} - '
                      f'Train Loss: {train_losses[-1]:.4f} - '
                      f'Val Loss: {val_loss:.4f}')
    return train_losses, val_losses