# Reward-Forcing / train.py
# Uploaded by fffiloni; migrated from GitHub (commit 34c9227, verified).
# Original file size: 1.77 kB.
import argparse
import os
from omegaconf import OmegaConf
import wandb
from trainer import DiffusionTrainer, GANTrainer, ODETrainer, ScoreDistillationTrainer, RewardedDistillationTrainer
def main():
    """Parse CLI arguments, build the run config, and launch the chosen trainer.

    The YAML file given by ``--config_path`` is merged over
    ``configs/default_config.yaml`` (per OmegaConf merge semantics, values from
    the later argument win), then the command-line flags are copied onto the
    merged config. The trainer class is selected by ``config.trainer``.

    Raises:
        ValueError: If ``config.trainer`` is not one of the supported names.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, required=True)
    parser.add_argument("--no_save", action="store_true")
    parser.add_argument("--no_visualize", action="store_true")
    parser.add_argument("--logdir", type=str, default="", help="Path to the directory to save logs")
    parser.add_argument("--wandb-save-dir", type=str, default="", help="Path to the directory to save wandb logs")
    parser.add_argument("--disable-wandb", action="store_true")
    args = parser.parse_args()

    config = OmegaConf.load(args.config_path)
    # NOTE: this path is relative to the current working directory.
    default_config = OmegaConf.load("configs/default_config.yaml")
    config = OmegaConf.merge(default_config, config)
    config.no_save = args.no_save
    config.no_visualize = args.no_visualize

    # Use the config file's base name (text before the first ".") as the run's config name.
    config_name = os.path.basename(args.config_path).split(".")[0]
    config.config_name = config_name
    config.logdir = args.logdir
    config.wandb_save_dir = args.wandb_save_dir
    config.disable_wandb = args.disable_wandb

    # Map each supported ``config.trainer`` value to its trainer class.
    trainer_classes = {
        "diffusion": DiffusionTrainer,
        "gan": GANTrainer,
        "ode": ODETrainer,
        "score_distillation": ScoreDistillationTrainer,
        "rewarded_distillation": RewardedDistillationTrainer,
    }
    trainer_cls = trainer_classes.get(config.trainer)
    if trainer_cls is None:
        # Fail with a clear message instead of an UnboundLocalError on `trainer` below.
        raise ValueError(
            f"Unknown trainer {config.trainer!r}; "
            f"expected one of: {', '.join(sorted(trainer_classes))}"
        )

    trainer = trainer_cls(config)
    trainer.train()
    wandb.finish()
# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()