Spaces:
Sleeping
Sleeping
| import os | |
| import shutil | |
| import subprocess | |
| from distutils.dir_util import copy_tree | |
| from shutil import copyfile | |
| from typing import List, Optional | |
| from hydra.utils import instantiate | |
| import click | |
| import git | |
| from omegaconf import DictConfig | |
| #---------------------------------------------------------------------------- | |
| def copy_objects(target_dir: os.PathLike, objects_to_copy: List[os.PathLike]): | |
| for src_path in objects_to_copy: | |
| trg_path = os.path.join(target_dir, os.path.basename(src_path)) | |
| if os.path.islink(src_path): | |
| os.symlink(os.readlink(src_path), trg_path) | |
| elif os.path.isfile(src_path): | |
| copyfile(src_path, trg_path) | |
| elif os.path.isdir(src_path): | |
| copy_tree(src_path, trg_path) | |
| else: | |
| raise NotImplementedError(f"Unknown object type: {src_path}") | |
| #---------------------------------------------------------------------------- | |
| def create_symlinks(target_dir: os.PathLike, symlinks_to_create: List[os.PathLike]): | |
| """ | |
| Creates symlinks to the given paths | |
| """ | |
| for src_path in symlinks_to_create: | |
| trg_path = os.path.join(target_dir, os.path.basename(src_path)) | |
| if os.path.islink(src_path): | |
| # Let's not create symlinks to symlinks | |
| # Since dropping the current symlink will break the experiment | |
| os.symlink(os.readlink(src_path), trg_path) | |
| else: | |
| print(f'Creating a symlink to {src_path}, so try not to delete it occasionally!') | |
| os.symlink(src_path, trg_path) | |
| #---------------------------------------------------------------------------- | |
| def is_git_repo(path: os.PathLike): | |
| try: | |
| _ = git.Repo(path).git_dir | |
| return True | |
| except git.exc.InvalidGitRepositoryError: | |
| return False | |
| #---------------------------------------------------------------------------- | |
| def create_project_dir( | |
| project_dir: os.PathLike, | |
| objects_to_copy: List[os.PathLike], | |
| symlinks_to_create: List[os.PathLike], | |
| quiet: bool=False, | |
| ignore_uncommited_changes: bool=False, | |
| overwrite: bool=False): | |
| if is_git_repo(os.getcwd()) and are_there_uncommitted_changes(): | |
| if ignore_uncommited_changes or click.confirm("There are uncommited changes. Continue?", default=False): | |
| pass | |
| else: | |
| raise PermissionError("Cannot created a dir when there are uncommited changes") | |
| if os.path.exists(project_dir): | |
| if overwrite or click.confirm(f'Dir {project_dir} already exists. Overwrite it?', default=False): | |
| shutil.rmtree(project_dir) | |
| else: | |
| print('User refused to delete an existing project dir.') | |
| raise PermissionError("There is an existing dir and I cannot delete it.") | |
| os.makedirs(project_dir) | |
| copy_objects(project_dir, objects_to_copy) | |
| create_symlinks(project_dir, symlinks_to_create) | |
| if not quiet: | |
| print(f'Created a project dir: {project_dir}') | |
| #---------------------------------------------------------------------------- | |
| def get_git_hash() -> Optional[str]: | |
| if not is_git_repo(os.getcwd()): | |
| return None | |
| try: | |
| return subprocess \ | |
| .check_output(['git', 'rev-parse', '--short', 'HEAD']) \ | |
| .decode("utf-8") \ | |
| .strip() | |
| except: | |
| return None | |
| #---------------------------------------------------------------------------- | |
| # def get_experiment_path(master_dir: os.PathLike, experiment_name: str) -> os.PathLike: | |
| # return os.path.join(master_dir, f"{experiment_name}-{get_git_hash()}") | |
| #---------------------------------------------------------------------------- | |
| def get_git_hash_suffix() -> str: | |
| git_hash: Optional[str] = get_git_hash() | |
| git_hash_suffix = "-nogit" if git_hash is None else f"-{git_hash}" | |
| return git_hash_suffix | |
| #---------------------------------------------------------------------------- | |
| def are_there_uncommitted_changes() -> bool: | |
| return len(subprocess.check_output('git status -s'.split()).decode("utf-8")) > 0 | |
| #---------------------------------------------------------------------------- | |
| def cfg_to_args_str(cfg: DictConfig, use_dashes=True) -> str: | |
| dashes = '--' if use_dashes else '' | |
| return ' '.join([f'{dashes}{p}={cfg[p]}' for p in cfg]) | |
| #---------------------------------------------------------------------------- | |
| def recursive_instantiate(cfg: DictConfig): | |
| for key in cfg: | |
| # print(type(cfg[key])) | |
| if isinstance(cfg[key], DictConfig): | |
| if '_target_' in cfg[key]: | |
| cfg[key] = instantiate(cfg[key]) | |
| else: | |
| recursive_instantiate(cfg[key]) | |
| #---------------------------------------------------------------------------- | |
| def num_gpus_to_mem(num_gpus: int, mem_per_gpu: 64) -> str: | |
| # Doing it here since hydra config cannot do formatting for ${...} | |
| return f"{num_gpus * mem_per_gpu}G" | |
| #---------------------------------------------------------------------------- |