# NOTE(review): importing `eval` from `train` shadows the builtin `eval` in
# this module; kept as-is so the module's public names stay unchanged.
from train import eval_ddim, train, eval


def main(model_config=None):
    """Run one stage (train / sample) of the diffusion pipeline.

    A default configuration is built below; any keys supplied in
    ``model_config`` override those defaults. The ``"state"`` key selects
    which function from the ``train`` module is invoked with the merged
    configuration dict.

    Args:
        model_config: Optional mapping of configuration overrides. When
            ``None`` (the default), the built-in defaults are used unchanged.
            The mapping passed in is never mutated.

    Raises:
        ValueError: If the resulting ``"state"`` is not one of
            ``"train"``, ``"eval"`` or ``"eval_ddim"``.
    """
    modelConfig = {
        "state": "eval",  # one of "train", "eval", "eval_ddim"
        "epochs": 200,
        "batch_size": 80,
        "T": 700,  # presumably the number of diffusion steps — confirm in train.py
        "channel": 128,
        "ch_mult": [1, 2, 3, 4],
        "attn": [2],
        "num_res_blocks": 2,
        "dropout": 0.15,
        "lr": 1e-4,
        "multiplier": 2.,
        "beta_1": 1e-4,
        "beta_T": 0.02,
        "img_size": 32,
        "grad_clip": 1.,
        "device": "cuda:0",
        "training_load_weight": None,
        "checkpoint_dir": "Checkpoints/",
        "test_load_weight": "ckpt_199.pth",
        "sample_dir": "SampledImgs/",
        "sampledNoisyImgName": "NoisyNoGuidenceImgs.png",
        "sampledImgName": "SampledNoGuidenceImgs_600.png",
        "nrow": 8,
    }
    # Previously the parameter was ignored; honour caller overrides on top of
    # the defaults (the defaults dict is freshly built per call, so updating it
    # never leaks state between calls or into the caller's mapping).
    if model_config is not None:
        modelConfig.update(model_config)

    # Map each accepted state to the train-module function that handles it.
    dispatch = {
        "train": train,
        "eval": eval,
        "eval_ddim": eval_ddim,
    }
    state = modelConfig["state"]
    try:
        runner = dispatch[state]
    except KeyError:
        # An unknown state used to fall through silently and do nothing.
        raise ValueError(
            f"Unknown state {state!r}; expected one of {sorted(dispatch)}"
        ) from None
    runner(modelConfig)


if __name__ == "__main__":
    main()