|
|
|
|
|
import argparse |
|
|
import os |
|
|
import yaml |
|
|
import json |
|
|
import random |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
def mkdir(dirpath):
    """Create *dirpath* (and any missing parents), ignoring it if it already exists."""
    os.makedirs(dirpath, exist_ok=True)
|
|
|
|
|
def get_config():
    """Build the CLI parser, parse sys.argv, and return the merged config dict.

    Every option except --config defaults to None (or False for the
    store_true flag) so that construct_config can tell "not given on the
    command line" apart from an explicit value.
    """
    parser = argparse.ArgumentParser()

    # Base config file that the CLI flags below may override.
    parser.add_argument('--config', type=str, default='config.json',
                        help="Path to base config file (json or yaml)")

    # Path-like options, all plain strings.
    for flag in ('--tokenizer_dir', '--data_dir', '--train_glob',
                 '--valid_glob', '--output_dir'):
        parser.add_argument(flag, type=str)

    # Training hyperparameters.
    parser.add_argument('--datapoint_length', type=int)
    parser.add_argument('--training_type', type=str,
                        choices=["strict", "strict_small"])
    parser.add_argument('--n_epochs', type=int)
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--learning_rate', type=float)
    parser.add_argument('--weight_decay', type=float)
    parser.add_argument('--num_training_steps', type=int)
    parser.add_argument('--num_warmup_steps', type=int)
    parser.add_argument('--gradient_clip_norm', type=float)

    # Experiment bookkeeping and logging.
    parser.add_argument('--seed', type=int)
    parser.add_argument('--base_folder', type=str)
    parser.add_argument('--experiment_name', type=str)
    parser.add_argument('--use_wandb', action='store_true')
    parser.add_argument('--wandb_project_name', type=str)
    parser.add_argument('--wandb_experiment_name', type=str)

    return construct_config(parser.parse_args())
|
|
|
|
|
def setup_experiment(cfg):
    """Seed every RNG, create the experiment directory tree, and dump the config.

    Mutates *cfg* in place: fills in "seed" (when missing or -1),
    "expdir", "checkpoint_dir" and "logdir".
    """
    # A missing seed or the sentinel -1 means "pick one at random".
    if cfg.get("seed", -1) == -1:
        cfg["seed"] = random.randint(0, 10**9)

    seed = cfg["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"[utils] Using seed {seed}")

    # Layout: <base_folder>/<experiment_name>/{checkpoints, logging}
    expdir = os.path.join(cfg["base_folder"], cfg["experiment_name"])
    cfg["expdir"] = expdir
    cfg["checkpoint_dir"] = os.path.join(expdir, 'checkpoints')
    cfg["logdir"] = os.path.join(expdir, 'logging')
    for dirpath in (expdir, cfg["checkpoint_dir"], cfg["logdir"]):
        mkdir(dirpath)

    # Persist the fully-resolved config alongside the logs for later audits.
    with open(os.path.join(cfg["logdir"], "exp_cfg.yaml"), 'w') as cfg_file:
        yaml.safe_dump(cfg, cfg_file, sort_keys=False)
|
|
|
|
|
def setup_wandb(cfg):
    """Start a Weights & Biases run for this experiment.

    Imports wandb lazily so the dependency is only needed when logging is
    actually enabled; raises RuntimeError when the package is absent.
    """
    try:
        import wandb
    except ImportError:
        raise RuntimeError("use_wandb is true but wandb is not installed")

    wandb.init(project=cfg["wandb_project_name"],
               name=cfg["wandb_experiment_name"])
|
|
|
|
|
def load_file_any(filepath):
    """Load a config file, dispatching on extension.

    Files ending in .yaml/.yml are parsed with yaml.safe_load; every
    other extension is treated as JSON.

    Args:
        filepath: path to the config file.

    Returns:
        The parsed content (typically a dict).

    Raises:
        json.JSONDecodeError or yaml.YAMLError on malformed content;
        OSError if the file cannot be opened.
    """
    ext = os.path.splitext(filepath)[1].lower()
    # Explicit UTF-8: without it, open() uses the platform's locale
    # encoding, which mis-decodes UTF-8 configs on e.g. Windows (cp1252).
    with open(filepath, 'r', encoding='utf-8') as f:
        if ext in ('.yaml', '.yml'):
            return yaml.safe_load(f)
        return json.load(f)
|
|
|
|
|
def construct_config(args):
    """Merge command-line overrides on top of the base config file.

    Precedence: a CLI option that was actually supplied wins over the
    config file; options left unspecified (None) keep the file's value.

    Args:
        args: argparse.Namespace produced by get_config's parser; its
            ``config`` attribute is the path of the base config file.

    Returns:
        dict with the merged configuration.
    """
    base_cfg = load_file_any(args.config)

    for k, v in vars(args).items():
        if k == "config":
            continue  # the config-file path itself is not a config entry
        if k == "use_wandb":
            # BUG FIX: store_true flags parse to False (not None) when
            # omitted, so the generic `v is not None` merge below would
            # always clobber a `use_wandb: true` set in the config file.
            # Only override when the flag was explicitly passed.
            if v:
                base_cfg[k] = True
            else:
                base_cfg.setdefault(k, False)
            continue
        if v is not None:
            base_cfg[k] = v
    return base_cfg