File size: 1,534 Bytes
d62394f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os
import random
import shutil
from argparse import ArgumentParser

import numpy as np
import torch
import yaml


def clean_dir(path):
    """Recursively delete *path* and everything under it; do nothing if it is absent.

    *path* is expected to be a directory: ``shutil.rmtree`` raises when it points
    at a regular file.
    """
    if not os.path.exists(path):
        return
    shutil.rmtree(path)


def get_latest_ckpt_step(load_path):
    """Return the highest checkpoint step found in *load_path*, or -1 if there is none.

    A checkpoint is a ``.pt`` file whose name (without extension) ends in
    ``-<step>``, e.g. ``ckpt-2000.pt`` -> 2000. Other ``.pt`` files whose suffix
    after the last ``-`` is not an integer (such as ``best.pt``) are ignored
    rather than aborting the scan.

    Raises:
        FileNotFoundError: If *load_path* does not exist (from ``os.listdir``).
    """
    saved_steps = []
    for filename in os.listdir(load_path):
        if not filename.endswith(".pt"):
            continue
        step_text = os.path.splitext(filename)[0].split("-")[-1]
        try:
            saved_steps.append(int(step_text))
        except ValueError:
            # Not a "<name>-<step>.pt" checkpoint; skip it instead of crashing.
            continue
    return max(saved_steps, default=-1)


def set_random_seed(seed):
    """Seed every random number generator used here and force deterministic cuDNN.

    Seeds Python's ``random`` module, NumPy's global generator and torch (CPU,
    the current CUDA device, and all CUDA devices). It then switches cuDNN to
    deterministic kernels with autotuning disabled.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    # Give up cuDNN's autotuned (possibly nondeterministic) kernel choice in
    # exchange for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def load_cfg(cfg_path: str, parser: ArgumentParser) -> ArgumentParser:
    """Register one ``--<key>`` option on *parser* per top-level entry of a YAML config.

    Each config value becomes the option's default, and ``type(value)`` converts
    command-line overrides. Boolean entries become flags: ``--<key>`` sets the
    value to True. Entries whose default is True also get a ``--no-<key>`` flag
    so they can be switched off from the command line.

    Args:
        cfg_path: Path to a UTF-8 YAML file whose top level is a mapping.
            An empty file registers no options.
        parser: Parser to extend in place.

    Returns:
        The same *parser*, with the new options added.

    Raises:
        ValueError: If the top level is not a mapping, or any value is ``None``.
    """
    with open(cfg_path, "r", encoding="utf-8") as file:
        cfg = yaml.safe_load(file)
    if cfg is None:
        # An empty (or comment-only) YAML document loads as None.
        cfg = {}
    if not isinstance(cfg, dict):
        raise ValueError(
            "The config file must contain a mapping at the top level, "
            f"got {type(cfg).__name__}"
        )
    for key, value in cfg.items():
        if value is None:
            raise ValueError(
                f"'None' is not a supported value in the config file (key: {key!r})"
            )
        if isinstance(value, bool):
            parser.add_argument(f"--{key}", action="store_true", default=value)
            if value:
                # store_true with a True default can never yield False, so add
                # an explicit negating flag writing to the same destination
                # (argparse derives dest from "--<key>" by turning '-' into '_').
                parser.add_argument(
                    f"--no-{key}",
                    dest=str(key).replace("-", "_"),
                    action="store_false",
                )
        else:
            # NOTE(review): for list/dict values, type(value) is applied to the
            # raw CLI string (e.g. list("ab") -> ["a", "b"]), so such entries
            # are not meaningfully overridable from the command line.
            parser.add_argument(f"--{key}", type=type(value), default=value)
    return parser


def save_cfg(path: str, args, mode="w"):
    """Write a banner line and then ``vars(args)`` as block-style YAML to *path*.

    The banner starts with ``#``, so it is a YAML comment and the YAML that
    follows stays loadable. *mode* is passed straight to ``open``, e.g. ``"w"``
    to overwrite or ``"a"`` to append. *args* must support ``vars()``;
    presumably an ``argparse.Namespace`` — confirm with callers.
    """
    banner = "#################### Training Config ####################"
    with open(path, mode=mode, encoding="utf-8") as stream:
        stream.write(banner + "\n")
        config_dict = vars(args)
        yaml.dump(config_dict, stream, default_flow_style=False)