# Author: YashNagraj75
# Commit: 04897d8 ("Add scheduler")
import torch
import yaml
import os
from tqdm import tqdm
from torch.utils.data import DataLoader
from dataset.celeba import create_dataloader
from models import vqvae
from models.unet_cond import UNet
from models.vqvae import VQVAE
from scheduler import LinearNoiseScheduler
from torch.optim import Adam
# Default compute device: use CUDA when torch reports it is available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def train(args):
with open(args.config_path, "r") as f:
try:
config = yaml.safe_load(f)
except yaml.YAMLError as exc:
print(exc)
print(config)
diffusion_config = config["diffusion_params"]
dataset_config = config["dataset_params"]
diffusion_model_config = config["ldm_params"]
autoencoder_config = config["autoencoder_params"]
train_config = config["train_config"]
# Create noise scheduler
scheduler = LinearNoiseScheduler(
num_timesteps=diffusion_config["num_timesteps"],
beta_start=diffusion_config["beta_start"],
beta_end=diffusion_config["beta_end"],
)
dataloader = create_dataloader(dataset_config["im_path"])
ldm_ckpt_path = os.path.join(train_config["task_name"],train_config["ldm_ckpt_name"])
vqvae_ckpt_path = os.path.join(
train_config["task_name"], train_config["vqvae_autoencoder_ckpt_name"]
)
if os.path.exists