import argparse
import os
import yaml
import json
import random
import numpy as np
import torch
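# For reference, a base config file might look like the sketch below; the keys
# simply mirror the CLI flags defined in get_config(), and the values are
# illustrative rather than recommended defaults:
#
#     {
#       "data_dir": "data/",
#       "train_glob": "train_*.txt",
#       "valid_glob": "valid_*.txt",
#       "output_dir": "outputs/",
#       "batch_size": 32,
#       "learning_rate": 3e-4,
#       "seed": -1
#     }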
def mkdir(dirpath):
    """Create a directory (and any missing parents); no-op if it already exists."""
    os.makedirs(dirpath, exist_ok=True)

def get_config():
    """Parse CLI arguments and merge them on top of the base config file."""
    parser = argparse.ArgumentParser()

    # Paths: base config file, tokenizer, data globs, and output location.
    parser.add_argument('--config', type=str, default='config.json',
                        help="Path to base config file (json or yaml)")
    parser.add_argument('--tokenizer_dir', type=str)
    parser.add_argument('--data_dir', type=str)
    parser.add_argument('--train_glob', type=str)
    parser.add_argument('--valid_glob', type=str)
    parser.add_argument('--output_dir', type=str)

    # Training hyperparameters.
    parser.add_argument('--datapoint_length', type=int)
    parser.add_argument('--training_type', type=str, choices=["strict", "strict_small"])
    parser.add_argument('--n_epochs', type=int)
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--learning_rate', type=float)
    parser.add_argument('--weight_decay', type=float)
    parser.add_argument('--num_training_steps', type=int)
    parser.add_argument('--num_warmup_steps', type=int)
    parser.add_argument('--gradient_clip_norm', type=float)

    # Experiment bookkeeping and Weights & Biases logging.
    parser.add_argument('--seed', type=int)
    parser.add_argument('--base_folder', type=str)
    parser.add_argument('--experiment_name', type=str)
    # default=None so an unset flag does not overwrite a value from the config file.
    parser.add_argument('--use_wandb', action='store_true', default=None)
    parser.add_argument('--wandb_project_name', type=str)
    parser.add_argument('--wandb_experiment_name', type=str)

    args = parser.parse_args()
    config = construct_config(args)
    return config
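
# A hypothetical invocation of a training script that relies on get_config()
# (the script name and values below are illustrative, not part of this module):
#
#     python train.py --config configs/base.yaml --train_glob 'data/train_*.txt' \
#         --batch_size 32 --learning_rate 3e-4 --use_wandb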


def setup_experiment(cfg):
    """Seed all RNGs, build the experiment directory tree, and dump the config."""
    # Draw a fresh random seed when none was provided (missing key or the -1 sentinel).
    if cfg.get("seed", -1) in (None, -1):
        cfg["seed"] = random.randint(0, 10**9)
    random.seed(cfg["seed"])
    np.random.seed(cfg["seed"])
    torch.manual_seed(cfg["seed"])
    torch.cuda.manual_seed_all(cfg["seed"])
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"[utils] Using seed {cfg['seed']}")

    # Experiment directory layout: <base_folder>/<experiment_name>/{checkpoints, logging}
    cfg["expdir"] = os.path.join(cfg["base_folder"], cfg["experiment_name"])
    cfg["checkpoint_dir"] = os.path.join(cfg["expdir"], 'checkpoints')
    cfg["logdir"] = os.path.join(cfg["expdir"], 'logging')
    mkdir(cfg["expdir"])
    mkdir(cfg["checkpoint_dir"])
    mkdir(cfg["logdir"])

    # Persist the resolved config alongside the logs for reproducibility.
    with open(os.path.join(cfg["logdir"], "exp_cfg.yaml"), 'w') as cfg_file:
        yaml.safe_dump(cfg, cfg_file, sort_keys=False)


def setup_wandb(cfg):
    """Initialize a Weights & Biases run; requires the wandb package."""
    try:
        import wandb
    except ImportError:
        raise RuntimeError("use_wandb is true but wandb is not installed")
    wandb.init(
        project=cfg["wandb_project_name"],
        name=cfg["wandb_experiment_name"]
    )


def load_file_any(filepath):
    """Load a YAML or JSON config file, dispatching on the file extension."""
    ext = os.path.splitext(filepath)[1].lower()
    with open(filepath, 'r') as f:
        if ext in ['.yaml', '.yml']:
            return yaml.safe_load(f)
        else:
            return json.load(f)


def construct_config(args):
    """Start from the base config file and override it with any CLI flags that were set."""
    base_cfg = load_file_any(args.config)
    # Only non-None CLI values override the file; unset flags leave the file values intact.
    for k, v in vars(args).items():
        if k == "config":
            continue
        if v is not None:
            base_cfg[k] = v
    return base_cfg
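

# Typical usage from a training script (a sketch; the caller is not part of this module):
#
#     cfg = get_config()
#     setup_experiment(cfg)
#     if cfg.get("use_wandb"):
#         setup_wandb(cfg)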