# dflash-test / config.py — uploaded via huggingface_hub (commit 331eb39, verified)
from typing import Any, Literal
from pydantic import Field, field_serializer, field_validator
from transformers import AutoConfig, PretrainedConfig
from transformers.models.qwen3.modeling_qwen3 import (
Qwen3Config,)
from speculators import SpeculatorModelConfig
# Public API of this module.
__all__ = ["DFlashSpeculatorConfig"]
@SpeculatorModelConfig.register("dflash")
class DFlashSpeculatorConfig(SpeculatorModelConfig):
    """
    Speculator configuration for the DFlash architecture.

    DFlash maps between a draft-model vocabulary and the target-model
    vocabulary, enabling cross-tokenizer speculation. The nested
    ``transformer_layer_config`` describes the decoder layers of the draft
    model and is the source of the target vocabulary size.

    :param transformer_layer_config: Configuration for the transformer decoder layer
    :param draft_vocab_size: Size of draft model vocabulary for speculation
    :param num_hidden_layers: Number of decoder layers in the DFlash draft model
    :param block_size: Number of draft tokens predicted per forward pass
    :param target_hidden_size: Target-model hidden size, when it differs from the draft's
    :param aux_hidden_state_layer_ids: Layer IDs supplying auxiliary hidden states
    """

    # Discriminator key used by the speculators registry (see @register above).
    speculators_model_type: Literal["dflash"] = "dflash"
    architectures: list[str] = Field(
        default_factory=lambda: ["DFlashSpeculator"],
        description="Model architectures that can load these weights",
    )
    transformer_layer_config: PretrainedConfig = Field(
        default_factory=Qwen3Config,
        description="Configuration for the transformer decoder layer",
    )
    draft_vocab_size: int = Field(
        default=32000,
        description="Size of draft model vocabulary for speculation",
    )
    num_hidden_layers: int = Field(
        default=3,
        description="Number of hidden layers in the DFlash model",
    )
    block_size: int = Field(
        default=8,
        description="Default size of the draft block predicted with a forward pass of the model",
    )
    target_hidden_size: int | None = Field(
        default=None,
        description="Hidden size of the target model (if different from draft model)",
    )
    aux_hidden_state_layer_ids: list[int] | None = Field(
        default=None,
        description="Layer IDs of the DFlash auxiliary hidden state layers",
    )

    @property
    def target_vocab_size(self) -> int:
        """Vocabulary size of the target model, read from the nested transformer config."""
        return self.transformer_layer_config.vocab_size

    @field_serializer("transformer_layer_config")
    def serialize_transformer_config(self, value: PretrainedConfig) -> dict:
        """Serialize the nested transformer config as a diff against its defaults."""
        return value.to_diff_dict()

    @field_validator("transformer_layer_config", mode="before")
    @classmethod
    def validate_transformer_config(cls, value: Any) -> PretrainedConfig:
        """Coerce a plain dict into the matching PretrainedConfig subclass.

        Non-dict values pass through unchanged for pydantic to validate.
        """
        if not isinstance(value, dict):
            return value
        # Resolve the concrete config class via the transformers registry when
        # "model_type" is present; otherwise assume a Qwen3 decoder layout.
        config_cls: type[PretrainedConfig] = Qwen3Config
        if "model_type" in value:
            config_cls = type(AutoConfig.for_model(model_type=value["model_type"]))
        return config_cls(**value)