File size: 2,577 Bytes
d87b951
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from typing import Literal

from pydantic import Field

from speculators import SpeculatorModelConfig
from speculators.models.eagle3.config import Eagle3SpeculatorConfig

# Names exported by ``from <this module> import *``; only the config class
# defined below is listed.
__all__ = [
    "PEagleSpeculatorConfig",
]


@SpeculatorModelConfig.register("peagle")
class PEagleSpeculatorConfig(Eagle3SpeculatorConfig):
    """
    Settings for the P-EAGLE (Parallel EAGLE) speculator.

    Subclasses the EAGLE-3 speculator config and is registered with
    ``SpeculatorModelConfig`` under the ``"peagle"`` key. In addition to the
    inherited EAGLE-3 options it carries the knobs for parallel multi-token
    prediction and for Conditional Drop Token (COD) sampling, which is used
    for memory-efficient training.

    :param para_depths: Number of parallel prediction groups; integer in
        [1, 16], default 8
    :param down_sample_ratio: Geometric decay ratio r for COD sampling;
        strictly between 0 and 1, default 0.7
    :param down_sample_ratio_min: Floor on the retention ratio; in (0, 1],
        default 0.1
    :param mask_token_id: Token ID for predicted token dropout (pads unused
        positions in parallel groups); optional, default ``None``
    :param max_seq_len: Maximum sequence length for attention mask
        construction; integer in [128, 8192], default 2048
    :param embed_requires_grad: Whether embedding layer weights require
        gradients during training; default ``True``
    :param prediction_loss_weight: Strictly positive weight for the
        prediction loss (cross-entropy on logits); default 1.0
    """

    # Discriminator value; pinned to the same key used in the registry call.
    speculators_model_type: Literal["peagle"] = "peagle"  # type: ignore[assignment]
    architectures: list[str] = Field(
        description="Model architectures that can load these weights",
        default_factory=lambda: ["PEagleSpeculator"],
    )

    # --- parallel prediction / COD sampling -------------------------------
    para_depths: int = Field(
        8,
        ge=1,
        le=16,
        description="Number of parallel prediction groups (depths)",
    )

    down_sample_ratio: float = Field(
        0.7,
        gt=0.0,
        lt=1.0,
        description="Geometric decay ratio for COD sampling (retention rate r)",
    )

    down_sample_ratio_min: float = Field(
        0.1,
        gt=0.0,
        le=1.0,
        description="Minimum retention ratio floor to prevent over-sampling",
    )

    mask_token_id: int | None = Field(
        None,
        description="Token ID used for padding unused positions in parallel groups",
    )

    max_seq_len: int = Field(
        2048,
        ge=128,
        le=8192,
        description="Maximum sequence length for attention mask construction",
    )

    # --- training behaviour -----------------------------------------------
    # The Eagle3 default is overridden here because P-EAGLE requires
    # trainable embeddings (matches p-eagle-train).
    embed_requires_grad: bool = Field(
        True,
        description=(
            "Whether embedding layer weights require gradients "
            "during training (True for P-EAGLE)"
        ),
    )

    prediction_loss_weight: float = Field(
        1.0,
        gt=0.0,
        description=(
            "Weight for prediction loss (cross-entropy on logits). "
            "P-eagle-train uses only prediction loss, "
            "no hidden state distillation."
        ),
    )