File size: 5,248 Bytes
ef677f1
 
 
 
 
 
 
 
b32645b
ef677f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbe81c1
ef677f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbe81c1
ef677f1
 
 
 
dbe81c1
 
a4c8f0c
ef677f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbe81c1
 
 
ef677f1
 
 
 
 
 
 
 
 
 
 
 
 
dbe81c1
ef677f1
dbe81c1
ef677f1
 
 
 
dbe81c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import torch
import numpy as np
from sklearn.linear_model import Ridge, LogisticRegression

def to_torch(x):
    """Convert a numpy array to a float32 torch tensor."""
    tensor = torch.from_numpy(x)
    return tensor.float()

def to_cuda(x, use_cuda):
    """Move *x* onto the GPU when *use_cuda* is truthy; otherwise return it unchanged."""
    if use_cuda:
        return x.cuda()
    return x

def to_numpy(x):
    """Detach *x* from the autograd graph and return it as a CPU numpy array."""
    detached = x.detach()
    return detached.cpu().numpy()

def rmse(a, b, mean=torch.mean):
    """Root-mean-square error between *a* and *b*.

    *mean* lets callers swap in a different reduction over the squared errors.
    """
    squared_error = (a - b) ** 2
    return mean(squared_error) ** 0.5

def latent_loss(z, use_cuda=True):
    """Losses pushing the latent batch toward zero mean and whitened covariance.

    Compares the Gram matrix z^T z against len(z) * I and the per-dimension
    batch mean against zero. Returns (loss_C, loss_mu, C, mu).
    """
    gram = z.T @ z
    batch_mean = torch.mean(z, dim=0)
    latent_dim = z.shape[-1]
    gram_target = to_cuda(torch.eye(latent_dim).float(), use_cuda) * len(z)
    mean_target = to_cuda(torch.zeros(latent_dim).float(), use_cuda)
    return rmse(gram, gram_target), rmse(batch_mean, mean_target), gram, batch_mean

def decor_loss(z, demo, use_cuda=True):
    """Decorrelation losses between latent codes *z* and each column of *demo*.

    For every demographic column, projects the mean-centered column onto each
    latent dimension, normalizes, and penalizes any deviation from zero.
    Returns (losses, ps): a stacked per-column loss tensor and the list of
    normalized projection vectors.
    """
    zero_target = to_cuda(torch.zeros(z.shape[-1]).float(), use_cuda)
    projections = []
    column_losses = []
    for col in range(demo.shape[1]):
        centered = demo[:, col] - torch.mean(demo[:, col])
        proj = torch.einsum('n,nz->z', centered, z)
        proj = proj / torch.std(centered)
        proj = proj / torch.einsum('nz,nz->z', z, z)
        column_losses.append(rmse(proj, zero_target))
        projections.append(proj)
    return torch.stack(column_losses), projections

def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
    """Encode raw demographics as an (n_samples, n_channels) float tensor.

    Continuous demographics contribute one channel each; categorical ones are
    one-hot encoded using the value list stored in *pred_stats* (which fixes
    the column order). Raises on any categorical value not seen at train time.
    """
    channels = []
    for demo_idx, (d, t, s) in enumerate(zip(demo, demo_types, pred_stats)):
        if t == 'continuous':
            channels.append(to_cuda(to_torch(d), use_cuda))
        elif t == 'categorical':
            # Reject values the auxiliary model was never fitted on.
            for dd in d:
                if dd not in s:
                    print(f'Model not trained with value {dd} for categorical demographic {demo_idx}')
                    raise Exception('Bad demographic')
            # One indicator channel per known category value.
            for category in s:
                one_hot = torch.zeros(len(d))
                one_hot[(d == category).astype('bool')] = 1
                channels.append(to_cuda(one_hot, use_cuda))
    return torch.stack(channels).permute(1, 0)

def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize, 
              loss_C_mult, loss_mu_mult, loss_rec_mult, loss_decor_mult, 
              loss_pred_mult, lr, weight_decay, alpha, LR_C, ret_obj):
    """Train *vae* with reconstruction, latent-whitening and demographic-decorrelation losses.

    First fits one auxiliary linear model per demographic (Ridge for
    continuous, LogisticRegression for categorical), then runs the main
    training loop over contiguous mini-batches of *x*.

    Parameters
    ----------
    vae : model exposing .enc, .dec, .use_cuda, .parameters(), .train(), .eval()
    x : numpy array of input features; converted to a torch tensor below
    demo : sequence of per-demographic value arrays, parallel to *demo_types*
    demo_types : 'continuous' or 'categorical' for each demographic
    nepochs, pperiod, bsize : epoch count, validation/print period, batch size
    loss_C_mult, loss_mu_mult, loss_rec_mult, loss_decor_mult : loss-term weights
    loss_pred_mult : NOTE(review) accepted but never used in this function —
        presumably a prediction-loss term was planned; confirm
    lr, weight_decay : Adam optimizer settings
    alpha : Ridge regularization strength
    LR_C : LogisticRegression inverse regularization strength
    ret_obj : mutated in place — gets .pred_stats attached

    Returns
    -------
    (train_losses, val_losses) : per-epoch mean batch loss, and a validation
        reconstruction RMSE recorded every *pperiod* epochs
    """
    
    # Get linear predictors for demographics
    # NOTE(review): pred_w/pred_i (and `ce` below) are populated but never
    # consumed in this function — likely intended for the unused
    # loss_pred_mult term; confirm against the rest of the file.
    pred_w = []
    pred_i = []
    pred_stats = []
    train_losses = []
    val_losses = []
    
    for i, d, t in zip(range(len(demo)), demo, demo_types):
        print(f'Fitting auxiliary guidance model for demographic {i} {t}...', end='')
        if t == 'continuous':
            # Record mean/std for this demographic, then fit a ridge
            # regression predicting it from the raw features.
            pred_stats.append([np.mean(d), np.std(d)])
            reg = Ridge(alpha=alpha).fit(x, d)
            reg_w = to_cuda(to_torch(reg.coef_), vae.use_cuda)
            reg_i = reg.intercept_
            pred_w.append(reg_w)
            pred_i.append(reg_i)
        elif t == 'categorical':
            # Record the sorted set of observed values; this fixes the
            # one-hot column order used by demo_to_torch.
            pred_stats.append(sorted(list(set(list(d)))))
            reg = LogisticRegression(C=LR_C).fit(x, d)
            if len(reg.coef_) == 1:
                # Binary case: sklearn returns a single weight row, so emit
                # a negated and a positive copy — one predictor per one-hot
                # column.
                reg_w = to_cuda(to_torch(reg.coef_[0]), vae.use_cuda)
                reg_i = reg.intercept_[0]
                pred_w.append(-reg_w)
                pred_i.append(-reg_i)
                pred_w.append(reg_w)
                pred_i.append(reg_i)
            else:
                # Multiclass: one predictor per class.
                # NOTE(review): this loop reuses `i`, shadowing the outer
                # demographic index — harmless here since zip reassigns it,
                # but worth renaming.
                for i in range(len(reg.coef_)):
                    reg_w = to_cuda(to_torch(reg.coef_[i]), vae.use_cuda)
                    reg_i = reg.intercept_[i]
                    pred_w.append(reg_w)
                    pred_i.append(reg_i)
        print(' done')
    
    ret_obj.pred_stats = pred_stats
    
    # Convert input to pytorch
    x = to_cuda(to_torch(x), vae.use_cuda)
    
    # Convert demographics to pytorch
    demo_t = demo_to_torch(demo, demo_types, pred_stats, vae.use_cuda)
    
    # Training loop
    ce = torch.nn.CrossEntropyLoss()  # NOTE(review): unused below — see note above
    optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)
    
    for e in range(nepochs):
        epoch_losses = []
        vae.train()
        
        # Contiguous, unshuffled mini-batches over the full dataset.
        for bs in range(0, len(x), bsize):
            xb = x[bs:(bs+bsize)]
            db = demo_t[bs:(bs+bsize)]
            optim.zero_grad()
            
            # Reconstruct
            z = vae.enc(xb)
            y = vae.dec(z, db)
            loss_C, loss_mu, _, _ = latent_loss(z, vae.use_cuda)
            loss_decor, _ = decor_loss(z, db, vae.use_cuda)
            loss_decor = sum(loss_decor)
            loss_rec = rmse(xb, y)
            
            # Calculate total loss
            total_loss = (loss_C_mult*loss_C + loss_mu_mult*loss_mu + 
                         loss_rec_mult*loss_rec + loss_decor_mult*loss_decor)
            
            total_loss.backward()
            optim.step()
            
            epoch_losses.append(total_loss.item())
        
        # Record training loss
        train_losses.append(np.mean(epoch_losses))
        
        # Validation step
        # NOTE(review): "validation" evaluates on the full training set x,
        # not a held-out split — confirm this is intended.
        if e % pperiod == 0:
            vae.eval()
            with torch.no_grad():
                z = vae.enc(x)
                y = vae.dec(z, demo_t)
                val_loss = rmse(x, y).item()
                val_losses.append(val_loss)
                
                print(f'Epoch {e}/{nepochs} - '
                      f'Train Loss: {train_losses[-1]:.4f} - '
                      f'Val Loss: {val_loss:.4f}')
    
    return train_losses, val_losses