# File size: 2,830 Bytes
# 3d7e366
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import sys
sys.path.append('../')
import os
import torch
import numpy as np
from codebase import utils as ut
from utils import get_batch_unin_dataset_withlabel
from torchvision.utils import save_image
from models.icm_vae import ICM_VAE

# --- Run configuration -------------------------------------------------
MAX_EPOCHS = 101  # upper bound of the epoch loop (range(MAX_EPOCHS))
DATA = "pendulum"  # dataset tag used in model and results names
SAVE_DIR = "icm_vae_recon"  # not referenced in the visible part of this script
NAME = "icm_vae_cdp"  # model tag used in model and results names
DATASET_DIR = '../data/pendulum'
RUN = 0
TRAIN = 1  # not referenced in the visible part of this script
ITER_SAVE = 5  # a checkpoint is written every ITER_SAVE epochs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG.
torch.manual_seed(44)


# (format template, value) pairs that are joined into the printed model name.
layout = [
    ('model={:s}', str(NAME)),
    ('run={:04d}', RUN),
    ('toy={:s}', str(DATA) + '_' + str(NAME)),
]
model_name = '_'.join([t.format(v) for (t, v) in layout])
print('Model name:', model_name)


# Directory the training loop writes its true/reconstructed images into.
# exist_ok=True replaces the exists()-then-makedirs() pair, which could raise
# if the directory appeared between the check and the create.
os.makedirs(f'./results/{DATA}/{DATA}_{NAME}_reconstructions/', exist_ok=True)


def save_model_by_name(model, global_step):
    """Save ``model.state_dict()`` under ``checkpoints/<model.name>/``.

    The file is named ``model-<global_step>.pt`` with the step zero-padded to
    five digits, and the destination path is printed after saving.

    Args:
        model: Object exposing a ``name`` attribute and a ``state_dict()``
            method.
        global_step: Integer step/epoch index used in the file name.
    """
    save_dir = os.path.join('checkpoints', model.name)
    # exist_ok=True makes the call a no-op when the directory is already
    # there, without a separate (racy) existence check.
    os.makedirs(save_dir, exist_ok=True)
    file_path = os.path.join(save_dir, 'model-{:05d}.pt'.format(global_step))
    torch.save(model.state_dict(), file_path)
    print('Saved to {}'.format(file_path))


# 4x4 integer 0/1 matrix handed to ICM_VAE as `C`.
# NOTE(review): presumably a relation/adjacency matrix over the four label
# dimensions — confirm its orientation against ICM_VAE.
C = torch.tensor([
    [0, 0, 1, 1],
    [0, 0, 1, 1],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
])
# One pair of numbers per row, handed to ICM_VAE as `scale`.
# NOTE(review): presumably per-label normalisation constants — confirm.
scale = np.array([
    [0, 44],
    [100, 40],
    [6.5, 3.5],
    [10, 5],
])


icm_vae = ICM_VAE(
    name=f'{NAME}_{DATA}',
    z_dim=16,
    z1_dim=4,
    z2_dim=4,
    C=C,
    scale=scale,
).to(device)
# NOTE(review): the second argument (64) is presumably the batch size — confirm.
train_dataset = get_batch_unin_dataset_withlabel(DATASET_DIR, 64)
optimizer = torch.optim.Adam(
    icm_vae.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
)

def linear_scheduler(step, total_steps, initial, final):
    """Linearly interpolate from ``initial`` to ``final`` over a schedule.

    Returns ``final`` once ``step >= total_steps``, ``initial`` while
    ``step <= 0``, and ``final`` for a degenerate schedule
    (``total_steps <= 1``). Otherwise the value is interpolated so that
    ``step == total_steps - 1`` already yields ``final``.
    """
    if step >= total_steps:
        value = final
    elif step <= 0:
        value = initial
    elif total_steps <= 1:
        # Avoids dividing by zero in the interpolation below.
        value = final
    else:
        progress = step / (total_steps - 1)
        value = (1.0 - progress) * initial + progress * final
    return value


# Training loop: each epoch iterates once over train_dataset, takes one
# optimizer step per batch, writes one true/reconstructed image pair, updates
# the model's beta/alpha attributes, prints the epoch averages and
# periodically saves a checkpoint.
recon_dir = f'./results/{DATA}/{DATA}_{NAME}_reconstructions'
# Loop-invariant, so computed once here instead of once per batch. This also
# keeps it defined when the dataset yields no batches.
num_batches = len(train_dataset)

for epoch in range(MAX_EPOCHS):
    icm_vae.train()
    total_loss = 0
    total_rec = 0
    total_kl = 0
    last_true = None
    last_recon = None
    for X, labels in train_dataset:
        optimizer.zero_grad()
        X = X.to(device)
        # Call the module itself rather than .forward() so that any hooks
        # registered on it are run.
        loss, kl, rec, reconstructed_image, _, _ = icm_vae(X, labels, sample=False)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_kl += kl.item()
        total_rec += rec.item()

        # Remember the first sample of the most recent batch. Only the last
        # batch's images survived the per-batch overwrite in the old code.
        last_true = X[0]
        last_recon = reconstructed_image[0]

    # Write the image pair once per epoch instead of rewriting the same two
    # files on every batch. Skipped when the dataset produced no batches.
    if last_true is not None:
        save_image(last_true, f'{recon_dir}/true_{epoch}.png')
        save_image(last_recon, f'{recon_dir}/reconstructed_{epoch}.png')

    # NOTE(review): beta/alpha are assigned after this epoch's batches have
    # run, so the scheduled value for `epoch` first takes effect in
    # epoch + 1, and epoch 0 trains with whatever values the model was
    # constructed with — confirm this one-epoch lag is intended.
    beta = linear_scheduler(epoch, 94, 0.0, 1.2)
    icm_vae.beta = beta

    alpha = linear_scheduler(epoch, 94, 0.0, 0.1)
    icm_vae.alpha = alpha

    # Per-batch averages. max(..., 1) avoids dividing by zero on an empty
    # dataset.
    denom = max(num_batches, 1)
    print(f'{epoch} loss:{total_loss / denom} kl:{total_kl / denom} rec:{total_rec / denom}m:{num_batches}')

    if epoch % ITER_SAVE == 0:
        ut.save_model_by_name(icm_vae, epoch)