from typing import Any, Literal
from pydantic import Field, field_serializer, field_validator
from transformers import AutoConfig, PretrainedConfig
from transformers.models.llama.configuration_llama import LlamaConfig
from speculators import SpeculatorModelConfig
__all__ = [
"Eagle3SpeculatorConfig",
]
@SpeculatorModelConfig.register("eagle3")
class Eagle3SpeculatorConfig(SpeculatorModelConfig):
    """
    Configuration for EAGLE-3 speculator with vocabulary mapping.

    EAGLE-3 features vocabulary mapping between draft (32K) and target (128K)
    vocabularies, enabling cross-tokenizer speculation.

    :param transformer_layer_config: Configuration for the transformer decoder layer
    :param draft_vocab_size: Size of draft model vocabulary for speculation
    :param norm_before_residual: Apply hidden_norm before storing residual
    :param target_hidden_size: Hidden size of the target model, if it differs
        from the draft model's
    :param eagle_aux_hidden_state_layer_ids: Layer IDs of the Eagle auxiliary
        hidden state layers
    :param embed_requires_grad: Whether embedding layer weights require
        gradients during training
    """

    # Literal discriminator matching the register() key above, so serialized
    # configs round-trip back to this class.
    speculators_model_type: Literal["eagle3"] = "eagle3"
    architectures: list[str] = Field(
        default_factory=lambda: ["Eagle3Speculator"],
        description="Model architectures that can load these weights",
    )
    # Accepts any PretrainedConfig instance, or a plain dict which is coerced
    # by validate_transformer_config below; defaults to a stock LlamaConfig.
    transformer_layer_config: PretrainedConfig = Field(
        default_factory=LlamaConfig,
        description="Configuration for the transformer decoder layer",
    )
    draft_vocab_size: int = Field(
        default=32000,
        description="Size of draft model vocabulary for speculation",
    )
    norm_before_residual: bool = Field(
        default=False,
        description="Apply hidden_norm before storing residual",
    )
    # None means the target model shares the draft model's hidden size.
    target_hidden_size: int | None = Field(
        default=None,
        description="Hidden size of the target model (if different from draft model)",
    )
    eagle_aux_hidden_state_layer_ids: list[int] | None = Field(
        default=None,
        description="Layer IDs of the Eagle auxiliary hidden state layers",
    )
    embed_requires_grad: bool = Field(
        default=False,
        description="Whether embedding layer weights require gradients during training",
    )

    @property
    def target_vocab_size(self) -> int:
        """Get target vocabulary size from transformer config."""
        # The target vocabulary is whatever the wrapped decoder-layer config
        # declares; the draft vocabulary (draft_vocab_size) may differ.
        return self.transformer_layer_config.vocab_size

    @field_serializer("transformer_layer_config")
    def serialize_transformer_config(self, value: PretrainedConfig) -> dict:
        """Serialize transformer config to dict."""
        # to_diff_dict() emits only values differing from the config class's
        # defaults, keeping the serialized payload minimal.
        return value.to_diff_dict()

    @field_validator("transformer_layer_config", mode="before")
    @classmethod
    def validate_transformer_config(cls, value: Any) -> PretrainedConfig:
        """Validate and convert transformer config.

        An already-built ``PretrainedConfig`` is returned unchanged. A plain
        dict is rebuilt into a config instance: the class registered for its
        ``"model_type"`` key is used when present, otherwise ``LlamaConfig``.
        """
        if isinstance(value, dict):
            config_class: type[PretrainedConfig] = LlamaConfig
            if "model_type" in value:
                # AutoConfig.for_model returns a config *instance*; take its
                # class so the full dict can be passed to the constructor.
                config_class = AutoConfig.for_model(
                    model_type=value["model_type"]
                ).__class__
            return config_class(**value)
        return value