| | """ |
| | General helper functions for setting up experiments |
| | """ |
| | import os |
| | import random |
| |
|
| | from argparse import ArgumentParser |
| | from omegaconf import DictConfig |
| |
|
| | import torch |
| | import numpy as np |
| |
|
| | from .logging import _format_arg |
| |
|
| |
|
| | def init_wandb(args: ArgumentParser) -> any: |
| | """Initialize WandB""" |
| | if args.no_wandb: |
| | wandb = None |
| | else: |
| | import wandb |
| | wandb.init(config={}, |
| | entity=args.wandb_entity, |
| | name=args.run_name, |
| | project=args.project_name) |
| | return wandb |
| |
|
| |
|
| | def seed_everything(seed: int) -> None: |
| | """ |
| | Seed everything |
| | """ |
| | random.seed(seed) |
| | os.environ['PYTHONHASHSEED'] = str(seed) |
| | np.random.seed(seed) |
| | torch.manual_seed(seed) |
| | torch.cuda.manual_seed(seed) |
| | torch.cuda.manual_seed_all(seed) |
| | torch.backends.cudnn.deterministic = True |
| | torch.backends.cudnn.benchmark = False |
| |
|
| |
|
| | def get_run_name_from_checkpoint(checkpoint_path: str) -> str: |
| | """ |
| | Helper function to get a condensed run name from the checkpoint path |
| | """ |
| | name = [] |
| | for s in checkpoint_path.split('/')[-1].split('-'): |
| | if '.pt' in s: |
| | name.append(f'_{s[:-3]}') |
| | try: |
| | s = s.split('=') |
| | s = ''.join([c[0] for c in s[1].split('_')]) |
| | name.append(s) |
| | except IndexError: |
| | pass |
| | return ''.join(name) |
| |
|
| |
|
| | def get_run_name_from_args(args) -> str: |
| | """ |
| | Prepare a heinous identifier for the run based on args |
| | """ |
| | if args.load_distill_checkpoint is not None and args.load_distill_checkpoint != 'default': |
| | distill_name = get_run_name_from_checkpoint(args.load_distill_checkpoint) |
| | else: |
| | distill_name = args.distill_config |
| | if args.load_finetune_checkpoint is not None and args.finetune_config is None: |
| | finetune_name = get_run_name_from_checkpoint(args.load_finetune_checkpoint) |
| | else: |
| | finetune_name = args.finetune_config |
| | args.run_name = f'dl-d={distill_name}-m={args.model_config}-f={finetune_name}' |
| | if args.no_peft_grad_ckpt is not None: |
| | args.run_name += f'-npgc={args.no_peft_grad_ckpt}' |
| | args.run_name += f'-s={args.seed}' |
| | if args.debug: |
| | args.run_name += f'-debug' |
| | if args.no_attention_mask is not None: |
| | args.run_name += f'-nam=1' |
| | return args.run_name.replace('True', '1').replace('False', '0') |
| |
|
| |
|
| | def flatten_config(config: dict, flattened: dict, key: str) -> dict: |
| | """ |
| | Recursive way to flatten config args for saving to WandB |
| | """ |
| | for k, v in config.items(): |
| | if isinstance(v, dict): |
| | flatten_config(v, flattened, f'{key}{k}_') |
| | elif isinstance(v, list): |
| | for ix, _config in enumerate(v): |
| | if isinstance(_config, dict): |
| | flatten_config(_config, flattened, f'{key}{k}_{ix}_') |
| | else: |
| | flattened[f'{key}{k}'] = v |
| | return flattened |
| |
|
| |
|
| | def update_config_from_args(config: DictConfig, |
| | args: ArgumentParser, |
| | ignore_args: list = None) -> DictConfig: |
| | """ |
| | Quick hacks to override default configs |
| | """ |
| | ignore_args = [] if ignore_args is None else ignore_args |
| | |
| | |
| | if getattr(args, 'dataset', None): |
| | config.dataset.name = args.dataset |
| | args.run_name += f'-ds={args.dataset}' |
| | |
| | |
| | for arg in ['lr', 'weight_decay']: |
| | if arg not in ignore_args: |
| | argval = getattr(args, arg, None) |
| | if argval is not None: |
| | setattr(config.optimizer, arg, argval) |
| | args.run_name += f'-{_format_arg(arg)}={argval}' |
| | try: |
| | if getattr(args, 'optim', None): |
| | config.optimizer.optim = args.optim |
| | args.run_name += f'-o={args.optim}' |
| | except AttributeError: |
| | pass |
| | |
| | |
| | try: |
| | if getattr(args, 'scheduler', None): |
| | config.lr_scheduler.lr_scheduler_type = args.scheduler |
| | args.run_name += f'-sc={args.scheduler}' |
| | except AttributeError: |
| | pass |
| |
|
| | |
| | for arg in [a for a in dir(args) if 'dataset_' in a]: |
| | argval = getattr(args, arg, None) |
| | if argval is not None: |
| | setattr(config.dataset.dataset_config, arg[len('dataset_'):], argval) |
| | args.run_name += f'-{_format_arg(arg)}={argval}' |
| |
|
| | |
| | for arg in ['batch_size']: |
| | argval = getattr(args, arg, None) |
| | if argval is not None: |
| | setattr(config.dataloader, arg, argval) |
| | args.run_name += f'-{_format_arg(arg)}={argval}' |
| |
|
| | |
| | for arg in ['gradient_accumulation_steps', 'num_train_epochs', |
| | 'max_steps', 'max_finetune_steps', 'eval_steps', |
| | 'seed', 'max_eval_batches']: |
| | argval = getattr(args, arg, None) |
| | if argval is not None: |
| | setattr(config.trainer, arg, argval) |
| | if arg in ['max_steps', 'max_finetune_steps', |
| | 'gradient_accumulation_steps', 'num_train_epochs', 'seed']: |
| | args.run_name += f'-{_format_arg(arg)}={argval}' |
| |
|
| | |
| | for arg in ['replicate']: |
| | argval = getattr(args, arg, None) |
| | if argval is not None: |
| | args.run_name += f'-{_format_arg(arg)}={argval}' |
| |
|
| | return config |
| |
|
| |
|
| | def update_model_config_from_args(model_config: DictConfig, |
| | args: ArgumentParser) -> DictConfig: |
| | """ |
| | Override default configs given argparse args |
| | """ |
| | |
| | for arg in ['attention_type', 'learned_kernel', 'tie_qk_kernels', |
| | 'train_qk', 'state_chunk_len', 'no_peft_grad_ckpt', |
| | 'window_size']: |
| | argval = getattr(args, arg, None) |
| | if argval is not None: |
| | setattr(model_config['attention'], arg, argval) |
| | args.run_name += f'-{_format_arg(arg)}={argval}' |
| | else: |
| | try: |
| | getattr(model_config['attention'], arg) |
| | except AttributeError: |
| | setattr(model_config['attention'], arg, None) |
| |
|
| | |
| | for arg in ['lk_skip_connection', 'lk_zero_init', 'lk_normal_init']: |
| | argval = getattr(args, arg, None) |
| | if argval is not None: |
| | setattr(model_config['attention']['learned_kernel_kwargs'], |
| | arg[len('lk_'):], argval) |
| | args.run_name += f'-{_format_arg(arg)}={argval}' |
| | |
| | |
| | if args.pretrained_model_name_or_path is not None: |
| | pmnop = args.pretrained_model_name_or_path |
| | model_config.model.pretrained_model_name_or_path = pmnop |
| | args.run_name += f'-pmnop={pmnop.split("/")[-1]}' |
| | |
| | return model_config |
| |
|