| """ |
| A minimal training script for SiT using PyTorch DDP. |
| """ |
| import torch |
| |
| torch.backends.cuda.matmul.allow_tf32 = True |
| torch.backends.cudnn.allow_tf32 = True |
| import torch.distributed as dist |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from torch.utils.data import DataLoader |
| from torch.utils.data.distributed import DistributedSampler |
| from torchvision.datasets import ImageFolder |
| from torchvision import transforms |
| import numpy as np |
| from collections import OrderedDict |
| from PIL import Image |
| from copy import deepcopy |
| from glob import glob |
| from time import time |
| import argparse |
| import logging |
| import os |
|
|
| from models import SiT_models |
| from download import find_model |
| from transport import create_transport, Sampler |
| from diffusers.models import AutoencoderKL |
| from train_utils import parse_transport_args |
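
#################################################################################
#                             Training Helper Functions                         #
#################################################################################
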
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.
    """
    ema_params = OrderedDict(ema_model.named_parameters())
    model_params = OrderedDict(model.named_parameters())

    for name, param in model_params.items():
        # ema = decay * ema + (1 - decay) * param
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)


def requires_grad(model, flag=True):
    """
    Set requires_grad flag for all parameters in a model.
    """
    for p in model.parameters():
        p.requires_grad = flag


def cleanup():
    """
    End DDP training.
    """
    dist.destroy_process_group()


def create_logger(logging_dir):
    """
    Create a logger that writes to a log file and stdout.
    """
    if dist.get_rank() == 0:  # real logger on the master process
        logging.basicConfig(
            level=logging.INFO,
            format='[\033[34m%(asctime)s\033[0m] %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
            handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
        )
        logger = logging.getLogger(__name__)
    else:  # dummy logger on all other processes
        logger = logging.getLogger(__name__)
        logger.addHandler(logging.NullHandler())
    return logger


def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])


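#################################################################################
#                                  Training Loop                                #
#################################################################################
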
def main(args):
    """
    Trains a new SiT model.
    """
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."

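    # Setup DDP: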
    dist.init_process_group("nccl")
    assert args.global_batch_size % dist.get_world_size() == 0, \
        f"Global batch size ({args.global_batch_size}) must be divisible by world size ({dist.get_world_size()})."
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
    local_batch_size = int(args.global_batch_size // dist.get_world_size())

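    # Setup an experiment folder (rank 0 only):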
    if rank == 0:
        os.makedirs(args.results_dir, exist_ok=True)  # make results folder (holds all experiment subfolders)
        experiment_index = len(glob(f"{args.results_dir}/*"))
        model_string_name = args.model.replace("/", "-")  # e.g., SiT-XL/2 --> SiT-XL-2 (for naming folders)
        experiment_name = f"{experiment_index:03d}-{model_string_name}-" \
                          f"{args.path_type}-{args.prediction}-{args.loss_weight}"
        experiment_dir = f"{args.results_dir}/{experiment_name}"
        checkpoint_dir = f"{experiment_dir}/checkpoints"  # stores saved model checkpoints
        os.makedirs(checkpoint_dir, exist_ok=True)

        pic_dir = f"{experiment_dir}/pic"  # stores periodically generated sample grids
        os.makedirs(pic_dir, exist_ok=True)

        logger = create_logger(experiment_dir)
        logger.info(f"Experiment directory created at {experiment_dir}")
        logger.info(f"Sample images will be saved to {pic_dir}")
    else:
        logger = create_logger(None)

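    # Create model: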
    assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
    latent_size = args.image_size // 8
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes
    )
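    # Note that parameter initialization is done within the SiT constructor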
    ema = deepcopy(model).to(device)  # Create an EMA of the model for use after training
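
    # Optionally resume from a checkpoint; the optimizer state is restored after
    # the optimizer is created below: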
    state_dict = None
    if args.ckpt is not None:
        state_dict = find_model(args.ckpt)
        model.load_state_dict(state_dict["model"])
        ema.load_state_dict(state_dict["ema"])
        args = state_dict["args"]

    requires_grad(ema, False)

    model = DDP(model.to(device), device_ids=[device])
    transport = create_transport(
        args.path_type,
        args.prediction,
        args.loss_weight,
        args.train_eps,
        args.sample_eps
    )
    transport_sampler = Sampler(transport)
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
    logger.info(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")

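    # Setup optimizer (default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4):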
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0)
    if state_dict is not None:
        opt.load_state_dict(state_dict["opt"])

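    # Setup data: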
    transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])
    dataset = ImageFolder(args.data_path, transform=transform)
    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=rank,
        shuffle=True,
        seed=args.global_seed
    )
    loader = DataLoader(
        dataset,
        batch_size=local_batch_size,
        shuffle=False,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
    logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})")

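    # Prepare models for training: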
    update_ema(ema, model.module, decay=0)  # Ensure EMA is initialized with synced weights
    model.train()  # important! This enables embedding dropout for classifier-free guidance
    ema.eval()  # EMA model should always be in eval mode

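    # Variables for monitoring/logging purposes: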
    train_steps = 0
    log_steps = 0
    running_loss = 0
    start_time = time()

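    # Fixed noise and labels for periodic EMA sampling (saved as a 4x4 grid):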
    use_cfg = args.cfg_scale > 1.0
    fixed_ys = torch.randint(1000, size=(16,), device=device)
    fixed_zs = torch.randn(16, 4, latent_size, latent_size, device=device)
    if use_cfg:
        # For classifier-free guidance, duplicate the noise and append null-class labels:
        fixed_zs = torch.cat([fixed_zs, fixed_zs], 0)
        fixed_y_null = torch.tensor([1000] * 16, device=device)  # 1000 is the null class token
        fixed_ys = torch.cat([fixed_ys, fixed_y_null], 0)
        fixed_sample_model_kwargs = dict(y=fixed_ys, cfg_scale=args.cfg_scale)
        model_fn = ema.forward_with_cfg
    else:
        fixed_sample_model_kwargs = dict(y=fixed_ys)
        model_fn = ema.forward

| logger.info(f"Training for {args.epochs} epochs...") |
| for epoch in range(args.epochs): |
| sampler.set_epoch(epoch) |
| logger.info(f"Beginning epoch {epoch}...") |
| for x, y in loader: |
| x = x.to(device) |
| y = y.to(device) |
| with torch.no_grad(): |
| |
| x = vae.encode(x).latent_dist.sample().mul_(0.18215) |
| model_kwargs = dict(y=y) |
| loss_dict = transport.training_losses(model, x, model_kwargs) |
| loss = loss_dict["loss"].mean() |
| opt.zero_grad() |
| loss.backward() |
| opt.step() |
| update_ema(ema, model.module) |
            # Log loss values:
            running_loss += loss.item()
            log_steps += 1
            train_steps += 1
            if train_steps % args.log_every == 0:
                # Measure training speed:
                torch.cuda.synchronize()
                end_time = time()
                steps_per_sec = log_steps / (end_time - start_time)
                # Reduce loss history over all processes:
                avg_loss = torch.tensor(running_loss / log_steps, device=device)
                dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                avg_loss = avg_loss.item() / dist.get_world_size()
                logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
                # Reset monitoring variables:
                running_loss = 0
                log_steps = 0
                start_time = time()

            # Save SiT checkpoint:
            if train_steps % args.ckpt_every == 0 and train_steps > 0:
                if rank == 0:
                    checkpoint = {
                        "model": model.module.state_dict(),
                        "ema": ema.state_dict(),
                        "opt": opt.state_dict(),
                        "args": args
                    }
                    checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
                    torch.save(checkpoint, checkpoint_path)
                    logger.info(f"Saved checkpoint to {checkpoint_path}")
                dist.barrier()

            # Periodically sample from the EMA model and save an image grid:
            if train_steps % args.sample_every == 0 and train_steps > 0:
                logger.info("Generating EMA samples...")
                sample_fn = transport_sampler.sample_ode()  # default to ODE sampling
                with torch.no_grad():  # sampling needs no gradients
                    samples = sample_fn(fixed_zs, model_fn, **fixed_sample_model_kwargs)[-1]
                dist.barrier()

                if use_cfg:  # remove the null-class half of the batch
                    samples, _ = samples.chunk(2, dim=0)
                with torch.no_grad():
                    samples = vae.decode(samples / 0.18215).sample

                if rank == 0:
                    # Map samples from [-1, 1] to [0, 1]:
                    samples = (samples.clamp(-1, 1) + 1) / 2

                    # Paste the 16 samples into a 4x4 grid:
                    grid_size = args.image_size
                    grid_image = Image.new('RGB', (4 * grid_size, 4 * grid_size))
                    for i in range(min(16, samples.shape[0])):
                        # CHW float tensor -> HWC uint8 array -> PIL image:
                        img = samples[i].permute(1, 2, 0).cpu().detach().numpy()
                        img = (img * 255).astype(np.uint8)
                        pil_img = Image.fromarray(img)

                        row = i // 4
                        col = i % 4
                        grid_image.paste(pil_img, (col * grid_size, row * grid_size))

                    img_path = f"{pic_dir}/step_{train_steps:07d}_samples_grid.png"
                    grid_image.save(img_path)
                    logger.info(f"Saved sample images grid to {img_path}")
                logger.info("Generating EMA samples done.")

    model.eval()  # important! This disables randomized embedding dropout

    logger.info("Done!")
    cleanup()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="/gemini/platform/public/hzh/datasets/Imagenet/train/")
    parser.add_argument("--results-dir", type=str, default="results")
    parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--epochs", type=int, default=140000)
    parser.add_argument("--global-batch-size", type=int, default=256)
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--log-every", type=int, default=100)
    parser.add_argument("--ckpt-every", type=int, default=10)
    parser.add_argument("--sample-every", type=int, default=10)
    parser.add_argument("--cfg-scale", type=float, default=4.0)
    parser.add_argument("--ckpt", type=str, default=None,
                        help="Optional path to a custom SiT checkpoint")

    parse_transport_args(parser)
    args = parser.parse_args()
    main(args)