PyTorch
gpt2
achille-fusco's picture
Upload folder using huggingface_hub
c2760fe verified
raw
history blame
3.22 kB
# utils.py
import argparse
import os
import yaml
import json
import random
import numpy as np
import torch
def mkdir(dirpath):
    """Create the directory ``dirpath`` together with any missing parents.

    An already existing directory is not treated as an error.
    """
    os.makedirs(name=dirpath, exist_ok=True)
def get_config():
    """Parse the command line and build the run configuration.

    The base configuration is read from the file given by ``--config``
    (default ``config.json``). The parsed arguments are handed to
    ``construct_config``, which overlays every option that was actually
    supplied on top of the values from that file.

    Returns:
        The configuration returned by ``construct_config``, with a
        ``use_wandb`` entry added (as ``False``) when neither the config
        file nor the command line set one.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config.json',
                        help="Path to base config file (json or yaml)")
    parser.add_argument('--tokenizer_dir', type=str)
    parser.add_argument('--data_dir', type=str)
    parser.add_argument('--train_glob', type=str)
    parser.add_argument('--valid_glob', type=str)
    parser.add_argument('--output_dir', type=str)

    # Training settings
    parser.add_argument('--datapoint_length', type=int)
    parser.add_argument('--training_type', type=str, choices=["strict", "strict_small"])
    parser.add_argument('--n_epochs', type=int)
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--learning_rate', type=float)
    parser.add_argument('--weight_decay', type=float)
    parser.add_argument('--num_training_steps', type=int)
    parser.add_argument('--num_warmup_steps', type=int)
    parser.add_argument('--gradient_clip_norm', type=float)

    # Experiment
    parser.add_argument('--seed', type=int)
    parser.add_argument('--base_folder', type=str)
    parser.add_argument('--experiment_name', type=str)
    # ``store_true`` normally defaults to False, which cannot be told apart
    # from an explicit override and would always replace the value from the
    # config file. With ``default=None`` an absent flag leaves the file's
    # setting alone, like every other option above.
    parser.add_argument('--use_wandb', action='store_true', default=None)
    parser.add_argument('--wandb_project_name', type=str)
    parser.add_argument('--wandb_experiment_name', type=str)

    args = parser.parse_args()
    config = construct_config(args)

    # The flag used to be written into the config unconditionally; keep the
    # key present so existing ``config["use_wandb"]`` lookups still work.
    if "use_wandb" not in config:
        config["use_wandb"] = False
    return config
def setup_experiment(cfg):
    """Seed the RNGs, create the experiment folders and save the config.

    ``cfg`` is modified in place: ``seed`` is filled in when it is unset,
    and ``expdir``, ``checkpoint_dir`` and ``logdir`` are added. The
    resulting config is written to ``<logdir>/exp_cfg.yaml``.

    Args:
        cfg: Configuration mapping. Must contain ``base_folder`` and
            ``experiment_name``. ``seed`` is optional; a missing, ``None``
            or ``-1`` value requests a randomly drawn seed.
    """
    # Seed
    # None (e.g. ``seed: null`` in the config file) is handled like -1:
    # torch.manual_seed needs an integer and would raise on None.
    if cfg.get("seed", -1) in (None, -1):
        cfg["seed"] = random.randint(0, 10**9)
    seed = cfg["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"[utils] Using seed {cfg['seed']}")

    # Folders
    cfg["expdir"] = os.path.join(cfg["base_folder"], cfg["experiment_name"])
    cfg["checkpoint_dir"] = os.path.join(cfg["expdir"], 'checkpoints')
    cfg["logdir"] = os.path.join(cfg["expdir"], 'logging')
    for key in ("expdir", "checkpoint_dir", "logdir"):
        mkdir(cfg[key])

    # Save resolved config
    cfg_path = os.path.join(cfg["logdir"], "exp_cfg.yaml")
    with open(cfg_path, 'w', encoding='utf-8') as cfg_file:
        yaml.safe_dump(cfg, cfg_file, sort_keys=False)
def setup_wandb(cfg):
    """Initialise wandb with the project and run name taken from ``cfg``.

    Args:
        cfg: Configuration mapping providing ``wandb_project_name`` and
            ``wandb_experiment_name``.

    Raises:
        RuntimeError: If the ``wandb`` package cannot be imported. The
            original ``ImportError`` is attached as the cause.
    """
    # Imported lazily so that wandb is only required when it is used.
    try:
        import wandb
    except ImportError as err:
        raise RuntimeError("use_wandb is true but wandb is not installed") from err

    wandb.init(
        project=cfg["wandb_project_name"],
        name=cfg["wandb_experiment_name"],
    )
def load_file_any(filepath):
    """Load a config file as YAML or JSON, chosen by its extension.

    Args:
        filepath: Path of the file to read. A ``.yaml`` or ``.yml``
            extension (case-insensitive) selects the YAML loader; any
            other name is parsed as JSON.

    Returns:
        The parsed content of the file.
    """
    ext = os.path.splitext(filepath)[1].lower()
    # Read as UTF-8 explicitly instead of relying on the platform locale,
    # so non-ASCII values load the same way on every machine.
    with open(filepath, 'r', encoding='utf-8') as f:
        if ext in ('.yaml', '.yml'):
            return yaml.safe_load(f)
        return json.load(f)
def construct_config(args):
    """Load the base config file and overlay the CLI arguments on it.

    Args:
        args: Parsed argument namespace. ``args.config`` is the path of the
            base config file; every other attribute whose value is not
            ``None`` overrides the entry of the same name.

    Returns:
        The resulting configuration dict.

    Raises:
        TypeError: If the top level of the config file is not a mapping.
    """
    base_cfg = load_file_any(args.config)
    if base_cfg is None:
        # An empty YAML document loads as None; start from an empty config
        # so the command-line values can still be applied.
        base_cfg = {}
    if not isinstance(base_cfg, dict):
        raise TypeError(
            f"Config file {args.config!r} must contain a mapping at the top "
            f"level, got {type(base_cfg).__name__}"
        )

    # Overlay CLI args when provided
    overrides = {
        key: value
        for key, value in vars(args).items()
        if key != "config" and value is not None
    }
    base_cfg.update(overrides)
    return base_cfg