import os
import argparse

import numpy as np
import torch
import torch.nn as nn
import wandb
import yaml
from torch.optim import Adam
from torchvision.transforms import ToPILImage
from torchvision.utils import make_grid
from tqdm import tqdm

from dataset.celeba import create_dataloader
from models.discriminator import Discriminator
from models.lpips import LPIPS
from models.vqvae import VQVAE

wandb.init(project="vqvae")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_config(config_path):
    """Parse the YAML training config and return it as a dict.

    Raises yaml.YAMLError on a malformed file instead of swallowing it
    (the original printed the error, then crashed on an undefined name).
    """
    with open(config_path, "r") as f:
        return yaml.safe_load(f)


def _safe_mean(values):
    """Mean of a list of floats, or 0.0 when empty.

    Avoids np.mean([]) -> nan + RuntimeWarning for the loss lists that stay
    empty until the discriminator warm-up period ends.
    """
    return float(np.mean(values)) if values else 0.0


def train(args):
    """Train the VQVAE autoencoder with LPIPS perceptual loss and an adversarial discriminator.

    Reads the YAML config at args.config_path, resumes from checkpoints in
    the task directory when they exist, logs metrics and sample grids to
    wandb, and saves model/discriminator/optimizer checkpoints every epoch.
    """
    config = _load_config(args.config_path)
    autoencoder_config = config["autoencoder_params"]
    train_config = config["train_config"]
    dataset_config = config["dataset_config"]

    task_dir = train_config["task_name"]
    ckpt_names = (
        train_config["vqvae_autoencoder_ckpt_name"],
        train_config["vqvae_discriminator_ckpt_name"],
        train_config["vqvae_optim_d_ckpt_name"],
        train_config["vqvae_optim_g_ckpt_name"],
    )
    vqvae_ckpt_path, discriminator_ckpt_path, optimizer_d_ckpt, optimizer_g_ckpt = (
        os.path.join(task_dir, name) for name in ckpt_names
    )

    # Set seed for reproducibility
    seed = train_config["seed"]
    torch.manual_seed(seed)
    np.random.seed(seed)

    data_loader = create_dataloader(dataset_config["im_path"])

    # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir
    os.makedirs(task_dir, exist_ok=True)

    # L2 as reconstruction loss; discriminator uses BCE-with-logits
    recon_criterion = nn.MSELoss()
    disc_criterion = nn.BCEWithLogitsLoss()

    if os.path.exists(vqvae_ckpt_path) and os.path.exists(discriminator_ckpt_path):
        # NOTE(review): torch.load unpickles arbitrary objects -- only load
        # checkpoints from a trusted source.
        print("Loading checkpoint...")
        model = torch.load(vqvae_ckpt_path).to(device)
        discriminator = torch.load(discriminator_ckpt_path).to(device)
        optimizer_d = torch.load(optimizer_d_ckpt)
        optimizer_g = torch.load(optimizer_g_ckpt)
    else:
        model = VQVAE(
            im_channels=dataset_config["im_channels"], model_config=autoencoder_config
        ).to(device)
        discriminator = Discriminator(im_channels=dataset_config["im_channels"]).to(
            device
        )
        optimizer_d = Adam(
            discriminator.parameters(),
            lr=train_config["autoencoder_lr"],
            betas=(0.5, 0.999),
        )
        optimizer_g = Adam(
            model.parameters(), lr=train_config["autoencoder_lr"], betas=(0.5, 0.999)
        )

    wandb.watch(model, log="all", log_freq=100)

    # Frozen LPIPS network for the perceptual loss term
    lpips_model = LPIPS().eval().to(device)

    img_save_steps = train_config["autoencoder_img_save_steps"]
    # Hard-coded resume counters -- presumably carried over from an interrupted
    # run so the step-based triggers line up; TODO confirm / move into config.
    img_saved = 14
    steps = 15001
    disc_step_start = train_config["disc_start"]

    for epoch in range(train_config["autoencoder_epochs"]):
        recon_losses = []
        codebook_losses = []
        perceptual_losses = []
        disc_losses = []
        gen_losses = []
        losses = []

        optimizer_d.zero_grad()
        optimizer_g.zero_grad()

        for im_tensor in tqdm(data_loader):
            im_tensor = im_tensor.to(device)
            output, z, quantize_losses = model(im_tensor)

            # Periodically log an input/reconstruction grid to wandb
            if steps % img_save_steps == 0:
                sample_size = min(8, im_tensor.shape[0])
                save_output = (
                    torch.clamp(output[:sample_size], -1.0, 1.0).detach().cpu()
                )
                save_output = (save_output + 1) / 2  # [-1, 1] -> [0, 1]
                save_input = ((im_tensor[:sample_size] + 1) / 2).detach().cpu()
                grid = make_grid(
                    torch.cat([save_input, save_output], dim=0), nrow=sample_size
                )
                # BUGFIX: ToPILImage is a transform class; instantiate it, then
                # apply it to the tensor. The original passed the tensor to the
                # constructor, producing a transform object, which wandb.Image
                # cannot render.
                grid_image = ToPILImage()(grid)
                wandb.log(
                    {
                        "Latent generation": wandb.Image(
                            grid_image, caption=f"Epoch: {epoch+1}, Step: {steps}"
                        )
                    }
                )
                img_saved += 1
            steps += 1

            # ---- Generator (autoencoder) objective ----
            recon_loss = recon_criterion(output, im_tensor)
            recon_losses.append(recon_loss.item())
            g_loss = (
                recon_loss
                + (train_config["codebook_beta"] * quantize_losses["codebook_loss"])
                + (train_config["commitment_beta"] * quantize_losses["commitment_loss"])
            )
            codebook_losses.append(
                train_config["codebook_beta"] * quantize_losses["codebook_loss"].item()
            )

            # Adversarial term once the discriminator warm-up is over
            if steps > disc_step_start:
                disc_fake_pred = discriminator(output)
                disc_fake_loss = disc_criterion(
                    disc_fake_pred, torch.ones_like(disc_fake_pred)
                )
                gen_losses.append(train_config["disc_beta"] * disc_fake_loss.item())
                g_loss += train_config["disc_beta"] * disc_fake_loss

            lpips_loss = torch.mean(lpips_model(output, im_tensor))
            perceptual_losses.append(
                train_config["perceptual_weight"] * lpips_loss.item()
            )
            g_loss += train_config["perceptual_weight"] * lpips_loss
            losses.append(g_loss.item())
            g_loss.backward()

            # ---- Discriminator objective ----
            if steps > disc_step_start:
                # BUGFIX: g_loss.backward() above also deposited gradients into
                # the discriminator (through the non-detached disc_fake_pred
                # path); clear them so optimizer_d applies only the
                # discriminator's own loss.
                optimizer_d.zero_grad()
                fake = output
                disc_fake_pred = discriminator(fake.detach())
                disc_real_pred = discriminator(im_tensor)
                disc_fake_loss = disc_criterion(
                    disc_fake_pred, torch.zeros_like(disc_fake_pred)
                )
                disc_real_loss = disc_criterion(
                    disc_real_pred, torch.ones_like(disc_real_pred)
                )
                disc_loss = (
                    train_config["disc_beta"] * (disc_real_loss + disc_fake_loss) / 2
                )
                disc_losses.append(disc_loss.item())
                disc_loss.backward()

            optimizer_g.step()
            optimizer_g.zero_grad()
            optimizer_d.step()
            optimizer_d.zero_grad()

        # _safe_mean guards against empty gen/disc loss lists before disc_start
        wandb.log(
            {
                "epoch": epoch + 1,
                "step": steps,
                "image_saved": img_saved,
                "recon_loss": _safe_mean(recon_losses),
                "perceptual_loss": _safe_mean(perceptual_losses),
                "codebook_loss": _safe_mean(codebook_losses),
                "gen_loss": _safe_mean(gen_losses),
                "disc_loss": _safe_mean(disc_losses),
                "overall_loss": _safe_mean(losses),
            }
        )

        if disc_losses:
            print(
                "Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | "
                "Codebook : {:.4f} | G Loss : {:.4f} | D Loss {:.4f}".format(
                    epoch + 1,
                    np.mean(recon_losses),
                    np.mean(perceptual_losses),
                    np.mean(codebook_losses),
                    np.mean(gen_losses),
                    np.mean(disc_losses),
                )
            )
        else:
            print(
                "Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | Codebook : {:.4f}".format(
                    epoch + 1,
                    np.mean(recon_losses),
                    np.mean(perceptual_losses),
                    np.mean(codebook_losses),
                )
            )

        # Checkpoint everything after each epoch and mirror the files to wandb
        for obj, ckpt_name in zip(
            (model, discriminator, optimizer_d, optimizer_g), ckpt_names
        ):
            ckpt_path = os.path.join(task_dir, ckpt_name)
            torch.save(obj, ckpt_path)
            wandb.save(ckpt_path)

    print("Done Training....")
    wandb.finish()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Arguments for vq vae training")
    parser.add_argument(
        "--config_path", type=str, dest="config_path", default="config/celebahq.yaml"
    )
    args = parser.parse_args()
    train(args)