PyTorch
gpt2
File size: 3,216 Bytes
c2760fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# utils.py
import argparse
import os
import yaml
import json
import random
import numpy as np
import torch

def mkdir(dirpath):
    """Create *dirpath* (including missing parents); a no-op if it already exists."""
    if not os.path.isdir(dirpath):
        os.makedirs(dirpath, exist_ok=True)

def get_config():
    """Build the run configuration.

    Parses command-line flags, loads the base config file given by
    ``--config`` (json or yaml), and overlays any explicitly-passed CLI
    values on top of the file's contents.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--config', type=str, default='config.json',
                        help="Path to base config file (json or yaml)")

    # (flag, value type, extra kwargs) for every optional override.
    # All of these default to None, which construct_config treats as
    # "not passed on the CLI, keep the config-file value".
    option_specs = [
        # Paths
        ('--tokenizer_dir', str, {}),
        ('--data_dir', str, {}),
        ('--train_glob', str, {}),
        ('--valid_glob', str, {}),
        ('--output_dir', str, {}),
        # Training settings
        ('--datapoint_length', int, {}),
        ('--training_type', str, {'choices': ["strict", "strict_small"]}),
        ('--n_epochs', int, {}),
        ('--batch_size', int, {}),
        ('--learning_rate', float, {}),
        ('--weight_decay', float, {}),
        ('--num_training_steps', int, {}),
        ('--num_warmup_steps', int, {}),
        ('--gradient_clip_norm', float, {}),
        # Experiment bookkeeping
        ('--seed', int, {}),
        ('--base_folder', str, {}),
        ('--experiment_name', str, {}),
        ('--use_wandb', None, {'action': 'store_true'}),
        ('--wandb_project_name', str, {}),
        ('--wandb_experiment_name', str, {}),
    ]
    for flag, argtype, extra in option_specs:
        if argtype is None:
            parser.add_argument(flag, **extra)
        else:
            parser.add_argument(flag, type=argtype, **extra)

    return construct_config(parser.parse_args())

def setup_experiment(cfg):
    """Prepare an experiment run in place.

    Seeds the python/numpy/torch RNGs (drawing a fresh seed when none was
    configured), creates the experiment folder tree, and writes the fully
    resolved config to ``<logdir>/exp_cfg.yaml``.

    Mutates *cfg*: fills in "seed", "expdir", "checkpoint_dir", "logdir".
    """
    # Seed every RNG we use. A missing seed, an explicit null in the
    # config file (which arrives here as None), or the sentinel -1 all
    # mean "pick one at random" — but record the choice in cfg so the run
    # stays reproducible. The original `== -1` check let None through and
    # crashed in torch.manual_seed.
    if cfg.get("seed") in (None, -1):
        cfg["seed"] = random.randint(0, 10**9)
    seed = cfg["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cudnn autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"[utils] Using seed {seed}")

    # Folder layout: <base_folder>/<experiment_name>/{checkpoints,logging}
    cfg["expdir"] = os.path.join(cfg["base_folder"], cfg["experiment_name"])
    cfg["checkpoint_dir"] = os.path.join(cfg["expdir"], 'checkpoints')
    cfg["logdir"] = os.path.join(cfg["expdir"], 'logging')
    mkdir(cfg["expdir"]); mkdir(cfg["checkpoint_dir"]); mkdir(cfg["logdir"])

    # Save the resolved config next to the logs for later inspection;
    # write UTF-8 explicitly so the dump is platform-independent.
    with open(os.path.join(cfg["logdir"], "exp_cfg.yaml"), 'w',
              encoding='utf-8') as cfg_file:
        yaml.safe_dump(cfg, cfg_file, sort_keys=False)

def setup_wandb(cfg):
    """Start a Weights & Biases run named from the config.

    Reads "wandb_project_name" and "wandb_experiment_name" from *cfg*.

    Raises:
        RuntimeError: if wandb was requested but is not installed.
    """
    try:
        import wandb
    except ImportError:
        raise RuntimeError("use_wandb is true but wandb is not installed")
    wandb.init(project=cfg["wandb_project_name"],
               name=cfg["wandb_experiment_name"])

def load_file_any(filepath):
    """Load a config mapping from *filepath*.

    Files ending in ``.yaml``/``.yml`` are parsed with ``yaml.safe_load``;
    anything else is assumed to be JSON.

    Returns the parsed object (typically a dict).
    """
    ext = os.path.splitext(filepath)[1].lower()
    # Decode as UTF-8 explicitly so parsing does not depend on the
    # platform's default locale encoding.
    with open(filepath, 'r', encoding='utf-8') as f:
        if ext in ('.yaml', '.yml'):
            return yaml.safe_load(f)
        return json.load(f)

def construct_config(args):
    """Merge the base config file with explicitly-passed CLI arguments.

    CLI values override the file's values; arguments left at None (i.e.
    not passed) keep the file's value. The '--config' path itself is
    never copied into the result.
    """
    cfg = load_file_any(args.config)
    overrides = {k: v for k, v in vars(args).items()
                 if k != "config" and v is not None}
    cfg.update(overrides)
    return cfg