from typing import Any, Literal

from pydantic import Field, field_serializer, field_validator
from transformers import AutoConfig, PretrainedConfig, Qwen3Config

from speculators import SpeculatorModelConfig

__all__ = [
    "DFlashSpeculatorConfig",
]


@SpeculatorModelConfig.register("dflash")
class DFlashSpeculatorConfig(SpeculatorModelConfig):
    """
    Configuration for the DFlash speculator with vocabulary mapping.

    DFlash maps between a smaller draft vocabulary and a larger target
    vocabulary (for example, a 64K draft vocabulary speculating for a 128K
    target vocabulary), enabling cross-tokenizer speculation.

    :param transformer_layer_config: Configuration for the transformer decoder layer
    :param draft_vocab_size: Size of the draft model vocabulary used for speculation
    :param num_hidden_layers: Number of hidden layers in the DFlash model
    :param block_size: Number of draft tokens predicted in a single forward pass
    :param target_hidden_size: Hidden size of the target model, if it differs
        from the draft model's hidden size
    :param aux_hidden_state_layer_ids: Layer IDs of the DFlash auxiliary
        hidden state layers
    """

    speculators_model_type: Literal["dflash"] = "dflash"
    architectures: list[str] = Field(
        default_factory=lambda: ["DFlashSpeculator"],
        description="Model architectures that can load these weights",
    )

    transformer_layer_config: PretrainedConfig = Field(
        default_factory=Qwen3Config,
        description="Configuration for the transformer decoder layer",
    )

    draft_vocab_size: int = Field(
        default=32000,
        description="Size of draft model vocabulary for speculation",
    )

    num_hidden_layers: int = Field(
        default=3,
        description="Number of hidden layers in the DFlash model",
    )

    block_size: int = Field(
        default=8,
        description="Number of draft tokens predicted in a single forward pass",
    )

    target_hidden_size: int | None = Field(
        default=None,
        description="Hidden size of the target model (if different from draft model)",
    )

    aux_hidden_state_layer_ids: list[int] | None = Field(
        default=None,
        description="Layer IDs of the DFlash auxiliary hidden state layers",
    )

    @property
    def target_vocab_size(self) -> int:
        """Get target vocabulary size from transformer config."""
        return self.transformer_layer_config.vocab_size

    @field_serializer("transformer_layer_config")
    def serialize_transformer_config(
        self, value: PretrainedConfig
    ) -> dict[str, Any]:
        """Serialize the transformer config to a JSON-safe dict."""
        return value.to_diff_dict()

    @field_validator("transformer_layer_config", mode="before")
    @classmethod
    def validate_transformer_config(cls, value: Any) -> PretrainedConfig:
        """Validate a transformer config, converting dicts to config objects."""
        if isinstance(value, dict):
            # Fall back to Qwen3Config when no model_type is given; otherwise
            # resolve the concrete config class through AutoConfig.
            config_class: type[PretrainedConfig] = Qwen3Config
            if "model_type" in value:
                config_class = AutoConfig.for_model(
                    model_type=value["model_type"]
                ).__class__
            return config_class(**value)
        return value
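

# Minimal usage sketch (an illustration, not part of the library's API): it
# assumes the `speculators` package is installed and that SpeculatorModelConfig
# behaves as a standard pydantic v2 model. The vocab sizes below are
# illustrative values, not defaults taken from any released model.
if __name__ == "__main__":
    # The before-mode validator accepts a plain dict and resolves the concrete
    # config class from its "model_type" entry via AutoConfig.
    config = DFlashSpeculatorConfig(
        transformer_layer_config={"model_type": "qwen3", "vocab_size": 128000},
        draft_vocab_size=64000,
    )
    assert isinstance(config.transformer_layer_config, Qwen3Config)
    assert config.target_vocab_size == 128000

    # model_dump() emits the transformer config as a diff dict (via the
    # field_serializer above), so the payload stays JSON-serializable and
    # round-trips through model_validate.
    payload = config.model_dump()
    restored = DFlashSpeculatorConfig.model_validate(payload)
    assert restored.target_vocab_size == config.target_vocab_size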