Initial upload: policy checkpoint, config, model card, source
Browse files- README.md +98 -0
- config.json +17 -0
- policy_best.pt +3 -0
- src/__init__.py +0 -0
- src/grpo/__init__.py +0 -0
- src/grpo/action_space.py +114 -0
- src/grpo/policy.py +273 -0
README.md
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language:
|
| 3 |
+
- en
|
| 4 |
+
license: apache-2.0
|
| 5 |
+
tags:
|
| 6 |
+
- speculative-decoding
|
| 7 |
+
- layer-skipping
|
| 8 |
+
- grpo
|
| 9 |
+
- clasp
|
| 10 |
+
- llama
|
| 11 |
+
- efficiency
|
| 12 |
+
base_model: meta-llama/Meta-Llama-3-8B
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
# CASM — CLaSp Adaptive Skip Mask
|
| 16 |
+
|
| 17 |
+
CASM is a lightweight GRPO-trained skip policy for self-speculative decoding with [Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B). It learns which transformer layers to bypass during the draft stage of speculative decoding, replacing the CLaSp dynamic-programming optimizer with a small neural policy that runs in microseconds.
|
| 18 |
+
|
| 19 |
+
## How it works
|
| 20 |
+
|
| 21 |
+
Self-speculative decoding runs the *same* frozen model in two modes:
|
| 22 |
+
|
| 23 |
+
1. **Draft** — selected decoder layers are skipped, producing `K` candidate tokens cheaply.
|
| 24 |
+
2. **Verify** — the full model validates the draft block and accepts the longest matching prefix.
|
| 25 |
+
3. **Policy update** — CASM observes the verify hidden states and chooses a new skip mask for the next cycle.
|
| 26 |
+
|
| 27 |
+
CASM replaces step 3's DP solver with a 2-layer Transformer encoder (~200 K parameters) that maps per-layer hidden states to a skip-mask distribution. It is trained end-to-end with GRPO against a reward combining token-acceptance rate, decoding speed, and mismatch regularization.
|
| 28 |
+
|
| 29 |
+
## Architecture
|
| 30 |
+
|
| 31 |
+
| Component | Description |
|
| 32 |
+
|---|---|
|
| 33 |
+
| `HiddenStateProjector` | Projects per-layer hidden states `[L, d_model]` → `[L, 128]` |
|
| 34 |
+
| `ScalarFeatureEmbedder` | Embeds 5 scalar context features (acceptance rate, latency, position, mask age, temperature) |
|
| 35 |
+
| `PolicyEncoder` | 2-layer Transformer encoder over layer positions |
|
| 36 |
+
| `logit_head` | Per-layer skip logits → top-M selection |
|
| 37 |
+
| `AcceptanceRateHead` | Predicts E[τ/K] → optimal draft length K* |
|
| 38 |
+
|
| 39 |
+
**Parameters:** ~200 K
|
| 40 |
+
**Base model:** `meta-llama/Meta-Llama-3-8B` (32 layers, hidden_dim=4096)
|
| 41 |
+
**Skip budget:** 8 layers per draft cycle
|
| 42 |
+
|
| 43 |
+
## Usage
|
| 44 |
+
|
| 45 |
+
```python
|
| 46 |
+
import torch
|
| 47 |
+
from src.grpo.policy import SkipPolicy
|
| 48 |
+
|
| 49 |
+
# Load policy
|
| 50 |
+
ckpt = torch.load("policy_best.pt", map_location="cpu")
|
| 51 |
+
policy = SkipPolicy(
|
| 52 |
+
hidden_dim=4096,
|
| 53 |
+
n_layers=32,
|
| 54 |
+
n_skip=8,
|
| 55 |
+
policy_dim=128,
|
| 56 |
+
context_tokens=1,
|
| 57 |
+
)
|
| 58 |
+
policy.load_state_dict(ckpt["policy_state_dict"])
|
| 59 |
+
policy.eval()
|
| 60 |
+
|
| 61 |
+
# During self-speculative decoding, call after each verify pass:
|
| 62 |
+
# hidden_states: tuple of (L+1) tensors from model output_hidden_states=True
|
| 63 |
+
mask, draft_len = policy.greedy_mask(
|
| 64 |
+
hidden_states,
|
| 65 |
+
last_tau=accepted_tokens,
|
| 66 |
+
draft_len=current_draft_len,
|
| 67 |
+
position=current_position,
|
| 68 |
+
max_len=max_new_tokens,
|
| 69 |
+
)
|
| 70 |
+
# mask: list of 0/1 per layer (1 = skip this layer during draft)
|
| 71 |
+
# draft_len: recommended tokens to draft next cycle
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
See [grpo-clasp](https://github.com/dayne-2stacks/grpo-clasp) for the full training and evaluation codebase.
|
| 75 |
+
|
| 76 |
+
## Training
|
| 77 |
+
|
| 78 |
+
Trained with GRPO on SpecBench-style prompts using `meta-llama/Meta-Llama-3-8B` on a single A100 80 GB for 10 000 steps. Imitation warm-start from CLaSp DP masks was used for the first ~1000 steps.
|
| 79 |
+
|
| 80 |
+
| Metric | Value |
|
| 81 |
+
|---|---|
|
| 82 |
+
| Training steps | 10 000 |
|
| 83 |
+
| Eval reward | 99.8 |
|
| 84 |
+
| Test reward | 100.4 |
|
| 85 |
+
| GPU | NVIDIA A100 80 GB |
|
| 86 |
+
|
| 87 |
+
## Citation
|
| 88 |
+
|
| 89 |
+
If you use CASM, please cite the CLaSp paper and this repository:
|
| 90 |
+
|
| 91 |
+
```bibtex
|
| 92 |
+
@misc{casm2026,
|
| 93 |
+
author = {Dayne Guy},
|
| 94 |
+
title = {CASM: CLaSp Adaptive Skip Mask},
|
| 95 |
+
year = {2026},
|
| 96 |
+
url = {https://huggingface.co/dayngerous/CASM}
|
| 97 |
+
}
|
| 98 |
+
```
|
config.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "casm_skip_policy",
|
| 3 |
+
"hidden_dim": 4096,
|
| 4 |
+
"n_layers": 32,
|
| 5 |
+
"n_skip": 8,
|
| 6 |
+
"policy_dim": 128,
|
| 7 |
+
"n_heads": 4,
|
| 8 |
+
"n_encoder_layers": 2,
|
| 9 |
+
"keep_prefix": 0,
|
| 10 |
+
"keep_suffix": 0,
|
| 11 |
+
"context_tokens": 1,
|
| 12 |
+
"draft_len_choices": [4, 8, 12, 16, 24, 32, 48, 64],
|
| 13 |
+
"base_model": "meta-llama/Meta-Llama-3-8B",
|
| 14 |
+
"training_steps": 10000,
|
| 15 |
+
"eval_reward": 99.773,
|
| 16 |
+
"test_reward": 100.429
|
| 17 |
+
}
|
policy_best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:762abdcb49e258e4bf2c29402429722f8e6cd0f09344e38cf137575946977783
|
| 3 |
+
size 3698949
|
src/__init__.py
ADDED
|
File without changes
|
src/grpo/__init__.py
ADDED
|
File without changes
|
src/grpo/action_space.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Action space definitions for the GRPO skip policy.
|
| 2 |
+
|
| 3 |
+
The action is a binary skip mask S ∈ {0,1}^L.
|
| 4 |
+
This module provides samplers and constraint enforcement for different
|
| 5 |
+
action space parameterizations.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import List, Optional, Tuple
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class TopMActionSampler(nn.Module):
|
| 16 |
+
"""Fixed-budget action sampler: select exactly M layers to skip.
|
| 17 |
+
|
| 18 |
+
Policy outputs per-layer logits; the top-M scoring eligible layers are skipped.
|
| 19 |
+
During training, uses a straight-through Gumbel-top-K estimator for gradients.
|
| 20 |
+
During evaluation / rollouts, uses deterministic top-M argmax.
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
n_layers: total number of transformer layers.
|
| 24 |
+
n_skip: skip budget M.
|
| 25 |
+
keep_prefix: number of layers at start that cannot be skipped.
|
| 26 |
+
keep_suffix: number of layers at end that cannot be skipped.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
n_layers: int,
|
| 32 |
+
n_skip: int,
|
| 33 |
+
keep_prefix: int = 2,
|
| 34 |
+
keep_suffix: int = 2,
|
| 35 |
+
):
|
| 36 |
+
super().__init__()
|
| 37 |
+
self.n_layers = n_layers
|
| 38 |
+
self.n_skip = n_skip
|
| 39 |
+
self.keep_prefix = keep_prefix
|
| 40 |
+
self.keep_suffix = keep_suffix
|
| 41 |
+
# Mask for eligible layers
|
| 42 |
+
eligible = torch.zeros(n_layers, dtype=torch.bool)
|
| 43 |
+
eligible[keep_prefix : n_layers - keep_suffix] = True
|
| 44 |
+
self.register_buffer("eligible", eligible)
|
| 45 |
+
|
| 46 |
+
def forward(self, logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
|
| 47 |
+
"""Sample a soft skip mask from policy logits.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
logits: [n_layers] raw skip logits from policy network.
|
| 51 |
+
temperature: sampling temperature (1.0 = standard, lower = more peaked).
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
hard_mask: [n_layers] binary tensor (differentiable via straight-through).
|
| 55 |
+
"""
|
| 56 |
+
# Zero out ineligible logits
|
| 57 |
+
masked_logits = logits.clone()
|
| 58 |
+
masked_logits[~self.eligible] = float("-inf")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
scale = masked_logits[self.eligible].std().detach().clamp(min=1.0)
|
| 62 |
+
masked_logits = masked_logits / scale
|
| 63 |
+
|
| 64 |
+
# Gumbel-top-K for differentiable discrete selection
|
| 65 |
+
gumbel = -torch.log(-torch.log(torch.clamp(torch.rand_like(masked_logits), 1e-9, 1.0)))
|
| 66 |
+
perturbed = (masked_logits + gumbel) / temperature
|
| 67 |
+
|
| 68 |
+
# Select top n_skip eligible indices
|
| 69 |
+
topk_vals, topk_idx = torch.topk(perturbed, self.n_skip)
|
| 70 |
+
hard_mask = torch.zeros(self.n_layers, device=logits.device)
|
| 71 |
+
hard_mask.scatter_(0, topk_idx, 1.0)
|
| 72 |
+
|
| 73 |
+
# Straight-through estimator: use hard mask in forward, soft mask in backward
|
| 74 |
+
soft_mask = torch.sigmoid(masked_logits / temperature)
|
| 75 |
+
return hard_mask + (soft_mask - soft_mask.detach())
|
| 76 |
+
|
| 77 |
+
def greedy_mask(self, logits: torch.Tensor) -> List[int]:
|
| 78 |
+
"""Deterministic top-M mask for inference."""
|
| 79 |
+
masked_logits = logits.clone()
|
| 80 |
+
masked_logits[~self.eligible] = float("-inf")
|
| 81 |
+
_, topk_idx = torch.topk(masked_logits, self.n_skip)
|
| 82 |
+
mask = torch.zeros(self.n_layers, dtype=torch.long)
|
| 83 |
+
mask[topk_idx] = 1
|
| 84 |
+
return mask.tolist()
|
| 85 |
+
|
| 86 |
+
def log_prob(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
| 87 |
+
"""Log-probability of a discrete mask under the Plackett-Luce model.
|
| 88 |
+
|
| 89 |
+
Uses sequential conditioning: each selected layer is drawn from a
|
| 90 |
+
categorical over the remaining eligible layers, consistent with
|
| 91 |
+
Gumbel-top-K sampling.
|
| 92 |
+
"""
|
| 93 |
+
# Apply the same unit-std normalization used in forward() so that log_p_old
|
| 94 |
+
# (computed at rollout time) and log_p_new (computed during the PPO update)
|
| 95 |
+
# are on the same scale and the ratio exp(log_p_new - log_p_old) is correct.
|
| 96 |
+
# Clamp after normalizing for numerical safety in log_softmax.
|
| 97 |
+
eligible_logits = logits[self.eligible]
|
| 98 |
+
scale = eligible_logits.std().detach().clamp(min=1.0)
|
| 99 |
+
eligible_logits = (eligible_logits / scale).clamp(-50.0, 50.0)
|
| 100 |
+
selected_indices = mask[self.eligible].bool().nonzero(as_tuple=True)[0]
|
| 101 |
+
|
| 102 |
+
log_p = logits.new_zeros(())
|
| 103 |
+
# Use a bool exclusion mask (no grad) instead of in-place modification of
|
| 104 |
+
# eligible_logits (which has requires_grad=True during _update_policy).
|
| 105 |
+
# In-place ops on a grad tensor mid-loop corrupt autograd's version counter
|
| 106 |
+
# and can silently produce NaN gradients.
|
| 107 |
+
exclusion = torch.zeros(eligible_logits.shape[0], dtype=torch.bool,
|
| 108 |
+
device=logits.device)
|
| 109 |
+
for idx in selected_indices:
|
| 110 |
+
masked = eligible_logits.masked_fill(exclusion, float("-inf"))
|
| 111 |
+
log_p = log_p + F.log_softmax(masked, dim=0)[idx]
|
| 112 |
+
exclusion = exclusion.clone()
|
| 113 |
+
exclusion[idx] = True
|
| 114 |
+
return log_p
|
src/grpo/policy.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""GRPO skip policy network.
|
| 2 |
+
|
| 3 |
+
Architecture:
|
| 4 |
+
Input: per-layer hidden state projections z_l ∈ R^{d'} for l=0..L-1
|
| 5 |
+
+ scalar context features (last_tau, last_ms, position, age, ...)
|
| 6 |
+
Encoder: 2-layer Transformer encoder over layer index l (treating l as seq pos)
|
| 7 |
+
Output: per-layer skip logits u_l ∈ R^L
|
| 8 |
+
+ scalar p̂ ∈ (0,1) predicting E[τ/K] (acceptance rate)
|
| 9 |
+
Action: top-M selection via TopMActionSampler; K derived from p̂ analytically
|
| 10 |
+
|
| 11 |
+
The policy is lightweight by design — it should be orders of magnitude smaller
|
| 12 |
+
than the verify model to keep training cost negligible.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from typing import List, Optional, Tuple
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
|
| 21 |
+
from .action_space import TopMActionSampler
|
| 22 |
+
|
| 23 |
+
_DEFAULT_DRAFT_LEN_CHOICES = [4, 8, 12, 16, 24, 32, 48, 64]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def optimal_draft_len(p_hat: float, choices: List[int]) -> int:
|
| 27 |
+
"""Return the K from choices closest to the natural optimum K* = 1/(1−p̂).
|
| 28 |
+
|
| 29 |
+
Intuition: under a geometric acceptance model, K* ≈ 1/(1−p) maximises
|
| 30 |
+
expected accepted tokens per verify pass. Beyond this point, extra draft
|
| 31 |
+
tokens are increasingly likely to be rejected.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
p_hat: predicted per-token acceptance probability in (0, 1).
|
| 35 |
+
choices: discrete candidate K values (must be non-empty).
|
| 36 |
+
"""
|
| 37 |
+
p_hat = max(0.01, min(p_hat, 0.99))
|
| 38 |
+
k_natural = 1.0 / (1.0 - p_hat)
|
| 39 |
+
return min(choices, key=lambda k: abs(k - k_natural))
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class HiddenStateProjector(nn.Module):
|
| 43 |
+
"""Project per-layer hidden states from d → d_policy.
|
| 44 |
+
|
| 45 |
+
Input: tuple of [1, seq_len, d] tensors (one per layer, L+1 total)
|
| 46 |
+
Output: [L, d_policy] (one projected vector per transformer layer)
|
| 47 |
+
|
| 48 |
+
When context_tokens > 1, the last K token positions are mean-pooled per
|
| 49 |
+
layer before projection, giving the policy a richer view of recent context.
|
| 50 |
+
If the sequence is shorter than K, all available tokens are used.
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
def __init__(
|
| 54 |
+
self,
|
| 55 |
+
hidden_dim: int,
|
| 56 |
+
policy_dim: int,
|
| 57 |
+
n_layers: int,
|
| 58 |
+
context_tokens: int = 1,
|
| 59 |
+
):
|
| 60 |
+
super().__init__()
|
| 61 |
+
self.n_layers = n_layers
|
| 62 |
+
self.context_tokens = context_tokens
|
| 63 |
+
self.proj = nn.Linear(hidden_dim, policy_dim, bias=False)
|
| 64 |
+
|
| 65 |
+
def forward(
|
| 66 |
+
self, hidden_states: Tuple[torch.Tensor, ...]
|
| 67 |
+
) -> torch.Tensor:
|
| 68 |
+
"""Extract last K-token mean hidden state from each layer, project, return [L, d_p]."""
|
| 69 |
+
# hidden_states has L+1 entries: [embed, layer_0_out, ..., layer_{L-1}_out]
|
| 70 |
+
# We use layers 1..L (i.e., skip the embedding output at index 0)
|
| 71 |
+
layer_hs = [
|
| 72 |
+
hs[0, -self.context_tokens:, :].mean(dim=0) # [d] (mean over last K tokens)
|
| 73 |
+
for hs in hidden_states[1:self.n_layers + 1]
|
| 74 |
+
]
|
| 75 |
+
stacked = torch.stack(layer_hs, dim=0) # [L, d]
|
| 76 |
+
return self.proj(stacked.float()) # [L, d_policy]
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class PolicyEncoder(nn.Module):
|
| 80 |
+
"""2-layer Transformer encoder over layer indices (treating L layers as seq).
|
| 81 |
+
|
| 82 |
+
Input: [L, d_policy] + scalar features appended to each position
|
| 83 |
+
Output: [L, d_policy]
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
def __init__(self, d_policy: int, n_heads: int = 4, n_encoder_layers: int = 2):
|
| 87 |
+
super().__init__()
|
| 88 |
+
encoder_layer = nn.TransformerEncoderLayer(
|
| 89 |
+
d_model=d_policy,
|
| 90 |
+
nhead=n_heads,
|
| 91 |
+
dim_feedforward=d_policy * 4,
|
| 92 |
+
dropout=0.0,
|
| 93 |
+
batch_first=True,
|
| 94 |
+
norm_first=True,
|
| 95 |
+
)
|
| 96 |
+
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_encoder_layers)
|
| 97 |
+
|
| 98 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 99 |
+
"""x: [L, d_policy] → [L, d_policy]"""
|
| 100 |
+
return self.encoder(x.unsqueeze(0)).squeeze(0) # add/remove batch dim
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class ScalarFeatureEmbedder(nn.Module):
|
| 104 |
+
"""Embeds scalar context features and adds them to each layer position."""
|
| 105 |
+
|
| 106 |
+
FEATURE_NAMES = [
|
| 107 |
+
"last_tau_norm", # last_tau / draft_len (acceptance rate)
|
| 108 |
+
"latency_norm", # last cycle ms / expected ms (rough normalization)
|
| 109 |
+
"position_norm", # current token position / max_len
|
| 110 |
+
"age_norm", # mask age / update_interval
|
| 111 |
+
"temperature", # generation temperature (fixed per run)
|
| 112 |
+
]
|
| 113 |
+
N_FEATURES = len(FEATURE_NAMES)
|
| 114 |
+
|
| 115 |
+
def __init__(self, d_policy: int):
|
| 116 |
+
super().__init__()
|
| 117 |
+
self.embed = nn.Linear(self.N_FEATURES, d_policy, bias=False)
|
| 118 |
+
|
| 119 |
+
def forward(
|
| 120 |
+
self,
|
| 121 |
+
last_tau: int,
|
| 122 |
+
draft_len: int,
|
| 123 |
+
last_ms: float,
|
| 124 |
+
position: int,
|
| 125 |
+
max_len: int,
|
| 126 |
+
age: int,
|
| 127 |
+
update_interval: int,
|
| 128 |
+
temperature: float,
|
| 129 |
+
) -> torch.Tensor:
|
| 130 |
+
"""Return [d_policy] scalar feature embedding."""
|
| 131 |
+
feats = torch.tensor([
|
| 132 |
+
last_tau / max(draft_len, 1),
|
| 133 |
+
last_ms / 1000.0,
|
| 134 |
+
position / max(max_len, 1),
|
| 135 |
+
age / max(update_interval, 1),
|
| 136 |
+
temperature,
|
| 137 |
+
], dtype=torch.float32, device=self.embed.weight.device)
|
| 138 |
+
return self.embed(feats)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class AcceptanceRateHead(nn.Module):
|
| 142 |
+
"""Predicts E[τ/K] from mean-pooled encoder output via scalar regression.
|
| 143 |
+
|
| 144 |
+
Trained with MSE against observed τ/K each rollout — no policy gradient needed.
|
| 145 |
+
The prediction p̂ is used to derive optimal draft length K* analytically via
|
| 146 |
+
optimal_draft_len().
|
| 147 |
+
"""
|
| 148 |
+
|
| 149 |
+
def __init__(self, d_policy: int):
|
| 150 |
+
super().__init__()
|
| 151 |
+
self.head = nn.Linear(d_policy, 1, bias=True)
|
| 152 |
+
|
| 153 |
+
def forward(self, encoded: torch.Tensor) -> torch.Tensor:
|
| 154 |
+
"""encoded: [L, d_policy] → scalar p̂ ∈ (0, 1)"""
|
| 155 |
+
pooled = encoded.mean(dim=0) # [d_policy]
|
| 156 |
+
return torch.sigmoid(self.head(pooled)).squeeze(-1) # scalar
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class SkipPolicy(nn.Module):
|
| 160 |
+
"""Full skip policy: projects hidden states → encodes → outputs skip logits
|
| 161 |
+
and a predicted acceptance rate p̂.
|
| 162 |
+
|
| 163 |
+
Usage::
|
| 164 |
+
|
| 165 |
+
policy = SkipPolicy(hidden_dim=4096, n_layers=32, n_skip=16, policy_dim=128)
|
| 166 |
+
skip_logits, p_hat = policy(hidden_states, last_tau=8, ...)
|
| 167 |
+
mask, draft_len = policy.greedy_mask(hidden_states, ...)
|
| 168 |
+
"""
|
| 169 |
+
|
| 170 |
+
def __init__(
|
| 171 |
+
self,
|
| 172 |
+
hidden_dim: int,
|
| 173 |
+
n_layers: int,
|
| 174 |
+
n_skip: int,
|
| 175 |
+
policy_dim: int = 128,
|
| 176 |
+
n_heads: int = 4,
|
| 177 |
+
n_encoder_layers: int = 2,
|
| 178 |
+
keep_prefix: int = 2,
|
| 179 |
+
keep_suffix: int = 2,
|
| 180 |
+
draft_len_choices: Optional[List[int]] = None,
|
| 181 |
+
context_tokens: int = 1,
|
| 182 |
+
):
|
| 183 |
+
super().__init__()
|
| 184 |
+
self.n_layers = n_layers
|
| 185 |
+
self.n_skip = n_skip
|
| 186 |
+
self.draft_len_choices = (
|
| 187 |
+
draft_len_choices if draft_len_choices is not None
|
| 188 |
+
else _DEFAULT_DRAFT_LEN_CHOICES
|
| 189 |
+
)
|
| 190 |
+
self.projector = HiddenStateProjector(hidden_dim, policy_dim, n_layers, context_tokens)
|
| 191 |
+
self.scalar_embedder = ScalarFeatureEmbedder(policy_dim)
|
| 192 |
+
self.encoder = PolicyEncoder(policy_dim, n_heads, n_encoder_layers)
|
| 193 |
+
self.logit_head = nn.Linear(policy_dim, 1, bias=True)
|
| 194 |
+
self.sampler = TopMActionSampler(n_layers, n_skip, keep_prefix, keep_suffix)
|
| 195 |
+
self.acceptance_head = AcceptanceRateHead(policy_dim)
|
| 196 |
+
|
| 197 |
+
def forward(
|
| 198 |
+
self,
|
| 199 |
+
hidden_states: Tuple[torch.Tensor, ...],
|
| 200 |
+
last_tau: int = 0,
|
| 201 |
+
draft_len: int = 16,
|
| 202 |
+
last_ms: float = 0.0,
|
| 203 |
+
position: int = 0,
|
| 204 |
+
max_len: int = 256,
|
| 205 |
+
age: int = 0,
|
| 206 |
+
update_interval: int = 1,
|
| 207 |
+
temperature: float = 0.0,
|
| 208 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 209 |
+
"""Compute per-layer skip logits and predicted acceptance rate.
|
| 210 |
+
|
| 211 |
+
Returns:
|
| 212 |
+
skip_logits: [n_layers]
|
| 213 |
+
p_hat: scalar ∈ (0, 1), predicted E[τ/K]
|
| 214 |
+
"""
|
| 215 |
+
z = self.projector(hidden_states) # [L, d_policy]
|
| 216 |
+
|
| 217 |
+
scalar_emb = self.scalar_embedder(
|
| 218 |
+
last_tau, draft_len, last_ms, position, max_len,
|
| 219 |
+
age, update_interval, temperature,
|
| 220 |
+
) # [d_policy]
|
| 221 |
+
z = z + scalar_emb.unsqueeze(0) # [L, d_policy]
|
| 222 |
+
|
| 223 |
+
encoded = self.encoder(z) # [L, d_policy]
|
| 224 |
+
|
| 225 |
+
skip_logits = self.logit_head(encoded).squeeze(-1) # [L]
|
| 226 |
+
p_hat = self.acceptance_head(encoded) # scalar
|
| 227 |
+
return skip_logits, p_hat
|
| 228 |
+
|
| 229 |
+
def sample_mask(
|
| 230 |
+
self,
|
| 231 |
+
hidden_states: Tuple[torch.Tensor, ...],
|
| 232 |
+
temperature: float = 1.0,
|
| 233 |
+
**kwargs,
|
| 234 |
+
) -> Tuple[List[int], int, torch.Tensor]:
|
| 235 |
+
"""Sample a skip mask and derive draft length from predicted acceptance rate.
|
| 236 |
+
|
| 237 |
+
draft_len is selected deterministically via optimal_draft_len(p̂) — no RL.
|
| 238 |
+
log_p covers only the skip mask action.
|
| 239 |
+
|
| 240 |
+
Returns:
|
| 241 |
+
hard_mask: List[int] of length n_layers
|
| 242 |
+
draft_len: int, derived from p̂
|
| 243 |
+
log_p: scalar tensor, log π(mask | h)
|
| 244 |
+
"""
|
| 245 |
+
skip_logits, p_hat = self.forward(hidden_states, **kwargs)
|
| 246 |
+
|
| 247 |
+
soft_mask = self.sampler(skip_logits, temperature=temperature)
|
| 248 |
+
hard_mask = (soft_mask.detach() > 0.5).long().tolist()
|
| 249 |
+
log_p = self.sampler.log_prob(skip_logits, soft_mask.detach())
|
| 250 |
+
|
| 251 |
+
draft_len = optimal_draft_len(p_hat.detach().item(), self.draft_len_choices)
|
| 252 |
+
return hard_mask, draft_len, log_p
|
| 253 |
+
|
| 254 |
+
def greedy_mask(
|
| 255 |
+
self,
|
| 256 |
+
hidden_states: Tuple[torch.Tensor, ...],
|
| 257 |
+
**kwargs,
|
| 258 |
+
) -> Tuple[List[int], int]:
|
| 259 |
+
"""Deterministic greedy mask and draft length for evaluation."""
|
| 260 |
+
with torch.no_grad():
|
| 261 |
+
skip_logits, p_hat = self.forward(hidden_states, **kwargs)
|
| 262 |
+
mask = self.sampler.greedy_mask(skip_logits)
|
| 263 |
+
draft_len = optimal_draft_len(p_hat.item(), self.draft_len_choices)
|
| 264 |
+
return mask, draft_len
|
| 265 |
+
|
| 266 |
+
def compile_for_inference(self) -> None:
|
| 267 |
+
"""Replace forward with a torch.compile'd version for faster inference.
|
| 268 |
+
|
| 269 |
+
Call once after policy.eval() and before the generation loop.
|
| 270 |
+
Use fullgraph=False to tolerate the torch.tensor() call inside
|
| 271 |
+
ScalarFeatureEmbedder without needing to refactor it.
|
| 272 |
+
"""
|
| 273 |
+
self.forward = torch.compile(self.forward, mode="max-autotune", fullgraph=False)
|