from typing import TYPE_CHECKING

from vllm.model_executor.models.config import VerifyAndUpdateConfig

if TYPE_CHECKING:
    from vllm.config import VllmConfig


class EmbedderModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
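        """Normalize an EmbedderConfig hf_config for vLLM.

        Renames checkpoint attributes to the names vLLM expects, builds the
        rotary embedding kwargs, and recalculates max_model_len against the
        number of trained positions.
        """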
        from copy import deepcopy

        from vllm.transformers_utils.config import set_default_rope_theta

        config = vllm_config.model_config.hf_config
        assert config.__class__.__name__ == "EmbedderConfig"
        assert config.activation_function in ["swiglu", "gelu"]
        config.position_embedding_type = getattr(
            config, "position_embedding_type", "rope"
        )
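
        # Map the checkpoint's activation name onto the activation used by vLLM.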
        if config.activation_function == "swiglu":
            config.hidden_act = "silu"
        else:
            config.hidden_act = config.activation_function
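
        # All projection bias flags must agree; collapse them into a single flag.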
        assert config.mlp_fc1_bias == config.mlp_fc2_bias == config.qkv_proj_bias
        config.bias = config.qkv_proj_bias

        assert config.rotary_emb_scale_base is None
        assert not config.rotary_emb_interleaved
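
        # Translate GPT-2-style attribute names into the standard HF names.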
        config.layer_norm_eps = config.layer_norm_epsilon
        config.intermediate_size = config.n_inner
        config.hidden_size = config.n_embd
        config.num_hidden_layers = config.n_layer

        head_dim = config.hidden_size // config.num_attention_heads
        rotary_emb_dim = int(head_dim * config.rotary_emb_fraction)
        max_trained_positions = getattr(config, "max_trained_positions", 2048)

        set_default_rope_theta(config, default_theta=config.rotary_emb_base)
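
        # Keyword arguments used to construct the rotary position embedding.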
        config.rotary_kwargs = {
            "head_size": head_dim,
            "rotary_dim": rotary_emb_dim,
            "max_position": max_trained_positions,
            "rope_parameters": config.rope_parameters,
        }
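
        # Without user overrides, cap the context length at the trained position
        # count; otherwise re-verify the overridden max_model_len below.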
        if (
            not vllm_config.model_config.hf_overrides
            and vllm_config.model_config.original_max_model_len is None
        ):
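            # Default path: never exceed the number of trained positions.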
            max_model_len = min(
                vllm_config.model_config.max_model_len, max_trained_positions
            )

            vllm_config.recalculate_max_model_len(max_model_len)
        else:
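            # hf_overrides or an explicit max_model_len is set: re-verify it
            # against the position embeddings before recalculating.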
            model_config = vllm_config.model_config
            hf_text_config = model_config.hf_text_config

            if isinstance(model_config.hf_overrides, dict):
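                # A dict of hf_overrides may carry an explicit max_model_len.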
                max_model_len = model_config.hf_overrides.get(
                    "max_model_len", vllm_config.model_config.max_model_len
                )
            else:
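                # Non-dict overrides: keep the already-resolved max_model_len.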
                max_model_len = vllm_config.model_config.max_model_len
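
            # Reset hf_text_config so recalculate_max_model_len uses the
            # trained position count and rope parameters derived above.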
            if hasattr(hf_text_config, "max_model_len"):
                delattr(hf_text_config, "max_model_len")
            hf_text_config.max_position_embeddings = max_trained_positions
            hf_text_config.rope_parameters = config.rotary_kwargs["rope_parameters"]
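
            # Drop the encoder-level max_seq_length so it does not constrain the
            # recalculated max_model_len.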
            encoder_config = deepcopy(model_config.encoder_config)
            if encoder_config:
                encoder_config.pop("max_seq_length", None)
            model_config.encoder_config = encoder_config

            vllm_config.recalculate_max_model_len(max_model_len)