---
language:
- en
license: apache-2.0
tags:
- speculative-decoding
- layer-skipping
- grpo
- clasp
- llama
- efficiency
base_model: meta-llama/Meta-Llama-3-8B
---
# CASM — CLaSp Adaptive Skip Mask
CASM is a lightweight GRPO-trained skip policy for self-speculative decoding with [Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B). It learns which transformer layers to bypass during the draft stage of speculative decoding, replacing the CLaSp dynamic-programming optimizer with a small neural policy that runs in microseconds.
## How it works
Self-speculative decoding alternates the *same* frozen model between two modes, with a policy update between cycles:
1. **Draft** — selected decoder layers are skipped, producing `K` candidate tokens cheaply.
2. **Verify** — the full model validates the draft block and accepts the longest matching prefix.
3. **Policy update** — CASM reads the hidden states from the verify pass and chooses a new skip mask for the next cycle.
CASM replaces step 3's DP solver with a 2-layer Transformer encoder (~200 K parameters) that maps per-layer hidden states to a skip-mask distribution. It is trained end-to-end with GRPO against a reward combining token-acceptance rate, decoding speed, and mismatch regularization.
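As a concrete illustration, a reward of that shape might look like the sketch below; the function name, weights, and the exact mismatch term are assumptions for exposition, not the values used in training.

```python
def grpo_reward(tau: int, k: int, cycle_ms: float, mask_flips: int,
                alpha: float = 1.0, beta: float = 0.1, gamma: float = 0.01) -> float:
    """Toy reward combining acceptance, speed, and mismatch regularization.

    tau: tokens accepted by the verify pass; k: tokens drafted (acceptance = tau/k).
    cycle_ms: wall-clock time of the draft + verify cycle (speed = tau/cycle_ms).
    mask_flips: layers whose skip decision changed since the last cycle.
    All weights here are illustrative assumptions, not the trained values.
    """
    acceptance = tau / max(k, 1)
    tokens_per_ms = tau / max(cycle_ms, 1e-6)
    return alpha * acceptance + beta * tokens_per_ms - gamma * mask_flips
```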
## Architecture
| Component | Description |
|---|---|
| `HiddenStateProjector` | Projects per-layer hidden states `[L, d_model]` → `[L, 128]` |
| `ScalarFeatureEmbedder` | Embeds 5 scalar context features (acceptance rate, latency, position, mask age, temperature) |
| `PolicyEncoder` | 2-layer Transformer encoder over layer positions |
| `logit_head` | Per-layer skip logits → top-M selection |
| `AcceptanceRateHead` | Predicts E[τ/K] → optimal draft length K* |
- **Parameters:** ~200 K
- **Base model:** `meta-llama/Meta-Llama-3-8B` (32 layers, `hidden_dim=4096`)
- **Skip budget:** 8 layers per draft cycle
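For intuition, here is a minimal sketch of the forward pass the table implies, written in plain PyTorch. Module internals, head sizes, and the scalar-feature handling are assumptions, not the repository's `SkipPolicy`; in particular, a full 4096→128 linear projector alone would exceed the ~200 K budget, so the real projector is presumably more compact.

```python
import torch
import torch.nn as nn

class SkipPolicySketch(nn.Module):
    """Illustrative forward pass matching the table above (internals assumed)."""

    def __init__(self, n_layers=32, hidden_dim=4096, policy_dim=128,
                 n_scalars=5, n_skip=8):
        super().__init__()
        self.projector = nn.Linear(hidden_dim, policy_dim)     # HiddenStateProjector
        self.scalar_embed = nn.Linear(n_scalars, policy_dim)   # ScalarFeatureEmbedder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=policy_dim, nhead=4, dim_feedforward=256, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)  # PolicyEncoder
        self.logit_head = nn.Linear(policy_dim, 1)             # per-layer skip logit
        self.accept_head = nn.Linear(policy_dim, 1)            # AcceptanceRateHead
        self.n_skip = n_skip

    def forward(self, layer_hidden, scalar_feats):
        # layer_hidden: [L, hidden_dim] hidden states from the verify pass
        # scalar_feats: [n_scalars] context features
        x = self.projector(layer_hidden).unsqueeze(0)           # [1, L, policy_dim]
        x = x + self.scalar_embed(scalar_feats).view(1, 1, -1)  # broadcast context
        x = self.encoder(x)                                     # [1, L, policy_dim]
        logits = self.logit_head(x).squeeze(-1).squeeze(0)      # [L]
        skip_idx = logits.topk(self.n_skip).indices             # top-M selection
        accept_pred = torch.sigmoid(self.accept_head(x.mean(dim=1)))  # E[tau/K]
        return skip_idx, accept_pred
```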
## Usage
```python
import torch

from src.grpo.policy import SkipPolicy

# Load the trained policy checkpoint
ckpt = torch.load("policy_best.pt", map_location="cpu")
policy = SkipPolicy(
    hidden_dim=4096,
    n_layers=32,
    n_skip=8,
    policy_dim=128,
    context_tokens=1,
)
policy.load_state_dict(ckpt["policy_state_dict"])
policy.eval()

# During self-speculative decoding, call after each verify pass.
# hidden_states: tuple of (L+1) tensors from a forward pass with
# output_hidden_states=True
mask, draft_len = policy.greedy_mask(
    hidden_states,
    last_tau=accepted_tokens,     # tokens accepted in the last verify pass
    draft_len=current_draft_len,  # K used for the last draft block
    position=current_position,    # current position in the sequence
    max_len=max_new_tokens,
)
# mask: list of 0/1 per layer (1 = skip this layer during draft)
# draft_len: recommended number of tokens to draft next cycle
```
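How the mask is consumed during drafting is up to the decoding loop. Purely for illustration, the context manager below emulates layer skipping by replacing the selected decoder layers with identities; this helper is an assumption for exposition, not part of the grpo-clasp API, and it ignores KV-cache bookkeeping for the skipped layers.

```python
import contextlib

@contextlib.contextmanager
def skip_layers(model, mask):
    """Temporarily bypass decoder layers where mask[i] == 1 (illustration only).

    Each skipped layer's forward is replaced with an identity over the hidden
    states; a real integration must also handle the KV cache for skipped layers.
    """
    layers = model.model.layers                     # LLaMA-style module path
    originals = [layer.forward for layer in layers]
    try:
        for flag, layer in zip(mask, layers):
            if flag:
                # Return the hidden states untouched, preserving the tuple shape.
                layer.forward = lambda hidden_states, *args, **kwargs: (hidden_states,)
        yield model
    finally:
        for layer, fwd in zip(layers, originals):
            layer.forward = fwd
```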
See [grpo-clasp](https://github.com/dayne-2stacks/grpo-clasp) for the full training and evaluation codebase.
## Training
Trained with GRPO on SpecBench-style prompts using `meta-llama/Meta-Llama-3-8B` on a single NVIDIA A100 80 GB for 10,000 steps. The first ~1,000 steps used an imitation warm-start from CLaSp DP masks.
| Metric | Value |
|---|---|
| Training steps | 10,000 |
| Eval reward | 99.8 |
| Test reward | 100.4 |
| GPU | NVIDIA A100 80 GB |
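For reference, GRPO's update is group-relative: several candidate masks are sampled from the same state, each is scored with the reward, and advantages are normalized within the group before a clipped policy-gradient step. The sketch below is a generic GRPO step for exposition, not the repository's trainer.

```python
import torch

def grpo_advantages(rewards: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Group-relative advantages: z-score rewards within one sampled group.

    rewards: [G] rewards for G skip masks sampled from the same decoding state.
    """
    return (rewards - rewards.mean()) / (rewards.std() + eps)

def grpo_loss(logp_new: torch.Tensor, logp_old: torch.Tensor,
              advantages: torch.Tensor, clip: float = 0.2) -> torch.Tensor:
    """Clipped surrogate objective averaged over the group (illustrative)."""
    ratio = (logp_new - logp_old).exp()
    clipped = ratio.clamp(1.0 - clip, 1.0 + clip)
    return -torch.min(ratio * advantages, clipped * advantages).mean()
```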
## Citation
If you use CASM, please cite the CLaSp paper and this repository:
```bibtex
@misc{casm2026,
  author = {Dayne Guy},
  title  = {CASM: CLaSp Adaptive Skip Mask},
  year   = {2026},
  url    = {https://huggingface.co/dayngerous/CASM}
}
```