| from train import eval_ddim, train, eval | |
def main(model_config=None):
    """Build the run configuration and dispatch to training or sampling.

    Args:
        model_config: Optional mapping of overrides. Each entry replaces the
            matching default below; keys it does not mention keep their
            default value. The mapping itself is not modified. ``None``
            (the default) runs with the defaults unchanged.

    Raises:
        ValueError: If the resulting ``"state"`` is not one of ``"train"``,
            ``"eval"`` or ``"eval_ddim"``.
    """
    config = {
        "state": "eval",  # one of "train", "eval", "eval_ddim"
        "epochs": 200,
        "batch_size": 80,
        "T": 700,
        "channel": 128,
        "ch_mult": [1, 2, 3, 4],
        "attn": [2],
        "num_res_blocks": 2,
        "dropout": 0.15,
        "lr": 1e-4,
        "multiplier": 2.,
        "beta_1": 1e-4,
        "beta_T": 0.02,
        "img_size": 32,
        "grad_clip": 1.,
        "device": "cuda:0",
        "training_load_weight": None,
        "checkpoint_dir": "Checkpoints/",
        "test_load_weight": "ckpt_199.pth",
        "sample_dir": "SampledImgs/",
        "sampledNoisyImgName": "NoisyNoGuidenceImgs.png",
        "sampledImgName": "SampledNoGuidenceImgs_600.png",
        "nrow": 8,
    }
    # Apply caller overrides on top of the defaults; `config` is a fresh
    # dict, so the caller's mapping is never mutated.
    if model_config is not None:
        config.update(model_config)

    state = config["state"]
    # `train`, `eval` and `eval_ddim` are the functions imported from the
    # `train` module (that import shadows the builtin `eval` in this file).
    if state == "train":
        train(config)
    elif state == "eval":
        eval(config)
    elif state == "eval_ddim":
        eval_ddim(config)
    else:
        # Fail loudly instead of silently doing nothing on a typo'd state.
        raise ValueError(
            f"Unknown state {state!r}; expected 'train', 'eval' or 'eval_ddim'"
        )
if __name__ == "__main__":
    # Script entry point: run with the built-in default configuration.
    main(model_config=None)