File size: 4,999 Bytes
f075308
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os
import shutil
import subprocess
from distutils.dir_util import copy_tree
from shutil import copyfile
from typing import List, Optional

from hydra.utils import instantiate
import click
import git
from omegaconf import DictConfig

#----------------------------------------------------------------------------

def copy_objects(target_dir: os.PathLike, objects_to_copy: List[os.PathLike]):
    """
    Copies every path in `objects_to_copy` into `target_dir`, keeping its base name.

    `some/path/name` ends up at `target_dir/name`:
      - symlinks are recreated with the very same (verbatim) link target instead of being followed;
      - regular files are copied with `copyfile` (contents only, no metadata);
      - directories are copied recursively (merged into the target if it already exists).

    Raises:
        NotImplementedError: if a path is neither a symlink, a file nor a directory
            (e.g. it does not exist).
    """
    for src_path in objects_to_copy:
        # `normpath` drops trailing separators: `basename("dir/")` would otherwise be "",
        # which silently copies the directory's *contents* straight into `target_dir`.
        trg_path = os.path.join(target_dir, os.path.basename(os.path.normpath(src_path)))

        if os.path.islink(src_path):
            os.symlink(os.readlink(src_path), trg_path)
        elif os.path.isfile(src_path):
            copyfile(src_path, trg_path)
        elif os.path.isdir(src_path):
            if os.path.exists(trg_path):
                # Merge into an already existing directory, as before.
                copy_tree(src_path, trg_path)
            else:
                # The stdlib `distutils.dir_util.copy_tree` remembers the directories it
                # created for the lifetime of the process, so copying into a target that
                # was deleted and re-created (e.g. `create_project_dir(..., overwrite=True)`
                # run twice in one process) fails. `shutil.copytree` keeps no such cache,
                # and distutils is deprecated anyway (PEP 632).
                shutil.copytree(src_path, trg_path)
        else:
            raise NotImplementedError(f"Unknown object type: {src_path}")

#----------------------------------------------------------------------------

def create_symlinks(target_dir: os.PathLike, symlinks_to_create: List[os.PathLike]):
    """
    Creates, inside `target_dir`, one symlink per path in `symlinks_to_create`.

    Each link is named after the base name of its source path. Every created link
    points to an absolute path: a symlink's target is resolved relative to the
    directory that contains the link, so copying a relative target verbatim into
    `target_dir` would leave a dangling (or self-referencing) link.
    """
    for src_path in symlinks_to_create:
        # Strip trailing separators so that `basename("data/")` is "data", not "".
        trg_path = os.path.join(target_dir, os.path.basename(os.path.normpath(src_path)))

        if os.path.islink(src_path):
            # Let's not create symlinks to symlinks
            # Since dropping the current symlink will break the experiment
            link_target = os.readlink(src_path)
            if not os.path.isabs(link_target):
                # A relative link target is relative to the source link's own directory.
                link_target = os.path.join(os.path.dirname(os.path.abspath(src_path)), link_target)
            os.symlink(link_target, trg_path)
        else:
            print(f'Creating a symlink to {src_path}, so try not to delete it occasionally!')
            os.symlink(os.path.abspath(src_path), trg_path)

#----------------------------------------------------------------------------

def is_git_repo(path: os.PathLike) -> bool:
    """
    Returns True if GitPython can open `path` itself as a git repository.

    NOTE: `git.Repo` is called with its default arguments, so parent directories of
    `path` are not searched; a subdirectory of a repository yields False.
    A path that does not exist yields False instead of raising.
    """
    try:
        _ = git.Repo(path).git_dir
        return True
    except (git.exc.InvalidGitRepositoryError, git.exc.NoSuchPathError):
        # `NoSuchPathError` is raised for nonexistent paths, which cannot be repositories either.
        return False

#----------------------------------------------------------------------------

def create_project_dir(
    project_dir: os.PathLike,
    objects_to_copy: List[os.PathLike],
    symlinks_to_create: List[os.PathLike],
    quiet: bool=False,
    ignore_uncommited_changes: bool=False,
    overwrite: bool=False):
    """
    Creates `project_dir`, copies `objects_to_copy` into it and symlinks `symlinks_to_create` there.

    Safety checks, performed before anything is created:
      - if the current working directory is a git repo with uncommitted changes, the call
        proceeds only when `ignore_uncommited_changes` is True or the user confirms
        interactively; otherwise PermissionError is raised;
      - if `project_dir` already exists, it is deleted when `overwrite` is True or the user
        confirms interactively; otherwise PermissionError is raised.

    Args:
        quiet: when True, the final "Created a project dir" message is not printed.
    """
    repo_is_dirty = is_git_repo(os.getcwd()) and are_there_uncommitted_changes()
    if repo_is_dirty:
        may_continue = ignore_uncommited_changes or click.confirm("There are uncommited changes. Continue?", default=False)
        if not may_continue:
            raise PermissionError("Cannot created a dir when there are uncommited changes")

    if os.path.exists(project_dir):
        may_delete = overwrite or click.confirm(f'Dir {project_dir} already exists. Overwrite it?', default=False)
        if not may_delete:
            print('User refused to delete an existing project dir.')
            raise PermissionError("There is an existing dir and I cannot delete it.")
        shutil.rmtree(project_dir)

    os.makedirs(project_dir)
    copy_objects(project_dir, objects_to_copy)
    create_symlinks(project_dir, symlinks_to_create)

    if not quiet:
        print(f'Created a project dir: {project_dir}')

#----------------------------------------------------------------------------

def get_git_hash() -> Optional[str]:
    """
    Returns the abbreviated hash of the current `HEAD` commit, or None.

    None is returned when the current working directory is not a git repository
    (according to `is_git_repo`), or when `git rev-parse` cannot be launched or
    exits with a non-zero status (e.g. `git` is missing or there are no commits yet).
    """
    if not is_git_repo(os.getcwd()):
        return None

    try:
        raw_output = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'])
    except (subprocess.CalledProcessError, OSError):
        # Only "git failed" / "git not runnable" are expected here; a bare `except:`
        # would also swallow KeyboardInterrupt and SystemExit.
        return None

    return raw_output.decode("utf-8").strip()

#----------------------------------------------------------------------------

# def get_experiment_path(master_dir: os.PathLike, experiment_name: str) -> os.PathLike:
#     return os.path.join(master_dir, f"{experiment_name}-{get_git_hash()}")

#----------------------------------------------------------------------------

def get_git_hash_suffix() -> str:
    """
    Returns "-<short git hash>" for the current HEAD, or "-nogit" when `get_git_hash` gives None.
    """
    short_hash = get_git_hash()
    if short_hash is None:
        return "-nogit"
    return f"-{short_hash}"

#----------------------------------------------------------------------------

def are_there_uncommitted_changes() -> bool:
    """
    Returns True if `git status -s`, run in the current working directory, prints anything.

    Any output counts, including untracked files. Raises CalledProcessError if git fails
    (e.g. outside of a repository).
    """
    status_text = subprocess.check_output(['git', 'status', '-s']).decode("utf-8")
    return status_text != ""

#----------------------------------------------------------------------------

def cfg_to_args_str(cfg: DictConfig, use_dashes=True) -> str:
    """
    Renders each top-level key of `cfg` as a `key=value` token, joined by single spaces.

    Args:
        cfg: config whose top-level keys are rendered; values are formatted as in an f-string.
            NOTE: values are not shell-quoted.
        use_dashes: when True, every token is prefixed with `--`.
    """
    prefix = '--' if use_dashes else ''
    return ' '.join(f'{prefix}{name}={cfg[name]}' for name in cfg)

#----------------------------------------------------------------------------

def recursive_instantiate(cfg: DictConfig):
    """
    Walks `cfg` in place, replacing nested DictConfig nodes that have a `_target_` key
    with the result of hydra's `instantiate` on that node.

    Nodes without `_target_` are descended into recursively; nodes with `_target_` are
    replaced and not descended into. Only DictConfig children are visited; values of
    any other type (including lists) are left untouched.
    """
    for key in cfg:
        child = cfg[key]
        if not isinstance(child, DictConfig):
            continue

        if '_target_' in child:
            cfg[key] = instantiate(child)
        else:
            recursive_instantiate(child)

#----------------------------------------------------------------------------

def num_gpus_to_mem(num_gpus: int, mem_per_gpu: int=64) -> str:
    """
    Returns the total memory for `num_gpus` GPUs as a string with a "G" suffix, e.g. "128G".

    Args:
        num_gpus: number of GPUs.
        mem_per_gpu: amount (in the "G" unit) per GPU; defaults to 64.
    """
    # Doing it here since hydra config cannot do formatting for ${...}
    # NOTE: the default must be written `int=64`; the former `mem_per_gpu: 64` only
    # annotated the parameter, so calling with `num_gpus` alone raised TypeError.
    return f"{num_gpus * mem_per_gpu}G"

#----------------------------------------------------------------------------