|
|
|
|
|
from typing import Any, Dict |
|
|
|
|
|
from swift.utils import get_logger |
|
|
|
|
|
logger = get_logger() |
|
|
# Mapping from Megatron argument names to candidate HuggingFace config
# attribute names.  For each Megatron key, the first HF attribute present on
# the config wins (see convert_hf_config, which breaks after the first match).
config_mapping = {
    'num_layers': ['num_hidden_layers'],
    'hidden_size': ['hidden_size'],
    'ffn_hidden_size': ['intermediate_size'],
    'num_attention_heads': ['num_attention_heads'],
    'num_query_groups': ['num_key_value_heads'],
    'max_position_embeddings': ['max_position_embeddings'],
    'norm_epsilon': ['rms_norm_eps'],
    'rotary_base': ['rope_theta'],
    'padded_vocab_size': ['vocab_size'],
    'attention_dropout': ['attention_dropout'],
    # NOTE: this flag is the logical negation of the HF value (handled in
    # convert_hf_config), not a direct copy.
    'untie_embeddings_and_output_weights': ['tie_word_embeddings'],
    # 'swiglu' is set to True only when hidden_act == 'silu' (special-cased in
    # convert_hf_config).
    'swiglu': ['hidden_act'],
    'add_qkv_bias': ['attention_bias'],
    # Also negated relative to HF semantics (mlp_bias -> disable_bias_linear).
    'disable_bias_linear': ['mlp_bias'],
    'kv_channels': ['head_dim'],
    'model_type': ['model_type'],

    # MoE-specific options.
    'moe_ffn_hidden_size': ['moe_intermediate_size'],
    'moe_shared_expert_intermediate_size': ['shared_expert_intermediate_size'],
    'moe_router_topk': ['num_experts_per_tok'],
    'num_experts': ['num_experts'],
    # Negated relative to HF semantics (norm_topk_prob -> moe_router_pre_softmax).
    'moe_router_pre_softmax': ['norm_topk_prob'],
    'moe_aux_loss_coeff': ['router_aux_loss_coef'],
}
|
|
|
|
|
|
|
|
def convert_hf_config(config) -> Dict[str, Any]:
    """Translate a HuggingFace model config into Megatron argument names.

    Args:
        config: A HuggingFace ``PretrainedConfig``-like object; attributes are
            read via ``hasattr``/``getattr``, so any object with the expected
            attributes works.

    Returns:
        Dict[str, Any]: Megatron-style configuration values derived from
        ``config`` according to ``config_mapping`` plus an optional
        ``rope_scaling`` entry.
    """
    megatron_config = {}
    for k, hf_keys in config_mapping.items():
        for hf_k in hf_keys:
            if hasattr(config, hf_k):
                hf_v = getattr(config, hf_k)
                if k == 'rotary_base':
                    # Megatron expects an integer rope base; HF may store a float.
                    megatron_config[k] = int(hf_v)
                elif k in {'untie_embeddings_and_output_weights', 'disable_bias_linear', 'moe_router_pre_softmax'}:
                    # These Megatron flags are the logical negation of their
                    # HF counterparts.
                    megatron_config[k] = not hf_v
                elif k == 'swiglu':
                    # Only 'silu' activates Megatron's swiglu flag; other
                    # activation names intentionally leave the key unset.
                    if hf_v == 'silu':
                        megatron_config[k] = True
                else:
                    megatron_config[k] = hf_v
                # First matching HF attribute wins for this Megatron key.
                break

    if getattr(config, 'rope_scaling', None) is not None:
        if isinstance(config.rope_scaling, int):
            # Bug fix: a trailing comma previously turned this assignment into
            # a 1-tuple wrapping the dict; assign the dict itself.
            megatron_config['rope_scaling'] = {'factor': config.rope_scaling, 'type': 'linear'}
        elif isinstance(config.rope_scaling, dict):
            megatron_config['rope_scaling'] = config.rope_scaling
    logger.info(f'megatron_config: {megatron_config}')
    return megatron_config
|
|
|