| | """ |
| | Alternative way to load trained models for evaluation |
| | """ |
| | import copy |
| | import sys |
| | from os.path import join |
| | from omegaconf import OmegaConf |
| |
|
| | import torch |
| |
|
| | from src.utils.logging import print_header, print_config, _format_arg |
| | from .pretrained import get_pretrained_loader |
| | from .peft import create_peft_config |
| | from .load_model import load_and_convert_attns |
| | from .convert_model import remove_base_attention, toggle_attention |
| |
|
| | |


def get_args_from_checkpoint(fname: str):
    """
    Recover run arguments encoded in a checkpoint filename
    """
    id_to_name = {
        'lk': 'learned_kernel',
        'tqk': 'tie_qk_kernels',
        'tq': 'train_qk',
        'lzi': 'lk_zero_init',
        'lsc': 'lk_skip_connection',
        'pmnop': 'pretrained_model_name_or_path',
    }
    id_to_type = {
        'lk': str,
        'tqk': bool,
        'tq': bool,
        'lzi': bool,
        'lsc': bool,
        'pmnop': str,
    }
    args = {name: None for name in id_to_name.values()}
    args['run_name'] = ''

    for id_val in fname.split('-'):
        try:
            _id, val = id_val.split('=')
            if val.endswith('_distill.pt'):
                val = val[:-len('_distill.pt')]
            if _id in id_to_type:
                if id_to_type[_id] is bool:
                    # bool('0') is True, so parse boolean flags explicitly
                    args[id_to_name[_id]] = val not in ('0', 'False', 'false')
                else:
                    args[id_to_name[_id]] = id_to_type[_id](val)
        except ValueError:  # segment without '=' (or with several); skip it
            pass
    return OmegaConf.create(args)
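
# Hypothetical example of the filename convention parsed above (the exact
# segment names come from the training scripts' run names):
#   args = get_args_from_checkpoint('dl-lk=untied_head-tqk=0-lzi=1_distill.pt')
#   -> args.learned_kernel == 'untied_head', args.tie_qk_kernels is False,
#      args.lk_zero_init is True; unrecognized segments are skipped.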


def update_model_config_from_args(model_config, args):
    """Override default model config values with args recovered from a checkpoint filename"""
    # Attention-level overrides
    for arg in ['learned_kernel', 'tie_qk_kernels', 'train_qk']:
        argval = getattr(args, arg)
        if argval is not None:
            setattr(model_config['attention'], arg, argval)
            args.run_name += f'-{_format_arg(arg)}={argval}'

    # Learned-kernel kwargs; strip the 'lk_' prefix to match the config keys
    for arg in ['lk_skip_connection', 'lk_zero_init']:
        argval = getattr(args, arg)
        if argval is not None:
            setattr(model_config['attention']['learned_kernel_kwargs'],
                    arg[len('lk_'):], argval)
            args.run_name += f'-{_format_arg(arg)}={argval}'

    # Pretrained model override
    if args.pretrained_model_name_or_path is not None:
        pmnop = args.pretrained_model_name_or_path
        model_config.model.pretrained_model_name_or_path = pmnop
        args.run_name += f'-pmnop={pmnop.split("/")[-1]}'
    return model_config
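
# Sketch of how the two helpers above compose (filename segments hypothetical):
#   args = get_args_from_checkpoint(checkpoint_path.split('/')[-1])
#   model_config = update_model_config_from_args(model_config, args)
# Fields parsed from the filename override the YAML defaults, and args.run_name
# accumulates a '-<id>=<val>' suffix per override.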


def get_lm_eval_model(model_kwargs: dict,
                      path_to_lm_eval_harness: str,
                      hedgehog_model: bool = False,
                      long_model: bool = False,
                      ):
    """
    Load model for evaluation using LM Evaluation Harness
    """
    lm_kwargs = copy.deepcopy(model_kwargs)
    lm_kwargs['pretrained'] = lm_kwargs['pretrained_model_name_or_path']
    # lm-eval expects dtype as a string, e.g., torch.bfloat16 -> 'bfloat16'
    lm_kwargs['dtype'] = str(lm_kwargs['torch_dtype']).split('.')[-1]
    del lm_kwargs['torch_dtype']

    lm_kwargs['output_attentions'] = False
    lm_kwargs['output_hidden_states'] = False

    print('-> Loading as lm-evaluation-harness model')
    if hedgehog_model:
        if 'mistral' in lm_kwargs['pretrained']:
            from lm_eval_harness.models import LolcatsMistralForCausalLM as ModelClass
        else:
            from lm_eval_harness.models import LolcatsLlamaForCausalLM as ModelClass
        lm = ModelClass.create_from_arg_string('', lm_kwargs)
    else:
        sys.path.append(path_to_lm_eval_harness)
        from lm_eval.models import get_model
        lm = get_model('hf-causal-experimental').create_from_arg_string('', lm_kwargs)
    return lm
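
# Minimal usage sketch (assumes a `model_loader` from get_pretrained_loader, as
# in the loaders below; the harness path is a hypothetical local clone):
#   lm = get_lm_eval_model(model_loader.loading_kwargs,
#                          path_to_lm_eval_harness='./lm-evaluation-harness',
#                          hedgehog_model=True)
#   model = lm.model  # underlying HF model; `lm` wraps it for harness tasks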


def load_model_from_config(model_config_name: str,
                           config_dir: str = './configs',
                           lm_eval_model: bool = False,
                           path_to_lm_eval_harness: str = '/juice2/scr2/mzhang/projects/lm-evaluation-harness',
                           ):
    """
    Load model from a config file
    """
    model_config_path = join(config_dir, 'model', f'{model_config_name}.yaml')
    model_config = OmegaConf.load(model_config_path)

    model_loader = get_pretrained_loader(**model_config.model)
    tokenizer = model_loader.load_tokenizer()
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = 'left'

    if lm_eval_model:
        lm = get_lm_eval_model(model_loader.loading_kwargs, path_to_lm_eval_harness)
        model = lm.model
    else:
        model = model_loader.load()

    model.eval()
    if lm_eval_model:
        lm.model = model
        model = lm
    return model, model_config, tokenizer
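
# Usage sketch ('llama3_8b' is a hypothetical config name; any YAML under
# `{config_dir}/model/` works):
#   model, model_config, tokenizer = load_model_from_config('llama3_8b')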


def load_model_from_checkpoint(attn_mlp_checkpoint_path: str = None,
                               finetune_checkpoint_path: str = None,
                               config_dir: str = './configs',
                               print_model: bool = False,
                               debug: bool = False,
                               lm_eval_model: bool = False,
                               path_to_lm_eval_harness: str = '/juice2/scr2/mzhang/projects/lm-evaluation-harness',
                               profile_model: bool = False,
                               ):
    """
    Load model architecture from a checkpoint path
    -> attn_mlp_checkpoint_path should point to the checkpoint with the learned attention MLPs
    -> finetune_checkpoint_path should point to the checkpoint with all other parameters
    -> Assumes the checkpoint path strings encode the model_config and finetune_config names
    """
    # Recover the model config name either from the directory layout or from
    # the '-m=<config>' segment of the checkpoint filename
    if attn_mlp_checkpoint_path is not None:
        if len(attn_mlp_checkpoint_path.split('/')) == 4:
            model_config = attn_mlp_checkpoint_path.split('/')[2]
        else:
            model_config = attn_mlp_checkpoint_path.split('/')[-1].split('-m=')[-1].split('-')[0]
        model_config_path = join(config_dir, 'model', f'{model_config}.yaml')
        model_config = OmegaConf.load(model_config_path)
        args = get_args_from_checkpoint(attn_mlp_checkpoint_path.split('/')[-1])
        model_config = update_model_config_from_args(model_config, args)
    else:
        if len(finetune_checkpoint_path.split('/')) == 4:
            model_config = finetune_checkpoint_path.split('/')[2]
        else:
            model_config = finetune_checkpoint_path.split('/')[-1].split('-m=')[-1].split('-')[0]
        model_config_path = join(config_dir, 'model', f'{model_config}.yaml')
        model_config = OmegaConf.load(model_config_path)

    if profile_model:
        model_config['attention']['attention_type'] += '_profile'

    if finetune_checkpoint_path is not None:
        finetune_config = finetune_checkpoint_path.split('-f=')[-1].split('-')[0]
        finetune_config_path = join(config_dir, 'experiment', f'{finetune_config}.yaml')
        finetune_config = OmegaConf.load(finetune_config_path)

    if debug:
        print_header('-- Model Config --')
        print_config(model_config)
        try:
            print_header('-- Finetune Config --')
            print_config(finetune_config)
        except NameError:  # no finetune checkpoint, so finetune_config is unbound
            pass

    model_loader = get_pretrained_loader(**model_config.model)
    tokenizer = model_loader.load_tokenizer()
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = 'left'

    if lm_eval_model and attn_mlp_checkpoint_path is not None:
        lm = get_lm_eval_model(model_loader.loading_kwargs, path_to_lm_eval_harness,
                               hedgehog_model=True)
        model = lm.model
    elif lm_eval_model:
        lm = get_lm_eval_model(model_loader.loading_kwargs, path_to_lm_eval_harness)
        model = lm.model
    elif attn_mlp_checkpoint_path is None:
        model = model_loader.load()
    else:
        model = model_loader.load(model_type=model_config['attention']['attention_type'])
        try:
            model.state_chunk_len = model_config['attention']['state_chunk_len']
        except KeyError:  # config may not define a chunked-state length
            pass

    if attn_mlp_checkpoint_path is not None:
        # Swap in the linearized attentions and load the learned-kernel weights
        model = load_and_convert_attns(model, model_config,
                                       checkpoint_path=attn_mlp_checkpoint_path)[0]
        if 'peft' in model_config['attention']:
            model = model.merge_and_unload()
        # Turn attention-transfer training mode off for evaluation
        model = toggle_attention(model, False)
        if debug:
            print_header('*** Model after attention conversion ***')
            print(model)

    if finetune_checkpoint_path is not None:
        if finetune_config.finetune.method == 'lora':
            model, _ = create_peft_config(model, finetune_config.finetune)
        else:
            for p in model.parameters():
                p.requires_grad = True

        state_dict = torch.load(finetune_checkpoint_path)['model_state_dict']
        _keys = model.load_state_dict(state_dict, strict=False)
        try:
            assert len(_keys.unexpected_keys) == 0
            print_header('*** All expected keys matched successfully ***')
        except AssertionError:
            print_header('*** Error: unexpected keys in checkpoint ***')
            print('Unexpected keys:')
            for k in _keys.unexpected_keys:
                print(k)
        if debug:
            print_header('Missing keys:')
            for k in _keys.missing_keys:
                print(k)
            print_header('Unexpected keys:')
            for k in _keys.unexpected_keys:
                print(k)

    try:
        print('-> Training attention:', model.model.layers[0].self_attn.train_attention)
    except AttributeError as e:
        print('Error at:', e)
        # Layers sit one level deeper (model.model.model.layers), e.g. after PEFT wrapping
        _train_attn = model.model.model.layers[0].self_attn.train_attention
        print(f"But it's ok, {type(model.model.model)} has attribute 'layers'")
        print('-> Training attention:', _train_attn)

    if print_model or debug:
        print_header('*** Model ***')
        print(model)
        print_header('*** Trainable Parameters ***')
        for n, p in model.named_parameters():
            if p.requires_grad:
                print(f'├── {n}.requires_grad: {p.requires_grad}')
    model.eval()
    if lm_eval_model:
        lm.model = model
        model = lm
    return model, model_config, tokenizer
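
# Minimal end-to-end sketch (paths are hypothetical; real checkpoint names
# carry the '-m=<model_config>' and '-f=<finetune_config>' segments parsed above):
#   model, model_config, tokenizer = load_model_from_checkpoint(
#       attn_mlp_checkpoint_path='./checkpoints/<run>-m=<model_config>_distill.pt',
#       finetune_checkpoint_path='./checkpoints/<run>-f=<finetune_config>.pt',
#       print_model=True,
#   )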