import argparse
import os

import numpy as np
import torch
import torch.nn as nn
import yaml
from torch.optim import AdamW
from tqdm import tqdm

import wandb
from celeba import create_dataloader
from model.transformer import DIT
from model.vae import VAE
from scheduler.linear_scheduler import LinearNoiseScheduler

device = "cuda" if torch.cuda.is_available() else "cpu"


def train(args):
    with open(args.config_path, "r") as file:
        try:
            config = yaml.safe_load(file)
        except yaml.YAMLError as e:
            print(f"Error in loading yaml: {e}")
            return

    train_config = config["train_params"]
    dit_config = config["dit_params"]
    dataset_config = config["dataset_params"]
    diffusion_params = config["diffusion_params"]
    vae_config = config["autoencoder_params"]

    wandb.init(
        project="diffusion-transformer",
        name=f"{train_config['task_name']}_dit_training",
        config={
            "train_config": train_config,
            "dit_config": dit_config,
            "dataset_config": dataset_config,
            "diffusion_params": diffusion_params,
            "vae_config": vae_config,
            "device": device,
        },
        tags=["dit", "diffusion", "transformer"],
    )

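    # Data and diffusion setup: CelebA dataloader plus a linear beta schedule
    # over num_timesteps steps for the forward noising process.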
    dataloader = create_dataloader(dataset_config["im_path"])

    scheduler = LinearNoiseScheduler(
        diffusion_params["num_timesteps"],
        diffusion_params["beta_start"],
        diffusion_params["beta_end"],
    )

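    # The DIT operates on VAE latents, so its spatial size is the image size
    # halved once per VAE downsampling stage.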
    im_size = dataset_config["im_size"] // 2 ** sum(vae_config["down_sample"])
    model = DIT(
        im_size=im_size, im_channels=vae_config["z_channels"], config=dit_config
    ).to(device)
    model.train()
    wandb.watch(model, log="all", log_freq=100)

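    # Resume from an existing DIT checkpoint if one is present, otherwise start fresh.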
    if os.path.exists(
        os.path.join(train_config["task_name"], train_config["dit_ckpt_name"])
    ):
        checkpoint = torch.load(
            os.path.join(train_config["task_name"], train_config["dit_ckpt_name"]),
            map_location=device,
        )
        optimizer = AdamW(model.parameters(), lr=train_config["dit_lr"])

        model.load_state_dict(checkpoint["dit"])
        start_epoch = checkpoint["epoch"]
        step_count = checkpoint["step_count"]
        optimizer.load_state_dict(checkpoint["optimizer"])
        print(f"Resuming from epoch {start_epoch}, step {step_count}")
        wandb.log({"resumed_from_epoch": start_epoch, "resumed_from_step": step_count})
    else:
        step_count = 0
        start_epoch = 0
        optimizer = AdamW(model.parameters(), lr=train_config["dit_lr"])

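    # A pretrained VAE checkpoint is required; the VAE is frozen and only used
    # to encode images into latents for DIT training.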
    if not os.path.exists(
        os.path.join(
            train_config["task_name"], train_config["vae_autoencoder_ckpt_name"]
        )
    ):
        print("No VAE checkpoint found, VAE checkpoint needed")
        wandb.finish()
        return
    else:
        vae = VAE(dataset_config["im_channels"], vae_config).to(device)
        vae.load_state_dict(
            torch.load(
                os.path.join(
                    train_config["task_name"], train_config["vae_autoencoder_ckpt_name"]
                ),
                map_location=device,
            )
        )
        vae.eval()
        for param in vae.parameters():
            param.requires_grad = False
        print("VAE checkpoint loaded")

    wandb.log(
        {
            "model_parameters": sum(p.numel() for p in model.parameters()),
            "trainable_parameters": sum(
                p.numel() for p in model.parameters() if p.requires_grad
            ),
        }
    )
    num_epochs = train_config["dit_epochs"]
    accu_steps = train_config["dit_acc_steps"]
    criterion = nn.MSELoss()

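    # Training loop: encode images to latents, noise them at a random timestep,
    # and train the DIT to predict the added noise (MSE on epsilon).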
    for epoch in range(start_epoch, num_epochs):
        losses = []
        for im in tqdm(dataloader):
            im = im.float().to(device)
            step_count += 1
            with torch.no_grad():
                im, _ = vae.encode(im)

            noise = torch.randn_like(im).to(device)

            t = torch.randint(0, diffusion_params["num_timesteps"], (im.shape[0],)).to(
                device
            )
            noisy_im = scheduler.add_noise(im, noise, t)
            pred = model(noisy_im, t)
            loss = criterion(pred, noise)
            losses.append(loss.item())
            loss = loss / accu_steps
            loss.backward()

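            # Log batch metrics every 10 steps; step the optimizer only every
            # accu_steps batches (gradient accumulation).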
            if step_count % 10 == 0:
                wandb.log(
                    {
                        "batch_loss": loss.item() * accu_steps,
                        "learning_rate": optimizer.param_groups[0]["lr"],
                        "step_count": step_count,
                        "epoch": epoch,
                    }
                )
            if step_count % accu_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
        # Flush any gradients left over from an incomplete accumulation window
        # at the end of the epoch.
        optimizer.step()
        optimizer.zero_grad()
        wandb.log(
            {
                "epoch": epoch,
                "epoch_loss": np.mean(losses),
                "epoch_loss_std": np.std(losses),
                "learning_rate": optimizer.param_groups[0]["lr"],
            }
        )
        print(f"Epoch {epoch}: Loss: {np.mean(losses)}")
        torch.save(
            {
                "dit": model.state_dict(),
                "epoch": epoch + 1,
                "step_count": step_count,
                "optimizer": optimizer.state_dict(),
            },
            os.path.join(train_config["task_name"], train_config["dit_ckpt_name"]),
        )
        if (epoch + 1) % 5 == 0:
            artifact = wandb.Artifact(
                f"dit_model_epoch_{epoch + 1}",
                type="model",
                description=f"DIT model checkpoint at epoch {epoch + 1}",
            )
            artifact.add_file(
                os.path.join(train_config["task_name"], train_config["dit_ckpt_name"]),
            )
            wandb.log_artifact(artifact)

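    # Upload the last saved checkpoint as the final model artifact.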
    final_artifact = wandb.Artifact(
        "dit_model_final", type="model", description="Final DIT model checkpoint"
    )
    final_artifact.add_file(
        os.path.join(train_config["task_name"], train_config["dit_ckpt_name"])
    )
    wandb.log_artifact(final_artifact)

    print("Done Training")
    wandb.finish()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Arguments for dit training")
    parser.add_argument(
        "--config", dest="config_path", default="celeba/config.yaml", type=str
    )
    args = parser.parse_args()
    train(args)