# cs2229 / experiments / train_pendulum.py
# pltnhan07's picture
# Add files using upload-large-folder tool
# 3d7e366 verified
import sys
sys.path.append('../')
import os
import torch
import numpy as np
from codebase import utils as ut
from utils import get_batch_unin_dataset_withlabel
from torchvision.utils import save_image
from models.icm_vae import ICM_VAE
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
MAX_EPOCHS = 101   # the training loop iterates over range(MAX_EPOCHS)
DATA = "pendulum"
SAVE_DIR = "icm_vae_recon"   # not referenced elsewhere in the visible script
NAME = "icm_vae_cdp"
DATASET_DIR = '../data/pendulum'
RUN = 0
TRAIN = 1   # not referenced elsewhere in the visible script
ITER_SAVE = 5   # a checkpoint is written whenever epoch % ITER_SAVE == 0

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(44)

# (format template, value) pairs joined with '_' to build a run identifier.
layout = [
    ('model={:s}', str(NAME)),
    ('run={:04d}', RUN),
    ('toy={:s}', str(DATA) + '_' + str(NAME)),
]
model_name = '_'.join([t.format(v) for (t, v) in layout])
print('Model name:', model_name)

# Directory that receives the per-epoch true/reconstructed sample images.
# exist_ok=True replaces the racy "check, then create" pair with one call
# that succeeds whether or not the directory is already there.
os.makedirs(f'./results/{DATA}/{DATA}_{NAME}_reconstructions/', exist_ok=True)
def save_model_by_name(model, global_step):
    """Write ``model.state_dict()`` to ``checkpoints/<model.name>/model-<step>.pt``.

    Args:
        model: Object exposing a ``name`` attribute (used as the checkpoint
            sub-directory) and a ``state_dict()`` method.
        global_step: Integer step/epoch index; zero-padded to five digits in
            the file name.

    The target directory is created on demand. An existing file for the same
    step is overwritten.
    """
    save_dir = os.path.join('checkpoints', model.name)
    # exist_ok=True avoids the exists()/makedirs() race and is a no-op when
    # the directory is already present.
    os.makedirs(save_dir, exist_ok=True)
    file_path = os.path.join(save_dir, 'model-{:05d}.pt'.format(global_step))
    torch.save(model.state_dict(), file_path)
    print('Saved to {}'.format(file_path))
# 4x4 integer matrix passed to ICM_VAE as ``C``.
# NOTE(review): presumably the adjacency structure between the four labelled
# factors (rows 0 and 1 point at columns 2 and 3) -- confirm against ICM_VAE.
C = torch.tensor([
    [0, 0, 1, 1],
    [0, 0, 1, 1],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
])

# One row of two numbers per label, passed to ICM_VAE as ``scale``.
# NOTE(review): looks like per-label normalisation constants -- confirm the
# meaning of each column against ICM_VAE.
scale = np.array([
    [0, 44],
    [100, 40],
    [6.5, 3.5],
    [10, 5],
])

icm_vae = ICM_VAE(
    name=f'{NAME}_{DATA}',
    z_dim=16,
    z1_dim=4,
    z2_dim=4,
    C=C,
    scale=scale,
).to(device)

# Second positional argument is 64.
# NOTE(review): presumably the batch size -- confirm in utils.
train_dataset = get_batch_unin_dataset_withlabel(DATASET_DIR, 64)

optimizer = torch.optim.Adam(
    icm_vae.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
)
def linear_scheduler(step, total_steps, initial, final):
    """Linearly interpolate from ``initial`` to ``final`` over ``total_steps``.

    Args:
        step: Current position in the schedule.
        total_steps: Length of the schedule; ``final`` is reached at
            ``step == total_steps - 1`` and held from then on.
        initial: Value returned for ``step <= 0``.
        final: Value returned once the schedule is exhausted.

    Returns:
        ``final`` when ``step >= total_steps``; otherwise ``initial`` when
        ``step <= 0``; otherwise ``final`` when ``total_steps <= 1`` (which
        also avoids dividing by zero); otherwise the linear blend
        ``(1 - t) * initial + t * final`` with ``t = step / (total_steps - 1)``.
    """
    if step >= total_steps:
        value = final
    elif step <= 0:
        value = initial
    elif total_steps <= 1:
        value = final
    else:
        progress = step / (total_steps - 1)
        value = (1.0 - progress) * initial + progress * final
    return value
# Number of items the loader yields per epoch (presumably batches); it does
# not change between epochs, so compute it once instead of on every pass.
m = len(train_dataset)
recon_dir = f'./results/{DATA}/{DATA}_{NAME}_reconstructions'

for epoch in range(MAX_EPOCHS):
    icm_vae.train()
    # Running sums of the per-batch scalars, averaged over m for logging.
    total_loss = 0
    total_rec = 0
    total_kl = 0

    for X, l in train_dataset:
        optimizer.zero_grad()
        X = X.to(device)
        # Call the module itself rather than .forward() so that any hooks
        # registered on the module are executed as nn.Module intends.
        L, kl, rec, reconstructed_image, z, cp_m = icm_vae(X, l, sample=False)
        L.backward()
        optimizer.step()

        total_loss += L.item()
        total_kl += kl.item()
        total_rec += rec.item()

    # NOTE(review): the original indentation was lost; the statements below
    # are placed at epoch level (run once per epoch) -- confirm this nesting.

    # Keep one qualitative sample per epoch: the first image of the last
    # batch together with its reconstruction.
    save_image(X[0], f'{recon_dir}/true_{epoch}.png')
    save_image(reconstructed_image[0], f'{recon_dir}/reconstructed_{epoch}.png')

    # Ramp the model's beta/alpha attributes linearly over the first 94 epochs.
    # NOTE(review): the new values are assigned after this epoch's updates,
    # so epoch 0 runs with whatever ICM_VAE set at construction and every
    # schedule value takes effect one epoch later -- confirm this is intended.
    beta = linear_scheduler(epoch, 94, 0.0, 1.2)
    icm_vae.beta = beta
    alpha = linear_scheduler(epoch, 94, 0.0, 0.1)
    icm_vae.alpha = alpha

    # Logged every epoch (the former `epoch % 1 == 0` guard was always true).
    print(f'{epoch} loss:{total_loss / m} kl:{total_kl / m} rec:{total_rec / m}m:{m}')

    if epoch % ITER_SAVE == 0:
        # NOTE(review): this uses the helper from codebase.utils, not the
        # save_model_by_name defined in this file -- confirm which is wanted.
        ut.save_model_by_name(icm_vae, epoch)