# utils.py
import argparse
import os
import yaml
import json
import random
import numpy as np
import torch


def mkdir(dirpath):
    """Create ``dirpath`` (and any missing parents); do nothing if it exists."""
    os.makedirs(dirpath, exist_ok=True)


def get_config(argv=None):
    """Parse command-line options and merge them over the base config file.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behavior.

    Returns:
        The dict produced by ``construct_config``: the contents of the file
        named by ``--config`` with every explicitly supplied CLI option
        written on top. ``use_wandb`` is always present (``False`` when
        neither the file nor the command line sets it).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config.json',
                        help="Path to base config file (json or yaml)")
    parser.add_argument('--tokenizer_dir', type=str)
    parser.add_argument('--data_dir', type=str)
    parser.add_argument('--train_glob', type=str)
    parser.add_argument('--valid_glob', type=str)
    parser.add_argument('--output_dir', type=str)

    # Training settings
    parser.add_argument('--datapoint_length', type=int)
    parser.add_argument('--training_type', type=str,
                        choices=["strict", "strict_small"])
    parser.add_argument('--n_epochs', type=int)
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--learning_rate', type=float)
    parser.add_argument('--weight_decay', type=float)
    parser.add_argument('--num_training_steps', type=int)
    parser.add_argument('--num_warmup_steps', type=int)
    parser.add_argument('--gradient_clip_norm', type=float)

    # Experiment
    parser.add_argument('--seed', type=int)
    parser.add_argument('--base_folder', type=str)
    parser.add_argument('--experiment_name', type=str)
    # default=None (instead of store_true's implicit False) so that an
    # omitted flag does not overwrite `use_wandb: true` from the config file;
    # construct_config only overlays values that are not None.
    parser.add_argument('--use_wandb', action='store_true', default=None)
    parser.add_argument('--wandb_project_name', type=str)
    parser.add_argument('--wandb_experiment_name', type=str)

    args = parser.parse_args(argv)
    config = construct_config(args)
    # Keep the key always present, as it was when the flag defaulted to False.
    config.setdefault("use_wandb", False)
    return config


def setup_experiment(cfg):
    """Seed the RNGs, create the experiment folders and save the config.

    Mutates ``cfg`` in place: fills in ``seed`` when it is missing, ``None``
    or ``-1``, and adds ``expdir``, ``checkpoint_dir`` and ``logdir``. The
    resolved config is written to ``<logdir>/exp_cfg.yaml``.

    Args:
        cfg: Config dict; must contain ``base_folder`` and
            ``experiment_name``.
    """
    # Seed: -1 (or a missing / null entry) means "pick one at random".
    seed = cfg.get("seed", -1)
    if seed is None or seed == -1:
        cfg["seed"] = random.randint(0, 10**9)
    random.seed(cfg["seed"])
    np.random.seed(cfg["seed"])
    torch.manual_seed(cfg["seed"])
    torch.cuda.manual_seed_all(cfg["seed"])
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"[utils] Using seed {cfg['seed']}")

    # Folders
    cfg["expdir"] = os.path.join(cfg["base_folder"], cfg["experiment_name"])
    cfg["checkpoint_dir"] = os.path.join(cfg["expdir"], 'checkpoints')
    cfg["logdir"] = os.path.join(cfg["expdir"], 'logging')
    for folder in (cfg["expdir"], cfg["checkpoint_dir"], cfg["logdir"]):
        mkdir(folder)

    # Save resolved config
    cfg_path = os.path.join(cfg["logdir"], "exp_cfg.yaml")
    with open(cfg_path, 'w', encoding="utf-8") as cfg_file:
        yaml.safe_dump(cfg, cfg_file, sort_keys=False)


def setup_wandb(cfg):
    """Start a wandb run named from the config.

    Args:
        cfg: Config dict; ``wandb_project_name`` and
            ``wandb_experiment_name`` are passed to ``wandb.init``.

    Raises:
        RuntimeError: If the ``wandb`` package cannot be imported.
    """
    try:
        import wandb
    except ImportError as err:
        # Chain the original error so the failing import stays visible.
        raise RuntimeError(
            "use_wandb is true but wandb is not installed"
        ) from err
    wandb.init(
        project=cfg["wandb_project_name"],
        name=cfg["wandb_experiment_name"]
    )


def load_file_any(filepath):
    """Load ``filepath`` as YAML (``.yaml`` / ``.yml``) or otherwise as JSON.

    The format is chosen from the file extension, case-insensitively.

    Returns:
        Whatever the parser produced (``None`` for an empty YAML file).
    """
    ext = os.path.splitext(filepath)[1].lower()
    with open(filepath, 'r', encoding="utf-8") as f:
        if ext in ('.yaml', '.yml'):
            return yaml.safe_load(f)
        return json.load(f)


def construct_config(args):
    """Build the config: the file at ``args.config`` overlaid with CLI values.

    Every attribute of ``args`` other than ``config`` whose value is not
    ``None`` replaces (or adds) the entry of the same name.

    Args:
        args: Namespace with a ``config`` attribute holding the file path.

    Returns:
        The merged config dict.

    Raises:
        TypeError: If the top level of the config file is not a mapping.
    """
    base_cfg = load_file_any(args.config)
    if base_cfg is None:
        # An empty YAML document parses to None; treat it as "no settings".
        base_cfg = {}
    if not isinstance(base_cfg, dict):
        raise TypeError(
            f"Config file {args.config!r} must contain a mapping at the top "
            f"level, got {type(base_cfg).__name__}"
        )

    # Overlay CLI args when provided
    for k, v in vars(args).items():
        if k == "config":
            continue
        if v is not None:
            base_cfg[k] = v
    return base_cfg