import argparse
import os

import torch
import torchvision
import yaml
from torchvision.utils import make_grid
from tqdm import tqdm

from model.transformer import DIT
from model.vae import VAE
from scheduler.linear_scheduler import LinearNoiseScheduler

device = "cuda" if torch.cuda.is_available() else "cpu"


def sample(
    model,
    scheduler,
    train_config,
    dit_config,
    vae_config,
    vae,
    diffusion_config,
    dataset_config,
):
    """Run the reverse diffusion loop and save an image grid at every timestep.

    Starting from pure Gaussian noise in VAE latent space, repeatedly asks
    ``model`` for a noise prediction and steps the ``scheduler`` backwards.
    Intermediate steps save the (clamped) raw latents; the final step (i == 0)
    decodes the latents through ``vae`` to pixel space before saving.

    Args:
        model: Trained DiT noise-prediction network (already on ``device``).
        scheduler: Noise scheduler exposing ``sample_prev_timestep``.
        train_config: Dict with ``num_samples``, ``num_grid_rows``, ``task_name``.
        dit_config: DiT config dict (unused here; kept for interface parity).
        vae_config: Autoencoder config with ``z_channels`` and ``down_sample``.
        vae: Trained VAE used to decode latents at the final step.
        diffusion_config: Dict with ``num_timesteps``.
        dataset_config: Dict with the pixel-space ``im_size``.
    """
    # Latent spatial size: the image is halved once per VAE downsampling stage.
    im_size = dataset_config["im_size"] // 2 ** sum(vae_config["down_sample"])

    # Initial sample: standard normal noise in latent space.
    xt = torch.randn(
        (train_config["num_samples"], vae_config["z_channels"], im_size, im_size)
    ).to(device)

    num_timesteps = diffusion_config["num_timesteps"]
    sample_dir = os.path.join(train_config["task_name"], "samples")
    # makedirs(exist_ok=True) is race-free and also creates the parent,
    # unlike the original exists()+mkdir() pair.
    os.makedirs(sample_dir, exist_ok=True)

    # total= is required for a progress bar: reversed(range(...)) has no len().
    for i in tqdm(reversed(range(num_timesteps)), total=num_timesteps):
        # Predict noise at timestep i, then take one reverse-diffusion step.
        noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device))
        xt, _x0_pred = scheduler.sample_prev_timestep(
            xt, noise_pred, torch.as_tensor(i).to(device)
        )

        if i == 0:
            # Final step: decode latents to pixel space for the real output.
            ims = vae.to(device).decode(xt)
        else:
            # Intermediate steps: visualize the raw latents directly.
            ims = xt
        # BUGFIX: the original unconditionally re-assigned
        # ``ims = xt[:, :-1, :, :]`` here, which discarded the decoded image
        # at i == 0 (and silently dropped a latent channel). Removed so the
        # branch above actually determines what gets saved.

        ims = torch.clamp(ims, -1.0, 1.0).detach().cpu()
        ims = (ims + 1) / 2  # map [-1, 1] -> [0, 1] for image encoding

        grid = make_grid(ims, nrow=train_config["num_grid_rows"])
        img = torchvision.transforms.ToPILImage()(grid)
        img.save(os.path.join(sample_dir, "x0_{}.jpg".format(i)))
        img.close()


def infer(args):
    """Load configs and checkpoints, then generate samples with a trained DiT.

    Args:
        args: Parsed CLI namespace with ``config_path`` pointing at a YAML
            config containing ``diffusion_params``, ``dataset_params``,
            ``dit_params``, ``autoencoder_params`` and ``train_params``.

    Raises:
        yaml.YAMLError: If the config file cannot be parsed.
        AssertionError: If the DiT or VAE checkpoint is missing.
    """
    # Read the config file #
    with open(args.config_path, "r") as file:
        try:
            config = yaml.safe_load(file)
        except yaml.YAMLError as exc:
            # Originally the error was printed and swallowed, leading to a
            # confusing NameError on ``config`` below; re-raise instead.
            print(exc)
            raise
    ########################

    diffusion_config = config["diffusion_params"]
    dataset_config = config["dataset_params"]
    dit_model_config = config["dit_params"]
    autoencoder_model_config = config["autoencoder_params"]
    train_config = config["train_params"]

    # Create the noise scheduler
    scheduler = LinearNoiseScheduler(
        num_timesteps=diffusion_config["num_timesteps"],
        beta_start=diffusion_config["beta_start"],
        beta_end=diffusion_config["beta_end"],
    )

    # Latent image size seen by the DiT (after VAE downsampling).
    im_size = dataset_config["im_size"] // 2 ** sum(
        autoencoder_model_config["down_sample"]
    )
    model = DIT(
        im_size=im_size,
        im_channels=autoencoder_model_config["z_channels"],
        config=dit_model_config,
    ).to(device)
    model.eval()

    dit_ckpt = os.path.join(train_config["task_name"], train_config["dit_ckpt_name"])
    assert os.path.exists(dit_ckpt), "Train DiT first"
    # map_location keeps CPU-only machines working with GPU-saved checkpoints.
    checkpoint = torch.load(dit_ckpt, map_location=device)
    model.load_state_dict(checkpoint["dit"])
    print("Loaded dit checkpoint")

    # Create output directories (race-free; also creates missing parents).
    os.makedirs(train_config["task_name"], exist_ok=True)

    vae = VAE(
        im_channels=dataset_config["im_channels"],
        model_config=autoencoder_model_config,
    )
    vae.eval()

    # Load vae if found
    vae_ckpt = os.path.join(
        train_config["task_name"], train_config["vae_autoencoder_ckpt_name"]
    )
    assert os.path.exists(vae_ckpt), "VAE checkpoint not present. Train VAE first."
    vae.load_state_dict(torch.load(vae_ckpt, map_location=device), strict=True)
    print("Loaded vae checkpoint")

    # no_grad: inference only — avoids building autograd graphs in the loop.
    with torch.no_grad():
        sample(
            model=model,
            dataset_config=dataset_config,
            vae_config=autoencoder_model_config,
            dit_config=dit_model_config,
            scheduler=scheduler,
            vae=vae,
            train_config=train_config,
            diffusion_config=diffusion_config,
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Arguments for dit image generation")
    parser.add_argument(
        "--config", dest="config_path", default="celeba/config.yaml", type=str
    )
    args = parser.parse_args()
    infer(args)