| """ |
| Training code for LightningDiT together with VA-VAE. |
| It involves advanced training methods, sampling methods, |
| architecture design methods, computation methods. We achieve |
| state-of-the-art FID 1.35 on ImageNet 256x256. |
| |
| by Maple (Jingfeng Yao) from HUST-VL |
| """ |
|
|
| import torch |
| import torch.distributed as dist |
| import torch.backends.cuda |
| import torch.backends.cudnn |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from torch.utils.data import DataLoader |
| from torch.utils.tensorboard import SummaryWriter |
|
|
| import math |
| import yaml |
| import json |
| import numpy as np |
| import logging |
| import os |
| import argparse |
| from time import time |
| from glob import glob |
| from copy import deepcopy |
| from collections import OrderedDict |
| from PIL import Image |
| from tqdm import tqdm |
|
|
| from diffusers.models import AutoencoderKL |
| from models.lightningdit import LightningDiT_models |
| from transport import create_transport, Sampler |
| from accelerate import Accelerator |
| from datasets.img_latent_dataset import ImgLatentDataset |
|
|
def do_train(train_config, accelerator):
    """
    Train a LightningDiT on pre-extracted latents until ``train.max_steps``.

    Args:
        train_config: Nested config dict (as returned by ``load_config``). The
            ``train``, ``model``, ``vae``, ``data``, ``transport`` and
            ``optimizer`` sections are read; optional keys fall back to the
            defaults written inline below.
        accelerator: An ``accelerate.Accelerator``; supplies the device,
            process info, mixed-precision mode, ``prepare`` and ``backward``.

    Returns:
        The ``accelerator`` that was passed in.

    Side effects:
        On the main process: creates ``<output_dir>/<exp_name>/checkpoints``,
        a ``log.txt`` logger and a TensorBoard run in
        ``tensorboard_logs/<exp_name>``, and saves ``<step:07d>.pt``
        checkpoints every ``train.ckpt_every`` steps. Also sets
        ``train_config['train']['resume']`` to an explicit value.

    Raises:
        ValueError: if ``data.image_size`` is not divisible by the VAE
            downsample ratio.
    """
    device = accelerator.device
    is_main = accelerator.is_main_process
    train_cfg = train_config['train']
    model_cfg = train_config['model']
    data_cfg = train_config['data']
    transport_cfg = train_config['transport']
    optim_cfg = train_config['optimizer']

    # ---- Experiment folder, logger and TensorBoard (main process only) ----
    logger = None
    writer = None
    checkpoint_dir = None
    if is_main:
        os.makedirs(train_cfg['output_dir'], exist_ok=True)
        experiment_index = len(glob(f"{train_cfg['output_dir']}/*"))
        model_string_name = model_cfg['model_type'].replace("/", "-")
        exp_name = train_cfg.get('exp_name')
        if exp_name is None:
            exp_name = f'{experiment_index:03d}-{model_string_name}'
        experiment_dir = f"{train_cfg['output_dir']}/{exp_name}"
        checkpoint_dir = f"{experiment_dir}/checkpoints"
        os.makedirs(checkpoint_dir, exist_ok=True)
        logger = create_logger(experiment_dir)
        logger.info(f"Experiment directory created at {experiment_dir}")
        tensorboard_dir_log = f"tensorboard_logs/{exp_name}"
        os.makedirs(tensorboard_dir_log, exist_ok=True)
        writer = SummaryWriter(log_dir=tensorboard_dir_log)

        # Record the full config in TensorBoard for reproducibility.
        config_str = json.dumps(train_config, indent=4)
        writer.add_text('training configs', config_str, global_step=0)
    # All ranks must use the directory the main process actually created.
    # Rebuilding it from the raw config would give "<output_dir>/None/checkpoints"
    # whenever exp_name is auto-generated, which was never created.
    checkpoint_dir = _broadcast_from_main(checkpoint_dir)

    rank = accelerator.local_process_index

    # ---- Model ----
    downsample_ratio = train_config['vae'].get('downsample_ratio', 16)
    image_size = data_cfg['image_size']
    if image_size % downsample_ratio != 0:
        raise ValueError(
            f"Image size must be divisible by {downsample_ratio} (the VAE downsample ratio)."
        )
    latent_size = image_size // downsample_ratio
    model = LightningDiT_models[model_cfg['model_type']](
        input_size=latent_size,
        num_classes=data_cfg['num_classes'],
        use_qknorm=model_cfg['use_qknorm'],
        use_swiglu=model_cfg.get('use_swiglu', False),
        use_rope=model_cfg.get('use_rope', False),
        use_rmsnorm=model_cfg.get('use_rmsnorm', False),
        wo_shift=model_cfg.get('wo_shift', False),
        in_channels=model_cfg.get('in_chans', 4),
        use_checkpoint=model_cfg.get('use_checkpoint', False),
    )

    # EMA copy of the model; it is only ever updated through update_ema.
    ema = deepcopy(model).to(device)

    if 'weight_init' in train_cfg:
        checkpoint = torch.load(train_cfg['weight_init'], map_location=lambda storage, loc: storage)
        # Drop DDP "module." prefixes so names match the bare model.
        checkpoint['model'] = {k.replace('module.', ''): v for k, v in checkpoint['model'].items()}
        model = load_weights_with_shape_check(model, checkpoint, rank=rank)
        ema = load_weights_with_shape_check(ema, checkpoint, rank=rank)
        if is_main:
            logger.info(f"Loaded pretrained model from {train_cfg['weight_init']}")
    requires_grad(ema, False)

    model = DDP(model.to(device), device_ids=[rank])

    # ---- Transport / loss ----
    # Resolve optional flags once so the logging below cannot KeyError on them.
    use_cosine_loss = transport_cfg.get('use_cosine_loss', False)
    use_lognorm = transport_cfg.get('use_lognorm', False)
    transport = create_transport(
        transport_cfg['path_type'],
        transport_cfg['prediction'],
        transport_cfg['loss_weight'],
        transport_cfg['train_eps'],
        transport_cfg['sample_eps'],
        use_cosine_loss=use_cosine_loss,
        use_lognorm=use_lognorm,
    )
    if is_main:
        logger.info(f"LightningDiT Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")
        logger.info(f"Optimizer: AdamW, lr={optim_cfg['lr']}, beta2={optim_cfg['beta2']}")
        logger.info(f'Use lognorm sampling: {use_lognorm}')
        logger.info(f'Use cosine loss: {use_cosine_loss}')
    opt = torch.optim.AdamW(model.parameters(), lr=optim_cfg['lr'], weight_decay=0, betas=(0.9, optim_cfg['beta2']))

    # ---- Data ----
    latent_kwargs = dict(
        latent_norm=data_cfg.get('latent_norm', False),
        latent_multiplier=data_cfg.get('latent_multiplier', 0.18215),
    )
    dataset = ImgLatentDataset(data_dir=data_cfg['data_path'], **latent_kwargs)
    batch_size_per_gpu = int(np.round(train_cfg['global_batch_size'] / accelerator.num_processes))
    global_batch_size = batch_size_per_gpu * accelerator.num_processes
    loader = DataLoader(
        dataset,
        batch_size=batch_size_per_gpu,
        shuffle=True,
        num_workers=data_cfg['num_workers'],
        pin_memory=True,
        drop_last=True,
    )
    if is_main:
        logger.info(f"Dataset contains {len(dataset):,} images {data_cfg['data_path']}")
        logger.info(f"Batch size {batch_size_per_gpu} per gpu, with {global_batch_size} global batch size")

    valid_loader = None
    if 'valid_path' in data_cfg:
        valid_dataset = ImgLatentDataset(data_dir=data_cfg['valid_path'], **latent_kwargs)
        valid_loader = DataLoader(
            valid_dataset,
            batch_size=batch_size_per_gpu,
            shuffle=True,
            num_workers=data_cfg['num_workers'],
            pin_memory=True,
            drop_last=True,
        )
        if is_main:
            logger.info(f"Validation Dataset contains {len(valid_dataset):,} images {data_cfg['valid_path']}")

    # Initialise the EMA with an exact copy of the (possibly pre-loaded) weights.
    update_ema(ema, model.module, decay=0)
    model.train()
    ema.eval()

    # ---- Optional resume ----
    train_cfg['resume'] = train_cfg.get('resume', False)
    # Start from step 0 unless a checkpoint is actually found. This also covers
    # resume=True with an empty checkpoint directory.
    train_steps = 0
    if train_cfg['resume']:
        latest = _find_latest_checkpoint(checkpoint_dir)
        if latest is not None:
            latest_checkpoint, train_steps = latest
            checkpoint = torch.load(latest_checkpoint, map_location=lambda storage, loc: storage)
            # Saved keys carry one "module." per wrapper the saving model had.
            # Normalise them and load into the bare model underneath our DDP wrapper.
            model_state = {_strip_module_prefix(k): v for k, v in checkpoint['model'].items()}
            model.module.load_state_dict(model_state)
            # NOTE: the saved optimizer state ("opt") is not restored here, so
            # AdamW moments restart from zero after a resume.
            ema.load_state_dict(checkpoint['ema'])
            if is_main:
                logger.info(f"Resuming training from checkpoint: {latest_checkpoint}")
        elif is_main:
            logger.info("No checkpoint found. Starting training from scratch.")
    model, opt, loader = accelerator.prepare(model, opt, loader)

    # ---- Training loop ----
    log_steps = 0
    running_loss = 0
    start_time = time()
    use_checkpoint = train_cfg.get('use_checkpoint', True)
    if is_main:
        logger.info(f"Using checkpointing: {use_checkpoint}")

    max_steps = train_cfg['max_steps']
    while True:
        for x, y in loader:
            if accelerator.mixed_precision == 'no':
                x = x.to(device, dtype=torch.float32)
            else:
                x = x.to(device)
            # Labels go to the same device as the model in both precision modes.
            y = y.to(device)
            model_kwargs = dict(y=y)
            loss_dict = transport.training_losses(model, x, model_kwargs)
            if 'cos_loss' in loss_dict:
                mse_loss = loss_dict["loss"].mean()
                loss = loss_dict["cos_loss"].mean() + mse_loss
                # Only the MSE term is reported as the training loss.
                reported_loss = mse_loss
            else:
                loss = loss_dict["loss"].mean()
                reported_loss = loss
            opt.zero_grad()
            accelerator.backward(loss)
            if 'max_grad_norm' in optim_cfg and accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), optim_cfg['max_grad_norm'])
            opt.step()
            update_ema(ema, model.module)

            running_loss += reported_loss.item()
            log_steps += 1
            train_steps += 1

            if train_steps % train_cfg['log_every'] == 0:
                # Wait for queued GPU work so the throughput measurement is accurate.
                torch.cuda.synchronize()
                end_time = time()
                steps_per_sec = log_steps / (end_time - start_time)
                # Average the per-rank running loss across all processes.
                avg_loss = torch.tensor(running_loss / log_steps, device=device)
                dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                avg_loss = avg_loss.item() / dist.get_world_size()
                if is_main:
                    logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
                    writer.add_scalar('Loss/train', avg_loss, train_steps)
                running_loss = 0
                log_steps = 0
                start_time = time()

            if train_steps % train_cfg['ckpt_every'] == 0 and train_steps > 0:
                if is_main:
                    checkpoint = {
                        "model": model.module.state_dict(),
                        "ema": ema.state_dict(),
                        "opt": opt.state_dict(),
                        "config": train_config,
                    }
                    checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
                    torch.save(checkpoint, checkpoint_path)
                    logger.info(f"Saved checkpoint to {checkpoint_path}")
                dist.barrier()

                if valid_loader is not None:
                    if is_main:
                        logger.info(f"Start evaluating at step {train_steps}")
                    # NOTE(review): `evaluate` is not defined in this file. Confirm it is
                    # brought into scope before enabling `data.valid_path`.
                    val_loss = evaluate(model, valid_loader, device, transport, (0.0, 1.0))
                    dist.all_reduce(val_loss, op=dist.ReduceOp.SUM)
                    val_loss = val_loss.item() / dist.get_world_size()
                    if is_main:
                        logger.info(f"Validation Loss: {val_loss:.4f}")
                        writer.add_scalar('Loss/validation', val_loss, train_steps)
                    model.train()

            if train_steps >= max_steps:
                break
        if train_steps >= max_steps:
            break

    if is_main:
        logger.info("Done!")

    return accelerator


def _broadcast_from_main(value):
    """Return rank 0's *value* on every rank; a no-op when torch.distributed is not initialized."""
    if not (dist.is_available() and dist.is_initialized()):
        return value
    holder = [value]
    dist.broadcast_object_list(holder, src=0)
    return holder[0]


def _strip_module_prefix(name):
    """Remove every leading ``"module."`` (one per DDP wrapper) from a state-dict key."""
    prefix = 'module.'
    while name.startswith(prefix):
        name = name[len(prefix):]
    return name


def _find_latest_checkpoint(checkpoint_dir):
    """
    Return ``(path, step)`` for the highest-step ``<step>.pt`` file in *checkpoint_dir*.

    The step is parsed from the file name, which is how checkpoints are named
    when saved. Files whose stem is not all digits are ignored. Returns
    ``None`` when no such file exists.
    """
    latest = None
    for path in glob(f"{checkpoint_dir}/*.pt"):
        stem = os.path.splitext(os.path.basename(path))[0]
        if not stem.isdigit():
            continue
        step = int(stem)
        if latest is None or step > latest[1]:
            latest = (path, step)
    return latest
|
|
def load_weights_with_shape_check(model, checkpoint, rank=0, copy_channels=16):
    """
    Load ``checkpoint['model']`` into *model*, tolerating shape mismatches.

    For each tensor in the checkpoint state dict:
      * if the model has a tensor of the same name and shape, it is copied in;
      * if the name is ``x_embedder.proj.weight`` but the shape differs, a
        zero tensor with the model's shape is built and its leading input
        channels (dim 1) are filled from the checkpoint. At most
        ``copy_channels`` channels are copied, and never more than either
        tensor actually has. The remaining dims are expected to match.
      * any other mismatch, or a name the model lacks, is skipped, with a
        message printed when ``rank == 0``.

    Args:
        model: ``torch.nn.Module`` to load into (updated in place).
        checkpoint: Mapping whose ``'model'`` entry is a state dict.
        rank: Messages are printed only when this is 0.
        copy_channels: Upper bound on input channels copied in the
            ``x_embedder.proj.weight`` special case (default 16).

    Returns:
        The same *model* object.
    """
    model_state_dict = model.state_dict()

    for name, param in checkpoint['model'].items():
        if name not in model_state_dict:
            if rank == 0:
                print(f"Parameter '{name}' not found in model, skipping.")
            continue

        target = model_state_dict[name]
        if param.shape == target.shape:
            # state_dict() tensors share storage with the module, so this writes through.
            target.copy_(param)
        elif name == 'x_embedder.proj.weight':
            # Input-channel count differs (e.g. a different latent channel count).
            # Copy the overlapping leading channels and leave the rest zero.
            n_channels = min(copy_channels, param.shape[1], target.shape[1])
            weight = torch.zeros_like(target)
            weight[:, :n_channels] = param[:, :n_channels]
            model_state_dict[name] = weight
        elif rank == 0:
            print(f"Skipping loading parameter '{name}' due to shape mismatch: "
                  f"checkpoint shape {param.shape}, model shape {target.shape}")

    model.load_state_dict(model_state_dict, strict=False)

    return model
|
|
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Move each EMA parameter one step toward its live counterpart, in place.

    For every parameter of *model*: ``ema = decay * ema + (1 - decay) * param``.
    Any ``"module."`` substring in the live parameter name is removed before
    the EMA lookup, e.g. for names coming from a DDP wrapper. ``decay=0``
    copies the live weights into the EMA model exactly.
    """
    ema_by_name = dict(ema_model.named_parameters())
    blend = 1 - decay
    for live_name, live_param in model.named_parameters():
        ema_param = ema_by_name[live_name.replace("module.", "")]
        ema_param.mul_(decay).add_(live_param.data, alpha=blend)
|
|
|
|
def requires_grad(model, flag=True):
    """
    Switch gradient tracking on or off for every parameter of *model*.

    Args:
        model: Module whose ``parameters()`` are updated in place.
        flag: Value assigned to each parameter's ``requires_grad``.
    """
    params = model.parameters()
    for tensor in params:
        tensor.requires_grad = flag
|
|
def load_config(config_path):
    """
    Read a YAML config file and return its parsed contents.

    ``yaml.safe_load`` builds only plain YAML types (no arbitrary Python
    objects). The file is decoded as UTF-8 explicitly, so parsing does not
    depend on the platform's locale encoding.

    Args:
        config_path: Path to the YAML file.

    Returns:
        The object produced by ``yaml.safe_load``, e.g. a dict for a mapping
        document or ``None`` for an empty file.
    """
    with open(config_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)
|
|
def create_logger(logging_dir):
    """
    Create a logger that writes to ``<logging_dir>/log.txt`` and stdout.

    On rank 0 the root logger is configured via ``logging.basicConfig`` with
    a stream handler and a file handler. Note that ``basicConfig`` does
    nothing if the root logger already has handlers. On other ranks a
    ``NullHandler`` is attached instead. When ``torch.distributed`` is not
    initialized (a single-process run), the caller is treated as rank 0
    rather than raising.

    Args:
        logging_dir: Existing directory in which ``log.txt`` is created.

    Returns:
        The module-level ``logging.Logger``.
    """
    distributed = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if distributed else 0
    logger = logging.getLogger(__name__)
    if rank == 0:
        logging.basicConfig(
            level=logging.INFO,
            format='[\033[34m%(asctime)s\033[0m] %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
            handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
        )
    else:
        logger.addHandler(logging.NullHandler())
    return logger
|
|
if __name__ == "__main__":
    # Command-line entry point: read the YAML config named by --config and train.
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument('--config', type=str, default='configs/debug.yaml')
    cli_args = cli_parser.parse_args()

    accelerator = Accelerator()
    train_config = load_config(cli_args.config)
    do_train(train_config, accelerator)