| | import logging |
| | import os |
| | import time |
| |
|
| | import git |
| | import omegaconf |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
def load_config(config_path, command_line_args=None):
    """Load an OmegaConf config file, merging in base configs and CLI overrides.

    Args:
        config_path: Path of the config file to load.
        command_line_args: Optional list of dotlist-style overrides that is
            passed to ``OmegaConf.from_dotlist``. ``None`` skips this step.

    Returns:
        The merged config. Priority, lowest to highest: the files listed
        under the ``include`` key (in listed order), the file itself, then
        the command line overrides.
    """
    logger.info("Loading from: %s", config_path)

    config = omegaconf.OmegaConf.load(config_path)

    # Membership test rather than hasattr(): it does not depend on how
    # attribute access on a missing key behaves.
    if "include" in config:
        # Include paths are resolved relative to the including file.
        config_dir = os.path.dirname(config_path)
        # Recursive call without CLI args: overrides are applied only once,
        # at the outermost level, after all files have been merged.
        base_configs = [
            load_config(os.path.join(config_dir, include_path))
            for include_path in config.include
        ]
        # Later arguments win, so the including file overrides its bases.
        config = omegaconf.OmegaConf.merge(*base_configs, config)

    if command_line_args is not None:
        command_line_config = omegaconf.OmegaConf.from_dotlist(command_line_args)
        config = omegaconf.OmegaConf.merge(config, command_line_config)

    return config
| |
|
| |
|
def save_git_commit_info(save_path):
    """Write a summary of the repository's HEAD commit to ``save_path``.

    The repository is opened through gitpython with
    ``search_parent_directories=True``. The summary holds the commit hash,
    the author and committer names with their timestamps, and the stripped
    commit message.

    Args:
        save_path: Destination passed to ``OmegaConf.save``.

    Returns:
        The OmegaConf object that was written to ``save_path``.
    """
    commit = git.Repo(search_parent_directories=True).head.commit

    summary = {"hexsha": commit.hexsha}
    summary["authored"] = {
        "author": commit.author.name,
        "authored_time": commit.authored_date,
    }
    # NOTE(review): the committer's name is stored under the key "commit";
    # kept as-is because readers of the saved file may depend on it.
    summary["committed"] = {
        "commit": commit.committer.name,
        "committed_time": commit.committed_date,
    }
    summary["message"] = commit.message.strip()

    info = omegaconf.OmegaConf.create(summary)
    omegaconf.OmegaConf.save(info, save_path)
    return info
| |
|
| |
|
def create_workspace(config):
    """Create the results folder for a run and point logging at it.

    Steps, in order:
      1. Expand strftime directives in ``config.name`` using local time.
      2. Create ``config.save_dir``.
      3. Save the config as ``config.yaml`` and the current git commit info
         as ``git.yaml`` inside that directory.
      4. Reconfigure the root logger to write to ``output.log`` in that
         directory and to a stream handler.

    Args:
        config: Config object providing ``name`` and ``save_dir``.
            ``name`` is overwritten in place.

    Returns:
        The same ``config`` object.

    Raises:
        FileExistsError: If ``config.save_dir`` already exists
            (``os.makedirs`` is called without ``exist_ok``).
    """
    # NOTE(review): presumably save_dir interpolates name, which is why the
    # name is expanded before the directory is created -- confirm in configs.
    config.name = time.strftime(config.name, time.localtime())

    # No exist_ok: an already existing directory raises instead of being reused.
    os.makedirs(config.save_dir)

    omegaconf.OmegaConf.save(config, os.path.join(config.save_dir, "config.yaml"))
    save_git_commit_info(os.path.join(config.save_dir, "git.yaml"))

    # basicConfig() does nothing while the root logger still has handlers, so
    # every existing handler must go first. Iterate over a copy: removing
    # from the live list while looping over it skips every other handler.
    for handler in list(logging.root.handlers):
        logging.root.removeHandler(handler)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        handlers=[
            logging.FileHandler(os.path.join(config.save_dir, "output.log")),
            logging.StreamHandler(),
        ],
    )

    return config
| |
|