# Diffusion-Transformer / train_vae.py
# Author: YashNagraj75
# Commit 31677e7: "Add the dataset and the training script"
import argparse
import os
import random
import numpy as np
import torch
import torchvision
import yaml
from torch.optim import Adam
from torchvision.utils import make_grid
from tqdm import tqdm
from celeba import create_dataloader
from model.discriminator import Discriminator
from model.lpips import LPIPS
from model.vae import VAE
# Global compute device used by every model and tensor in this script.
# Apple's Metal backend (MPS) wins whenever it is available; otherwise fall
# back to CUDA if present, else the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using mps")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def _load_config(config_path):
    """Read and parse the YAML training config at ``config_path``.

    Returns:
        Whatever ``yaml.safe_load`` produces; ``train`` expects a mapping with
        ``dataset_params``, ``autoencoder_params`` and ``train_params`` keys.

    Raises:
        yaml.YAMLError: re-raised after printing. Previously the error was
            swallowed and training died later with a misleading ``NameError``.
    """
    with open(config_path, "r") as file:
        try:
            return yaml.safe_load(file)
        except yaml.YAMLError as exc:
            print(exc)
            raise


def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs for reproducible runs."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # ``device`` is a torch.device; comparing it with the string "cuda" is
    # always False, so check its ``type`` attribute instead.
    if device.type == "cuda":
        torch.cuda.manual_seed_all(seed)


def _load_if_exists(module, ckpt_path, label):
    """Load ``module``'s state dict from ``ckpt_path`` when that file exists."""
    if os.path.exists(ckpt_path):
        module.load_state_dict(torch.load(ckpt_path, map_location=device))
        print("Loaded {} from checkpoint".format(label))


def _set_requires_grad(params, flag):
    """Set ``requires_grad`` to ``flag`` on every tensor in ``params``."""
    for param in params:
        param.requires_grad_(flag)


def _save_sample_grid(im, output, sample_dir, index):
    """Save up to 8 inputs (first row) and their reconstructions (second row).

    Both tensors are mapped with ``(x + 1) / 2`` before gridding, which
    presumes pixel values in [-1, 1] (TODO confirm against the dataloader);
    reconstructions are clamped to [-1, 1] first.
    """
    sample_size = min(8, im.shape[0])
    save_output = torch.clamp(output[:sample_size], -1.0, 1.0).detach().cpu()
    save_output = (save_output + 1) / 2
    save_input = ((im[:sample_size] + 1) / 2).detach().cpu()
    grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size)
    img = torchvision.transforms.ToPILImage()(grid)
    os.makedirs(sample_dir, exist_ok=True)
    img.save(
        os.path.join(sample_dir, "current_autoencoder_sample_{}.png".format(index))
    )
    img.close()


def train(args):
    """Train the VAE with reconstruction, KL, LPIPS and adversarial losses.

    Args:
        args: namespace with a ``config_path`` attribute naming the YAML config.

    Autoencoder and discriminator weights are resumed from ``task_name`` when
    checkpoints exist there, and both are re-saved at the end of every epoch.
    Gradients are accumulated over ``autoencoder_acc_steps`` iterations before
    each optimizer step.
    """
    config = _load_config(args.config_path)
    dataset_config = config["dataset_params"]
    autoencoder_config = config["autoencoder_params"]
    train_config = config["train_params"]

    _seed_everything(train_config["seed"])

    # Create the model (after seeding, so its initialisation is reproducible).
    model = VAE(
        im_channels=dataset_config["im_channels"], model_config=autoencoder_config
    ).to(device)

    # Output locations. makedirs tolerates nested task names and re-runs.
    task_dir = train_config["task_name"]
    os.makedirs(task_dir, exist_ok=True)
    ae_ckpt_path = os.path.join(task_dir, train_config["vae_autoencoder_ckpt_name"])
    disc_ckpt_path = os.path.join(
        task_dir, train_config["vae_discriminator_ckpt_name"]
    )
    sample_dir = os.path.join(task_dir, "vae_autoencoder_samples")

    num_epochs = train_config["autoencoder_epochs"]
    # L2 loss for reconstruction.
    recon_criterion = torch.nn.MSELoss()
    # Discriminator loss: MSE against 0/1 targets (BCEWithLogits would also work).
    disc_criterion = torch.nn.MSELoss()
    # No need to freeze lpips here; per the original note, lpips.py does that.
    lpips_model = LPIPS().eval().to(device)
    discriminator = Discriminator(im_channels=dataset_config["im_channels"]).to(device)
    data_loader = create_dataloader(dataset_config["im_path"])

    _load_if_exists(model, ae_ckpt_path, "autoencoder")
    _load_if_exists(discriminator, disc_ckpt_path, "discriminator")

    optimizer_d = Adam(
        discriminator.parameters(),
        lr=train_config["autoencoder_lr"],
        betas=(0.5, 0.999),
    )
    optimizer_g = Adam(
        model.parameters(), lr=train_config["autoencoder_lr"], betas=(0.5, 0.999)
    )
    # Only the parameters that are trainable now get toggled below, so any
    # deliberately frozen discriminator weights stay frozen.
    disc_trainable = [p for p in discriminator.parameters() if p.requires_grad]

    disc_step_start = train_config["disc_start"]
    step_count = 0
    # Gradient accumulation, for when images are too large for a big batch.
    acc_steps = train_config["autoencoder_acc_steps"]
    image_save_steps = train_config["autoencoder_img_save_steps"]
    img_save_count = 0

    for epoch_idx in range(num_epochs):
        recon_losses = []
        perceptual_losses = []
        disc_losses = []
        gen_losses = []
        optimizer_g.zero_grad()
        optimizer_d.zero_grad()

        for im in tqdm(data_loader):
            step_count += 1
            im = im.float().to(device)

            # Reconstruction plus the encoder's raw output (mean and
            # log-variance stacked along dim 1, split by the chunk below).
            output, encoder_output = model(im)

            if step_count % image_save_steps == 0 or step_count == 1:
                _save_sample_grid(im, output, sample_dir, img_save_count)
                img_save_count += 1

            ######### Optimize Generator ##########
            recon_loss = recon_criterion(output, im)
            recon_losses.append(recon_loss.item())
            recon_loss = recon_loss / acc_steps
            mean, logvar = torch.chunk(encoder_output, 2, dim=1)
            # KL(N(mean, exp(logvar)) || N(0, I)), summed over dims 1-3 per
            # sample, then averaged over the batch.
            kl_loss = torch.mean(
                0.5 * torch.sum(torch.exp(logvar) + mean**2 - 1 - logvar, dim=[1, 2, 3])
            )
            g_loss = recon_loss + (train_config["kl_weight"] * kl_loss / acc_steps)

            # Adversarial loss only once disc_step_start steps have passed.
            if step_count > disc_step_start:
                # Freeze D for this forward pass. Gradients still reach the
                # generator through ``output``, but g_loss.backward() no longer
                # accumulates the generator's "call fakes real" gradient into
                # D's .grad, where it would oppose D's own objective.
                _set_requires_grad(disc_trainable, False)
                disc_fake_pred = discriminator(output)
                _set_requires_grad(disc_trainable, True)
                disc_fake_loss = disc_criterion(
                    disc_fake_pred, torch.ones_like(disc_fake_pred)
                )
                gen_losses.append(train_config["disc_weight"] * disc_fake_loss.item())
                g_loss += train_config["disc_weight"] * disc_fake_loss / acc_steps

            lpips_loss = torch.mean(lpips_model(output, im))
            perceptual_losses.append(
                train_config["perceptual_weight"] * lpips_loss.item()
            )
            g_loss += train_config["perceptual_weight"] * lpips_loss / acc_steps
            g_loss.backward()
            #####################################

            ######### Optimize Discriminator #######
            if step_count > disc_step_start:
                # detach(): the discriminator update must not backprop into the VAE.
                disc_fake_pred = discriminator(output.detach())
                disc_real_pred = discriminator(im)
                disc_fake_loss = disc_criterion(
                    disc_fake_pred, torch.zeros_like(disc_fake_pred)
                )
                disc_real_loss = disc_criterion(
                    disc_real_pred, torch.ones_like(disc_real_pred)
                )
                disc_loss = (
                    train_config["disc_weight"] * (disc_fake_loss + disc_real_loss) / 2
                )
                disc_losses.append(disc_loss.item())
                (disc_loss / acc_steps).backward()
                if step_count % acc_steps == 0:
                    optimizer_d.step()
                    optimizer_d.zero_grad()
            #####################################

            if step_count % acc_steps == 0:
                optimizer_g.step()
                optimizer_g.zero_grad()

        # Apply gradients left over from an incomplete accumulation window.
        # When step_count is a multiple of acc_steps the loop has just stepped
        # and zeroed both optimizers, so another step would be redundant.
        if step_count % acc_steps != 0:
            optimizer_d.step()
            optimizer_d.zero_grad()
            optimizer_g.step()
            optimizer_g.zero_grad()

        if disc_losses:
            print(
                "Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | "
                "G Loss : {:.4f} | D Loss {:.4f}".format(
                    epoch_idx + 1,
                    np.mean(recon_losses),
                    np.mean(perceptual_losses),
                    np.mean(gen_losses),
                    np.mean(disc_losses),
                )
            )
        else:
            print(
                "Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f}".format(
                    epoch_idx + 1, np.mean(recon_losses), np.mean(perceptual_losses)
                )
            )

        # Checkpoint every epoch; these paths are what the resume logic reads.
        torch.save(model.state_dict(), ae_ckpt_path)
        torch.save(discriminator.state_dict(), disc_ckpt_path)

    print("Done Training...")
if __name__ == "__main__":
    # CLI entry point: the only option is the YAML config path, which
    # train() reads as ``args.config_path``.
    cli = argparse.ArgumentParser(description="Arguments for vae training")
    cli.add_argument(
        "--config",
        dest="config_path",
        type=str,
        default="celeba/config.yaml",
    )
    train(cli.parse_args())