peagle-gemma4 / config.py
MeganEFlynn's picture
Upload folder using huggingface_hub
d87b951 verified
Raw
History Blame Contribute Delete
2.58 kB
from typing import Literal
from pydantic import Field
from speculators import SpeculatorModelConfig
from speculators.models.eagle3.config import Eagle3SpeculatorConfig
__all__ = [
"PEagleSpeculatorConfig",
]
@SpeculatorModelConfig.register("peagle")
class PEagleSpeculatorConfig(Eagle3SpeculatorConfig):
"""
Configuration for P-EAGLE (Parallel EAGLE) speculator.
P-EAGLE extends EAGLE-3 with parallel multi-token prediction using
Conditional Drop Token (COD) sampling for memory-efficient training.
:param para_depths: Number of parallel prediction groups (typically 8)
:param down_sample_ratio: Geometric decay ratio for COD sampling (r in [0,1])
:param down_sample_ratio_min: Minimum retention ratio floor
:param mask_token_id: Token ID for predicted token dropout (padding unused positions)
:param max_seq_len: Maximum sequence length for attention mask construction
"""
speculators_model_type: Literal["peagle"] = "peagle" # type: ignore[assignment]
architectures: list[str] = Field(
default_factory=lambda: ["PEagleSpeculator"],
description="Model architectures that can load these weights",
)
para_depths: int = Field(
default=8,
description="Number of parallel prediction groups (depths)",
ge=1,
le=16,
)
down_sample_ratio: float = Field(
default=0.7,
description="Geometric decay ratio for COD sampling (retention rate r)",
gt=0.0,
lt=1.0,
)
down_sample_ratio_min: float = Field(
default=0.1,
description="Minimum retention ratio floor to prevent over-sampling",
gt=0.0,
le=1.0,
)
mask_token_id: int | None = Field(
default=None,
description="Token ID used for padding unused positions in parallel groups",
)
max_seq_len: int = Field(
default=2048,
description="Maximum sequence length for attention mask construction",
ge=128,
le=8192,
)
# Override Eagle3 default: P-EAGLE requires trainable embeddings
# (matches p-eagle-train)
embed_requires_grad: bool = Field(
default=True,
description=(
"Whether embedding layer weights require gradients during "
"training (True for P-EAGLE)"
),
)
prediction_loss_weight: float = Field(
default=1.0,
description="Weight for prediction loss (cross-entropy on logits). "
"P-eagle-train uses only prediction loss, no hidden state distillation.",
gt=0.0,
)