| import torch |
| import yaml |
| import os |
| from tqdm import tqdm |
| from torch.utils.data import DataLoader |
| from dataset.celeba import create_dataloader |
| from models import vqvae |
| from models.unet_cond import UNet |
| from models.vqvae import VQVAE |
| from scheduler import LinearNoiseScheduler |
| from torch.optim import Adam |
|
|
# Module-wide compute device: use CUDA when torch reports it as available,
# otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
| def train(args): |
| with open(args.config_path, "r") as f: |
| try: |
| config = yaml.safe_load(f) |
| except yaml.YAMLError as exc: |
| print(exc) |
|
|
| print(config) |
|
|
| diffusion_config = config["diffusion_params"] |
| dataset_config = config["dataset_params"] |
| diffusion_model_config = config["ldm_params"] |
| autoencoder_config = config["autoencoder_params"] |
| train_config = config["train_config"] |
|
|
| |
| scheduler = LinearNoiseScheduler( |
| num_timesteps=diffusion_config["num_timesteps"], |
| beta_start=diffusion_config["beta_start"], |
| beta_end=diffusion_config["beta_end"], |
| ) |
|
|
| dataloader = create_dataloader(dataset_config["im_path"]) |
|
|
| ldm_ckpt_path = os.path.join(train_config["task_name"],train_config["ldm_ckpt_name"]) |
| vqvae_ckpt_path = os.path.join( |
| train_config["task_name"], train_config["vqvae_autoencoder_ckpt_name"] |
| ) |
|
|
| if os.path.exists |
|
|