Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import random | |
| import numpy as np | |
| from sklearn.linear_model import Ridge | |
| from sklearn.linear_model import LogisticRegression | |
def to_torch(x):
    """Wrap a NumPy array as a float32 torch tensor.

    Args:
        x: NumPy array accepted by ``torch.from_numpy``.

    Returns:
        The array's contents as a tensor cast with ``.float()``.
    """
    as_tensor = torch.from_numpy(x)
    return as_tensor.float()
def to_cuda(x, use_cuda):
    """Move ``x`` to the GPU when asked to, falling back to the input on failure.

    Args:
        x: Any object exposing a ``.cuda()`` method (tensor or module).
        use_cuda: When falsy, ``x`` is returned untouched.

    Returns:
        ``x.cuda()`` if the transfer succeeds, otherwise ``x`` itself.
    """
    if not use_cuda:
        return x
    try:
        on_gpu = x.cuda()
    except (RuntimeError, AssertionError) as e:
        # Best effort: report the problem and keep working with the original object.
        print(f"Warning: CUDA error: {e}. Falling back to CPU.")
        return x
    return on_gpu
def to_numpy(x):
    """Convert a torch tensor to a NumPy array.

    The tensor is detached from the autograd graph and copied to host
    memory before conversion.
    """
    detached = x.detach()
    on_host = detached.cpu()
    return on_host.numpy()
class VAE(nn.Module):
    """Two-layer encoder and demographic-conditioned two-layer decoder.

    The encoder maps an input vector to a latent code through one hidden
    ReLU layer; it returns the code directly (no sampling step is performed
    in ``enc``). The decoder receives the latent code concatenated with the
    demographic features and maps it back to the input dimensionality.

    Args:
        input_dim: Size of each input / reconstructed vector.
        latent_dim: Size of the latent code.
        demo_dim: Number of demographic feature columns appended to the
            latent code before decoding.
        use_cuda: Passed to ``to_cuda`` for every layer and generated tensor.
        hidden_dim: Width of the hidden layer in both encoder and decoder.
            Defaults to 1000, the value that was previously hard-coded.
    """

    def __init__(self, input_dim, latent_dim, demo_dim, use_cuda=True, hidden_dim=1000):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.demo_dim = demo_dim
        self.use_cuda = use_cuda
        self.hidden_dim = hidden_dim
        # Layers are created in the same order as before so that parameter
        # initialisation consumes the RNG identically for the default width.
        self.enc1 = to_cuda(nn.Linear(input_dim, hidden_dim).float(), use_cuda)
        self.enc2 = to_cuda(nn.Linear(hidden_dim, latent_dim).float(), use_cuda)
        self.dec1 = to_cuda(nn.Linear(latent_dim + demo_dim, hidden_dim).float(), use_cuda)
        self.dec2 = to_cuda(nn.Linear(hidden_dim, input_dim).float(), use_cuda)

    def enc(self, x):
        """Encode ``x`` into a latent code (linear -> ReLU -> linear)."""
        x = F.relu(self.enc1(x))
        z = self.enc2(x)
        return z

    def gen(self, n):
        """Draw ``n`` latent codes from a standard normal distribution."""
        return to_cuda(torch.randn(n, self.latent_dim).float(), self.use_cuda)

    def dec(self, z, demo):
        """Decode latent codes conditioned on demographic features.

        Args:
            z: Latent codes, one row per sample.
            demo: Demographic features, one row per sample; concatenated to
                ``z`` along dim 1.

        Returns:
            The decoder output (linear -> ReLU -> linear), one row per sample.
        """
        z = to_cuda(torch.cat([z, demo], dim=1), self.use_cuda)
        x = F.relu(self.dec1(z))
        x = self.dec2(x)
        # NOTE: disabled alternative output head, kept for reference:
        #x = x.reshape(len(z), 264, 5)
        #x = torch.einsum('nac,nbc->nab', x, x)
        #a,b = np.triu_indices(264, 1)
        #x = x[:,a,b]
        return x
def rmse(a, b, mean=torch.mean):
    """Root of the aggregated squared difference between ``a`` and ``b``.

    Args:
        a: First operand.
        b: Second operand; ``a - b`` must be a valid expression.
        mean: Reduction applied to the squared differences
            (``torch.mean`` by default).

    Returns:
        ``mean((a - b) ** 2) ** 0.5``.
    """
    squared_error = (a - b) ** 2
    reduced = mean(squared_error)
    return reduced ** 0.5
def latent_loss(z, use_cuda=True):
    """Penalise latent codes whose second moment or mean drifts from the target.

    The Gram matrix ``z.T @ z`` is compared against the identity scaled by
    the number of rows in ``z``, and the per-dimension mean is compared
    against zero. Both comparisons use ``rmse``.

    Args:
        z: Latent codes, one row per sample.
        use_cuda: Forwarded to ``to_cuda`` when building the target tensors.

    Returns:
        Tuple ``(loss_C, loss_mu, C, mu)``: the two losses followed by the
        Gram matrix and the mean vector they were computed from.
    """
    n_samples = len(z)
    n_latent = z.shape[-1]
    gram = torch.matmul(z.T, z)
    center = z.mean(dim=0)
    gram_target = to_cuda(torch.eye(n_latent).float(), use_cuda) * n_samples
    center_target = to_cuda(torch.zeros(n_latent).float(), use_cuda)
    return rmse(gram, gram_target), rmse(center, center_target), gram, center
def decor_loss(z, demo, use_cuda=True, eps=1e-8):
    """Penalise linear association between latent dimensions and demographics.

    For each demographic column the centred column is projected onto every
    latent dimension, scaled by the column's standard deviation and by the
    latent dimension's sum of squares, and the resulting vector is pushed
    towards zero with ``rmse``.

    Denominators are floored at ``eps``. Without the floor, a demographic
    column that is constant within the batch (std 0) or an all-zero latent
    dimension gives 0/0 = NaN. A single-row batch is treated as having zero
    spread because ``torch.std`` is NaN there. In all these cases the
    numerator is zero, so the affected entries become 0 instead of NaN.

    Args:
        z: Latent codes, one row per sample.
        demo: Demographic features, one row per sample, one column per feature.
        use_cuda: Forwarded to ``to_cuda`` when building the zero target.
        eps: Lower bound applied to both denominators.

    Returns:
        Tuple ``(losses, ps)``: a stacked tensor with one loss per
        demographic column, and the list of per-column projection vectors.
    """
    # Loop-invariant quantities, computed once.
    z_sq = torch.einsum('nz,nz->z', z, z).clamp_min(eps)
    tgt = to_cuda(torch.zeros(z.shape[-1]).float(), use_cuda)
    ps = []
    losses = []
    for di in range(demo.shape[1]):
        d = demo[:, di]
        d = d - torch.mean(d)
        # torch.std is NaN for fewer than two samples; treat that as zero spread.
        spread = torch.std(d) if d.numel() > 1 else d.new_zeros(())
        p = torch.einsum('n,nz->z', d, z)
        p = p / spread.clamp_min(eps)
        p = p / z_sq
        losses.append(rmse(p, tgt))
        ps.append(p)
    losses = torch.stack(losses)
    return losses, ps
def pretty(x):
    """Format a scalar as text, rounded to four decimal places."""
    rounded = round(float(x), 4)
    return str(rounded)
def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
    """Build the demographic feature matrix used to condition the decoder.

    Continuous demographics contribute one column each. Categorical
    demographics contribute one indicator column per known value in the
    matching ``pred_stats`` entry, in that entry's order. Entries whose type
    is neither ``'continuous'`` nor ``'categorical'`` contribute nothing.

    Args:
        demo: Sequence of per-demographic value arrays (presumably NumPy
            arrays, since ``.astype`` and ``to_torch`` are applied to them).
        demo_types: Sequence of type strings aligned with ``demo``.
        pred_stats: Per-demographic statistics aligned with ``demo``; for
            categorical entries, the collection of values seen in training.
        use_cuda: Forwarded to ``to_cuda`` for every column.

    Returns:
        Tensor with one row per sample and one column per feature.

    Raises:
        Exception: If a categorical demographic contains a value that is
            missing from its ``pred_stats`` entry.
    """
    columns = []
    for demo_idx, (values, kind, known) in enumerate(zip(demo, demo_types, pred_stats)):
        if kind == 'continuous':
            columns.append(to_cuda(to_torch(values), use_cuda))
            continue
        if kind != 'categorical':
            continue
        # Reject any value the model has no indicator column for.
        for dd in values:
            if dd not in known:
                print(f'Model not trained with value {dd} for categorical demographic {demo_idx}')
                raise Exception('Bad demographic')
        # One indicator column per known value.
        for level in known:
            mask = (values == level).astype('bool')
            onehot = torch.zeros(len(values))
            onehot[mask] = 1
            columns.append(to_cuda(onehot, use_cuda))
    # Columns were collected feature-major; flip to sample-major.
    stacked = torch.stack(columns)
    return stacked.permute(1, 0)
def _fit_guidance_models(x, demo, demo_types, alpha, LR_C, use_cuda):
    """Fit one linear guidance predictor per demographic feature column.

    Returns ``(pred_w, pred_i, pred_stats)``. ``pred_w`` / ``pred_i`` hold one
    weight vector / intercept per feature column (one per continuous
    demographic, one per category of a categorical demographic).
    ``pred_stats`` holds ``[mean, std]`` for continuous demographics and the
    sorted list of distinct values for categorical ones.
    """
    pred_w = []
    pred_i = []
    pred_stats = []
    for demo_idx, (d, t) in enumerate(zip(demo, demo_types)):
        print(f'Fitting auxilliary guidance model for demographic {demo_idx} {t}...', end='')
        if t == 'continuous':
            pred_stats.append([np.mean(d), np.std(d)])
            reg = Ridge(alpha=alpha).fit(x, d)
            pred_w.append(to_cuda(to_torch(reg.coef_), use_cuda))
            pred_i.append(reg.intercept_)
        elif t == 'categorical':
            pred_stats.append(sorted(set(d)))
            reg = LogisticRegression(C=LR_C).fit(x, d)
            if len(reg.coef_) == 1:
                # Binary: a single coefficient row is fitted; its negation is
                # used as the predictor for the other category.
                reg_w = to_cuda(to_torch(reg.coef_[0]), use_cuda)
                reg_i = reg.intercept_[0]
                pred_w.extend([-reg_w, reg_w])
                pred_i.extend([-reg_i, reg_i])
            else:
                # Multi-class: one coefficient row per category.
                for coef, intercept in zip(reg.coef_, reg.intercept_):
                    pred_w.append(to_cuda(to_torch(coef), use_cuda))
                    pred_i.append(intercept)
        else:
            print(f'demographic type "{t}" not "continuous" or "categorical"')
            raise Exception('Bad demographic type')
        print(' done')
    return pred_w, pred_i, pred_stats


def _sample_demographics(pred_stats, demo_types, n, use_cuda):
    """Draw a synthetic demographic matrix with ``n`` rows.

    Continuous features are sampled from a normal distribution with the
    stored mean and std. For each categorical demographic one category is
    picked at random and shared by all ``n`` rows as a one-hot block.
    """
    columns = []
    for s, t in zip(pred_stats, demo_types):
        if t == 'continuous':
            mu = s[0]
            std = s[1]
            col = torch.randn(n).float()
            col = col * std + mu
            columns.append(to_cuda(col, use_cuda))
        elif t == 'categorical':
            chosen = random.randint(0, len(s) - 1)
            for k in range(len(s)):
                col = torch.ones(n).float() if k == chosen else torch.zeros(n).float()
                columns.append(to_cuda(col, use_cuda))
    return torch.stack(columns).permute(1, 0)


def _guidance_losses(y, demo_gen, pred_w, pred_i, pred_stats, demo_types, ce):
    """Score generated samples ``y`` against the demographics they were conditioned on.

    Returns ``(losses, targets)``: one loss per demographic (RMSE for
    continuous; summed per-category cross-entropy for categorical) and the
    first row's target value for every feature column, for logging.
    """
    losses = []
    targets = []
    col = 0  # running index into demo_gen columns and pred_w / pred_i
    for s, t in zip(pred_stats, demo_types):
        if t == 'continuous':
            yy = y @ pred_w[col] + pred_i[col]
            losses.append(rmse(demo_gen[:, col], yy))
            targets.append(float(demo_gen[0, col]))
            col += 1
        elif t == 'categorical':
            loss = 0
            for _ in range(len(s)):
                yy = y @ pred_w[col] + pred_i[col]
                # Two-way logits (-yy, yy) against the 0/1 indicator column.
                loss += ce(torch.stack([-yy, yy], dim=1), demo_gen[:, col].long())
                targets.append(int(demo_gen[0, col]))
                col += 1
            losses.append(loss)
    return losses, targets


def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize, loss_C_mult, loss_mu_mult,
              loss_rec_mult, loss_decor_mult, loss_pred_mult, lr, weight_decay, alpha, LR_C,
              ret_obj, n_gen=100):
    """Train ``vae`` in place with reconstruction, latent and guidance losses.

    Steps: fit a linear guidance model per demographic on the raw data,
    convert data and demographics to tensors, then run mini-batch Adam
    optimisation. Each step combines the latent covariance and mean losses,
    the reconstruction RMSE, the latent/demographic decorrelation loss, and
    a guidance loss computed on freshly generated samples.

    Args:
        vae: Model exposing ``enc``, ``dec``, ``gen``, ``parameters`` and
            ``use_cuda``.
        x: Training data passed to the scikit-learn fits and to ``to_torch``.
        demo: Sequence of per-demographic value arrays aligned with ``x``.
        demo_types: ``'continuous'`` or ``'categorical'`` for each entry of ``demo``.
        nepochs: Number of passes over ``x``.
        pperiod: Progress is printed every ``pperiod`` epochs and on the last epoch.
        bsize: Mini-batch size; batches are taken in order, without shuffling.
        loss_C_mult, loss_mu_mult, loss_rec_mult, loss_decor_mult,
        loss_pred_mult: Weights of the individual loss terms.
        lr, weight_decay: Adam optimiser settings.
        alpha: Ridge regularisation strength for continuous demographics.
        LR_C: ``C`` of the logistic regression for categorical demographics.
        ret_obj: Object that receives the fitted statistics as ``pred_stats``.
        n_gen: Number of samples generated per step for the guidance loss;
            must be at least 1. Defaults to the previously hard-coded 100.

    Raises:
        Exception: If a demographic type is neither ``'continuous'`` nor
            ``'categorical'``.
    """
    use_cuda = vae.use_cuda
    # Get linear predictors for demographics
    pred_w, pred_i, pred_stats = _fit_guidance_models(x, demo, demo_types, alpha, LR_C, use_cuda)
    ret_obj.pred_stats = pred_stats
    # Convert input to pytorch
    print('Converting input to pytorch')
    x = to_cuda(to_torch(x), use_cuda)
    # Convert demographics to pytorch
    print('Converting demographics to pytorch')
    demo_t = demo_to_torch(demo, demo_types, pred_stats, use_cuda)
    # Training loop
    print('Beginning VAE training')
    ce = nn.CrossEntropyLoss()
    optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)
    for e in range(nepochs):
        for bs in range(0, len(x), bsize):
            xb = x[bs:(bs + bsize)]
            db = demo_t[bs:(bs + bsize)]
            optim.zero_grad()
            # Reconstruct
            z = vae.enc(xb)
            y = vae.dec(z, db)
            loss_C, loss_mu, _, _ = latent_loss(z, use_cuda)
            loss_decor, _ = decor_loss(z, db, use_cuda)
            loss_decor = loss_decor.sum()
            loss_rec = rmse(xb, y)
            # Sample demographics first, then latents, to keep the RNG call order.
            demo_gen = _sample_demographics(pred_stats, demo_types, n_gen, use_cuda)
            # Generate
            y_gen = vae.dec(vae.gen(n_gen), demo_gen)
            # Regressor/classifier guidance loss
            losses_pred, idcs = _guidance_losses(
                y_gen, demo_gen, pred_w, pred_i, pred_stats, demo_types, ce)
            total_loss = (loss_C_mult * loss_C
                          + loss_mu_mult * loss_mu
                          + loss_rec_mult * loss_rec
                          + loss_decor_mult * loss_decor
                          + loss_pred_mult * sum(losses_pred))
            total_loss.backward()
            optim.step()
        if e % pperiod == 0 or e == nepochs - 1:
            # Values reported are those of the epoch's final batch.
            shown_pred = [pretty(loss) for loss in losses_pred]
            print(
                f'Epoch {e} '
                f'ReconLoss {pretty(loss_rec)} '
                f'CovarianceLoss {pretty(loss_C)} '
                f'MeanLoss {pretty(loss_mu)} '
                f'DecorLoss {pretty(loss_decor)} '
                f'GuidanceTargets {idcs} GuidanceLosses {shown_pred} '
            )
    print('Training complete.')