# ldmae / VMAE /train.py
# isno0907's picture
# Upload 115 files
# 6c49103 verified
"""
Training code for LightningDiT together with VA-VAE.
It involves advanced training methods, sampling methods,
architecture design methods, and computation methods. We achieve
a state-of-the-art FID of 1.35 on ImageNet 256x256.
By Maple (Jingfeng Yao) from HUST-VL.
"""
import torch
import torch.distributed as dist
import torch.backends.cuda
import torch.backends.cudnn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import math
import yaml
import json
import numpy as np
import logging
import os
import argparse
from time import time
from glob import glob
from copy import deepcopy
from collections import OrderedDict
from PIL import Image
from tqdm import tqdm
from diffusers.models import AutoencoderKL
from models.lightningdit import LightningDiT_models
from transport import create_transport, Sampler
from accelerate import Accelerator
from datasets.img_latent_dataset import ImgLatentDataset
def _checkpoint_step(path):
    """Return the training step encoded in a checkpoint filename, or None.

    Checkpoints are written by ``do_train`` as ``{train_steps:07d}.pt``; any
    other ``.pt`` file yields None so it can be ignored when resuming.
    """
    stem = os.path.basename(path).split('.')[0]
    return int(stem) if stem.isdigit() else None


def do_train(train_config, accelerator):
    """
    Trains a LightningDiT.

    Args:
        train_config: Nested configuration dict (e.g. from ``load_config``).
            The sections ``train``, ``model``, ``vae``, ``data``,
            ``transport`` and ``optimizer`` are read below. Note that
            ``train_config['train']['resume']`` is written back with its
            resolved value, and the whole dict is stored in every checkpoint.
        accelerator: The ``accelerate.Accelerator`` used for device placement,
            ``prepare``, ``backward`` and gradient clipping.

    Returns:
        The ``accelerator`` that was passed in.
    """
    # Setup accelerator:
    device = accelerator.device
    train_cfg = train_config['train']
    model_cfg = train_config['model']
    data_cfg = train_config['data']
    transport_cfg = train_config['transport']
    optim_cfg = train_config['optimizer']

    # Setup an experiment folder:
    if accelerator.is_main_process:
        os.makedirs(train_cfg['output_dir'], exist_ok=True)  # Make results folder (holds all experiment subfolders)
        experiment_index = len(glob(f"{train_cfg['output_dir']}/*"))
        model_string_name = model_cfg['model_type'].replace("/", "-")
        if train_cfg['exp_name'] is None:
            exp_name = f'{experiment_index:03d}-{model_string_name}'
        else:
            exp_name = train_cfg['exp_name']
        experiment_dir = f"{train_cfg['output_dir']}/{exp_name}"  # Create an experiment folder
        # Keep this resolved directory on the main process. Previously it was
        # overwritten with ".../None/checkpoints" when exp_name was auto-generated,
        # a directory that is never created, so saving checkpoints failed.
        checkpoint_dir = f"{experiment_dir}/checkpoints"  # Stores saved model checkpoints
        os.makedirs(checkpoint_dir, exist_ok=True)
        logger = create_logger(experiment_dir)
        logger.info(f"Experiment directory created at {experiment_dir}")
        tensorboard_dir_log = f"tensorboard_logs/{exp_name}"
        os.makedirs(tensorboard_dir_log, exist_ok=True)
        writer = SummaryWriter(log_dir=tensorboard_dir_log)

        # add configs to tensorboard
        config_str = json.dumps(train_config, indent=4)
        writer.add_text('training configs', config_str, global_step=0)
    else:
        # Other ranks only need the directory to look for checkpoints when resuming.
        checkpoint_dir = f"{train_cfg['output_dir']}/{train_cfg['exp_name']}/checkpoints"

    # get rank
    rank = accelerator.local_process_index

    # Create model:
    downsample_ratio = train_config['vae'].get('downsample_ratio', 16)
    assert data_cfg['image_size'] % downsample_ratio == 0, \
        f"Image size must be divisible by {downsample_ratio} (the VAE downsample ratio)."
    latent_size = data_cfg['image_size'] // downsample_ratio
    model = LightningDiT_models[model_cfg['model_type']](
        input_size=latent_size,
        num_classes=data_cfg['num_classes'],
        use_qknorm=model_cfg['use_qknorm'],
        use_swiglu=model_cfg.get('use_swiglu', False),
        use_rope=model_cfg.get('use_rope', False),
        use_rmsnorm=model_cfg.get('use_rmsnorm', False),
        wo_shift=model_cfg.get('wo_shift', False),
        in_channels=model_cfg.get('in_chans', 4),
        use_checkpoint=model_cfg.get('use_checkpoint', False),
    )
    ema = deepcopy(model).to(device)  # Create an EMA of the model for use after training

    # load pretrained model
    if 'weight_init' in train_cfg:
        checkpoint = torch.load(train_cfg['weight_init'], map_location=lambda storage, loc: storage)
        # remove the prefix 'module.' from the keys
        checkpoint['model'] = {k.replace('module.', ''): v for k, v in checkpoint['model'].items()}
        model = load_weights_with_shape_check(model, checkpoint, rank=rank)
        ema = load_weights_with_shape_check(ema, checkpoint, rank=rank)
        if accelerator.is_main_process:
            logger.info(f"Loaded pretrained model from {train_cfg['weight_init']}")
    requires_grad(ema, False)

    model = DDP(model.to(device), device_ids=[rank])
    use_cosine_loss = transport_cfg.get('use_cosine_loss', False)
    use_lognorm = transport_cfg.get('use_lognorm', False)
    transport = create_transport(
        transport_cfg['path_type'],
        transport_cfg['prediction'],
        transport_cfg['loss_weight'],
        transport_cfg['train_eps'],
        transport_cfg['sample_eps'],
        use_cosine_loss=use_cosine_loss,
        use_lognorm=use_lognorm,
    )  # default: velocity;
    if accelerator.is_main_process:
        logger.info(f"LightningDiT Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")
        logger.info(f"Optimizer: AdamW, lr={optim_cfg['lr']}, beta2={optim_cfg['beta2']}")
        # Log the resolved values: both keys are optional, so indexing the
        # config directly here raised KeyError when they were omitted.
        logger.info(f'Use lognorm sampling: {use_lognorm}')
        logger.info(f'Use cosine loss: {use_cosine_loss}')
    opt = torch.optim.AdamW(model.parameters(), lr=optim_cfg['lr'], weight_decay=0, betas=(0.9, optim_cfg['beta2']))

    # Setup data
    latent_norm = data_cfg.get('latent_norm', False)
    latent_multiplier = data_cfg.get('latent_multiplier', 0.18215)
    dataset = ImgLatentDataset(
        data_dir=data_cfg['data_path'],
        latent_norm=latent_norm,
        latent_multiplier=latent_multiplier,
    )
    batch_size_per_gpu = int(np.round(train_cfg['global_batch_size'] / accelerator.num_processes))
    # Effective global batch size after rounding the per-process share.
    global_batch_size = batch_size_per_gpu * accelerator.num_processes
    loader = DataLoader(
        dataset,
        batch_size=batch_size_per_gpu,
        shuffle=True,
        num_workers=data_cfg['num_workers'],
        pin_memory=True,
        drop_last=True
    )
    if accelerator.is_main_process:
        logger.info(f"Dataset contains {len(dataset):,} images {data_cfg['data_path']}")
        logger.info(f"Batch size {batch_size_per_gpu} per gpu, with {global_batch_size} global batch size")

    if 'valid_path' in data_cfg:
        valid_dataset = ImgLatentDataset(
            data_dir=data_cfg['valid_path'],
            latent_norm=latent_norm,
            latent_multiplier=latent_multiplier,
        )
        valid_loader = DataLoader(
            valid_dataset,
            batch_size=batch_size_per_gpu,
            shuffle=True,
            num_workers=data_cfg['num_workers'],
            pin_memory=True,
            drop_last=True
        )
        if accelerator.is_main_process:
            logger.info(f"Validation Dataset contains {len(valid_dataset):,} images {data_cfg['valid_path']}")

    # Prepare models for training:
    update_ema(ema, model.module, decay=0)  # Ensure EMA is initialized with synced weights
    model.train()  # important! This enables embedding dropout for classifier-free guidance
    ema.eval()  # EMA model should always be in eval mode

    # Start from step 0 unless a checkpoint is found below. Previously
    # train_steps was left unbound when resume was requested but the
    # checkpoint directory was empty.
    train_steps = 0
    train_cfg['resume'] = train_cfg.get('resume', False)
    if train_cfg['resume']:
        # check if the checkpoint exists
        found = []
        for path in glob(f"{checkpoint_dir}/*.pt"):
            step = _checkpoint_step(path)
            if step is not None:
                found.append((step, path))
        if found:
            # Pick the checkpoint with the highest step; file size (used
            # previously) does not identify the most recent checkpoint.
            train_steps, latest_checkpoint = max(found)
            checkpoint = torch.load(latest_checkpoint, map_location=lambda storage, loc: storage)
            # Checkpoints store ``model.module.state_dict()`` (no DDP prefix), so load
            # into the wrapped module; a leading 'module.' is stripped if present.
            model_state = {
                (k[len('module.'):] if k.startswith('module.') else k): v
                for k, v in checkpoint['model'].items()
            }
            model.module.load_state_dict(model_state)
            # NOTE: the optimizer state saved under checkpoint['opt'] is not restored.
            ema.load_state_dict(checkpoint['ema'])
            if accelerator.is_main_process:
                logger.info(f"Resuming training from checkpoint: {latest_checkpoint}")
        else:
            if accelerator.is_main_process:
                logger.info("No checkpoint found. Starting training from scratch.")

    model, opt, loader = accelerator.prepare(model, opt, loader)

    # Variables for monitoring/logging purposes:
    log_steps = 0
    running_loss = 0
    start_time = time()
    use_checkpoint = train_cfg.get('use_checkpoint', True)
    if accelerator.is_main_process:
        logger.info(f"Using checkpointing: {use_checkpoint}")

    max_steps = train_cfg['max_steps']
    while True:
        for x, y in loader:
            if accelerator.mixed_precision == 'no':
                x = x.to(device, dtype=torch.float32)
            else:
                x = x.to(device)
            # Labels are moved in both modes (the fp32 branch used to skip this).
            y = y.to(device)
            model_kwargs = dict(y=y)
            loss_dict = transport.training_losses(model, x, model_kwargs)
            if 'cos_loss' in loss_dict:
                mse_loss = loss_dict["loss"].mean()
                loss = loss_dict["cos_loss"].mean() + mse_loss
            else:
                loss = loss_dict["loss"].mean()
            opt.zero_grad()
            accelerator.backward(loss)
            if 'max_grad_norm' in optim_cfg:
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), optim_cfg['max_grad_norm'])
            opt.step()
            update_ema(ema, model.module)

            # Log loss values (only the MSE term is tracked when a cosine loss is present):
            if 'cos_loss' in loss_dict:
                running_loss += mse_loss.item()
            else:
                running_loss += loss.item()
            log_steps += 1
            train_steps += 1
            if train_steps % train_cfg['log_every'] == 0:
                # Measure training speed:
                torch.cuda.synchronize()
                end_time = time()
                steps_per_sec = log_steps / (end_time - start_time)
                # Reduce loss history over all processes:
                avg_loss = torch.tensor(running_loss / log_steps, device=device)
                dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                avg_loss = avg_loss.item() / dist.get_world_size()
                if accelerator.is_main_process:
                    logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
                    writer.add_scalar('Loss/train', avg_loss, train_steps)
                # Reset monitoring variables:
                running_loss = 0
                log_steps = 0
                start_time = time()

            # Save checkpoint:
            if train_steps % train_cfg['ckpt_every'] == 0 and train_steps > 0:
                if accelerator.is_main_process:
                    checkpoint = {
                        "model": model.module.state_dict(),
                        "ema": ema.state_dict(),
                        "opt": opt.state_dict(),
                        "config": train_config,
                    }
                    checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
                    torch.save(checkpoint, checkpoint_path)
                    logger.info(f"Saved checkpoint to {checkpoint_path}")
                # Hold every rank until the main process has finished writing.
                dist.barrier()

                # Evaluate on validation set
                if 'valid_path' in data_cfg:
                    if accelerator.is_main_process:
                        logger.info(f"Start evaluating at step {train_steps}")
                    # NOTE(review): `evaluate` is neither defined nor imported in the
                    # visible part of this file; unless it is provided elsewhere this
                    # call raises NameError. It is assumed to return a tensor.
                    val_loss = evaluate(model, valid_loader, device, transport, (0.0, 1.0))
                    dist.all_reduce(val_loss, op=dist.ReduceOp.SUM)
                    val_loss = val_loss.item() / dist.get_world_size()
                    if accelerator.is_main_process:
                        logger.info(f"Validation Loss: {val_loss:.4f}")
                        writer.add_scalar('Loss/validation', val_loss, train_steps)
                    model.train()
            if train_steps >= max_steps:
                break
        if train_steps >= max_steps:
            break

    if accelerator.is_main_process:
        logger.info("Done!")
        # Flush pending TensorBoard events and release the event file.
        writer.close()

    return accelerator
def load_weights_with_shape_check(model, checkpoint, rank=0):
    """
    Load ``checkpoint['model']`` into ``model``, tolerating shape mismatches.

    Parameters whose shapes match are copied. Parameters that are missing
    from the model, or whose shapes differ, are skipped with a message
    printed on rank 0. ``x_embedder.proj.weight`` is special-cased: when only
    its input-channel axis (dim 1) differs, the leading input channels are
    copied and the remaining ones are zero-initialised.

    Args:
        model: Module to load weights into (modified in place).
        checkpoint: Dict holding a name -> tensor mapping under ``'model'``.
        rank: Process rank; messages are printed only when it is 0.

    Returns:
        The same ``model`` object.
    """
    # At most this many leading input channels are transferred for the patch
    # embedding (the value the original implementation hard-coded).
    max_copied_in_channels = 16

    model_state_dict = model.state_dict()
    # check shape and load weights
    for name, param in checkpoint['model'].items():
        if name not in model_state_dict:
            if rank == 0:
                print(f"Parameter '{name}' not found in model, skipping.")
            continue

        target = model_state_dict[name]
        if param.shape == target.shape:
            target.copy_(param)
            continue

        # Only the input-channel axis may differ for the special case; any
        # other mismatch cannot be sliced and falls through to "skip".
        only_in_channels_differ = (
            param.dim() == target.dim()
            and param.dim() >= 2
            and param.shape[0] == target.shape[0]
            and param.shape[2:] == target.shape[2:]
        )
        if name == 'x_embedder.proj.weight' and only_in_channels_differ:
            # Clamp to what both tensors actually have, so a checkpoint with
            # fewer than 16 input channels no longer raises a shape error.
            n_copy = min(max_copied_in_channels, param.shape[1], target.shape[1])
            weight = torch.zeros_like(target)
            weight[:, :n_copy] = param[:, :n_copy]
            model_state_dict[name] = weight
        elif rank == 0:
            print(f"Skipping loading parameter '{name}' due to shape mismatch: "
                  f"checkpoint shape {param.shape}, model shape {target.shape}")

    # load state dict
    model.load_state_dict(model_state_dict, strict=False)
    return model
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.

    Each EMA parameter is updated in place as
    ``ema = decay * ema + (1 - decay) * param``. Only entries of
    ``named_parameters()`` are touched.

    Args:
        ema_model: Module holding the averaged parameters (updated in place).
        model: Module providing the current parameters. Leading ``module.``
            wrapper prefixes on its parameter names are ignored.
        decay: EMA decay factor; ``0`` copies ``model`` into ``ema_model``.

    Raises:
        KeyError: If ``model`` has a parameter that ``ema_model`` lacks.
    """
    wrapper_prefix = "module."
    ema_params = dict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        # Strip only *leading* wrapper prefixes. Replacing the substring anywhere
        # would corrupt names such as "submodule.weight" -> "subweight".
        while name.startswith(wrapper_prefix):
            name = name[len(wrapper_prefix):]
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
def requires_grad(model, flag=True):
    """
    Enable or disable gradient tracking on every parameter of ``model``.

    Args:
        model: Object exposing ``parameters()``.
        flag: Value applied to each parameter's ``requires_grad``.
    """
    for tensor in model.parameters():
        tensor.requires_grad_(flag)
def load_config(config_path):
    """Parse the YAML file at ``config_path`` with ``yaml.safe_load`` and return the result."""
    with open(config_path, "r") as stream:
        return yaml.safe_load(stream)
def create_logger(logging_dir):
    """
    Create a logger that writes to a log file and stdout.

    On rank 0 the root logger is configured through ``logging.basicConfig``
    with a stream handler and a file handler on ``{logging_dir}/log.txt``.
    On every other rank the module logger only gets a ``NullHandler``.

    Args:
        logging_dir: Existing directory in which ``log.txt`` is created.

    Returns:
        The module-level ``logging.Logger``.
    """
    # Treat a run without an initialised process group as rank 0, so that
    # dist.get_rank() does not raise in that case.
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    logger = logging.getLogger(__name__)
    if rank == 0:  # real logger
        logging.basicConfig(
            level=logging.INFO,
            format='[\033[34m%(asctime)s\033[0m] %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
            handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
        )
    elif not any(isinstance(h, logging.NullHandler) for h in logger.handlers):
        # dummy logger: attach the NullHandler once, not on every call
        logger.addHandler(logging.NullHandler())
    return logger
if __name__ == "__main__":
    # Command-line interface: one optional argument, the path to the YAML config.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config',
        type=str,
        default='configs/debug.yaml',
    )
    args = parser.parse_args()

    # The accelerator is created before the config is read; training then
    # runs with both.
    accelerator = Accelerator()
    train_config = load_config(args.config)
    do_train(train_config, accelerator)