import argparse
import os

import numpy as np
import torch
import torch.nn as nn
import wandb
import yaml
from torch.optim import AdamW
from tqdm import tqdm

from celeba import create_dataloader
from model.transformer import DIT
from model.vae import VAE
from scheduler.linear_scheduler import LinearNoiseScheduler

device = "cuda" if torch.cuda.is_available() else "cpu"
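
# Expected layout of the YAML config, inferred from the keys accessed in
# train() below (the actual values live in the config file, e.g.
# celeba/config.yaml):
#
#   dataset_params:     im_path, im_size, im_channels
#   diffusion_params:   num_timesteps, beta_start, beta_end
#   dit_params:         transformer hyperparameters consumed by DIT
#   autoencoder_params: z_channels, down_sample, ...
#   train_params:       task_name, dit_ckpt_name, vae_autoencoder_ckpt_name,
#                       dit_lr, dit_epochs, dit_acc_steps
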
def train(args):
    # Load the experiment configuration; fail fast on malformed YAML instead
    # of falling through with an undefined config.
    with open(args.config_path, "r") as file:
        try:
            config = yaml.safe_load(file)
        except yaml.YAMLError as e:
            raise ValueError(f"Error loading yaml config: {e}") from e
train_config = config["train_params"]
dit_config = config["dit_params"]
dataset_config = config["dataset_params"]
diffusion_params = config["diffusion_params"]
vae_config = config["autoencoder_params"]
wandb.init(
project="diffusion-transformer",
name=f"{train_config['task_name']}_dit_training",
config={
"train_config": train_config,
"dit_config": dit_config,
"dataset_config": dataset_config,
"diffusion_params": diffusion_params,
"vae_config": vae_config,
"device": device,
},
tags=["dit", "diffusion", "transformer"],
)
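    # create_dataloader is assumed to yield batches of image tensors shaped
    # [B, C, H, W]; see celeba/ for the dataset implementation.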
dataloader = create_dataloader(dataset_config["im_path"])
scheduler = LinearNoiseScheduler(
diffusion_params["num_timesteps"],
diffusion_params["beta_start"],
diffusion_params["beta_end"],
)
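    # Sketch of the assumed scheduler contract (the standard DDPM forward
    # process with a linear beta schedule):
    #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
    # where alpha_bar_t is the cumulative product of (1 - beta_s) for s <= t.
    # See scheduler/linear_scheduler.py for the actual implementation.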
    # Spatial size of the VAE latent: each active down-sampling stage halves
    # the resolution (e.g. a hypothetical im_size of 256 with three stages
    # gives 256 // 2**3 = 32).
    im_size = dataset_config["im_size"] // 2 ** sum(vae_config["down_sample"])
model = DIT(
im_size=im_size, im_channels=vae_config["z_channels"], config=dit_config
).to(device)
model.train()
wandb.watch(model, log="all", log_freq=100)
    ckpt_path = os.path.join(train_config["task_name"], train_config["dit_ckpt_name"])
    optimizer = AdamW(model.parameters(), lr=train_config["dit_lr"])
    if os.path.exists(ckpt_path):
        # Resume from an existing checkpoint. The keys here must match the
        # ones written by torch.save at the end of each epoch below.
        checkpoint = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(checkpoint["dit"])
        start_epoch = checkpoint["epoch"]
        step_count = checkpoint["step"]
        optimizer.load_state_dict(checkpoint["optimizer"])
        print(f"Resuming from epoch {start_epoch}, step {step_count}")
        wandb.log({"resumed_from_epoch": start_epoch, "resumed_from_step": step_count})
    else:
        start_epoch = 0
        step_count = 0
    vae_ckpt_path = os.path.join(
        train_config["task_name"], train_config["vae_autoencoder_ckpt_name"]
    )
    if not os.path.exists(vae_ckpt_path):
        print("No VAE checkpoint found; a trained VAE checkpoint is required")
        wandb.finish()
        return
    # Load the pretrained VAE and freeze it: the DIT is trained entirely in
    # the VAE's latent space, so the autoencoder receives no gradients.
    vae = VAE(dataset_config["im_channels"], vae_config).to(device)
    vae.load_state_dict(torch.load(vae_ckpt_path, map_location=device))
    vae.eval()
    for param in vae.parameters():
        param.requires_grad = False
    print("VAE checkpoint loaded")
# Log model architecture
wandb.log(
{
"model_parameters": sum(p.numel() for p in model.parameters()),
"trainable_parameters": sum(
p.numel() for p in model.parameters() if p.requires_grad
),
}
)
num_epochs = train_config["dit_epochs"]
accu_steps = train_config["dit_acc_steps"]
criterion = nn.MSELoss()
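    # Training objective: epsilon-prediction. The model receives a noised
    # latent x_t and its timestep t, and is trained with MSE to recover the
    # exact Gaussian noise that was added. vae.encode is assumed to return
    # (latent, extra); the second element is discarded in the loop below.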
    for epoch in range(start_epoch, num_epochs):
        losses = []
        for im in tqdm(dataloader):
            im = im.float().to(device)
            step_count += 1
            # Encode images to latents; the frozen VAE needs no gradients.
            with torch.no_grad():
                im, _ = vae.encode(im)
            # Sample noise and a random timestep for each latent in the batch.
            noise = torch.randn_like(im)
            t = torch.randint(
                0, diffusion_params["num_timesteps"], (im.shape[0],), device=device
            )
            noisy_im = scheduler.add_noise(im, noise, t)
            pred = model(noisy_im, t)
            loss = criterion(pred, noise)
            losses.append(loss.item())
            # Scale the loss so gradients average correctly over the
            # accumulation window.
            loss = loss / accu_steps
            loss.backward()
            if step_count % 10 == 0:  # Log every 10 steps
                wandb.log(
                    {
                        "batch_loss": loss.item() * accu_steps,
                        "learning_rate": optimizer.param_groups[0]["lr"],
                        "step_count": step_count,
                        "epoch": epoch,
                    }
                )
            # Step the optimizer only once per accumulation window.
            if step_count % accu_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
        # Flush any gradients left over from an incomplete accumulation
        # window at the end of the epoch.
        optimizer.step()
        optimizer.zero_grad()
        wandb.log(
            {
                "epoch": epoch,
                "epoch_loss": np.mean(losses),
                "epoch_loss_std": np.std(losses),
                "learning_rate": optimizer.param_groups[0]["lr"],
            }
        )
        print(f"Epoch {epoch}: Loss: {np.mean(losses):.4f}")
        torch.save(
            {
                "dit": model.state_dict(),
                "epoch": epoch + 1,
                "step": step_count,
                "optimizer": optimizer.state_dict(),
            },
            ckpt_path,
        )
        if (epoch + 1) % 5 == 0:  # Upload a checkpoint artifact every 5 epochs
            artifact = wandb.Artifact(
                f"dit_model_epoch_{epoch + 1}",
                type="model",
                description=f"DIT model checkpoint at epoch {epoch + 1}",
            )
            artifact.add_file(ckpt_path)
            wandb.log_artifact(artifact)
    final_artifact = wandb.Artifact(
        "dit_model_final", type="model", description="Final DIT model checkpoint"
    )
    final_artifact.add_file(ckpt_path)
    wandb.log_artifact(final_artifact)
print("Done Training")
wandb.finish()
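
# Typical invocation (the default config path below is the repo's own;
# substitute your own YAML to train on a different setup):
#   python train_dit.py --config celeba/config.yaml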
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Arguments for dit training")
parser.add_argument(
"--config", dest="config_path", default="celeba/config.yaml", type=str
)
args = parser.parse_args()
train(args)