| import os |
| from source.trainer import EDGSTrainer |
| from source.utils_aux import set_seed |
| import omegaconf |
| import wandb |
| import hydra |
| from argparse import Namespace |
| from omegaconf import OmegaConf |
|
|
|
|
def _write_cfg_args(dataset_cfg):
    """Create the output folder and write a ``cfg_args`` file into it.

    Prints the output folder and creates it if needed. Then writes, to
    ``<model_path>/cfg_args``, the ``str()`` of an ``argparse.Namespace``.
    That Namespace holds a few values copied from *dataset_cfg* plus fixed
    constants. The field set looks like the argument set of the reference
    Gaussian Splatting scripts, presumably so their tooling can read it back
    (TODO confirm).

    Args:
        dataset_cfg: the ``cfg.gs.dataset`` config node. It must provide
            ``model_path``, ``source_path``, ``images``, ``white_background``
            and ``data_device``.
    """
    model_path = dataset_cfg.model_path
    print(f"Output folder: {model_path}")
    os.makedirs(model_path, exist_ok=True)

    cfg_args_path = os.path.join(model_path, "cfg_args")
    with open(cfg_args_path, 'w') as out_file:
        # Keyword order is preserved by Namespace and therefore decides the
        # order in which the fields appear in the written file.
        run_args = Namespace(
            # NOTE(review): hard-coded; confirm it matches the SH degree the
            # instantiated GS model actually uses.
            sh_degree=3,
            source_path=dataset_cfg.source_path,
            model_path=model_path,
            images=dataset_cfg.images,
            depths="",
            resolution=-1,
            # NOTE(review): unlike every other field this key keeps a leading
            # underscore; confirm the reader of cfg_args expects that spelling.
            _white_background=dataset_cfg.white_background,
            train_test_exp=False,
            data_device=dataset_cfg.data_device,
            eval=False,
            convert_SHs_python=False,
            compute_cov3D_python=False,
            debug=False,
            antialiasing=False,
        )
        out_file.write(str(run_args))


@hydra.main(config_path="configs", config_name="train", version_base="1.2")
def main(cfg: omegaconf.DictConfig):
    """Train an EDGS model as described by the Hydra config *cfg*.

    Steps, in order:
    1. Start a Weights & Biases run that logs the fully resolved config.
    2. Resolve *cfg* in place and seed the run.
    3. Write ``cfg_args`` into the output folder.
    4. Instantiate the model from ``cfg.gs`` and wrap it in an ``EDGSTrainer``.
    5. Load checkpoints (``cfg.load``), start the trainer's timer, run
       ``init_with_corr`` (``cfg.init_wC``), then train (``cfg.train``).
    6. Close the W&B run.
    """
    # throw_on_missing=True makes the export fail fast if any mandatory
    # ('???') config value was left unset.
    wandb.init(
        entity=cfg.wandb.entity,
        project=cfg.wandb.project,
        config=omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
        tags=[cfg.wandb.tag],
        name=cfg.wandb.name,
        mode=cfg.wandb.mode,
    )
    omegaconf.OmegaConf.resolve(cfg)
    set_seed(cfg.seed)

    _write_cfg_args(cfg.gs.dataset)

    gaussians = hydra.utils.instantiate(cfg.gs)
    trainer = EDGSTrainer(GS=gaussians, training_config=cfg.gs.opt, device=cfg.device)

    trainer.load_checkpoints(cfg.load)
    trainer.timer.start()
    trainer.init_with_corr(cfg.init_wC)
    trainer.train(cfg.train)

    wandb.finish()
    print("\nTraining complete.")


if __name__ == "__main__":
    # The hydra.main decorator builds cfg from the config files and the
    # command line, so main is called here without arguments.
    main()
|
|
|
|