# File size: 2,908 Bytes
# 3d7e366
import sys
sys.path.append('../')
import os
import torch
import numpy as np
from codebase import utils as ut
from utils import get_batch_unin_dataset_withlabel
from torchvision.utils import save_image
from models.icm_vae import ICM_VAE
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
MAX_EPOCHS = 101  # upper bound of the epoch loop (range(MAX_EPOCHS))
DATA = "flow"  # dataset tag; appears in the results path and the model name
SAVE_DIR = "icm_vae_recon"  # NOTE(review): not referenced in the visible code
NAME = "icm_vae_cdp"  # model tag; appears in the results path and the model name
DATASET_DIR = '../data/flow_noise'  # passed to the dataset loader
RUN = 0  # run index, zero-padded to four digits in the model name
TRAIN = 1  # NOTE(review): not referenced in the visible code
ITER_SAVE = 5  # a checkpoint is written every ITER_SAVE epochs

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(44)

# (format string, value) pairs that are rendered and joined with '_' to form
# the printed model name.
layout = [
    ('model={:s}', str(NAME)),
    ('run={:04d}', RUN),
    ('toy={:s}', str(DATA) + '_' + str(NAME)),
]
model_name = '_'.join([t.format(v) for (t, v) in layout])
print('Model name:', model_name)

# exist_ok=True replaces the exists()-then-makedirs() pair, which could fail
# if the directory appeared between the check and the creation.
os.makedirs(f'./results/{DATA}/{DATA}_{NAME}_reconstructions/', exist_ok=True)
def save_model_by_name(model, global_step):
    """Save ``model.state_dict()`` under ``checkpoints/<model.name>/``.

    The file is named ``model-<global_step>.pt`` with the step zero-padded to
    five digits, and the destination path is printed after saving.

    Args:
        model: Object exposing a ``name`` attribute (used as the
            sub-directory name) and a ``state_dict()`` method.
        global_step: Integer step number embedded in the file name.
    """
    save_dir = os.path.join('checkpoints', model.name)
    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_dir, exist_ok=True)
    file_path = os.path.join(save_dir, 'model-{:05d}.pt'.format(global_step))
    torch.save(model.state_dict(), file_path)
    print('Saved to {}'.format(file_path))
# 4x4 integer matrix handed to the model as ``C``.
# NOTE(review): presumably a causal adjacency matrix — confirm against ICM_VAE.
C = torch.tensor(
    [
        [0, 0, 1, 0],
        [0, 0, 0, 1],
        [0, 0, 0, 1],
        [0, 0, 0, 0],
    ]
)

# 4x2 array handed to the model as ``scale``.
# NOTE(review): looks like per-label normalisation constants — confirm.
scale = np.array(
    [
        [20, 15],
        [10.5, 4.5],
        [2, 2],
        [59.5, 26.5],
    ]
)

# Build the model and move it to the selected device.
icm_vae = ICM_VAE(
    name=NAME + '_' + DATA,
    z_dim=16,
    z1_dim=4,
    z2_dim=4,
    C=C,
    scale=scale,
).to(device)

# NOTE(review): dataset_dir is not read anywhere in the visible code; the
# loader call below uses DATASET_DIR instead.
dataset_dir = '../../data/orig_data/flow_noise'

# NOTE(review): the second argument (64) is presumably the batch size — confirm.
train_dataset = get_batch_unin_dataset_withlabel(DATASET_DIR, 64)

optimizer = torch.optim.Adam(
    icm_vae.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
)
def linear_scheduler(step, total_steps, initial, final):
    """Interpolate linearly from ``initial`` to ``final`` over ``total_steps``.

    Args:
        step: Current position in the schedule.
        total_steps: Length of the schedule; ``final`` is reached at
            ``step == total_steps - 1`` and held from then on.
        initial: Value returned for ``step <= 0`` (when the schedule is not
            already exhausted).
        final: Value returned once ``step >= total_steps``, or whenever the
            schedule has at most one step.

    Returns:
        ``initial``, ``final``, or the linear blend
        ``(1 - t) * initial + t * final`` with ``t = step / (total_steps - 1)``.
    """
    if step >= total_steps:
        # Schedule exhausted; this check takes precedence over step <= 0.
        value = final
    elif step <= 0:
        value = initial
    elif total_steps <= 1:
        # Degenerate schedule: avoids dividing by zero below.
        value = final
    else:
        progress = step / (total_steps - 1)
        value = (1.0 - progress) * initial + progress * final
    return value
# Directory that receives one "true" and one "reconstructed" image per epoch.
recon_dir = f'./results/{DATA}/{DATA}_{NAME}_reconstructions/'

# ``total_steps`` handed to linear_scheduler for both beta and alpha.
SCHEDULE_EPOCHS = 94

# Number of batches per epoch; used to average the running totals.  It does
# not change between batches, so it is measured once up front.
m = len(train_dataset)
if m == 0:
    # Previously an empty loader surfaced as a NameError after the inner loop.
    raise ValueError('train_dataset yielded no batches; nothing to train on')

# NOTE(review): the source this block came from had lost its indentation.  The
# nesting below assumes image saving, the beta/alpha update, logging and
# checkpointing all happen once per epoch, after the batch loop.  Saving the
# images per epoch leaves the same files on disk as saving them per batch
# (each batch would overwrite the same file name) — confirm against the
# original script.
for epoch in range(MAX_EPOCHS):
    icm_vae.train()
    total_loss = 0
    total_rec = 0
    total_kl = 0

    for X, l in train_dataset:
        optimizer.zero_grad()
        X = X.to(device)
        # Only the loss terms and the reconstruction are used; the last two
        # return values are ignored.
        L, kl, rec, reconstructed_image, _, _ = icm_vae.forward(X, l, sample=False)
        L.backward()
        optimizer.step()

        total_loss += L.item()
        total_kl += kl.item()
        total_rec += rec.item()

    # X and reconstructed_image still refer to the last batch of this epoch.
    save_image(X[0], f'{recon_dir}true_{epoch}.png')
    save_image(reconstructed_image[0], f'{recon_dir}reconstructed_{epoch}.png')

    # NOTE(review): the weights are assigned after this epoch's batches, so
    # they take effect from the next epoch on, and epoch 0 runs with whatever
    # beta/alpha the model was constructed with — confirm this is intended.
    icm_vae.beta = linear_scheduler(epoch, SCHEDULE_EPOCHS, 0.0, 1.2)
    icm_vae.alpha = linear_scheduler(epoch, SCHEDULE_EPOCHS, 0.0, 0.1)

    # Per-epoch averages over the m batches (log format kept unchanged).
    print(f'{epoch} loss:{total_loss / m} kl:{total_kl / m} rec:{total_rec / m}m:{m}')

    if epoch % ITER_SAVE == 0:
        ut.save_model_by_name(icm_vae, epoch)
|