"""Builders for dense and attention-sparse LLaMA models.

Converts a Hugging Face LLaMA checkpoint to the flash-attn GPT layout and,
for the sparse variant, merges in attention-router weights from a separate
checkpoint directory.
"""

import torch
import torch.nn as nn

from transformers import AutoConfig, AutoTokenizer, LlamaConfig, LlamaTokenizer

from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.models.llama import (
    config_from_checkpoint,
    inv_remap_state_dict_hf_llama,
    llama_config_to_gpt2_config,
    remap_state_dict_hf_llama,
    remap_state_dict_meta_llama,
    state_dicts_from_checkpoint,
)

from HybridTensor.models.create_sparse_model import GPTLMHeadModel
from HybridTensor.modules.SelectiveRouters import (
    create_attn_router_state_dict,
    create_mlp_router_state_dict,
)

|
class SparseConfig:
    """Default router and sparsity hyperparameters for the sparse model."""

    def __init__(self):
        self.mlp_low_rank_dim = 1024   # low-rank dim of the MLP router
        self.attn_low_rank_dim = 128   # low-rank dim of the attention router
        self.mlp_act_th = 0.5          # MLP activation threshold
        self.attn_topk = 0.3           # attention top-k budget (overridden via args.attn_topk below)

|
def build_dense_llama(model_name: str, device=None, dtype=torch.float16, process_group=None, world_size=None, rank=None, **kwargs):
    """Build a dense LLaMA model in the flash-attn GPT format from a Hugging Face checkpoint."""
    config = llama_config_to_gpt2_config(AutoConfig.from_pretrained(model_name, trust_remote_code=True))
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True
    config.prenorm = True

    # Load the HF weights on CPU and remap them to the flash-attn layout.
    state_dict = state_dict_from_pretrained(model_name, device='cpu', dtype=dtype)
    state_dict = remap_state_dict_hf_llama(state_dict, config)

    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(state_dict, strict=True)
    model.eval()

    return model

|
def build_sparse_llama(args, model_name: str, attn_ckpt_dir: str, device=None, dtype=torch.float16, process_group=None, world_size=None, rank=None, **kwargs):
    """Build a LLaMA model with selective (top-k) attention sparsity, loading router weights from attn_ckpt_dir."""
    config = llama_config_to_gpt2_config(AutoConfig.from_pretrained(model_name, trust_remote_code=True))
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True
    config.prenorm = True

    # Only attention is sparsified; the MLP stays dense.
    spconfig = SparseConfig()
    spconfig.attn_topk = args.attn_topk
    config.mlp_sparse = False
    config.att_sparse = True

    # Load the HF weights on CPU and remap them to the flash-attn layout.
    state_dict = state_dict_from_pretrained(model_name, device='cpu', dtype=dtype)
    state_dict = remap_state_dict_hf_llama(state_dict, config)

    model = GPTLMHeadModel(config, sp_config=spconfig, device=device, dtype=dtype)

    # Merge the attention-router weights into the base state dict when a router checkpoint is given.
    if attn_ckpt_dir is not None:
        attn_router_state_dict = create_attn_router_state_dict(attn_ckpt_dir)
        merged_state_dict = {**state_dict, **attn_router_state_dict}
    else:
        merged_state_dict = state_dict

    model.load_state_dict(merged_state_dict, strict=True)
    model.eval()

    return model
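
# Usage sketch (illustrative only): how these builders might be called. The
# model name, the router checkpoint directory, and the `args` namespace below
# are assumptions for demonstration, and the forward-pass output format is
# assumed to follow flash-attn's GPTLMHeadModel (a CausalLMOutput with .logits).
if __name__ == "__main__":
    from types import SimpleNamespace

    model_name = "meta-llama/Llama-2-7b-hf"        # hypothetical HF checkpoint
    attn_ckpt_dir = "checkpoints/attn_routers"     # hypothetical router checkpoint dir
    args = SimpleNamespace(attn_topk=0.3)

    dense_model = build_dense_llama(model_name, device="cuda", dtype=torch.float16)
    sparse_model = build_sparse_llama(args, model_name, attn_ckpt_dir, device="cuda", dtype=torch.float16)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    input_ids = tokenizer("Hello, world!", return_tensors="pt").input_ids.to("cuda")
    with torch.no_grad():
        out = sparse_model(input_ids)
    print(out.logits.shape)   # assumed shape: (batch, seq_len, vocab_size)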