File size: 2,830 Bytes
3d7e366 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 | import sys
sys.path.append('../')
import os
import torch
import numpy as np
from codebase import utils as ut
from utils import get_batch_unin_dataset_withlabel
from torchvision.utils import save_image
from models.icm_vae import ICM_VAE
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
MAX_EPOCHS = 101          # number of passes of the training loop over the data
DATA = "pendulum"         # dataset tag used in run names and result paths
SAVE_DIR = "icm_vae_recon"  # NOTE(review): not referenced anywhere in this script
NAME = "icm_vae_cdp"      # model tag used in run names and result paths
DATASET_DIR = '../data/pendulum'
RUN = 0                   # run index, zero-padded into model_name
TRAIN = 1                 # NOTE(review): not referenced anywhere in this script
ITER_SAVE = 5             # checkpoint every ITER_SAVE epochs

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Fixed seed for torch's RNG so weight init / sampling are reproducible.
torch.manual_seed(44)

# (format template, value) pairs joined with '_' into a descriptive run name,
# e.g. 'model=icm_vae_cdp_run=0000_toy=pendulum_icm_vae_cdp'.
layout = [
    ('model={:s}', str(NAME)),
    ('run={:04d}', RUN),
    ('toy={:s}', str(DATA) + '_' + str(NAME)),
]
model_name = '_'.join([t.format(v) for (t, v) in layout])
print('Model name:', model_name)

# Directory for per-epoch ground-truth / reconstruction images.
# exist_ok=True replaces the racy exists-then-create check.
os.makedirs(f'./results/{DATA}/{DATA}_{NAME}_reconstructions/', exist_ok=True)
def save_model_by_name(model, global_step):
    """Save ``model.state_dict()`` to ``checkpoints/<model.name>/model-<step>.pt``.

    The path is relative to the current working directory.

    NOTE(review): the training loop in this script calls
    ``ut.save_model_by_name`` rather than this local helper.

    Args:
        model: object exposing a ``name`` attribute and a ``state_dict()``
            method (e.g. a ``torch.nn.Module`` with ``name`` assigned).
        global_step: integer step/epoch number; zero-padded to 5 digits in
            the file name.
    """
    save_dir = os.path.join('checkpoints', model.name)
    # exist_ok avoids the race between an exists() check and makedirs()
    # (e.g. two runs writing into the same checkpoint directory).
    os.makedirs(save_dir, exist_ok=True)
    file_path = os.path.join(save_dir, 'model-{:05d}.pt'.format(global_step))
    torch.save(model.state_dict(), file_path)
    print('Saved to {}'.format(file_path))
# 4x4 concept adjacency matrix passed to ICM_VAE as `C`. Rows 0-1 have ones
# in columns 2-3; presumably C[i, j] = 1 marks factor i as a parent of
# factor j — TODO confirm against ICM_VAE.
C = torch.tensor([
    [0, 0, 1, 1],
    [0, 0, 1, 1],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
])
# One (a, b) pair per factor, forwarded to ICM_VAE as `scale`.
# NOTE(review): the meaning of the two columns (e.g. offset / range for label
# normalisation) is defined inside ICM_VAE — confirm there.
scale = np.array([
    [0, 44],
    [100, 40],
    [6.5, 3.5],
    [10, 5],
])

icm_vae = ICM_VAE(
    name=NAME + '_' + DATA,
    z_dim=16,
    z1_dim=4,
    z2_dim=4,
    C=C,
    scale=scale,
)
icm_vae = icm_vae.to(device)

# Second argument presumably the batch size — TODO confirm in utils.
train_dataset = get_batch_unin_dataset_withlabel(DATASET_DIR, 64)

optimizer = torch.optim.Adam(
    icm_vae.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
)
def linear_scheduler(step, total_steps, initial, final):
    """Linearly interpolate from ``initial`` to ``final`` as ``step`` advances.

    ``step <= 0`` yields ``initial``; ``step >= total_steps - 1`` (and any
    step at or past ``total_steps``) yields ``final``. In between, the value
    moves linearly with ``step / (total_steps - 1)``.

    Args:
        step: current step (epoch) index.
        total_steps: length of the schedule.
        initial: value at the start of the schedule.
        final: value at the end of the schedule.

    Returns:
        The interpolated value.
    """
    # Past the end of the schedule: clamp to the final value.
    if step >= total_steps:
        return final
    # At or before the start: clamp to the initial value.
    if step <= 0:
        return initial
    span = total_steps - 1
    # Degenerate schedule with no room to interpolate.
    if span <= 0:
        return final
    progress = step / span
    return (1.0 - progress) * initial + progress * final
# Number of epochs over which the beta / alpha loss weights ramp linearly
# from 0 to their final values (see linear_scheduler).
SCHEDULE_EPOCHS = 94
# Per-epoch ground-truth / reconstruction images are written here.
recon_dir = f'./results/{DATA}/{DATA}_{NAME}_reconstructions'

# Number of batches per epoch; the accumulated losses are averaged over it.
m = len(train_dataset)
if m == 0:
    # With no batches, X / reconstructed_image would never be bound and the
    # averages below would divide by zero, so fail early with a clear message.
    raise ValueError(f'no training batches loaded from {DATASET_DIR!r}')

for epoch in range(MAX_EPOCHS):
    icm_vae.train()
    total_loss = 0
    total_rec = 0
    total_kl = 0
    for X, l in train_dataset:
        optimizer.zero_grad()
        X = X.to(device)
        # NOTE(review): labels `l` are passed without .to(device); presumably
        # ICM_VAE handles device placement itself — confirm on CUDA.
        # Call the module (not .forward) so any registered hooks run.
        L, kl, rec, reconstructed_image, _z, _cp_m = icm_vae(X, l, sample=False)
        L.backward()
        optimizer.step()
        total_loss += L.item()
        total_kl += kl.item()
        total_rec += rec.item()

    # Save the first sample of the epoch's last batch and its reconstruction.
    save_image(X[0], f'{recon_dir}/true_{epoch}.png')
    save_image(reconstructed_image[0], f'{recon_dir}/reconstructed_{epoch}.png')

    # The weights are updated after the epoch's batches, so a new value can
    # only influence the following epoch (presumably read by ICM_VAE's loss).
    beta = linear_scheduler(epoch, SCHEDULE_EPOCHS, 0.0, 1.2)
    icm_vae.beta = beta
    alpha = linear_scheduler(epoch, SCHEDULE_EPOCHS, 0.0, 0.1)
    icm_vae.alpha = alpha

    # Per-batch averages of the epoch's losses.
    print(f'{epoch} loss:{total_loss / m} kl:{total_kl / m} rec:{total_rec / m} m:{m}')
    if epoch % ITER_SAVE == 0:
        ut.save_model_by_name(icm_vae, epoch)
|