| """ |
| Sampling Scripts of LightningDiT. |
| |
| by Maple (Jingfeng Yao) from HUST-VL |
| """ |
|
|
| import os, math, json, pickle, logging, argparse, yaml, torch, numpy as np |
| from time import time, strftime |
| from glob import glob |
| from copy import deepcopy |
| from collections import OrderedDict |
| from PIL import Image |
| from tqdm import tqdm |
| import torch.distributed as dist |
| from accelerate import Accelerator |
| from torch.utils.data import DataLoader |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from torch.utils.tensorboard import SummaryWriter |
| import torchvision |
| |
| from tokenizer.vavae import VA_VAE |
| from tokenizer.sdvae import Diffusers_AutoencoderKL |
| from tokenizer import models_mae |
| import threading |
|
|
| from models.lightningdit import LightningDiT_models |
| from transport import create_transport, Sampler |
| from datasets.img_latent_dataset import ImgLatentDataset |
| from torchvision.utils import save_image |
|
|
| |
def save_images_async(images, indices, save_dir):
    """Save each image as ``{save_dir}/{index:06d}.png``.

    ``images`` and ``indices`` are consumed pairwise. NumPy inputs are
    converted with ``permute(2, 0, 1)`` and divided by 255 before saving
    (so they are presumably HWC arrays in the 0-255 range -- TODO confirm);
    any other input is handed to ``save_image`` unchanged.

    Despite the name, the body itself runs synchronously; any asynchrony
    would have to come from how the caller invokes it.
    """
    for picture, position in zip(images, indices):
        if not isinstance(picture, np.ndarray):
            tensor = picture
        else:
            # Reorder axes to channel-first and rescale to floats for save_image.
            tensor = torch.from_numpy(picture).permute(2, 0, 1).float() / 255.0
        save_image(tensor, f"{save_dir}/{position:06d}.png")
|
|
def do_sample(train_config, accelerator, ckpt_path=None, cfg_scale=None, model=None, vae=None, demo_sample_mode=False):
    """
    Run sampling.

    Loads ``ckpt_path`` into ``model``, builds the ODE sampler described by
    ``train_config['sample']`` and then either writes individual PNG samples
    into a per-run folder (default) or saves a single 2x4 demo grid under
    ``demo_images/`` (``demo_sample_mode=True``).

    Args:
        train_config: Nested config mapping. The ``model``, ``sample``,
            ``train``, ``vae``, ``data`` and ``transport`` sections are read.
            The mapping is not modified.
        accelerator: Object providing ``device``, ``process_index``,
            ``local_process_index``, ``num_processes`` and
            ``wait_for_everyone()``.
        ckpt_path: Path of the model checkpoint. Required, despite the
            ``None`` default kept for signature compatibility.
        cfg_scale: Guidance scale. When ``None`` it is taken from
            ``train_config['sample']['cfg_scale']``. Values above 1.0 switch
            to ``model.forward_with_cfg``.
        model: Network to sample from; it receives the checkpoint weights and
            is moved to ``accelerator.device``. Required whenever sampling
            actually runs.
        vae: Decoder exposing ``decode_to_images``. When ``None`` one is built
            from ``train_config['vae']``.
        demo_sample_mode: When true, only the main process generates eight
            images and saves them as one grid.

    Returns:
        The sample folder path. If the folder already holds at least
        ``fid_num`` PNG files (non-demo mode), sampling is skipped and the
        path is returned immediately. In demo mode the main process returns
        ``None`` after saving the grid; other processes fall through and
        return the folder path.

    Raises:
        ValueError: If ``ckpt_path`` is missing, or ``model`` is missing when
            sampling has to run.
        NotImplementedError: If the sampling mode is not ``"ODE"`` or the VAE
            ``model_name`` prefix is not recognised.
    """
    if ckpt_path is None:
        raise ValueError("ckpt_path must be provided to do_sample")

    sample_cfg = train_config['sample']
    is_main = accelerator.process_index == 0

    # The output folder name encodes model type, checkpoint stem and sampler settings.
    folder_name = (
        f"{train_config['model']['model_type'].replace('/', '-')}"
        f"-ckpt-{ckpt_path.split('/')[-1].split('.')[0]}"
        f"-{sample_cfg['sampling_method']}"
        f"-{sample_cfg['num_sampling_steps']}"
    ).lower()
    if cfg_scale is None:
        cfg_scale = sample_cfg['cfg_scale']
    cfg_interval_start = sample_cfg.get('cfg_interval_start', 0)
    timestep_shift = sample_cfg.get('timestep_shift', 0)
    if cfg_scale > 1.0:
        folder_name += f"-interval{cfg_interval_start:.2f}" + f"-cfg{cfg_scale:.2f}"
        folder_name += f"-shift{timestep_shift:.2f}"

    # Demo mode ignores the configured guidance interval and timestep shift.
    if demo_sample_mode:
        cfg_interval_start = 0
        timestep_shift = 0

    sample_folder_dir = os.path.join(train_config['train']['output_dir'], train_config['train']['exp_name'], folder_name)
    if is_main:
        if not demo_sample_mode:
            print_with_prefix('Sample_folder_dir=', sample_folder_dir)
        print_with_prefix('ckpt_path=', ckpt_path)
        print_with_prefix('cfg_scale=', cfg_scale)
        print_with_prefix('cfg_interval_start=', cfg_interval_start)
        print_with_prefix('timestep_shift=', timestep_shift)
    if not demo_sample_mode:
        if not os.path.exists(sample_folder_dir):
            if is_main:
                os.makedirs(sample_folder_dir, exist_ok=True)
        else:
            png_count = sum(1 for f in os.listdir(sample_folder_dir) if f.endswith('.png'))
            # ">=": a folder that already holds exactly fid_num images is complete.
            if png_count >= sample_cfg['fid_num']:
                if is_main:
                    print_with_prefix(f"Found {png_count} PNG files in {sample_folder_dir}, skip sampling.")
                return sample_folder_dir

    torch.backends.cuda.matmul.allow_tf32 = True
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    device = accelerator.device

    # A distinct seed per process so that ranks draw different noise.
    seed = train_config['train']['global_seed'] * accelerator.num_processes + accelerator.process_index
    torch.manual_seed(seed)
    print_with_prefix(f"Starting rank={accelerator.local_process_index}, seed={seed}, world_size={accelerator.num_processes}.")
    rank = accelerator.local_process_index

    downsample_ratio = train_config['vae'].get('downsample_ratio', 16)
    latent_size = train_config['data']['image_size'] // downsample_ratio

    if model is None:
        raise ValueError("model must be provided to do_sample")
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    if "ema" in checkpoint:
        checkpoint = checkpoint["ema"]
    model.load_state_dict(checkpoint)
    model.eval()
    model.to(device)

    transport_cfg = train_config['transport']
    transport = create_transport(
        transport_cfg['path_type'],
        transport_cfg['prediction'],
        transport_cfg['loss_weight'],
        transport_cfg['train_eps'],
        transport_cfg['sample_eps'],
        use_cosine_loss=transport_cfg.get('use_cosine_loss', False),
        use_lognorm=transport_cfg.get('use_lognorm', False),
    )
    sampler = Sampler(transport)
    mode = sample_cfg['mode']
    if mode == "ODE":
        sample_fn = sampler.sample_ode(
            sampling_method=sample_cfg['sampling_method'],
            num_steps=sample_cfg['num_sampling_steps'],
            atol=sample_cfg['atol'],
            rtol=sample_cfg['rtol'],
            reverse=sample_cfg['reverse'],
            timestep_shift=timestep_shift,
        )
    else:
        raise NotImplementedError(f"Sampling mode {mode} is not supported.")

    if vae is None:
        vae_cfg = train_config['vae']
        # The prefix of model_name (text before the first "_") selects the decoder.
        vae_family = vae_cfg['model_name'].split("_")[0]
        if vae_family == 'vmae':
            chkpt = vae_cfg['weight_path']
            arch = 'mae_for_ldmae_f8d16_prev'
            vae = getattr(models_mae, arch)(ldmae_mode=True, no_cls=True, kl_loss_weight=True, smooth_output=True, img_size=train_config['data']['image_size'])
            checkpoint = torch.load(chkpt, map_location='cpu')
            vae = vae.to(device).eval()
            # strict=False: key mismatches between checkpoint and module are tolerated.
            vae.load_state_dict(checkpoint['model'], strict=False)
        elif vae_family in ['ae', 'dae', 'vae', 'sdv3']:
            vae = Diffusers_AutoencoderKL(
                img_size=train_config['data']['image_size'],
                sample_size=128,
                in_channels=3,
                out_channels=3,
                layers_per_block=2,
                latent_channels=16,
                norm_num_groups=32,
                act_fn="silu",
                block_out_channels=(128, 256, 512, 512),
                force_upcast=False,
                use_quant_conv=False,
                use_post_quant_conv=False,
                down_block_types=(
                    "DownEncoderBlock2D",
                    "DownEncoderBlock2D",
                    "DownEncoderBlock2D",
                    "DownEncoderBlock2D",
                ),
                up_block_types=(
                    "UpDecoderBlock2D",
                    "UpDecoderBlock2D",
                    "UpDecoderBlock2D",
                    "UpDecoderBlock2D",
                ),
            ).to(device).eval()
            chkpt_dir = vae_cfg['weight_path']
            checkpoint = torch.load(chkpt_dir, map_location='cpu')
            vae.load_state_dict(checkpoint['model'], strict=False)
        else:
            # NotImplementedError subclasses RuntimeError, which is what the
            # previous bare `raise` produced, but now carries a message.
            raise NotImplementedError(f"VAE model {vae_cfg['model_name']} is not supported.")
        if is_main:
            print_with_prefix('Model Loaded')

    using_cfg = cfg_scale > 1.0
    if using_cfg and is_main:
        print_with_prefix('Using cfg:', using_cfg)

    # Local rank 0 creates the folder, then everyone synchronises before writing.
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
    if is_main and not demo_sample_mode:
        print_with_prefix(f"Saving .png samples at {sample_folder_dir}")
    accelerator.wait_for_everyone()

    # Round fid_num up to a whole number of global batches so every process
    # runs the same number of iterations.
    n = sample_cfg['per_proc_batch_size']
    global_batch_size = n * accelerator.num_processes
    total_samples = int(math.ceil(sample_cfg['fid_num'] / global_batch_size) * global_batch_size)
    if is_main:
        print_with_prefix(f"Total number of images that will be sampled: {total_samples}")
    assert total_samples % accelerator.num_processes == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // accelerator.num_processes)
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    pbar = range(iterations)
    if not demo_sample_mode:
        pbar = tqdm(pbar) if rank == 0 else pbar
    total = 0

    if is_main:
        print_with_prefix("Using latent normalization")
    data_cfg = train_config['data']
    # Work on a local copy of the path: the caller's config must not be
    # mutated, otherwise repeated calls would keep appending "_sample".
    data_path = data_cfg['data_path']
    if 'sample' in data_cfg:
        data_path += '_sample'
    latent_multiplier = data_cfg.get('latent_multiplier', 0.18215)
    dataset = ImgLatentDataset(
        data_dir=data_path,
        latent_norm=data_cfg.get('latent_norm', False),
        latent_multiplier=latent_multiplier,
        sample=data_cfg.get('sample', False),
    )
    latent_mean, latent_std = dataset.get_latent_stats()

    latent_mean = latent_mean.clone().detach().to(device)
    latent_std = latent_std.clone().detach().to(device)

    if demo_sample_mode:
        if is_main:
            images = []
            demo_labels = [975, 3, 207, 387, 388, 88, 979, 279] if using_cfg else [0] * 8
            for label in tqdm(demo_labels, desc="Generating Demo Samples"):
                z = torch.randn(1, model.in_channels, latent_size, latent_size, device=device)
                y = torch.tensor([label], device=device)
                if using_cfg:
                    z = torch.cat([z, z], 0)
                    # NOTE(review): 1000 is hard-coded as the second-half label
                    # (named y_null) -- confirm it matches the model's label table.
                    y_null = torch.tensor([1000] * 1, device=device)
                    y = torch.cat([y, y_null], 0)
                    model_kwargs = dict(y=y, cfg_scale=cfg_scale, cfg_interval=False, cfg_interval_start=cfg_interval_start)
                    model_fn = model.forward_with_cfg
                else:
                    model_kwargs = dict(y=y)
                    model_fn = model.forward
                samples = sample_fn(z, model_fn, **model_kwargs)[-1]
                # Undo the latent normalisation before decoding.
                samples = (samples * latent_std) / latent_multiplier + latent_mean
                samples = vae.decode_to_images(samples)
                images.append(samples)

            os.makedirs('demo_images', exist_ok=True)

            # Keep the first decoded image of every batch and tile them 2 rows x 4 columns.
            all_images = np.stack([img[0] for img in images])
            h, w = all_images.shape[1:3]
            grid = np.zeros((2 * h, 4 * w, 3), dtype=np.uint8)
            for idx, image in enumerate(all_images):
                row, col = divmod(idx, 4)
                grid[row * h:(row + 1) * h, col * w:(col + 1) * w] = image

            exp_name = train_config['train']['exp_name']
            # Use the checkpoint actually loaded; [:-3] drops a 3-character suffix such as ".pt".
            ckpt_iter = ckpt_path.split("/")[-1][:-3]
            Image.fromarray(grid).save(f'demo_images/{exp_name}_cfg{cfg_scale}_{ckpt_iter}_demo_samples.png')
            return None
    else:
        for _ in pbar:
            z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device)
            # Optional truncation: redraw entries whose magnitude exceeds the
            # bound, for at most 100 rounds. (The guard used to test the
            # misspelled key 'trunaction', so the bound was never applied.)
            if 'truncation' in sample_cfg:
                truncation_bound = sample_cfg['truncation']
                for _attempt in range(100):
                    invalid_mask = torch.abs(z) > truncation_bound
                    if not invalid_mask.any():
                        break
                    z[invalid_mask] = torch.randn_like(z[invalid_mask])
            y = torch.randint(0, train_config['data']['num_classes'], (n,), device=device)

            if using_cfg:
                z = torch.cat([z, z], 0)
                # NOTE(review): 1000 is hard-coded as the second-half label
                # (named y_null) -- confirm it matches the model's label table.
                y_null = torch.tensor([1000] * n, device=device)
                y = torch.cat([y, y_null], 0)
                model_kwargs = dict(y=y, cfg_scale=cfg_scale, cfg_interval=True, cfg_interval_start=cfg_interval_start)
                model_fn = model.forward_with_cfg
            else:
                model_kwargs = dict(y=y)
                model_fn = model.forward

            samples = sample_fn(z, model_fn, **model_kwargs)[-1]
            if using_cfg:
                # Keep only the first half of the doubled batch.
                samples, _ = samples.chunk(2, dim=0)

            # Undo the latent normalisation before decoding.
            samples = (samples * latent_std) / latent_multiplier + latent_mean
            samples = vae.decode_to_images(samples)

            # Interleave file indices across processes so names never collide.
            for offset, sample in enumerate(samples):
                index = offset * accelerator.num_processes + accelerator.process_index + total
                Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
            total += global_batch_size
            accelerator.wait_for_everyone()

    return sample_folder_dir
|
|
| |
def print_with_prefix(*messages):
    """Print the arguments, space-separated, behind a timestamped tag.

    Each argument is converted with ``str``. The tag is wrapped in the ANSI
    escape sequences ``\\033[34m`` ... ``\\033[0m`` and contains the current
    local time formatted as ``YYYY-mm-dd HH:MM:SS``.
    """
    stamp = strftime('%Y-%m-%d %H:%M:%S')
    tag = f"\033[34m[LightningDiT-Sampling {stamp}]\033[0m"
    body = ' '.join(str(part) for part in messages)
    print(f"{tag}: {body}")
|
|
def load_config(config_path):
    """Parse the YAML file at ``config_path`` and return the loaded object.

    The file is opened explicitly as UTF-8 so that parsing does not depend on
    the platform's default locale encoding (which would break on non-ASCII
    content on some systems). ``yaml.safe_load`` is used, so only plain YAML
    types are constructed.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document (a mapping for
        the configs consumed by this script).
    """
    with open(config_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    return config
|
|
if __name__ == "__main__":
    # Command line: a YAML config path plus an optional demo switch.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--config', type=str, default='configs/lightningdit_b_ldmvae_f16d16.yaml')
    arg_parser.add_argument('--demo', action='store_true', default=False)
    args = arg_parser.parse_args()
    accelerator = Accelerator()
    train_config = load_config(args.config)

    # The checkpoint to sample from comes from the config file.
    assert 'ckpt_path' in train_config, "ckpt_path must be specified in config"
    if accelerator.process_index == 0:
        print_with_prefix('Using ckpt:', train_config['ckpt_path'])
    ckpt_file = train_config['ckpt_path']

    # Latent resolution = image size divided by the VAE downsample ratio (16 if unset).
    downsample_ratio = train_config['vae'].get('downsample_ratio', 16)
    latent_size = train_config['data']['image_size'] // downsample_ratio

    # Build the network; optional architecture switches fall back to their defaults.
    model_cfg = train_config['model']
    build_model = LightningDiT_models[model_cfg['model_type']]
    model = build_model(
        input_size=latent_size,
        num_classes=train_config['data']['num_classes'],
        use_qknorm=model_cfg['use_qknorm'],
        use_swiglu=model_cfg.get('use_swiglu', False),
        use_rope=model_cfg.get('use_rope', False),
        use_rmsnorm=model_cfg.get('use_rmsnorm', False),
        wo_shift=model_cfg.get('wo_shift', False),
        in_channels=model_cfg.get('in_chans', 4),
        learn_sigma=model_cfg.get('learn_sigma', False),
        # Label dropout is disabled when there is only a single class.
        class_dropout_prob=0.1 if train_config['data']['num_classes'] != 1 else 0,
    )

    sample_folder_dir = do_sample(train_config, accelerator, ckpt_path=ckpt_file, model=model, demo_sample_mode=args.demo)

    # FID is computed once, on the main process, and only for full sampling runs.
    if not args.demo and accelerator.process_index == 0:
        from tools.calculate_fid import calculate_fid_given_paths
        fid_sample_count = train_config['sample']['fid_num']
        print_with_prefix('Calculating FID with {} number of samples'.format(fid_sample_count))
        assert 'fid_reference_file' in train_config['data'], "fid_reference_file must be specified in config"
        fid_reference_file = train_config['data']['fid_reference_file']
        fid = calculate_fid_given_paths(
            [fid_reference_file, sample_folder_dir],
            batch_size=50,
            dims=2048,
            device='cuda',
            num_workers=8,
            sp_len=train_config['sample']['fid_num'],
        )
        print_with_prefix('fid=', fid)
|
|