File size: 6,345 Bytes
5000658 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import torch
from tensorrt_llm.quantization import QuantAlgo
from ..._utils import str_dtype_to_torch
from .split_weights import shuffle_qkv_weights, split_weights_tp
def convert_hf_weights(hf_model, dtype, config, small_variant, args, rank) -> dict:
    """Convert a HuggingFace Phi-3 checkpoint into TRT-LLM weight naming.

    Renames every tensor in ``hf_model.state_dict()`` from the HF layout to
    the TRT-LLM layout, fuses separate q/k/v projections into a single qkv
    tensor (128k variant), re-orders the fused gate_up projection (4k
    variant), and finally splits the result for tensor parallelism.

    Args:
        hf_model: HuggingFace model; only its ``state_dict()`` is read.
        dtype: Target dtype name (string), converted via str_dtype_to_torch.
        config: TRT-LLM model config, forwarded to the shuffle/split helpers.
        small_variant: True for Phi3Small checkpoints, which use a custom
            fused-QKV layout and tied input/output embeddings.
        args: Command-line args forwarded to split_weights_tp.
        rank: Tensor-parallel rank whose shard is produced.

    Returns:
        dict mapping TRT-LLM weight names to CPU tensors of ``dtype``.
    """
    torch_dtype = str_dtype_to_torch(dtype)
    hf_state_dict = hf_model.state_dict()
    weights = {}
    # replace key name
    # NOTE: the replace chain below handles several checkpoint flavors at
    # once; the trailing "# small" / "#128k" / "#4k" tags mark which
    # variant each rename exists for. Order matters: e.g. "self_attn." must
    # become "attention." before the o_proj/qkv renames apply to it.
    for key, value in hf_state_dict.items():
        # Decoder Layers
        orig_key = key  # keep the HF name for state-dict lookups below
        if "model.layers." in key:
            key = key.replace("model.layers.", "transformer.layers.")
            #Attention
            key = key.replace("self_attn.", "attention.")
            key = key.replace("query_key_value.", "qkv.")  # small
            key = key.replace("Wqkv.weight", "qkv.weight")
            key = key.replace("qkv_proj.", "qkv.")  #128k
            #MLP
            key = key.replace("mlp.fc1.", "mlp.fc.")
            key = key.replace("mlp.fc2.", "mlp.proj.")
            key = key.replace("mlp.gate_up_proj.", "mlp.fc.")
            # up_proj maps to different TRT-LLM tensors per variant.
            key = key.replace(
                "mlp.up_proj.",
                "mlp.fc." if small_variant else "mlp.gate.")  #128k
            key = key.replace("mlp.down_proj.", "mlp.proj.")  #128k
            key = key.replace("mlp.gate_proj.", "mlp.fc.")  #128k
            key = key.replace("o_proj.", "dense.")  #128k
            #Layer norm
            key = key.replace("post_attention_layernorm.",
                              "post_layernorm.")  #128k

        # Embedding
        key = key.replace("model.embed_tokens.weight",
                          "transformer.vocab_embedding.weight")
        # Final Layer norm
        key = key.replace("model.final_layernorm.", "transformer.ln_f.")
        key = key.replace("model.norm.", "transformer.ln_f.")  #128k

        if "mlp.gate_up_proj." in orig_key:  #4k
            # HF stores [gate; up] stacked on dim 0; TRT-LLM expects the
            # halves in the opposite order, so swap them.
            original_weights = value.contiguous().clone()
            half_split = original_weights.shape[0] // 2
            first_half, second_half = original_weights[:
                                                       half_split, :], original_weights[
                                                           half_split:, :]
            # Swap the halves
            value = torch.cat((second_half, first_half), dim=0)

        if "q_proj" in key:  #128k
            # Fuse the three separate projections into one qkv tensor,
            # looking the k/v partners up by their original HF names.
            q_param = value
            k_param = hf_state_dict[orig_key.replace("q_proj", "k_proj")]
            v_param = hf_state_dict[orig_key.replace("q_proj", "v_proj")]
            value = torch.cat([q_param, k_param, v_param], dim=0)
            key = key.replace("q_proj.weight", "qkv.weight")
        elif "k_proj" in key or "v_proj" in key:
            # Already consumed by the q_proj branch above.
            continue

        weights[key] = value.to(torch_dtype).cpu()

    if small_variant:
        # Phi3Small ties lm_head to the input embedding.
        weights['lm_head.weight'] = weights[
            'transformer.vocab_embedding.weight'].clone()
        # Transform QKV weights from custom Phi3Small format to TRT-LLM format
        for key, value in weights.items():
            if "qkv." in key:
                weights[key] = shuffle_qkv_weights(weights[key], config)

    weights = split_weights_tp(config, weights, args, rank, torch_dtype)

    return weights
def convert_small_hf_config(hf_config) -> dict:
    """Map Phi3Small-specific HuggingFace config fields to TRT-LLM keys.

    Args:
        hf_config: HuggingFace Phi3Small config object.

    Returns:
        dict of TRT-LLM config entries for the small variant, including the
        mup scaling knobs and the block-sparse attention geometry.
    """
    # TRT-LLM key -> attribute name on the HF config. Note the one rename:
    # blocksparse_vertical_stride comes from `blocksparse_vert_stride`.
    field_map = {
        'rotary_base': 'rope_embedding_base',
        'gegelu_limit': 'gegelu_limit',
        'mup_attn_multiplier': 'mup_attn_multiplier',
        'mup_embedding_multiplier': 'mup_embedding_multiplier',
        'mup_use_scaling': 'mup_use_scaling',
        'mup_width_multiplier': 'mup_width_multiplier',
        'blocksparse_block_size': 'blocksparse_block_size',
        'blocksparse_homo_head_pattern': 'blocksparse_homo_head_pattern',
        'blocksparse_num_local_blocks': 'blocksparse_num_local_blocks',
        'blocksparse_vertical_stride': 'blocksparse_vert_stride',
        'dense_attention_every_n_layers': 'dense_attention_every_n_layers',
    }
    small_config = {'architecture': "Phi3SmallForCausalLM"}
    for trt_key, hf_attr in field_map.items():
        small_config[trt_key] = getattr(hf_config, hf_attr)
    return small_config
def convert_hf_config(hf_config, dtype, args) -> dict:
    """Build a TRT-LLM model config dict from a HuggingFace Phi-3 config.

    Args:
        hf_config: HuggingFace config; the architecture string selects
            between the Phi3 and Phi3Small field mappings.
        dtype: dtype name recorded verbatim in the output config.
        args: Optional CLI namespace carrying tp/pp sizes and weight-only
            quantization flags; when None those sections are omitted.

    Returns:
        dict ready for TRT-LLM model construction.
    """
    small_variant = hf_config.architectures[0] == "Phi3SmallForCausalLM"

    config = {
        'architecture': "Phi3ForCausalLM",
        'dtype': dtype,
        'num_hidden_layers': hf_config.num_hidden_layers,
        'num_attention_heads': hf_config.num_attention_heads,
        'num_key_value_heads': hf_config.num_key_value_heads,
        'hidden_size': hf_config.hidden_size,
        'intermediate_size': hf_config.intermediate_size,
        'vocab_size': hf_config.vocab_size,
        'max_position_embeddings': hf_config.max_position_embeddings,
        'hidden_act': hf_config.hidden_act,
        'share_embedding_table': False,
    }

    if small_variant:
        # Overrides 'architecture' and adds the small-variant-only fields.
        config.update(convert_small_hf_config(hf_config))
    else:
        config['rotary_base'] = hf_config.rope_theta
        config['norm_epsilon'] = hf_config.rms_norm_eps

    # Long-context variants carry LongRoPE scaling factors.
    if hf_config.max_position_embeddings >= 128000:
        scaling = hf_config.rope_scaling
        config['original_max_position_embeddings'] = \
            hf_config.original_max_position_embeddings
        config['longrope_scaling_short_factors'] = scaling["short_factor"]
        config['longrope_scaling_long_factors'] = scaling["long_factor"]
        if small_variant:
            config['longrope_long_mscale'] = scaling["long_mscale"]
            config['longrope_short_mscale'] = scaling["short_mscale"]

    # TRT-LLM names the gated-silu activation "swiglu".
    if config["hidden_act"] == "silu":
        config["hidden_act"] = "swiglu"

    # Tensor parallelism and weight-only quantization
    if args is not None:
        config['mapping'] = {
            'world_size': args.tp_size * args.pp_size,
            'tp_size': args.tp_size,
            'pp_size': args.pp_size,
        }
        if args.use_weight_only:
            algo_by_precision = {
                'int8': QuantAlgo.W8A16,
                'int4': QuantAlgo.W4A16,
            }
            algo = algo_by_precision.get(args.weight_only_precision)
            # Unrecognized precisions fall through without quantization,
            # matching the original if/elif chain.
            if algo is not None:
                config['quantization'] = {'quant_algo': algo}

    return config
|