| | import torch
|
| | import os
|
| |
|
| | from accelerate.utils import set_seed
|
| | from omegaconf import open_dict, DictConfig
|
| |
|
| |
|
def check_args_and_env(args: DictConfig) -> None:
    """Validate config invariants and the runtime environment.

    Args:
        args: config with ``optim.batch_size``, ``optim.grad_acc``,
            ``eval.every_steps``, ``logging.every_steps`` and ``device``.

    Raises:
        AssertionError: if any invariant is violated. Raised explicitly
            (instead of a bare ``assert``) so the checks still run under
            ``python -O``, and so each failure carries a message.
    """
    # Gradient accumulation must evenly divide the configured batch size.
    if args.optim.batch_size % args.optim.grad_acc != 0:
        raise AssertionError("optim.batch_size must be divisible by optim.grad_acc")

    # Eval cadence must align with logging cadence.
    if args.eval.every_steps % args.logging.every_steps != 0:
        raise AssertionError("eval.every_steps must be divisible by logging.every_steps")

    if args.device == "gpu":
        if not torch.cuda.is_available():
            raise AssertionError("We use GPU to train/eval the model")
|
| |
|
| |
|
def opti_flags(args: DictConfig) -> None:
    """Enable TF32 math on CUDA matmuls and cuDNN for faster training.

    ``args`` is accepted for a uniform setup-function signature but is
    not read here.
    """
    backends = torch.backends
    backends.cuda.matmul.allow_tf32 = True
    backends.cudnn.allow_tf32 = True
|
| |
|
| |
|
def update_args_with_env_info(args: DictConfig) -> None:
    """Record runtime environment info on the config, in place.

    Adds ``slurm_id`` (the ``SLURM_JOB_ID`` env var, or ``"none"`` when
    not running under SLURM) and ``working_dir`` (current directory).
    ``open_dict`` temporarily relaxes OmegaConf struct mode so the new
    keys may be inserted.

    Args:
        args: config to mutate in place.
    """
    with open_dict(args):
        # getenv's default collapses the former None check into one call.
        args.slurm_id = os.getenv("SLURM_JOB_ID", "none")
        args.working_dir = os.getcwd()
|
| |
|
| |
|
def setup_args(args: DictConfig) -> None:
    """Run all one-time setup derived from the parsed config.

    Validates config/environment, enriches the config with runtime info,
    enables CUDA math flags, and seeds RNGs when a seed is configured.
    """
    check_args_and_env(args)
    update_args_with_env_info(args)
    opti_flags(args)

    seed = args.seed
    if seed is not None:
        set_seed(seed)
|
| |
|