import argparse
import os

import numpy as np
import torch
import torch.nn as nn
import wandb
import yaml
from torch.optim import Adam
from torchvision.transforms import ToPILImage
from torchvision.utils import make_grid
from tqdm import tqdm

from dataset.celeba import create_dataloader
from models.discriminator import Discriminator
from models.lpips import LPIPS
from models.vqvae import VQVAE


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train(args):
    # Load the YAML config; abort rather than continuing with an undefined config.
    with open(args.config_path, "r") as f:
        try:
            config = yaml.safe_load(f)
        except yaml.YAMLError as e:
            raise SystemExit(f"Failed to parse config {args.config_path}: {e}")

    autoencoder_config = config["autoencoder_params"]
    train_config = config["train_config"]
    dataset_config = config["dataset_config"]

    # Start the W&B run here rather than at import time, and attach the
    # config so the run records its hyperparameters.
    wandb.init(project="vqvae", config=config)

    vqvae_ckpt_path = os.path.join(
        train_config["task_name"], train_config["vqvae_autoencoder_ckpt_name"]
    )
    discriminator_ckpt_path = os.path.join(
        train_config["task_name"], train_config["vqvae_discriminator_ckpt_name"]
    )
    optimizer_d_ckpt = os.path.join(
        train_config["task_name"], train_config["vqvae_optim_d_ckpt_name"]
    )
    optimizer_g_ckpt = os.path.join(
        train_config["task_name"], train_config["vqvae_optim_g_ckpt_name"]
    )
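
    # For reference, the config is expected to look roughly like the sketch
    # below (keys taken from their uses in this script; the contents of
    # autoencoder_params are model-specific and not assumed here):
    #   dataset_config: {im_path: ..., im_channels: ...}
    #   autoencoder_params: {...}
    #   train_config:
    #     task_name: ..., seed: ..., autoencoder_lr: ...,
    #     autoencoder_epochs: ..., autoencoder_img_save_steps: ...,
    #     disc_start: ..., disc_beta: ..., codebook_beta: ...,
    #     commitment_beta: ..., perceptual_weight: ...,
    #     vqvae_autoencoder_ckpt_name: ..., vqvae_discriminator_ckpt_name: ...,
    #     vqvae_optim_d_ckpt_name: ..., vqvae_optim_g_ckpt_name: ...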

    # Seed RNGs for reproducibility.
    seed = train_config["seed"]
    torch.manual_seed(seed)
    np.random.seed(seed)

    data_loader = create_dataloader(dataset_config["im_path"])

    # makedirs handles nested task names and an already-existing directory.
    os.makedirs(train_config["task_name"], exist_ok=True)

    # L2 reconstruction loss for the autoencoder; BCE-with-logits for the
    # discriminator's real/fake predictions.
    recon_criterion = nn.MSELoss()
    disc_criterion = nn.BCEWithLogitsLoss()

    # Resume from full-object checkpoints when both model files exist;
    # otherwise build the VQ-VAE, the discriminator, and one Adam optimizer
    # per network.
    if os.path.exists(vqvae_ckpt_path) and os.path.exists(discriminator_ckpt_path):
        print("Loading checkpoint...")
        model = torch.load(vqvae_ckpt_path).to(device)
        discriminator = torch.load(discriminator_ckpt_path).to(device)
        optimizer_d = torch.load(optimizer_d_ckpt)
        optimizer_g = torch.load(optimizer_g_ckpt)
    else:
        model = VQVAE(
            im_channels=dataset_config["im_channels"], model_config=autoencoder_config
        ).to(device)
        discriminator = Discriminator(im_channels=dataset_config["im_channels"]).to(
            device
        )
        optimizer_d = Adam(
            discriminator.parameters(),
            lr=train_config["autoencoder_lr"],
            betas=(0.5, 0.999),
        )
        optimizer_g = Adam(
            model.parameters(), lr=train_config["autoencoder_lr"], betas=(0.5, 0.999)
        )

    # Watch gradients and parameters for fresh and resumed runs alike.
    wandb.watch(model, log="all", log_freq=100)
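
    # Note: loading full objects with torch.load requires the original class
    # definitions to be importable at load time. A state_dict-based scheme, e.g.
    #   torch.save(model.state_dict(), vqvae_ckpt_path)
    #   model.load_state_dict(torch.load(vqvae_ckpt_path, map_location=device))
    # is the more portable alternative; full objects are kept here to match
    # the save calls at the end of each epoch.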

    # LPIPS network for the perceptual loss, kept in eval mode.
    lpips_model = LPIPS().eval().to(device)

    img_save_steps = train_config["autoencoder_img_save_steps"]
    img_saved = 0

    disc_step_start = train_config["disc_start"]
    steps = 0
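
    # Delaying the adversarial loss until disc_step_start, as in VQGAN-style
    # training, lets the autoencoder reach reasonable reconstructions before
    # the discriminator starts shaping them. The step counter is not
    # checkpointed, so resumed runs restart it from 0.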

    for epoch in range(train_config["autoencoder_epochs"]):
        recon_losses = []
        codebook_losses = []
        perceptual_losses = []
        disc_losses = []
        gen_losses = []
        losses = []

        optimizer_d.zero_grad()
        optimizer_g.zero_grad()

        for im_tensor in tqdm(data_loader):
            im_tensor = im_tensor.to(device)

            # Forward pass: reconstructions, latents, and quantization losses.
            output, z, quantize_losses = model(im_tensor)

            # Periodically log an input/reconstruction grid to W&B. Inputs
            # are assumed to be normalized to [-1, 1] by the dataloader,
            # hence the rescale to [0, 1] for visualization.
            if steps % img_save_steps == 0:
                sample_size = min(8, im_tensor.shape[0])
                save_output = (
                    torch.clamp(output[:sample_size], -1.0, 1.0).detach().cpu()
                )
                save_output = (save_output + 1) / 2
                save_input = ((im_tensor[:sample_size] + 1) / 2).detach().cpu()

                grid = make_grid(
                    torch.cat([save_input, save_output], dim=0), nrow=sample_size
                )
                # ToPILImage is a transform class; instantiate it before
                # applying it to the grid tensor.
                grid_image = ToPILImage()(grid)
                wandb.log(
                    {
                        "Latent generation": wandb.Image(
                            grid_image, caption=f"Epoch: {epoch + 1}, Step: {steps}"
                        )
                    }
                )
                img_saved += 1

            steps += 1

            # Autoencoder objective: reconstruction plus the weighted
            # codebook and commitment terms from the vector quantizer.
            recon_loss = recon_criterion(output, im_tensor)
            recon_losses.append(recon_loss.item())
            g_loss = (
                recon_loss
                + (train_config["codebook_beta"] * quantize_losses["codebook_loss"])
                + (train_config["commitment_beta"] * quantize_losses["commitment_loss"])
            )
            codebook_losses.append(
                train_config["codebook_beta"] * quantize_losses["codebook_loss"].item()
            )
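
            # For reference, this pairs an MSE reconstruction term with the
            # VQ-VAE quantizer losses (van den Oord et al., 2017):
            #   L = ||x - x_hat||^2 + ||sg[z_e(x)] - e||^2
            #       + beta * ||z_e(x) - sg[e]||^2
            # where sg[.] is the stop-gradient operator; the quantizer
            # supplies the codebook and commitment terms, weighted here by
            # the betas from the config.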

            # Adversarial term for the generator once the discriminator is
            # active: reconstructions should be classified as real (label 1).
            if steps > disc_step_start:
                disc_fake_pred = discriminator(output)
                disc_fake_loss = disc_criterion(
                    disc_fake_pred, torch.ones_like(disc_fake_pred)
                )
                gen_losses.append(train_config["disc_beta"] * disc_fake_loss.item())
                g_loss += train_config["disc_beta"] * disc_fake_loss

            # Perceptual (LPIPS) distance between input and reconstruction.
            lpips_loss = torch.mean(lpips_model(output, im_tensor))
            perceptual_losses.append(
                train_config["perceptual_weight"] * lpips_loss.item()
            )
            g_loss += train_config["perceptual_weight"] * lpips_loss
            losses.append(g_loss.item())
            g_loss.backward()

            # Discriminator objective: real images -> 1, detached
            # reconstructions -> 0, averaged and weighted by disc_beta.
            if steps > disc_step_start:
                # g_loss.backward() above also deposited gradients in the
                # discriminator (the adversarial term flows through it), so
                # clear them before computing the discriminator's own update.
                optimizer_d.zero_grad()
                fake = output
                disc_fake_pred = discriminator(fake.detach())
                disc_real_pred = discriminator(im_tensor)
                disc_fake_loss = disc_criterion(
                    disc_fake_pred, torch.zeros_like(disc_fake_pred)
                )
                disc_real_loss = disc_criterion(
                    disc_real_pred, torch.ones_like(disc_real_pred)
                )
                disc_loss = (
                    train_config["disc_beta"] * (disc_real_loss + disc_fake_loss) / 2
                )
                disc_losses.append(disc_loss.item())
                disc_loss.backward()

            optimizer_g.step()
            optimizer_g.zero_grad()
            optimizer_d.step()
            optimizer_d.zero_grad()
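
            # One backward per network per batch: the adversarial term
            # reaches the generator through a live discriminator forward
            # pass, while fake.detach() keeps the discriminator's loss from
            # back-propagating into the generator.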

        # Epoch-level averages. The adversarial terms only exist once the
        # discriminator has started, so guard against empty lists (np.mean
        # of an empty list is nan and emits a RuntimeWarning).
        epoch_log = {
            "epoch": epoch + 1,
            "step": steps,
            "image_saved": img_saved,
            "recon_loss": np.mean(recon_losses),
            "perceptual_loss": np.mean(perceptual_losses),
            "codebook_loss": np.mean(codebook_losses),
            "overall_loss": np.mean(losses),
        }
        if disc_losses:
            epoch_log["gen_loss"] = np.mean(gen_losses)
            epoch_log["disc_loss"] = np.mean(disc_losses)
        wandb.log(epoch_log)

        if len(disc_losses) > 0:
            print(
                "Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | "
                "Codebook : {:.4f} | G Loss : {:.4f} | D Loss : {:.4f}".format(
                    epoch + 1,
                    np.mean(recon_losses),
                    np.mean(perceptual_losses),
                    np.mean(codebook_losses),
                    np.mean(gen_losses),
                    np.mean(disc_losses),
                )
            )
        else:
            print(
                "Finished epoch: {} | Recon Loss : {:.4f} | "
                "Perceptual Loss : {:.4f} | Codebook : {:.4f}".format(
                    epoch + 1,
                    np.mean(recon_losses),
                    np.mean(perceptual_losses),
                    np.mean(codebook_losses),
                )
            )

        # Checkpoint full objects at the end of every epoch, reusing the
        # paths computed above.
        torch.save(model, vqvae_ckpt_path)
        torch.save(discriminator, discriminator_ckpt_path)
        torch.save(optimizer_d, optimizer_d_ckpt)
        torch.save(optimizer_g, optimizer_g_ckpt)

    # Sync the final checkpoints to the W&B run.
    wandb.save(vqvae_ckpt_path)
    wandb.save(discriminator_ckpt_path)
    wandb.save(optimizer_d_ckpt)
    wandb.save(optimizer_g_ckpt)

    print("Done Training...")
    wandb.finish()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Arguments for VQ-VAE training")
    parser.add_argument(
        "--config_path", type=str, dest="config_path", default="config/celebahq.yaml"
    )
    args = parser.parse_args()
    train(args)