File size: 1,138 Bytes
7ef7abb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import os

from accelerate.utils import set_seed
from omegaconf import open_dict, DictConfig


def check_args_and_env(args: DictConfig) -> None:
    assert args.optim.batch_size % args.optim.grad_acc == 0
    # Train log must happen before eval log
    assert args.eval.every_steps % args.logging.every_steps == 0

    if args.device == "gpu":
        assert torch.cuda.is_available(), "We use GPU to train/eval the model"


def opti_flags(args: DictConfig) -> None:
    # This lines reduce training step by 2.4x
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True


def update_args_with_env_info(args: DictConfig) -> None:
    """Record runtime environment details on the (unlocked) config.

    Adds ``slurm_id`` (the SLURM job id, or ``"none"`` when not running
    under SLURM) and ``working_dir`` (the current working directory).

    Args:
        args: Run configuration to mutate in place.
    """
    with open_dict(args):
        # getenv's default covers the not-under-SLURM case directly.
        args.slurm_id = os.getenv("SLURM_JOB_ID", "none")
        args.working_dir = os.getcwd()


def setup_args(args: DictConfig) -> None:
    """Run the full config setup pipeline: validate, enrich, tune, seed.

    Args:
        args: Run configuration; mutated in place by the pipeline steps.
    """
    # Order matters: validation first, then env enrichment, then flags.
    for step in (check_args_and_env, update_args_with_env_info, opti_flags):
        step(args)

    if args.seed is not None:
        set_seed(args.seed)