File size: 9,283 Bytes
31677e7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 | import argparse
import os
import random
import numpy as np
import torch
import torchvision
import yaml
from torch.optim import Adam
from torchvision.utils import make_grid
from tqdm import tqdm
from celeba import create_dataloader
from model.discriminator import Discriminator
from model.lpips import LPIPS
from model.vae import VAE
# Choose the compute device once at import time. MPS takes precedence when
# the backend reports it is available; otherwise use CUDA if present, else CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using mps")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def _load_config(config_path):
    """Read and parse the YAML training config at *config_path*.

    Raises:
        yaml.YAMLError: re-raised after printing. Every later step needs the
            parsed config, so swallowing the error would only defer the
            failure to a confusing NameError.
    """
    with open(config_path, "r") as file:
        try:
            return yaml.safe_load(file)
        except yaml.YAMLError as exc:
            print(exc)
            raise


def _seed_everything(seed):
    """Seed Python, NumPy and torch RNGs (plus all CUDA devices when on CUDA)."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # ``device`` is a torch.device; comparing it directly to the string "cuda"
    # is never true, so inspect its ``type`` attribute instead.
    if device.type == "cuda":
        torch.cuda.manual_seed_all(seed)


def _load_if_exists(module, ckpt_path, label):
    """Load a state dict from *ckpt_path* into *module* if that file exists."""
    if os.path.exists(ckpt_path):
        module.load_state_dict(torch.load(ckpt_path, map_location=device))
        print("Loaded {} from checkpoint".format(label))


def _save_sample_grid(im, output, sample_dir, img_save_count):
    """Save up to 8 inputs (first row) and their reconstructions (second row).

    Both tensors are mapped with (x + 1) / 2 before saving, which presumes
    images are normalized to [-1, 1] (reconstructions are clamped to that
    range first) -- confirm against the dataloader.
    """
    sample_size = min(8, im.shape[0])
    save_output = torch.clamp(output[:sample_size], -1.0, 1.0).detach().cpu()
    save_output = (save_output + 1) / 2
    save_input = ((im[:sample_size] + 1) / 2).detach().cpu()
    grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size)
    img = torchvision.transforms.ToPILImage()(grid)
    os.makedirs(sample_dir, exist_ok=True)
    img.save(
        os.path.join(
            sample_dir,
            "current_autoencoder_sample_{}.png".format(img_save_count),
        )
    )
    img.close()


def train(args):
    """Train the VAE autoencoder, adding an adversarial loss once enabled.

    Reads the YAML config at ``args.config_path`` (sections ``dataset_params``,
    ``autoencoder_params``, ``train_params``). If autoencoder/discriminator
    checkpoints already exist under ``train_params["task_name"]``, it resumes
    from them. It then trains for ``autoencoder_epochs`` epochs with gradient
    accumulation over ``autoencoder_acc_steps`` micro-batches and saves both
    state dicts after every epoch.

    Generator loss per micro-batch (each term divided by ``acc_steps``):
        MSE reconstruction + kl_weight * KL + perceptual_weight * LPIPS,
        plus disc_weight * adversarial MSE once ``step_count > disc_start``.
    Discriminator loss (same start condition):
        disc_weight * (MSE(D(fake), 0) + MSE(D(real), 1)) / 2.
    """
    config = _load_config(args.config_path)
    dataset_config = config["dataset_params"]
    autoencoder_config = config["autoencoder_params"]
    train_config = config["train_params"]

    _seed_everything(train_config["seed"])

    # Model construction order (VAE, LPIPS, Discriminator) is kept after
    # seeding so parameter initialisation stays reproducible.
    model = VAE(
        im_channels=dataset_config["im_channels"], model_config=autoencoder_config
    ).to(device)

    task_dir = train_config["task_name"]
    # makedirs also handles nested task paths, which os.mkdir could not.
    os.makedirs(task_dir, exist_ok=True)
    sample_dir = os.path.join(task_dir, "vae_autoencoder_samples")
    vae_ckpt_path = os.path.join(task_dir, train_config["vae_autoencoder_ckpt_name"])
    disc_ckpt_path = os.path.join(
        task_dir, train_config["vae_discriminator_ckpt_name"]
    )

    num_epochs = train_config["autoencoder_epochs"]
    # L1/L2 loss for reconstruction
    recon_criterion = torch.nn.MSELoss()
    # Disc loss can even be BCEWithLogits
    disc_criterion = torch.nn.MSELoss()
    # Per the original note, model/lpips.py freezes LPIPS itself (not visible here).
    lpips_model = LPIPS().eval().to(device)
    discriminator = Discriminator(im_channels=dataset_config["im_channels"]).to(device)
    data_loader = create_dataloader(dataset_config["im_path"])

    _load_if_exists(model, vae_ckpt_path, "autoencoder")
    _load_if_exists(discriminator, disc_ckpt_path, "discriminator")

    optimizer_d = Adam(
        discriminator.parameters(),
        lr=train_config["autoencoder_lr"],
        betas=(0.5, 0.999),
    )
    optimizer_g = Adam(
        model.parameters(), lr=train_config["autoencoder_lr"], betas=(0.5, 0.999)
    )

    disc_step_start = train_config["disc_start"]
    # Gradients are accumulated over acc_steps micro-batches for cases where
    # the images are too large for a bigger batch size.
    acc_steps = train_config["autoencoder_acc_steps"]
    image_save_steps = train_config["autoencoder_img_save_steps"]
    kl_weight = train_config["kl_weight"]
    disc_weight = train_config["disc_weight"]
    perceptual_weight = train_config["perceptual_weight"]
    step_count = 0
    img_save_count = 0

    for epoch_idx in range(num_epochs):
        recon_losses = []
        perceptual_losses = []
        disc_losses = []
        gen_losses = []
        optimizer_g.zero_grad()
        optimizer_d.zero_grad()

        for im in tqdm(data_loader):
            step_count += 1
            im = im.float().to(device)
            # Autoencoder returns (reconstruction, encoder output).
            output, encoder_output = model(im)

            if step_count % image_save_steps == 0 or step_count == 1:
                _save_sample_grid(im, output, sample_dir, img_save_count)
                img_save_count += 1

            use_disc = step_count > disc_step_start

            ######### Optimize Generator ##########
            recon_loss = recon_criterion(output, im)
            recon_losses.append(recon_loss.item())
            # encoder_output holds mean and logvar stacked along dim 1; the
            # sum over dims 1-3 presumes a 4-D (N, C, H, W) latent -- confirm.
            mean, logvar = torch.chunk(encoder_output, 2, dim=1)
            kl_loss = torch.mean(
                0.5 * torch.sum(torch.exp(logvar) + mean**2 - 1 - logvar, dim=[1, 2, 3])
            )
            g_loss = recon_loss / acc_steps + kl_weight * kl_loss / acc_steps

            if use_disc:
                # Freeze D for the generator's adversarial term. Otherwise
                # g_loss.backward() would deposit "predict real for fakes"
                # gradients into D's .grad, and optimizer_d.step() would
                # apply them alongside the real discriminator loss.
                discriminator.requires_grad_(False)
                disc_fake_pred = discriminator(output)
                disc_fake_loss = disc_criterion(
                    disc_fake_pred, torch.ones_like(disc_fake_pred)
                )
                gen_losses.append(disc_weight * disc_fake_loss.item())
                g_loss = g_loss + disc_weight * disc_fake_loss / acc_steps

            lpips_loss = torch.mean(lpips_model(output, im))
            perceptual_losses.append(perceptual_weight * lpips_loss.item())
            g_loss = g_loss + perceptual_weight * lpips_loss / acc_steps
            g_loss.backward()
            if use_disc:
                discriminator.requires_grad_(True)
            #####################################

            ######### Optimize Discriminator #######
            if use_disc:
                disc_fake_pred = discriminator(output.detach())
                disc_real_pred = discriminator(im)
                disc_fake_loss = disc_criterion(
                    disc_fake_pred, torch.zeros_like(disc_fake_pred)
                )
                disc_real_loss = disc_criterion(
                    disc_real_pred, torch.ones_like(disc_real_pred)
                )
                disc_loss = disc_weight * (disc_fake_loss + disc_real_loss) / 2
                disc_losses.append(disc_loss.item())
                (disc_loss / acc_steps).backward()
                if step_count % acc_steps == 0:
                    optimizer_d.step()
                    optimizer_d.zero_grad()
            #####################################

            if step_count % acc_steps == 0:
                optimizer_g.step()
                optimizer_g.zero_grad()

        # Flush gradients accumulated since the last optimizer step. Step only
        # when micro-batches are actually pending: stepping Adam on zeroed
        # (not None) gradients still moves parameters via its momentum terms.
        if step_count % acc_steps != 0:
            optimizer_d.step()
            optimizer_d.zero_grad()
            optimizer_g.step()
            optimizer_g.zero_grad()

        if disc_losses:
            print(
                "Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | "
                "G Loss : {:.4f} | D Loss {:.4f}".format(
                    epoch_idx + 1,
                    np.mean(recon_losses),
                    np.mean(perceptual_losses),
                    np.mean(gen_losses),
                    np.mean(disc_losses),
                )
            )
        else:
            print(
                "Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f}".format(
                    epoch_idx + 1, np.mean(recon_losses), np.mean(perceptual_losses)
                )
            )

        torch.save(model.state_dict(), vae_ckpt_path)
        torch.save(discriminator.state_dict(), disc_ckpt_path)
    print("Done Training...")
if __name__ == "__main__":
    # Command-line entry point: only the config path is configurable here.
    cli = argparse.ArgumentParser(description="Arguments for vae training")
    cli.add_argument(
        "--config",
        dest="config_path",
        type=str,
        default="celeba/config.yaml",
    )
    train(cli.parse_args())
|