import torch import yaml import os from tqdm import tqdm from torch.utils.data import DataLoader from dataset.celeba import create_dataloader from models import vqvae from models.unet_cond import UNet from models.vqvae import VQVAE from scheduler import LinearNoiseScheduler from torch.optim import Adam device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def train(args): with open(args.config_path, "r") as f: try: config = yaml.safe_load(f) except yaml.YAMLError as exc: print(exc) print(config) diffusion_config = config["diffusion_params"] dataset_config = config["dataset_params"] diffusion_model_config = config["ldm_params"] autoencoder_config = config["autoencoder_params"] train_config = config["train_config"] # Create noise scheduler scheduler = LinearNoiseScheduler( num_timesteps=diffusion_config["num_timesteps"], beta_start=diffusion_config["beta_start"], beta_end=diffusion_config["beta_end"], ) dataloader = create_dataloader(dataset_config["im_path"]) ldm_ckpt_path = os.path.join(train_config["task_name"],train_config["ldm_ckpt_name"]) vqvae_ckpt_path = os.path.join( train_config["task_name"], train_config["vqvae_autoencoder_ckpt_name"] ) if os.path.exists