import argparse
import os
import random

import numpy as np
import torch
import torchvision
import yaml
from torch.optim import Adam
from torchvision.utils import make_grid
from tqdm import tqdm

from celeba import create_dataloader
from model.discriminator import Discriminator
from model.lpips import LPIPS
from model.vae import VAE

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using mps")


def train(args):
    # Read the config file #
    with open(args.config_path, "r") as file:
        try:
            config = yaml.safe_load(file)
        except yaml.YAMLError as exc:
            print(exc)

    dataset_config = config["dataset_params"]
    autoencoder_config = config["autoencoder_params"]
    train_config = config["train_params"]

    # Set the desired seed value #
    seed = train_config["seed"]
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if device.type == "cuda":
        torch.cuda.manual_seed_all(seed)
    #############################

    # Create the model and dataset #
    model = VAE(
        im_channels=dataset_config["im_channels"], model_config=autoencoder_config
    ).to(device)

    # Create output directories
    if not os.path.exists(train_config["task_name"]):
        os.mkdir(train_config["task_name"])

    num_epochs = train_config["autoencoder_epochs"]

    # L1/L2 loss for Reconstruction
    recon_criterion = torch.nn.MSELoss()
    # Disc Loss can even be BCEWithLogits
    disc_criterion = torch.nn.MSELoss()

    # No need to freeze lpips as lpips.py takes care of that
    lpips_model = LPIPS().eval().to(device)
    discriminator = Discriminator(im_channels=dataset_config["im_channels"]).to(device)

    # Create the dataset
    data_loader = create_dataloader(dataset_config["im_path"])

    # Resume from checkpoints if they exist
    if os.path.exists(
        os.path.join(
            train_config["task_name"], train_config["vae_autoencoder_ckpt_name"]
        )
    ):
        model.load_state_dict(
            torch.load(
                os.path.join(
                    train_config["task_name"], train_config["vae_autoencoder_ckpt_name"]
                ),
                map_location=device,
            )
        )
        print("Loaded autoencoder from checkpoint")

    if os.path.exists(
        os.path.join(
            train_config["task_name"], train_config["vae_discriminator_ckpt_name"]
        )
    ):
        discriminator.load_state_dict(
            torch.load(
                os.path.join(
                    train_config["task_name"],
                    train_config["vae_discriminator_ckpt_name"],
                ),
                map_location=device,
            )
        )
        print("Loaded discriminator from checkpoint")

    optimizer_d = Adam(
        discriminator.parameters(),
        lr=train_config["autoencoder_lr"],
        betas=(0.5, 0.999),
    )
    optimizer_g = Adam(
        model.parameters(), lr=train_config["autoencoder_lr"], betas=(0.5, 0.999)
    )

    disc_step_start = train_config["disc_start"]
    step_count = 0

    # This is for accumulating gradients in case the images are huge
    # and one can't afford higher batch sizes
    acc_steps = train_config["autoencoder_acc_steps"]
    image_save_steps = train_config["autoencoder_img_save_steps"]
    img_save_count = 0

    for epoch_idx in range(num_epochs):
        recon_losses = []
        perceptual_losses = []
        disc_losses = []
        gen_losses = []
        losses = []

        optimizer_g.zero_grad()
        optimizer_d.zero_grad()

        for im in tqdm(data_loader):
            step_count += 1
            im = im.float().to(device)

            # Fetch the autoencoder's output (reconstructions)
            model_output = model(im)
            output, encoder_output = model_output

            # Image Saving Logic
            if step_count % image_save_steps == 0 or step_count == 1:
                sample_size = min(8, im.shape[0])
                save_output = (
                    torch.clamp(output[:sample_size], -1.0, 1.0).detach().cpu()
                )
                save_output = (save_output + 1) / 2
                save_input = ((im[:sample_size] + 1) / 2).detach().cpu()

                grid = make_grid(
                    torch.cat([save_input, save_output], dim=0), nrow=sample_size
                )
                img = torchvision.transforms.ToPILImage()(grid)
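                # Write the grid of inputs (top row) and reconstructions
                # (bottom row) under <task_name>/vae_autoencoder_samples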
                if not os.path.exists(
                    os.path.join(train_config["task_name"], "vae_autoencoder_samples")
                ):
                    os.mkdir(
                        os.path.join(
                            train_config["task_name"], "vae_autoencoder_samples"
                        )
                    )
                img.save(
                    os.path.join(
                        train_config["task_name"],
                        "vae_autoencoder_samples",
                        "current_autoencoder_sample_{}.png".format(img_save_count),
                    )
                )
                img_save_count += 1
                img.close()

            ######### Optimize Generator ##########
            # L2 Loss
            recon_loss = recon_criterion(output, im)
            recon_losses.append(recon_loss.item())
            recon_loss = recon_loss / acc_steps

            # KL divergence of N(mean, exp(logvar)) from the standard normal:
            # 0.5 * sum(exp(logvar) + mean^2 - 1 - logvar), averaged over batch
            mean, logvar = torch.chunk(encoder_output, 2, dim=1)
            kl_loss = torch.mean(
                0.5 * torch.sum(torch.exp(logvar) + mean**2 - 1 - logvar, dim=[1, 2, 3])
            )
            g_loss = recon_loss + (train_config["kl_weight"] * kl_loss / acc_steps)

            # Adversarial loss only if disc_step_start steps passed
            if step_count > disc_step_start:
                disc_fake_pred = discriminator(model_output[0])
                disc_fake_loss = disc_criterion(
                    disc_fake_pred,
                    torch.ones(disc_fake_pred.shape, device=disc_fake_pred.device),
                )
                gen_losses.append(train_config["disc_weight"] * disc_fake_loss.item())
                g_loss += train_config["disc_weight"] * disc_fake_loss / acc_steps

            lpips_loss = torch.mean(lpips_model(output, im))
            perceptual_losses.append(
                train_config["perceptual_weight"] * lpips_loss.item()
            )
            g_loss += train_config["perceptual_weight"] * lpips_loss / acc_steps
            losses.append(g_loss.item())
            g_loss.backward()
            #####################################

            ######### Optimize Discriminator #######
            if step_count > disc_step_start:
                fake = output
                disc_fake_pred = discriminator(fake.detach())
                disc_real_pred = discriminator(im)
                disc_fake_loss = disc_criterion(
                    disc_fake_pred,
                    torch.zeros(disc_fake_pred.shape, device=disc_fake_pred.device),
                )
                disc_real_loss = disc_criterion(
                    disc_real_pred,
                    torch.ones(disc_real_pred.shape, device=disc_real_pred.device),
                )
                disc_loss = (
                    train_config["disc_weight"] * (disc_fake_loss + disc_real_loss) / 2
                )
                disc_losses.append(disc_loss.item())
                disc_loss = disc_loss / acc_steps
                disc_loss.backward()
                if step_count % acc_steps == 0:
                    optimizer_d.step()
                    optimizer_d.zero_grad()
            #####################################

            if step_count % acc_steps == 0:
                optimizer_g.step()
                optimizer_g.zero_grad()

        # Flush any gradients still accumulated at the end of the epoch
        optimizer_d.step()
        optimizer_d.zero_grad()
        optimizer_g.step()
        optimizer_g.zero_grad()

        if len(disc_losses) > 0:
            print(
                "Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | "
                "G Loss : {:.4f} | D Loss {:.4f}".format(
                    epoch_idx + 1,
                    np.mean(recon_losses),
                    np.mean(perceptual_losses),
                    np.mean(gen_losses),
                    np.mean(disc_losses),
                )
            )
        else:
            print(
                "Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f}".format(
                    epoch_idx + 1, np.mean(recon_losses), np.mean(perceptual_losses)
                )
            )

        torch.save(
            model.state_dict(),
            os.path.join(
                train_config["task_name"], train_config["vae_autoencoder_ckpt_name"]
            ),
        )
        torch.save(
            discriminator.state_dict(),
            os.path.join(
                train_config["task_name"], train_config["vae_discriminator_ckpt_name"]
            ),
        )

    print("Done Training...")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Arguments for vae training")
    parser.add_argument(
        "--config", dest="config_path", default="celeba/config.yaml", type=str
    )
    args = parser.parse_args()
    train(args)
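# ---------------------------------------------------------------------------
# For reference, a minimal sketch of the config.yaml this script expects.
# The keys mirror the config lookups above; the values shown are illustrative
# assumptions, not the repository's actual defaults.
#
# dataset_params:
#   im_path: "data/celeba"          # passed to create_dataloader
#   im_channels: 3
# autoencoder_params:               # passed through to VAE as model_config
#   ...
# train_params:
#   seed: 1111
#   task_name: "celeba"
#   autoencoder_epochs: 20
#   autoencoder_lr: 0.0001
#   autoencoder_acc_steps: 1
#   autoencoder_img_save_steps: 64
#   disc_start: 15000
#   disc_weight: 0.5
#   kl_weight: 0.000005
#   perceptual_weight: 1
#   vae_autoencoder_ckpt_name: "vae_autoencoder_ckpt.pth"
#   vae_discriminator_ckpt_name: "vae_discriminator_ckpt.pth"
# ---------------------------------------------------------------------------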