File size: 1,534 Bytes
d62394f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import os
import random
import shutil
from argparse import ArgumentParser
import numpy as np
import torch
import yaml
def clean_dir(path):
    """Recursively delete ``path`` if it exists; silently do nothing otherwise.

    ``shutil.rmtree`` does the removal, so a path that exists but is not a
    directory still raises whatever ``rmtree`` raises for it.
    """
    if not os.path.exists(path):
        return
    shutil.rmtree(path)
def get_latest_ckpt_step(load_path):
    """Return the highest step number among ``*.pt`` files in ``load_path``.

    Checkpoint files are expected to be named ``<anything>-<step>.pt``: the
    step is the text after the last ``-`` in the file stem (a stem with no
    ``-`` is parsed as a whole). ``*.pt`` files whose step part is not an
    integer (e.g. ``best.pt``) are skipped rather than aborting the scan.
    The scan is not recursive.

    Args:
        load_path: Directory to scan.

    Returns:
        The largest step found, or ``-1`` when no parsable checkpoint exists.

    Raises:
        OSError: Propagated from ``os.listdir`` (e.g. ``FileNotFoundError``
            when ``load_path`` does not exist).
    """
    saved_steps = []
    for filename in os.listdir(load_path):
        if not filename.endswith(".pt"):
            continue
        step_str = os.path.splitext(filename)[0].rsplit("-", 1)[-1]
        try:
            saved_steps.append(int(step_str))
        except ValueError:
            # Not a "<name>-<step>.pt" checkpoint (e.g. "best.pt"); ignore it.
            continue
    return max(saved_steps, default=-1)
def set_random_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) with ``seed``.

    Also sets ``torch.backends.cudnn.deterministic = True`` and
    ``torch.backends.cudnn.benchmark = False`` so cuDNN runs are repeatable.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def load_cfg(cfg_path: str, parser: ArgumentParser) -> ArgumentParser:
    """Register one ``--<key>`` option on ``parser`` per top-level YAML entry.

    Each option's default is the value read from ``cfg_path``, so the YAML file
    supplies defaults that the command line can override:

    * ``bool``  -> ``--key`` / ``--no-key`` (``argparse.BooleanOptionalAction``,
      Python >= 3.9), so a ``true`` default can still be switched off.
    * ``list``  -> ``--key a b c``; each token is parsed like a YAML scalar.
    * ``dict``  -> ``--key "{a: 1}"``; the argument is parsed with YAML.
    * anything else -> converted with ``type(value)`` (``int``, ``float``, ``str``).

    Args:
        cfg_path: Path to a UTF-8 YAML file whose top level is a mapping.
        parser: Parser to add options to (mutated in place).

    Returns:
        The same ``parser``, for chaining.

    Raises:
        ValueError: If the top level is not a mapping, or a value is ``None``.
    """
    # Local import keeps the file's top-level imports unchanged.
    from argparse import BooleanOptionalAction

    with open(cfg_path, "r", encoding="utf-8") as file:
        loaded = yaml.safe_load(file)
    # An empty YAML document loads as None; treat it as "no options".
    cfg: dict = {} if loaded is None else loaded
    if not isinstance(cfg, dict):
        raise ValueError(
            f"Config file {cfg_path!r} must contain a mapping at the top level, "
            f"got {type(cfg).__name__}"
        )

    for key, value in cfg.items():
        if value is None:
            raise ValueError(
                f"'None' is not a supported value in the config file (key: {key!r})"
            )
        if isinstance(value, bool):
            # store_true with default=True could never be turned off from the
            # CLI; BooleanOptionalAction keeps --key and adds --no-key.
            parser.add_argument(f"--{key}", action=BooleanOptionalAction, default=value)
        elif isinstance(value, list):
            # type=list would split a CLI string into characters; take one
            # token per element and parse each the way the YAML file would.
            parser.add_argument(f"--{key}", nargs="*", type=yaml.safe_load, default=value)
        elif isinstance(value, dict):
            # type=dict cannot build a mapping from a CLI string; parse YAML.
            parser.add_argument(f"--{key}", type=yaml.safe_load, default=value)
        else:
            parser.add_argument(f"--{key}", type=type(value), default=value)
    return parser
def save_cfg(path: str, args, mode="w"):
    """Write ``vars(args)`` to ``path`` as block-style YAML under a banner line.

    The banner starts with ``#``, i.e. it is a YAML comment. ``mode`` is
    handed directly to ``open`` (e.g. ``"a"`` to append instead of overwrite).
    """
    banner = "#################### Training Config ####################"
    with open(path, mode=mode, encoding="utf-8") as out:
        out.write(banner + "\n")
        yaml.dump(vars(args), out, default_flow_style=False)
|