"""Train a Diffusion Transformer (DiT) on latents produced by a pretrained VAE."""

import argparse
import os

import numpy as np
import torch
import torch.nn as nn
import wandb
import yaml
from torch.optim import AdamW
from tqdm import tqdm

from celeba import create_dataloader
from model.transformer import DIT
from model.vae import VAE
from scheduler.linear_scheduler import LinearNoiseScheduler

device = "cuda" if torch.cuda.is_available() else "cpu"


def train(args):
    with open(args.config_path, "r") as file:
        try:
            config = yaml.safe_load(file)
        except yaml.YAMLError as e:
            print(f"Error in loading yaml: {e}")
            return

    train_config = config["train_params"]
    dit_config = config["dit_params"]
    dataset_config = config["dataset_params"]
    diffusion_params = config["diffusion_params"]
    vae_config = config["autoencoder_params"]

    wandb.init(
        project="diffusion-transformer",
        name=f"{train_config['task_name']}_dit_training",
        config={
            "train_config": train_config,
            "dit_config": dit_config,
            "dataset_config": dataset_config,
            "diffusion_params": diffusion_params,
            "vae_config": vae_config,
            "device": device,
        },
        tags=["dit", "diffusion", "transformer"],
    )

    dataloader = create_dataloader(dataset_config["im_path"])
    scheduler = LinearNoiseScheduler(
        diffusion_params["num_timesteps"],
        diffusion_params["beta_start"],
        diffusion_params["beta_end"],
    )

    # The DiT operates on VAE latents, so its spatial size is the image size
    # divided by the VAE's total downsampling factor.
    im_size = dataset_config["im_size"] // 2 ** sum(vae_config["down_sample"])
    model = DIT(
        im_size=im_size, im_channels=vae_config["z_channels"], config=dit_config
    ).to(device)
    model.train()

    wandb.watch(model, log="all", log_freq=100)

    dit_ckpt_path = os.path.join(
        train_config["task_name"], train_config["dit_ckpt_name"]
    )
    if os.path.exists(dit_ckpt_path):
        # Resume training from an existing checkpoint.
        checkpoint = torch.load(dit_ckpt_path, map_location=device)
        optimizer = AdamW(model.parameters(), lr=train_config["dit_lr"])
        model.load_state_dict(checkpoint["dit"])
        start_epoch = checkpoint["epoch"]
        step_count = checkpoint["step_count"]
        optimizer.load_state_dict(checkpoint["optimizer"])
        print(f"Resuming from epoch {start_epoch}, step {step_count}")
        wandb.log({"resumed_from_epoch": start_epoch, "resumed_from_step": step_count})
    else:
        step_count = 0
        start_epoch = 0
        optimizer = AdamW(model.parameters(), lr=train_config["dit_lr"])

    vae_ckpt_path = os.path.join(
        train_config["task_name"], train_config["vae_autoencoder_ckpt_name"]
    )
    if not os.path.exists(vae_ckpt_path):
        print("No VAE checkpoint found, VAE checkpoint needed")
        wandb.finish()
        return

    # Load the VAE and freeze it; it only provides latents for the DiT.
    vae = VAE(dataset_config["im_channels"], vae_config).to(device)
    vae.load_state_dict(torch.load(vae_ckpt_path, map_location=device))
    vae.eval()
    for param in vae.parameters():
        param.requires_grad = False
    print("VAE checkpoint loaded")

    # Log model architecture
    wandb.log(
        {
            "model_parameters": sum(p.numel() for p in model.parameters()),
            "trainable_parameters": sum(
                p.numel() for p in model.parameters() if p.requires_grad
            ),
        }
    )

    num_epochs = train_config["dit_epochs"]
    accu_steps = train_config["dit_acc_steps"]
    criterion = nn.MSELoss()

    for epoch in range(start_epoch, num_epochs):
        losses = []
        for im in tqdm(dataloader):
            im = im.float().to(device)
            step_count += 1

            # Encode images to latents with the frozen VAE.
            with torch.no_grad():
                im, _ = vae.encode(im)

            # Sample noise and a random timestep per example, then noise the latent.
            noise = torch.randn_like(im)
            t = torch.randint(
                0, diffusion_params["num_timesteps"], (im.shape[0],)
            ).to(device)
            noisy_im = scheduler.add_noise(im, noise, t)

            # Predict the noise and compute the denoising loss.
            pred = model(noisy_im, t)
            loss = criterion(pred, noise)
            losses.append(loss.item())

            # Scale the loss for gradient accumulation.
            loss = loss / accu_steps
            loss.backward()

            if step_count % 10 == 0:  # Log every 10 steps
                wandb.log(
                    {
                        "batch_loss": loss.item() * accu_steps,
                        "learning_rate": optimizer.param_groups[0]["lr"],
                        "step_count": step_count,
                        "epoch": epoch,
                    }
                )

            if step_count % accu_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

        # Flush any gradients left over from an incomplete accumulation window.
        optimizer.step()
        optimizer.zero_grad()

        wandb.log(
            {
                "epoch": epoch,
                "epoch_loss": np.mean(losses),
                "epoch_loss_std": np.std(losses),
                "learning_rate": optimizer.param_groups[0]["lr"],
            }
        )
        print(f"Epoch {epoch}: Loss: {np.mean(losses)}")

        torch.save(
            {
                "dit": model.state_dict(),
                "epoch": epoch + 1,
                "step_count": step_count,
                "optimizer": optimizer.state_dict(),
            },
            dit_ckpt_path,
        )

        if (epoch + 1) % 5 == 0:  # Save a model artifact every 5 epochs
            artifact = wandb.Artifact(
                f"dit_model_epoch_{epoch + 1}",
                type="model",
                description=f"DIT model checkpoint at epoch {epoch + 1}",
            )
            artifact.add_file(dit_ckpt_path)
            wandb.log_artifact(artifact)

    final_artifact = wandb.Artifact(
        "dit_model_final", type="model", description="Final DIT model checkpoint"
    )
    final_artifact.add_file(dit_ckpt_path)
    wandb.log_artifact(final_artifact)

    print("Done Training")
    wandb.finish()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Arguments for dit training")
    parser.add_argument(
        "--config", dest="config_path", default="celeba/config.yaml", type=str
    )
    args = parser.parse_args()
    train(args)