# Source: Diffusion-Transformer / sample_dit.py
# Last commit da3087e by YashNagraj75: "Add samples at all good, less budget"
import argparse
import os
import torch
import torchvision
import yaml
from torchvision.utils import make_grid
from tqdm import tqdm
from model.transformer import DIT
from model.vae import VAE
from scheduler.linear_scheduler import LinearNoiseScheduler
# Pick the compute device once at import time: use the GPU when CUDA is
# available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def sample(
    model,
    scheduler,
    train_config,
    dit_config,
    vae_config,
    vae,
    diffusion_config,
    dataset_config,
):
    """Run the reverse diffusion loop and save one image grid per timestep.

    Starts from standard-normal noise in latent space and, for every
    timestep from ``num_timesteps - 1`` down to ``0``, feeds the model's
    noise prediction to ``scheduler.sample_prev_timestep`` to obtain the
    latents of the previous timestep. After each step a grid is written to
    ``<task_name>/samples/x0_<t>.jpg``: intermediate steps show the latents
    with their last channel dropped, the final step (``t == 0``) shows the
    VAE-decoded result.

    Args:
        model: Noise predictor, called as ``model(xt, t)``.
        scheduler: Object exposing ``sample_prev_timestep(xt, noise_pred, t)``.
        train_config: Mapping providing ``num_samples``, ``num_grid_rows``
            and ``task_name``.
        dit_config: Not used by this function; kept so existing callers
            that pass it keep working.
        vae_config: Mapping providing ``down_sample`` and ``z_channels``.
        vae: Autoencoder whose ``decode`` is applied to the final latents.
        diffusion_config: Mapping providing ``num_timesteps``.
        dataset_config: Mapping providing ``im_size``.
    """
    # Each enabled down-sampling stage halves the spatial size of the latents.
    latent_size = dataset_config["im_size"] // 2 ** sum(vae_config["down_sample"])
    xt = torch.randn(
        (
            train_config["num_samples"],
            vae_config["z_channels"],
            latent_size,
            latent_size,
        )
    ).to(device)

    # Create the output folder once up front instead of re-checking it on
    # every timestep; makedirs also tolerates an already existing folder.
    samples_dir = os.path.join(train_config["task_name"], "samples")
    os.makedirs(samples_dir, exist_ok=True)

    num_timesteps = diffusion_config["num_timesteps"]
    to_pil = torchvision.transforms.ToPILImage()

    # reversed(range(...)) has no len(), so tell tqdm the total explicitly.
    for i in tqdm(reversed(range(num_timesteps)), total=num_timesteps):
        # Predict the noise present in the current latents.
        noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device))

        # Step the scheduler; only the previous-timestep latents are used,
        # the x0 estimate it also returns is ignored.
        xt, _ = scheduler.sample_prev_timestep(
            xt, noise_pred, torch.as_tensor(i).to(device)
        )

        if i == 0:
            # Final step: decode the latents to image space with the VAE.
            ims = vae.to(device).decode(xt)
        else:
            # Intermediate step: visualise the latents directly, dropping the
            # last channel.
            # NOTE(review): presumably this leaves 3 channels so the grid can
            # be saved as a JPEG - confirm against z_channels.
            ims = xt[:, :-1, :, :]

        # Map from [-1, 1] to [0, 1] for image conversion.
        ims = torch.clamp(ims, -1.0, 1.0).detach().cpu()
        ims = (ims + 1) / 2
        grid = make_grid(ims, nrow=train_config["num_grid_rows"])

        img = to_pil(grid)
        img.save(os.path.join(samples_dir, "x0_{}.jpg".format(i)))
        img.close()
def infer(args):
    """Load the config and trained checkpoints, then generate samples.

    Reads the YAML config at ``args.config_path``, builds the noise
    scheduler, the DiT model and the VAE, restores both checkpoints from
    the ``task_name`` directory and calls ``sample`` under
    ``torch.no_grad()``.

    Args:
        args: Object with a ``config_path`` attribute naming the YAML
            config file.

    Raises:
        yaml.YAMLError: If the config file cannot be parsed.
        AssertionError: If the DiT or VAE checkpoint file is missing.
    """
    # Read the config file #
    with open(args.config_path, "r") as file:
        try:
            config = yaml.safe_load(file)
        except yaml.YAMLError as exc:
            print(exc)
            # Re-raise: carrying on would only fail a few lines later with a
            # confusing NameError because `config` was never assigned.
            raise
    ########################

    diffusion_config = config["diffusion_params"]
    dataset_config = config["dataset_params"]
    dit_model_config = config["dit_params"]
    autoencoder_model_config = config["autoencoder_params"]
    train_config = config["train_params"]

    # Create the noise scheduler
    scheduler = LinearNoiseScheduler(
        num_timesteps=diffusion_config["num_timesteps"],
        beta_start=diffusion_config["beta_start"],
        beta_end=diffusion_config["beta_end"],
    )

    # Get latent image size
    im_size = dataset_config["im_size"] // 2 ** sum(
        autoencoder_model_config["down_sample"]
    )

    model = DIT(
        im_size=im_size,
        im_channels=autoencoder_model_config["z_channels"],
        config=dit_model_config,
    ).to(device)
    model.eval()

    # Build each checkpoint path once and reuse it for the check and the load.
    dit_ckpt_path = os.path.join(
        train_config["task_name"], train_config["dit_ckpt_name"]
    )
    assert os.path.exists(dit_ckpt_path), "Train DiT first"
    checkpoint = torch.load(dit_ckpt_path, map_location=device)
    # The DiT weights are stored under the "dit" key of the checkpoint.
    model.load_state_dict(checkpoint["dit"])
    print("Loaded dit checkpoint")

    # Create output directories
    os.makedirs(train_config["task_name"], exist_ok=True)

    vae = VAE(
        im_channels=dataset_config["im_channels"],
        model_config=autoencoder_model_config,
    )
    vae.eval()

    # Load vae if found
    vae_ckpt_path = os.path.join(
        train_config["task_name"], train_config["vae_autoencoder_ckpt_name"]
    )
    assert os.path.exists(
        vae_ckpt_path
    ), "VAE checkpoint not present. Train VAE first."
    vae.load_state_dict(
        torch.load(vae_ckpt_path, map_location=device),
        strict=True,
    )
    print("Loaded vae checkpoint")

    # Sampling is inference only, so gradient tracking is disabled.
    with torch.no_grad():
        sample(
            model=model,
            dataset_config=dataset_config,
            vae_config=autoencoder_model_config,
            dit_config=dit_model_config,
            scheduler=scheduler,
            vae=vae,
            train_config=train_config,
            diffusion_config=diffusion_config,
        )
# Script entry point: parse the command line and run inference.
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Arguments for dit image generation")
    # --config is stored as `config_path`, the attribute infer() reads.
    cli.add_argument(
        "--config",
        type=str,
        default="celeba/config.yaml",
        dest="config_path",
    )
    infer(cli.parse_args())