Spaces:
Sleeping
Sleeping
File size: 4,999 Bytes
f075308 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import os
import shutil
import subprocess
from distutils.dir_util import copy_tree
from shutil import copyfile
from typing import List, Optional
from hydra.utils import instantiate
import click
import git
from omegaconf import DictConfig
#----------------------------------------------------------------------------
def copy_objects(target_dir: os.PathLike, objects_to_copy: List[os.PathLike]):
    """
    Copy every path in ``objects_to_copy`` into ``target_dir``, keeping its basename.

    - symlinks are recreated as symlinks with the same link text (the link
      itself is copied, not what it points to);
    - regular files are copied with ``copyfile``;
    - directories are copied recursively with ``copy_tree``.

    Raises:
        NotImplementedError: if a path is neither a symlink, a file nor a
            directory (this includes paths that do not exist).
    """
    for src_path in objects_to_copy:
        # normpath drops trailing separators: basename("dir/") is "", which
        # would otherwise make the destination ``target_dir`` itself.
        name = os.path.basename(os.path.normpath(src_path))
        trg_path = os.path.join(target_dir, name)
        if os.path.islink(src_path):
            os.symlink(os.readlink(src_path), trg_path)
        elif os.path.isfile(src_path):
            copyfile(src_path, trg_path)
        elif os.path.isdir(src_path):
            copy_tree(src_path, trg_path)
        else:
            raise NotImplementedError(f"Unknown object type: {src_path}")
#----------------------------------------------------------------------------
def create_symlinks(target_dir: os.PathLike, symlinks_to_create: List[os.PathLike]):
    """
    Create, inside ``target_dir``, a symlink to every path in ``symlinks_to_create``.

    Each link is named after the basename of its source path. Link targets are
    written as absolute paths, so they resolve correctly from inside
    ``target_dir`` no matter how the source paths were spelled relative to the
    current working directory.

    If a source path is itself a symlink, the new link points at that symlink's
    target instead of at the symlink.
    """
    for src_path in symlinks_to_create:
        # normpath drops trailing separators: basename("dir/") would be "".
        name = os.path.basename(os.path.normpath(src_path))
        trg_path = os.path.join(target_dir, name)
        if os.path.islink(src_path):
            # Let's not create symlinks to symlinks,
            # since dropping the current symlink would break the experiment.
            # A relative link target is relative to the link's own directory,
            # so resolve it against that directory before making it absolute.
            link_target = os.path.join(os.path.dirname(src_path), os.readlink(src_path))
            os.symlink(os.path.abspath(link_target), trg_path)
        else:
            print(f'Creating a symlink to {src_path}, so try not to delete it occasionally!')
            # A relative target would be resolved relative to target_dir, not cwd.
            os.symlink(os.path.abspath(src_path), trg_path)
#----------------------------------------------------------------------------
def is_git_repo(path: os.PathLike) -> bool:
    """
    Return True if GitPython can open ``path`` as a repository, False otherwise.

    A path that does not exist is reported as "not a git repo" instead of raising.
    NOTE(review): ``git.Repo`` is called with its default arguments; confirm
    whether that also accepts subdirectories of a repository.
    """
    try:
        _ = git.Repo(path).git_dir
    except (git.exc.InvalidGitRepositoryError, git.exc.NoSuchPathError):
        return False
    return True
#----------------------------------------------------------------------------
def create_project_dir(
        project_dir: os.PathLike,
        objects_to_copy: List[os.PathLike],
        symlinks_to_create: List[os.PathLike],
        quiet: bool=False,
        ignore_uncommited_changes: bool=False,
        overwrite: bool=False):
    """
    Create a fresh ``project_dir`` populated with copies and symlinks.

    Args:
        project_dir: directory to create.
        objects_to_copy: paths copied into ``project_dir`` (see ``copy_objects``).
        symlinks_to_create: paths symlinked into ``project_dir`` (see ``create_symlinks``).
        quiet: if True, skip the final "Created a project dir" message.
        ignore_uncommited_changes: do not ask for confirmation when the current
            working directory is a git repo with uncommitted changes.
        overwrite: delete an existing ``project_dir`` without asking.

    Raises:
        PermissionError: if the user declines to continue despite uncommitted
            changes, or declines to delete an existing ``project_dir``.
    """
    # Snapshotting uncommitted code needs an explicit opt-in (flag or prompt).
    if is_git_repo(os.getcwd()) and are_there_uncommitted_changes():
        if not (ignore_uncommited_changes
                or click.confirm("There are uncommitted changes. Continue?", default=False)):
            raise PermissionError("Cannot create a dir when there are uncommitted changes")

    if os.path.exists(project_dir):
        if overwrite or click.confirm(f'Dir {project_dir} already exists. Overwrite it?', default=False):
            shutil.rmtree(project_dir)
        else:
            print('User refused to delete an existing project dir.')
            raise PermissionError("There is an existing dir and I cannot delete it.")

    os.makedirs(project_dir)
    copy_objects(project_dir, objects_to_copy)
    create_symlinks(project_dir, symlinks_to_create)

    if not quiet:
        print(f'Created a project dir: {project_dir}')
#----------------------------------------------------------------------------
def get_git_hash() -> Optional[str]:
    """
    Return the short hash of the current HEAD commit, or None if it is unavailable.

    None is returned when the current working directory is not recognised by
    ``is_git_repo``, or when ``git rev-parse`` cannot be run (e.g. the ``git``
    executable is not found) or exits with a non-zero status.
    """
    if not is_git_repo(os.getcwd()):
        return None
    try:
        output = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'])
    except (subprocess.CalledProcessError, OSError):
        # Best effort: report "no hash" instead of failing. Unlike the previous
        # bare ``except:``, KeyboardInterrupt/SystemExit still propagate.
        return None
    return output.decode("utf-8").strip()
#----------------------------------------------------------------------------
# def get_experiment_path(master_dir: os.PathLike, experiment_name: str) -> os.PathLike:
# return os.path.join(master_dir, f"{experiment_name}-{get_git_hash()}")
#----------------------------------------------------------------------------
def get_git_hash_suffix() -> str:
    """
    Return a suffix tagging the current commit.

    The result is ``"-<short hash>"`` when ``get_git_hash()`` returns a hash,
    and ``"-nogit"`` when it returns None.
    """
    short_hash = get_git_hash()
    if short_hash is None:
        return "-nogit"
    return f"-{short_hash}"
#----------------------------------------------------------------------------
def are_there_uncommitted_changes() -> bool:
    """
    Return True if ``git status -s``, run in the current working directory,
    prints anything at all.

    Errors from running git (e.g. ``subprocess.CalledProcessError``) propagate.
    """
    raw_status = subprocess.check_output(["git", "status", "-s"])
    short_status = raw_status.decode("utf-8")
    return short_status != ""
#----------------------------------------------------------------------------
def cfg_to_args_str(cfg: DictConfig, use_dashes=True) -> str:
    """
    Render the top-level entries of ``cfg`` as space-separated ``key=value`` args.

    Each entry becomes ``--key=value`` when ``use_dashes`` is True, or plain
    ``key=value`` otherwise. Entries keep ``cfg``'s iteration order.
    """
    prefix = '--' if use_dashes else ''
    pieces = []
    for name in cfg:
        pieces.append(f'{prefix}{name}={cfg[name]}')
    return ' '.join(pieces)
#----------------------------------------------------------------------------
def recursive_instantiate(cfg: DictConfig):
    """
    Instantiate, in place, every nested DictConfig of ``cfg`` that has a ``_target_`` key.

    Children of ``cfg`` that are DictConfigs with ``_target_`` are replaced by
    ``instantiate(child)``. DictConfigs without ``_target_`` are walked
    recursively, and all other values are left untouched. ``cfg`` itself is
    never instantiated, only its descendants. Returns None.
    """
    for name in cfg:
        if not isinstance(cfg[name], DictConfig):
            continue
        if '_target_' not in cfg[name]:
            recursive_instantiate(cfg[name])
            continue
        cfg[name] = instantiate(cfg[name])
#----------------------------------------------------------------------------
def num_gpus_to_mem(num_gpus: int, mem_per_gpu: int = 64) -> str:
    """
    Return the total memory request for ``num_gpus`` GPUs as a string, e.g. ``"128G"``.

    Args:
        num_gpus: number of GPUs.
        mem_per_gpu: memory per GPU, in the unit implied by the ``G`` suffix
            (presumably gigabytes — confirm against the job scheduler). Defaults to 64.
    """
    # Doing it here since hydra config cannot do formatting for ${...}
    return f"{num_gpus * mem_per_gpu}G"
#---------------------------------------------------------------------------- |