# Susav's picture
# Upload folder using huggingface_hub
# b3a3b15 verified
from transformers import LlamaConfig, LlamaTokenizer
import torch
import torch.nn as nn
from HybridTensor.models.create_sparse_model import GPTLMHeadModel
from HybridTensor.modules.SelectiveRouters import create_mlp_router_state_dict, create_attn_router_state_dict
# from flash_attn.models.gpt import GPTLMHeadModel
from transformers import AutoConfig, AutoTokenizer
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.models.llama import (
config_from_checkpoint,
inv_remap_state_dict_hf_llama,
llama_config_to_gpt2_config,
remap_state_dict_hf_llama,
remap_state_dict_meta_llama,
state_dicts_from_checkpoint,
)
class SparseConfig:
    """Mutable holder for the sparsity hyper-parameters handed to the sparse model.

    An instance is passed as ``sp_config`` to ``GPTLMHeadModel`` by
    ``build_sparse_llama``, which overrides ``attn_topk`` from its ``args``.

    Attributes (all plain instance attributes, set to defaults here):
        mlp_low_rank_dim: defaults to 1024.
        attn_low_rank_dim: defaults to 128.
        mlp_act_th: defaults to 0.5.
        attn_topk: defaults to 0.3.

    NOTE(review): the names suggest router low-rank widths, an MLP activation
    threshold and an attention top-k fraction, but the exact semantics live in
    the model code that consumes them -- confirm there.
    """

    def __init__(self):
        # Low-rank dimensions (MLP first, then attention).
        self.mlp_low_rank_dim, self.attn_low_rank_dim = 1024, 128
        # Selection knobs (MLP threshold first, then attention top-k).
        self.mlp_act_th, self.attn_topk = 0.5, 0.3
def build_dense_llama(
    model_name: str,
    device=None,
    dtype=torch.float16,
    process_group=None,
    world_size=None,
    rank=None,
    **kwargs,
):
    """Build a dense LLaMA model wrapped in ``GPTLMHeadModel`` and load its weights.

    The Hugging Face config for ``model_name`` is converted with
    ``llama_config_to_gpt2_config``, the pretrained weights are fetched on CPU
    in ``dtype``, remapped with ``remap_state_dict_hf_llama`` and loaded
    strictly into a freshly constructed model.

    Args:
        model_name: Identifier passed to ``AutoConfig.from_pretrained`` and
            ``state_dict_from_pretrained``.
        device: Forwarded to the ``GPTLMHeadModel`` constructor.
        dtype: Used both when fetching the weights and when constructing the model.
        process_group, world_size, rank, **kwargs: Accepted but not used by
            this function.

    Returns:
        The model after ``model.eval()`` has been called on it.
    """
    hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    config = llama_config_to_gpt2_config(hf_config)

    # Applied in this order; each entry is a plain attribute set on the config.
    config_flags = {
        "use_flash_attn": True,
        "fused_bias_fc": True,
        "fused_mlp": False,  # We don't have fused GatedMLP yet
        "fused_dropout_add_ln": True,
        "residual_in_fp32": True,
        "prenorm": True,
    }
    for flag_name, flag_value in config_flags.items():
        setattr(config, flag_name, flag_value)

    # Weights are fetched on CPU, then renamed to the layout the model expects.
    # The same name is rebound so the pre-remap dict is not kept alive.
    weights = state_dict_from_pretrained(model_name, device='cpu', dtype=dtype)
    weights = remap_state_dict_hf_llama(weights, config)

    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(weights, strict=True)
    model.eval()
    return model
def build_sparse_llama(args, model_name: str, attn_ckpt_dir: str, device = None, dtype=torch.float16, process_group = None, world_size = None, rank = None, **kwargs):
    """Build a LLaMA ``GPTLMHeadModel`` with sparse attention enabled and load its weights.

    The config is prepared the same way as for the dense model, then
    ``config.att_sparse`` is set to True and ``config.mlp_sparse`` to False.
    A ``SparseConfig`` with ``attn_topk`` taken from ``args`` is passed to the
    model as ``sp_config``. The base weights are fetched on CPU, remapped with
    ``remap_state_dict_hf_llama`` and, when a router checkpoint directory is
    given, overlaid with the attention-router weights before a strict load.

    Args:
        args: Object exposing ``attn_topk`` (the only attribute read here).
        model_name: Identifier passed to ``AutoConfig.from_pretrained`` and
            ``state_dict_from_pretrained``.
        attn_ckpt_dir: Directory passed to ``create_attn_router_state_dict``.
            May be ``None``, in which case only the base weights are loaded.
        device: Forwarded to the ``GPTLMHeadModel`` constructor.
        dtype: Used both when fetching the weights and when constructing the model.
        process_group, world_size, rank, **kwargs: Accepted but not used by
            this function.

    Returns:
        The model after ``model.eval()`` has been called on it.

    Raises:
        Whatever ``model.load_state_dict(..., strict=True)`` raises on a key
        mismatch. NOTE(review): assuming ``GPTLMHeadModel`` follows
        ``torch.nn.Module`` semantics, that is a ``RuntimeError`` listing the
        missing keys -- the expected outcome when ``attn_ckpt_dir`` is None
        but the model contains router parameters.
    """
    config = llama_config_to_gpt2_config(AutoConfig.from_pretrained(model_name, trust_remote_code=True))
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True
    config.prenorm = True

    # Sparsity settings: only the attention top-k is taken from the caller.
    spconfig = SparseConfig()
    spconfig.attn_topk = args.attn_topk
    config.mlp_sparse = False
    config.att_sparse = True

    state_dict = state_dict_from_pretrained(model_name, device='cpu', dtype=dtype)
    state_dict = remap_state_dict_hf_llama(state_dict, config)

    # Bind the dict to load up front so it is defined even without a router
    # checkpoint; previously attn_ckpt_dir=None crashed with UnboundLocalError.
    merged_state_dict = state_dict
    if attn_ckpt_dir is not None:
        attn_router_state_dict = create_attn_router_state_dict(attn_ckpt_dir)
        # Router entries take precedence on any key collision.
        merged_state_dict = {**state_dict, **attn_router_state_dict}

    model = GPTLMHeadModel(config, sp_config=spconfig, device=device, dtype=dtype)

    # TODO: Add code for tensor parallel state dict sharding
    model.load_state_dict(merged_state_dict, strict=True)
    model.eval()
    return model