"""GRPO skip policy network.

Architecture:
  Input:   per-layer hidden state projections z_l ∈ R^{d'} for l = 0..L-1
           + scalar context features (last_tau, last_ms, position, age, ...)
  Encoder: 2-layer Transformer encoder over layer index l (treating l as seq pos)
  Output:  per-layer skip logits u ∈ R^L (one scalar logit u_l per layer)
           + scalar p̂ ∈ (0,1) predicting E[τ/K] (acceptance rate)
  Action:  top-M selection via TopMActionSampler; K derived from p̂ analytically

The policy is lightweight by design — it should be orders of magnitude smaller
than the verify model to keep training cost negligible.
"""
|
|
| from typing import List, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .action_space import TopMActionSampler |
|
|
| _DEFAULT_DRAFT_LEN_CHOICES = [4, 8, 12, 16, 24, 32, 48, 64] |
|
|
|
|
def optimal_draft_len(p_hat: float, choices: List[int]) -> int:
    """Return the K from ``choices`` closest to the natural optimum K* = 1/(1−p̂).

    Intuition (design note carried over from the original): under a geometric
    acceptance model, K* ≈ 1/(1−p) is used as the target draft length; beyond
    that point extra draft tokens are increasingly likely to be rejected.

    p̂ is clamped to [0.01, 0.99] before use, so the division can never hit
    zero and K* stays within roughly [1.01, 100].

    Args:
        p_hat: predicted per-token acceptance probability. Values outside
            [0.01, 0.99] are clamped rather than rejected.
        choices: discrete candidate K values (must be non-empty). When two
            candidates are equally close to K*, the one that appears first
            in ``choices`` is returned.

    Returns:
        The element of ``choices`` minimising |K − K*|.

    Raises:
        ValueError: if ``choices`` is empty.
    """
    # Materialise once so the emptiness check also works for one-shot iterables.
    candidates = list(choices)
    if not candidates:
        raise ValueError("optimal_draft_len: `choices` must be non-empty")

    clamped = max(0.01, min(float(p_hat), 0.99))
    k_natural = 1.0 / (1.0 - clamped)
    return min(candidates, key=lambda k: abs(k - k_natural))
|
|
|
|
class HiddenStateProjector(nn.Module):
    """Project per-layer hidden states from d → d_policy.

    Input:  tuple of [1, seq_len, d] tensors (one per layer, L+1 total).
            Entry 0 is skipped and entries 1..L are used; only batch
            element 0 of each tensor is read.
            NOTE(review): entry 0 is presumably the embedding output — confirm
            against the caller.
    Output: [L, d_policy] (one projected vector per transformer layer)

    When context_tokens > 1, the last K token positions are mean-pooled per
    layer before projection, giving the policy a richer view of recent context.
    If the sequence is shorter than K, all available tokens are used.
    """

    def __init__(
        self,
        hidden_dim: int,
        policy_dim: int,
        n_layers: int,
        context_tokens: int = 1,
    ):
        """Create the shared bias-free projection.

        Args:
            hidden_dim: width d of the incoming hidden states.
            policy_dim: width d_policy of the projected vectors.
            n_layers: number of layers L to read from ``hidden_states``.
            context_tokens: number K of trailing token positions to mean-pool.

        Raises:
            ValueError: if ``context_tokens`` is smaller than 1.
        """
        super().__init__()
        # A slice of `-0:` selects the whole sequence and a negative K drops
        # leading tokens, so anything below 1 would silently pool the wrong span.
        if context_tokens < 1:
            raise ValueError(
                f"context_tokens must be >= 1, got {context_tokens}"
            )
        self.n_layers = n_layers
        self.context_tokens = context_tokens
        self.proj = nn.Linear(hidden_dim, policy_dim, bias=False)

    def forward(
        self, hidden_states: Tuple[torch.Tensor, ...]
    ) -> torch.Tensor:
        """Extract last K-token mean hidden state from each layer, project, return [L, d_p].

        Raises:
            ValueError: if ``hidden_states`` holds fewer than ``n_layers + 1``
                entries, which would otherwise yield fewer than L output rows.
        """
        per_layer = hidden_states[1:self.n_layers + 1]
        if len(per_layer) != self.n_layers:
            raise ValueError(
                f"expected at least {self.n_layers + 1} hidden states, "
                f"got {len(hidden_states)}"
            )

        k = self.context_tokens
        # Slicing with -k: clamps automatically when seq_len < k.
        pooled = [hs[0, -k:, :].mean(dim=0) for hs in per_layer]
        stacked = torch.stack(pooled, dim=0)
        # Cast to the projection's own dtype (float32 unless the module has
        # been converted) so inputs of any precision can be fed in.
        return self.proj(stacked.to(self.proj.weight.dtype))
|
|
|
|
class PolicyEncoder(nn.Module):
    """Transformer encoder over layer indices (treating the L layers as a sequence).

    Input:  [L, d_policy] (any scalar-feature embedding must already have been
            added by the caller)
    Output: [L, d_policy]

    The stack is pre-norm (``norm_first=True``), dropout-free, and uses a
    feed-forward width of 4 * d_policy. No positional encoding is added in
    this module.
    """

    def __init__(self, d_policy: int, n_heads: int = 4, n_encoder_layers: int = 2):
        """Build the encoder stack.

        Args:
            d_policy: model width of every position.
            n_heads: number of attention heads.
            n_encoder_layers: number of stacked encoder layers.
        """
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_policy,
            nhead=n_heads,
            dim_feedforward=d_policy * 4,
            dropout=0.0,
            batch_first=True,
            norm_first=True,
        )
        # With norm_first=True PyTorch cannot use the nested-tensor fast path
        # and warns about it on construction; disable it explicitly.
        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=n_encoder_layers,
            enable_nested_tensor=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: [L, d_policy] → [L, d_policy]"""
        # The encoder is batch_first, so wrap the sequence in a batch of one.
        batched = x.unsqueeze(0)
        return self.encoder(batched).squeeze(0)
|
|
|
|
class ScalarFeatureEmbedder(nn.Module):
    """Embed the scalar context features into one [d_policy] vector.

    The module only returns the embedding; broadcasting it onto every layer
    position is left to the caller. The projection has no bias, so an
    all-zero feature vector maps to a zero embedding.
    """

    # Order matters: it must match the order in which forward() builds the
    # feature vector.
    FEATURE_NAMES = [
        "last_tau_norm",
        "latency_norm",
        "position_norm",
        "age_norm",
        "temperature",
    ]
    N_FEATURES = len(FEATURE_NAMES)

    def __init__(self, d_policy: int):
        """Create the bias-free linear map from N_FEATURES scalars to d_policy."""
        super().__init__()
        self.embed = nn.Linear(self.N_FEATURES, d_policy, bias=False)

    def forward(
        self,
        last_tau: int,
        draft_len: int,
        last_ms: float,
        position: int,
        max_len: int,
        age: int,
        update_interval: int,
        temperature: float,
    ) -> torch.Tensor:
        """Return [d_policy] scalar feature embedding.

        Args:
            last_tau: normalised as ``last_tau / max(draft_len, 1)``.
            draft_len: denominator for ``last_tau``.
            last_ms: divided by 1000.0.
            position: normalised as ``position / max(max_len, 1)``.
            max_len: denominator for ``position``.
            age: normalised as ``age / max(update_interval, 1)``.
            update_interval: denominator for ``age``.
            temperature: used as-is.
        """
        weight = self.embed.weight
        # Every denominator is floored at 1 to avoid division by zero.
        features = [
            last_tau / max(draft_len, 1),
            last_ms / 1000.0,
            position / max(max_len, 1),
            age / max(update_interval, 1),
            temperature,
        ]
        # Follow the layer's dtype/device so the module keeps working after
        # .double()/.half()/.bfloat16() or a device move.
        feats = torch.tensor(features, dtype=weight.dtype, device=weight.device)
        return self.embed(feats)
|
|
|
|
class AcceptanceRateHead(nn.Module):
    """Predicts E[τ/K] from mean-pooled encoder output via scalar regression.

    Per the original design note, this head is trained with MSE against the
    observed τ/K each rollout — no policy gradient needed (the training code
    is not shown here). The prediction p̂ is used to derive the draft length
    K* analytically via optimal_draft_len().
    """

    def __init__(self, d_policy: int):
        """Create the single-output linear regression layer."""
        super().__init__()
        self.head = nn.Linear(d_policy, 1, bias=True)

    def forward(self, encoded: torch.Tensor) -> torch.Tensor:
        """encoded: [L, d_policy] → scalar p̂ ∈ (0, 1).

        Leading batch dimensions are also accepted: [..., L, d_policy] yields
        a tensor of shape [...], one prediction per batch element.
        """
        # Pool over the layer axis (second to last) so unbatched and batched
        # inputs are both handled; for [L, d_policy] this is dim 0.
        pooled = encoded.mean(dim=-2)
        logit = self.head(pooled)
        return torch.sigmoid(logit).squeeze(-1)
|
|
|
|
class SkipPolicy(nn.Module):
    """Full skip policy: projects hidden states → encodes → outputs skip logits
    and a predicted acceptance rate p̂.

    Pipeline (see ``forward``): HiddenStateProjector → add the broadcast
    ScalarFeatureEmbedder vector → PolicyEncoder → a per-layer linear logit
    head plus an AcceptanceRateHead. The draft length is never sampled; it is
    derived from p̂ with ``optimal_draft_len``.

    Usage::

        policy = SkipPolicy(hidden_dim=4096, n_layers=32, n_skip=16, policy_dim=128)
        skip_logits, p_hat = policy(hidden_states, last_tau=8, ...)
        mask, draft_len = policy.greedy_mask(hidden_states, ...)
    """

    def __init__(
        self,
        hidden_dim: int,
        n_layers: int,
        n_skip: int,
        policy_dim: int = 128,
        n_heads: int = 4,
        n_encoder_layers: int = 2,
        keep_prefix: int = 2,
        keep_suffix: int = 2,
        draft_len_choices: Optional[List[int]] = None,
        context_tokens: int = 1,
    ):
        """Build all sub-modules.

        Args:
            hidden_dim: width of the verify model's hidden states.
            n_layers: number of layers L the policy decides over.
            n_skip: forwarded to TopMActionSampler (the M of top-M).
            policy_dim: internal width d_policy of the policy network.
            n_heads: attention heads in the PolicyEncoder.
            n_encoder_layers: depth of the PolicyEncoder.
            keep_prefix: forwarded to TopMActionSampler.
            keep_suffix: forwarded to TopMActionSampler.
                NOTE(review): presumably the number of leading/trailing layers
                that are never skipped — confirm in ``action_space``.
            draft_len_choices: candidate draft lengths; defaults to
                ``_DEFAULT_DRAFT_LEN_CHOICES`` when None.
            context_tokens: trailing token positions pooled per layer by the
                HiddenStateProjector.
        """
        super().__init__()
        self.n_layers = n_layers
        self.n_skip = n_skip
        # Keep a private copy: storing the module-level default (or the
        # caller's list) by reference would let a mutation on one policy leak
        # into every other policy sharing that list.
        self.draft_len_choices = list(
            _DEFAULT_DRAFT_LEN_CHOICES
            if draft_len_choices is None
            else draft_len_choices
        )
        # Sub-module creation order is kept stable so parameter registration
        # and seeded initialisation stay reproducible.
        self.projector = HiddenStateProjector(hidden_dim, policy_dim, n_layers, context_tokens)
        self.scalar_embedder = ScalarFeatureEmbedder(policy_dim)
        self.encoder = PolicyEncoder(policy_dim, n_heads, n_encoder_layers)
        self.logit_head = nn.Linear(policy_dim, 1, bias=True)
        self.sampler = TopMActionSampler(n_layers, n_skip, keep_prefix, keep_suffix)
        self.acceptance_head = AcceptanceRateHead(policy_dim)

    def forward(
        self,
        hidden_states: Tuple[torch.Tensor, ...],
        last_tau: int = 0,
        draft_len: int = 16,
        last_ms: float = 0.0,
        position: int = 0,
        max_len: int = 256,
        age: int = 0,
        update_interval: int = 1,
        temperature: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute per-layer skip logits and predicted acceptance rate.

        Args:
            hidden_states: per-layer hidden states consumed by the projector.
            last_tau, draft_len, last_ms, position, max_len, age,
            update_interval, temperature: scalar context features; they are
                normalised and embedded by ScalarFeatureEmbedder.

        Returns:
            skip_logits: [n_layers]
            p_hat:       scalar ∈ (0, 1), predicted E[τ/K]
        """
        layer_repr = self.projector(hidden_states)

        scalar_emb = self.scalar_embedder(
            last_tau, draft_len, last_ms, position, max_len,
            age, update_interval, temperature,
        )
        # The same context vector is broadcast-added to every layer position.
        layer_repr = layer_repr + scalar_emb.unsqueeze(0)

        encoded = self.encoder(layer_repr)

        skip_logits = self.logit_head(encoded).squeeze(-1)
        p_hat = self.acceptance_head(encoded)
        return skip_logits, p_hat

    def sample_mask(
        self,
        hidden_states: Tuple[torch.Tensor, ...],
        temperature: float = 1.0,
        **kwargs,
    ) -> Tuple[List[int], int, torch.Tensor]:
        """Sample a skip mask and derive draft length from predicted acceptance rate.

        draft_len is selected deterministically via optimal_draft_len(p̂) — no RL.
        log_p covers only the skip mask action.

        NOTE(review): ``temperature`` here is consumed by the sampler only and
        is not passed on to ``forward``, so the ``temperature`` context feature
        stays at forward's default (0.0) on this path, whereas ``greedy_mask``
        can set it through ``**kwargs``. Confirm this asymmetry is intended.

        Args:
            hidden_states: per-layer hidden states, passed to ``forward``.
            temperature: sampling temperature handed to the sampler.
            **kwargs: remaining scalar context features for ``forward``.

        Returns:
            hard_mask: List[int] of length n_layers
            draft_len: int, derived from p̂
            log_p:     scalar tensor, log π(mask | h)
        """
        skip_logits, p_hat = self.forward(hidden_states, **kwargs)

        soft_mask = self.sampler(skip_logits, temperature=temperature)
        # The sampled action is treated as a constant: gradients reach the
        # logits only through log_prob.
        sampled = soft_mask.detach()
        hard_mask = (sampled > 0.5).long().tolist()
        log_p = self.sampler.log_prob(skip_logits, sampled)

        draft_len = optimal_draft_len(p_hat.detach().item(), self.draft_len_choices)
        return hard_mask, draft_len, log_p

    def greedy_mask(
        self,
        hidden_states: Tuple[torch.Tensor, ...],
        **kwargs,
    ) -> Tuple[List[int], int]:
        """Deterministic greedy mask and draft length for evaluation.

        Runs entirely under ``torch.no_grad()``; ``**kwargs`` are the scalar
        context features forwarded to ``forward``.
        """
        with torch.no_grad():
            skip_logits, p_hat = self.forward(hidden_states, **kwargs)
            mask = self.sampler.greedy_mask(skip_logits)
            draft_len = optimal_draft_len(p_hat.item(), self.draft_len_choices)
        return mask, draft_len

    def compile_for_inference(self) -> None:
        """Replace forward with a torch.compile'd version for faster inference.

        Call once after policy.eval() and before the generation loop.
        Use fullgraph=False to tolerate the torch.tensor() call inside
        ScalarFeatureEmbedder without needing to refactor it.
        """
        # The compiled callable shadows the class-level forward on this
        # instance only; sample_mask/greedy_mask pick it up via self.forward.
        self.forward = torch.compile(self.forward, mode="max-autotune", fullgraph=False)
|
|