# Source: osu_mapper/osuT5/utils/init_utils.py (Tiger14n, commit 98121e1 "edit")
import torch
import os
from accelerate.utils import set_seed
from omegaconf import open_dict, DictConfig
def check_args_and_env(args: DictConfig) -> None:
    """Validate the run configuration and the runtime environment.

    Checks performed:
      * ``args.optim.batch_size`` is divisible by ``args.optim.grad_acc``.
      * ``args.eval.every_steps`` is a multiple of ``args.logging.every_steps``.
      * When ``args.device == "gpu"``, CUDA is available.

    Args:
        args: Run configuration; only the fields listed above are read.

    Raises:
        AssertionError: If any check fails. The checks are explicit ``raise``
            statements instead of ``assert`` so they still run under
            ``python -O``; the exception type is kept identical to the
            original asserts so existing callers are unaffected.
    """
    batch_size = args.optim.batch_size
    grad_acc = args.optim.grad_acc
    # Guard the divisor first: zero would otherwise surface as an unhelpful
    # ZeroDivisionError from the modulo below.
    if grad_acc == 0:
        raise AssertionError("optim.grad_acc must be non-zero")
    if batch_size % grad_acc != 0:
        raise AssertionError(
            f"optim.batch_size ({batch_size}) must be divisible by "
            f"optim.grad_acc ({grad_acc})"
        )

    eval_every = args.eval.every_steps
    log_every = args.logging.every_steps
    if log_every == 0:
        raise AssertionError("logging.every_steps must be non-zero")
    # Train log must happen before eval log: every eval step must also be a
    # train-logging step.
    if eval_every % log_every != 0:
        raise AssertionError(
            f"eval.every_steps ({eval_every}) must be a multiple of "
            f"logging.every_steps ({log_every})"
        )

    if args.device == "gpu" and not torch.cuda.is_available():
        raise AssertionError("We use GPU to train/eval the model")
def opti_flags(args: DictConfig) -> None:
    """Enable TensorFloat-32 (TF32) math for CUDA matmuls and cuDNN.

    This sets global torch backend flags; ``args`` is accepted but not read.
    An earlier note here reported roughly a 2.4x faster training step from
    these flags. NOTE(review): the gain is hardware-dependent, so confirm it
    on the target GPU.
    """
    tf32_backends = (torch.backends.cuda.matmul, torch.backends.cudnn)
    for backend in tf32_backends:
        backend.allow_tf32 = True
def update_args_with_env_info(args: DictConfig) -> None:
    """Store runtime environment details on ``args`` in place.

    Sets two keys:
      * ``slurm_id``: the value of the ``SLURM_JOB_ID`` environment variable,
        or the string ``"none"`` when that variable is not set.
      * ``working_dir``: the current working directory.
    """
    # open_dict allows new keys to be added even if the config is in
    # struct mode.
    with open_dict(args):
        args.slurm_id = os.getenv("SLURM_JOB_ID", "none")
        args.working_dir = os.getcwd()
def setup_args(args: DictConfig) -> None:
    """Run all one-time setup steps for a training/eval run.

    In order:
      1. validate the config and environment,
      2. record environment info on ``args``,
      3. enable the TF32 backend flags,
      4. call accelerate's ``set_seed`` if ``args.seed`` is not None.
    """
    for setup_step in (check_args_and_env, update_args_with_env_info, opti_flags):
        setup_step(args)

    seed = args.seed
    if seed is not None:
        set_seed(seed)