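"""Training script for a VQ-VAE with an adversarial discriminator and an LPIPS
perceptual loss, logging metrics, reconstructions, and checkpoints to
Weights & Biases."""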
import os
import argparse
import torch
import torch.nn as nn
from models.vqvae import VQVAE
from models.discriminator import Discriminator
from torch.optim import Adam
from models.lpips import LPIPS
from dataset.celeba import create_dataloader
from torchvision.utils import make_grid
from torchvision.transforms import ToPILImage
import yaml
import numpy as np
from tqdm import tqdm
import wandb
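# Start a wandb run for experiment tracking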
wandb.init(project="vqvae")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def train(args):
with open(args.config_path, "r") as f:
        try:
            config = yaml.safe_load(f)
        except yaml.YAMLError as e:
            # Re-raise so we fail fast instead of continuing with an undefined config
            print(e)
            raise
autoencoder_config = config["autoencoder_params"]
train_config = config["train_config"]
dataset_config = config["dataset_config"]
vqvae_ckpt_path = os.path.join(
train_config["task_name"], train_config["vqvae_autoencoder_ckpt_name"]
)
discriminator_ckpt_path = os.path.join(
train_config["task_name"], train_config["vqvae_discriminator_ckpt_name"]
)
optimizer_d_ckpt = os.path.join(
train_config["task_name"], train_config["vqvae_optim_d_ckpt_name"]
)
optimizer_g_ckpt = os.path.join(
train_config["task_name"], train_config["vqvae_optim_g_ckpt_name"]
)
    # Set seed for reproducibility
seed = train_config["seed"]
torch.manual_seed(seed)
np.random.seed(seed)
data_loader = create_dataloader(dataset_config["im_path"])
if not os.path.exists(train_config["task_name"]):
os.mkdir(train_config["task_name"])
    # L2 (MSE) as the reconstruction loss
    recon_criterion = nn.MSELoss()
    # Discriminator loss is BCEWithLogitsLoss
    disc_criterion = nn.BCEWithLogitsLoss()
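    # Resume from existing checkpoints when available; otherwise start fresh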
if os.path.exists(vqvae_ckpt_path) and os.path.exists(discriminator_ckpt_path):
print("Loading checkpoint...")
        model = torch.load(vqvae_ckpt_path, map_location=device).to(device)
        discriminator = torch.load(
            discriminator_ckpt_path, map_location=device
        ).to(device)
        optimizer_d = torch.load(optimizer_d_ckpt, map_location=device)
        optimizer_g = torch.load(optimizer_g_ckpt, map_location=device)
else:
model = VQVAE(
im_channels=dataset_config["im_channels"], model_config=autoencoder_config
).to(device)
discriminator = Discriminator(im_channels=dataset_config["im_channels"]).to(
device
)
optimizer_d = Adam(
discriminator.parameters(),
lr=train_config["autoencoder_lr"],
betas=(0.5, 0.999),
)
optimizer_g = Adam(
model.parameters(), lr=train_config["autoencoder_lr"], betas=(0.5, 0.999)
)
    wandb.watch(model, log="all", log_freq=100)
# LPIPS model
lpips_model = LPIPS().eval().to(device)
img_save_steps = train_config["autoencoder_img_save_steps"]
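    # NOTE: img_saved and steps are hardcoded non-zero values, apparently to
    # resume counting from an earlier run rather than start at zero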
img_saved = 14
disc_step_start = train_config["disc_start"]
steps = 15001
for epoch in range(train_config["autoencoder_epochs"]):
recon_losses = []
codebook_losses = []
perceptual_losses = []
disc_losses = []
gen_losses = []
losses = []
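        # Clear any stale gradients before the epoch's first batch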
optimizer_d.zero_grad()
optimizer_g.zero_grad()
for im_tensor in tqdm(data_loader):
# Model output with losses
im_tensor = im_tensor.to(device)
model_output = model(im_tensor)
            output, z, quantize_losses = model_output
# Image saving
if steps % img_save_steps == 0:
sample_size = min(8, im_tensor.shape[0])
save_output = (
torch.clamp(output[:sample_size], -1.0, 1.0).detach().cpu()
)
save_output = (save_output + 1) / 2
save_input = ((im_tensor[:sample_size] + 1) / 2).detach().cpu()
grid = make_grid(
torch.cat([save_input, save_output], dim=0), nrow=sample_size
)
                grid_image = ToPILImage()(grid)
                wandb.log(
                    {
                        "Latent generation": wandb.Image(
                            grid_image, caption=f"Epoch: {epoch+1}, Step: {steps}"
                        )
                    }
                )
img_saved += 1
steps += 1
# Optimizing generator
# Reconstruction loss
recon_loss = recon_criterion(output, im_tensor)
recon_losses.append(recon_loss.item())
            g_loss = (
                recon_loss
                + (train_config["codebook_beta"] * quantize_losses["codebook_loss"])
                + (train_config["commitment_beta"] * quantize_losses["commitment_loss"])
            )
            codebook_losses.append(
                train_config["codebook_beta"] * quantize_losses["codebook_loss"].item()
            )
            # Adversarial loss once disc_step_start is reached
            if steps > disc_step_start:
                disc_fake_pred = discriminator(model_output[0])
                disc_fake_loss = disc_criterion(
                    disc_fake_pred,
                    torch.ones_like(disc_fake_pred),
                )
gen_losses.append(train_config["disc_beta"] * disc_fake_loss.item())
g_loss += train_config["disc_beta"] * disc_fake_loss
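            # Perceptual similarity (LPIPS) between reconstruction and input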
lpips_loss = torch.mean(lpips_model(output, im_tensor))
perceptual_losses.append(
train_config["perceptual_weight"] * lpips_loss.item()
)
g_loss += train_config["perceptual_weight"] * lpips_loss
losses.append(g_loss.item())
g_loss.backward()
            # Optimizing discriminator
            if steps > disc_step_start:
                fake = output
                # Discard gradients that the generator's adversarial backward
                # pass leaked into the discriminator before its own update
                optimizer_d.zero_grad()
                disc_fake_pred = discriminator(fake.detach())
                disc_real_pred = discriminator(im_tensor)
                disc_fake_loss = disc_criterion(
                    disc_fake_pred,
                    torch.zeros_like(disc_fake_pred),
                )
                disc_real_loss = disc_criterion(
                    disc_real_pred,
                    torch.ones_like(disc_real_pred),
                )
disc_loss = (
train_config["disc_beta"] * (disc_real_loss + disc_fake_loss) / 2
)
disc_losses.append(disc_loss.item())
disc_loss.backward()
optimizer_g.step()
optimizer_g.zero_grad()
optimizer_d.step()
optimizer_d.zero_grad()
        # Per-epoch metrics; gen/disc losses only exist once the
        # discriminator has started training
        log_dict = {
            "epoch": epoch + 1,
            "step": steps,
            "image_saved": img_saved,
            "recon_loss": np.mean(recon_losses),
            "perceptual_loss": np.mean(perceptual_losses),
            "codebook_loss": np.mean(codebook_losses),
            "overall_loss": np.mean(losses),
        }
        if len(disc_losses) > 0:
            log_dict["gen_loss"] = np.mean(gen_losses)
            log_dict["disc_loss"] = np.mean(disc_losses)
        wandb.log(log_dict)
if len(disc_losses) > 0:
print(
"Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | "
"Codebook : {:.4f} | G Loss : {:.4f} | D Loss {:.4f}".format(
epoch + 1,
np.mean(recon_losses),
np.mean(perceptual_losses),
np.mean(codebook_losses),
np.mean(gen_losses),
np.mean(disc_losses),
)
)
else:
print(
"Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | Codebook : {:.4f}".format(
epoch + 1,
np.mean(recon_losses),
np.mean(perceptual_losses),
np.mean(codebook_losses),
)
)
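        # Save full model and optimizer objects each epoch; a resumed run
        # reloads them with torch.load above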
torch.save(
model,
os.path.join(
train_config["task_name"], train_config["vqvae_autoencoder_ckpt_name"]
),
)
torch.save(
discriminator,
os.path.join(
train_config["task_name"], train_config["vqvae_discriminator_ckpt_name"]
),
)
torch.save(
optimizer_d,
os.path.join(
train_config["task_name"], train_config["vqvae_optim_d_ckpt_name"]
),
)
torch.save(
optimizer_g,
os.path.join(
train_config["task_name"], train_config["vqvae_optim_g_ckpt_name"]
),
)
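        # Upload the latest checkpoint files to the wandb run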
wandb.save(
os.path.join(
train_config["task_name"], train_config["vqvae_autoencoder_ckpt_name"]
)
)
wandb.save(
os.path.join(
train_config["task_name"], train_config["vqvae_discriminator_ckpt_name"]
)
)
wandb.save(
os.path.join(
train_config["task_name"], train_config["vqvae_optim_d_ckpt_name"]
)
)
wandb.save(
os.path.join(
train_config["task_name"], train_config["vqvae_optim_g_ckpt_name"]
)
)
print("Done Training....")
wandb.finish()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Arguments for VQ-VAE training")
parser.add_argument(
"--config_path", type=str, dest="config_path", default="config/celebahq.yaml"
)
args = parser.parse_args()
train(args)