from typing import Any, Literal

from pydantic import Field, field_serializer, field_validator
from transformers import AutoConfig, PretrainedConfig
from transformers.models.llama.configuration_llama import LlamaConfig

from speculators import SpeculatorModelConfig

| __all__ = [ |
| "Eagle3SpeculatorConfig", |
| ] |
|
|
|
|
| @SpeculatorModelConfig.register("eagle3") |
| class Eagle3SpeculatorConfig(SpeculatorModelConfig): |
| """ |
| Configuration for EAGLE-3 speculator with vocabulary mapping. |
| |
| EAGLE-3 features vocabulary mapping between draft (32K) and target (128K) |
| vocabularies, enabling cross-tokenizer speculation. |
| |
| :param transformer_layer_config: Configuration for the transformer decoder layer |
| :param draft_vocab_size: Size of draft model vocabulary for speculation |
| :param norm_before_residual: Apply hidden_norm before storing residual |
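
    Example (illustrative sketch; assumes any other required base-class
    fields have usable defaults)::

        config = Eagle3SpeculatorConfig(
            draft_vocab_size=32000,
            transformer_layer_config=LlamaConfig(vocab_size=128256),
        )
        assert config.target_vocab_size == 128256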
| """ |
|
|
    speculators_model_type: Literal["eagle3"] = "eagle3"
    architectures: list[str] = Field(
        default_factory=lambda: ["Eagle3Speculator"],
        description="Model architectures that can load these weights",
    )

    transformer_layer_config: PretrainedConfig = Field(
        default_factory=LlamaConfig,
        description="Configuration for the transformer decoder layer",
    )

    draft_vocab_size: int = Field(
        default=32000,
        description="Size of draft model vocabulary for speculation",
    )

    norm_before_residual: bool = Field(
        default=False,
        description="Apply hidden_norm before storing residual",
    )

    target_hidden_size: int | None = Field(
        default=None,
        description="Hidden size of the target model (if different from draft model)",
    )

    eagle_aux_hidden_state_layer_ids: list[int] | None = Field(
        default=None,
        description="Layer IDs of the Eagle auxiliary hidden state layers",
    )

    embed_requires_grad: bool = Field(
        default=False,
        description="Whether embedding layer weights require gradients during training",
    )

    @property
    def target_vocab_size(self) -> int:
        """Get the target vocabulary size from the transformer layer config."""
        return self.transformer_layer_config.vocab_size

    @field_serializer("transformer_layer_config")
    def serialize_transformer_config(self, value: PretrainedConfig) -> dict:
        """Serialize the transformer config to a dict of its non-default values."""
        return value.to_diff_dict()

    @field_validator("transformer_layer_config", mode="before")
    @classmethod
    def validate_transformer_config(cls, value: Any) -> PretrainedConfig:
        """Validate and convert a dict or config object into a PretrainedConfig."""
        if isinstance(value, dict):
            # Default to LlamaConfig; when the dict carries a "model_type",
            # resolve the concrete config class through AutoConfig instead.
            config_class: type[PretrainedConfig] = LlamaConfig
            if "model_type" in value:
                config_class = AutoConfig.for_model(
                    model_type=value["model_type"]
                ).__class__
            return config_class(**value)
        return value
|
|
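if __name__ == "__main__":
    # Illustrative round-trip sketch, not part of the library API. Assumes any
    # other fields required by SpeculatorModelConfig have usable defaults.
    config = Eagle3SpeculatorConfig(
        transformer_layer_config={"model_type": "llama", "vocab_size": 128256},
    )
    # The "before" validator resolved the plain dict to a LlamaConfig instance.
    assert isinstance(config.transformer_layer_config, LlamaConfig)
    # Serialization emits the transformer config as a diff dict, so only
    # values that differ from the LlamaConfig defaults are stored.
    serialized = config.model_dump()["transformer_layer_config"]
    print(serialized.get("vocab_size"))  # 128256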