# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
A minimal training script for DiT.
"""
import os
import sys
# Make the repo root importable so the `core` package resolves when this
# script is run from a subdirectory:
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
# The first flag below was False when we tested this script but True makes A100 training a lot faster:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

import argparse
from copy import deepcopy
from glob import glob
from time import time

import numpy as np
import yaml
from accelerate import Accelerator
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from core.models import DiT_models
from core.diffusion import create_diffusion
from core.dataset import ImageParamsDataset
from core.utils.train_utils import create_logger, update_ema, requires_grad
#################################################################################
#                                 Training Loop                                 #
#################################################################################
def main(cfg):
    """
    Trains a new DiT model.
    """
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."

    # Setup accelerator:
    accelerator = Accelerator()
    device = accelerator.device
    # Setup an experiment folder:
    if accelerator.is_main_process:
        save_dir = cfg["save_dir"]
        os.makedirs(save_dir, exist_ok=True)  # Make results folder (holds all experiment subfolders)
        experiment_index = len(glob(f"{save_dir}/*"))
        experiment_dir = f"{save_dir}/{experiment_index:03d}-{cfg['model']}-{cfg['epochs']}-{cfg['batch_size']}"  # Create an experiment folder
        checkpoint_dir = f"{experiment_dir}/checkpoints"  # Stores saved model checkpoints
        os.makedirs(checkpoint_dir, exist_ok=True)
        logger = create_logger(experiment_dir)
        logger.info(f"Experiment directory created at {experiment_dir}")
        writer = SummaryWriter(experiment_dir)
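    # NOTE: `logger`, `writer`, `experiment_dir`, and `checkpoint_dir` exist
    # only on the main process; every later use of them is guarded by
    # `accelerator.is_main_process` for exactly that reason.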
    # Create model:
    latent_size = cfg["num_params"]
    condition_channels = 768
    model = DiT_models[cfg["model"]](input_size=latent_size, condition_channels=condition_channels)
    # Note that parameter initialization is done within the DiT constructor
    model = model.to(device)
    ema = deepcopy(model).to(device)  # Create an EMA of the model for use after training
    requires_grad(ema, False)
    diffusion = create_diffusion(timestep_respacing="")  # default: 1000 steps, linear noise schedule
    if accelerator.is_main_process:
        logger.info(f"DiT Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
    optimizer = torch.optim.AdamW(model.parameters(), lr=float(cfg["lr"]), weight_decay=0)
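    # With weight_decay=0, AdamW is mathematically equivalent to plain Adam,
    # so this matches the Adam hyperparameters mentioned in the comment above.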
    # Setup data:
    dataset = ImageParamsDataset(cfg["data_root"], cfg["train_file"], cfg["params_dict_file"])
    loader = DataLoader(
        dataset,
        batch_size=int(cfg["batch_size"] // accelerator.num_processes),
        shuffle=True,
        num_workers=cfg["num_workers"],
        pin_memory=True,
        drop_last=True,
    )
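    # `cfg["batch_size"]` is the *global* batch size: each of the
    # `accelerator.num_processes` processes loads batch_size // num_processes
    # samples per step (e.g. a global batch of 256 on 8 GPUs is 32 per GPU).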
    if accelerator.is_main_process:
        logger.info(f"Dataset contains {len(dataset):,} images")

    # Prepare models for training:
    update_ema(ema, model, decay=0)  # Ensure EMA is initialized with synced weights
    model.train()  # important! This enables embedding dropout for classifier-free guidance
    ema.eval()  # EMA model should always be in eval mode
    model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
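    # After `prepare`, `model` may be wrapped (e.g. in DistributedDataParallel)
    # and `loader` is sharded so each process sees a distinct slice of the
    # dataset; gradient synchronization happens inside `accelerator.backward`.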
    # Variables for monitoring/logging purposes:
    train_steps = 0
    log_steps = 0
    running_loss = 0
    start_time = time()

    if accelerator.is_main_process:
        logger.info(f"Training for {cfg['epochs']} epochs...")

    # Main training loop:
    for epoch in range(int(cfg["epochs"])):
        if accelerator.is_main_process:
            logger.info(f"Beginning epoch {epoch}...")
        for x, img_feat, img in loader:
            # Prepare the inputs:
            x = x.to(device)
            img_feat = img_feat.to(device)
            x = x.unsqueeze(dim=1)  # [B, 1, N]
            t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)
            model_kwargs = dict(y=img_feat)
            loss_dict = diffusion.training_losses(model, x, t, model_kwargs)
            loss = loss_dict["loss"].mean()
            optimizer.zero_grad()
            accelerator.backward(loss)
            optimizer.step()
            # Update the EMA on the unwrapped model so parameter names match
            # even when `prepare` has wrapped `model` (e.g. in DDP):
            update_ema(ema, accelerator.unwrap_model(model))
            if accelerator.is_main_process:
                writer.add_scalar("train/loss", loss.item(), train_steps)
            # Log loss values:
            running_loss += loss.item()
            log_steps += 1
            train_steps += 1
            if train_steps % cfg["logging_iter"] == 0:
                # Measure training speed:
                torch.cuda.synchronize()
                end_time = time()
                steps_per_sec = log_steps / (end_time - start_time)
                # Reduce loss history over all processes:
                avg_loss = torch.tensor(running_loss / log_steps, device=device)
                avg_loss = accelerator.reduce(avg_loss, reduction="mean").item()
                if accelerator.is_main_process:
                    logger.info(f"(Step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
                # Reset monitoring variables:
                running_loss = 0
                log_steps = 0
                start_time = time()
            # Save DiT checkpoint:
            if train_steps % cfg["ckpt_iter"] == 0 and train_steps > 0:
                if accelerator.is_main_process:
                    checkpoint = {
                        "model": accelerator.unwrap_model(model).state_dict(),  # strip any DDP wrapper
                        "ema": ema.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "config": cfg,
                    }
                    checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
                    torch.save(checkpoint, checkpoint_path)
                    logger.info(f"Saved checkpoint to {checkpoint_path}")
    model.eval()  # important! This disables randomized embedding dropout
    # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...

    if accelerator.is_main_process:
        writer.flush()
        logger.info("Done!")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    args = parser.parse_args()
    with open(args.config) as f:
        cfg = yaml.load(f, Loader=yaml.FullLoader)
    main(cfg)
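# Typical multi-GPU launch (via Hugging Face Accelerate):
#
#   accelerate launch train.py --config config.yaml
#
# An illustrative config covering every key this script reads; all values are
# placeholders, not the authors' settings, and the `model` key must be one of
# the names actually defined in DiT_models:
#
#   model: DiT-B                      # placeholder key into DiT_models
#   num_params: 1024                  # latent size N of the parameter vector
#   epochs: 100
#   batch_size: 256                   # global batch size across all processes
#   lr: 1e-4
#   num_workers: 4
#   logging_iter: 100
#   ckpt_iter: 10000
#   save_dir: ./results
#   data_root: ./data
#   train_file: ./data/train.txt
#   params_dict_file: ./data/params_dict.pt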