import os

import torch
from accelerate.utils import set_seed
from omegaconf import DictConfig, open_dict


def check_args_and_env(args: DictConfig) -> None:
    """Validate config invariants and the execution environment.

    Raises:
        AssertionError: if batch size is not divisible by the gradient
            accumulation factor, if eval frequency is not a multiple of
            the logging frequency, or if a GPU is requested but CUDA is
            unavailable.

    NOTE: these checks use ``assert`` and are therefore skipped under
    ``python -O``; kept as asserts to preserve the existing exception
    contract (AssertionError) for callers.
    """
    assert args.optim.batch_size % args.optim.grad_acc == 0, (
        f"batch_size ({args.optim.batch_size}) must be divisible by "
        f"grad_acc ({args.optim.grad_acc})"
    )
    # A train log must happen before (or together with) every eval log,
    # so the eval period must be a multiple of the logging period.
    assert args.eval.every_steps % args.logging.every_steps == 0, (
        f"eval.every_steps ({args.eval.every_steps}) must be a multiple of "
        f"logging.every_steps ({args.logging.every_steps})"
    )
    if args.device == "gpu":
        assert torch.cuda.is_available(), "We use GPU to train/eval the model"


def opti_flags(args: DictConfig) -> None:
    """Enable TF32 kernels for CUDA matmuls and cuDNN.

    These flags reduce training step time by ~2.4x (at slightly reduced
    matmul precision). ``args`` is currently unused; the parameter is kept
    so the call signature matches the other setup helpers.
    """
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True


def update_args_with_env_info(args: DictConfig) -> None:
    """Record runtime environment info (SLURM job id, cwd) in the config.

    ``open_dict`` temporarily disables struct mode so new keys may be
    added to the (normally read-only-keyed) DictConfig.
    """
    with open_dict(args):
        # "none" sentinel when not running under SLURM.
        args.slurm_id = os.getenv("SLURM_JOB_ID", "none")
        args.working_dir = os.getcwd()


def setup_args(args: DictConfig) -> None:
    """Validate, enrich, and apply the run configuration.

    Order matters: validation first, then environment enrichment and
    performance flags; finally seed all RNGs if a seed was given.
    """
    check_args_and_env(args)
    update_args_with_env_info(args)
    opti_flags(args)
    if args.seed is not None:
        set_seed(args.seed)