---
language:
- en
license: apache-2.0
tags:
- speculative-decoding
- layer-skipping
- grpo
- clasp
- llama
- efficiency
base_model: meta-llama/Meta-Llama-3-8B
---

# CASM — CLaSp Adaptive Skip Mask

CASM is a lightweight GRPO-trained skip policy for self-speculative decoding with [Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B). It learns which transformer layers to bypass during the draft stage of speculative decoding, replacing the CLaSp dynamic-programming optimizer with a small neural policy that runs in microseconds.

## How it works

Self-speculative decoding runs the *same* frozen model in two modes, draft and verify, and updates the skip policy between cycles:

1. **Draft** — selected decoder layers are skipped, producing `K` candidate tokens cheaply.
2. **Verify** — the full model validates the draft block and accepts the longest matching prefix.
3. **Policy update** — CASM observes the verify hidden states and chooses a new skip mask for the next cycle.

CASM replaces step 3's DP solver with a 2-layer Transformer encoder (~200 K parameters) that maps per-layer hidden states to a skip-mask distribution. It is trained end-to-end with GRPO against a reward combining token-acceptance rate, decoding speed, and mismatch regularization.
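
As a rough sketch of how the three steps fit together in one decoding loop: the `draft_with_skips` and `verify_prefix` helpers below are hypothetical stand-ins for the draft/verify routines in the grpo-clasp codebase; only `policy.greedy_mask` (documented under Usage) is the real interface.

```python
# Illustrative decoding loop only; draft_with_skips() and verify_prefix() are
# hypothetical helpers, not functions from the grpo-clasp codebase.
def speculative_generate(model, policy, prompt_ids, max_new_tokens, n_layers=32):
    tokens = list(prompt_ids)
    mask, draft_len = [0] * n_layers, 4   # start with no skipped layers and a small draft block
    while len(tokens) - len(prompt_ids) < max_new_tokens:
        # 1. Draft: run the model with the masked layers bypassed.
        draft_tokens = draft_with_skips(model, tokens, mask, draft_len)
        # 2. Verify: one full forward pass; accept the longest matching prefix.
        accepted, hidden_states = verify_prefix(model, tokens, draft_tokens)
        tokens.extend(draft_tokens[:accepted])
        # 3. Policy update: pick the skip mask and draft length for the next cycle.
        mask, draft_len = policy.greedy_mask(
            hidden_states,
            last_tau=accepted,
            draft_len=draft_len,
            position=len(tokens),
            max_len=max_new_tokens,
        )
    return tokens
```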

## Architecture

| Component | Description |
|---|---|
| `HiddenStateProjector` | Projects per-layer hidden states `[L, d_model]` → `[L, 128]` |
| `ScalarFeatureEmbedder` | Embeds 5 scalar context features (acceptance rate, latency, position, mask age, temperature) |
| `PolicyEncoder` | 2-layer Transformer encoder over layer positions |
| `logit_head` | Per-layer skip logits → top-M selection |
| `AcceptanceRateHead` | Predicts E[τ/K] → optimal draft length K* |

**Parameters:** ~200 K  
**Base model:** `meta-llama/Meta-Llama-3-8B` (32 layers, hidden_dim=4096)  
**Skip budget:** 8 layers per draft cycle  
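
For orientation, a minimal PyTorch sketch with roughly this shape is below. The module boundaries and layer sizes are approximations inferred from the table above (a dense 4096→128 projector alone exceeds the ~200 K budget, so the real `src/grpo/policy.py` necessarily differs in detail).

```python
import torch
import torch.nn as nn

class SkipPolicySketch(nn.Module):
    """Approximate restatement of the CASM architecture table; not the real implementation."""
    def __init__(self, hidden_dim=4096, n_layers=32, policy_dim=128, n_scalar_feats=5):
        super().__init__()
        # HiddenStateProjector: [L, d_model] -> [L, policy_dim]
        self.projector = nn.Linear(hidden_dim, policy_dim)
        # ScalarFeatureEmbedder: 5 scalar context features -> policy_dim
        self.scalar_embed = nn.Linear(n_scalar_feats, policy_dim)
        # PolicyEncoder: 2-layer Transformer encoder over layer positions
        enc_layer = nn.TransformerEncoderLayer(
            d_model=policy_dim, nhead=4, dim_feedforward=4 * policy_dim, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=2)
        # logit_head: per-layer skip logits (top-M selection happens outside the module)
        self.logit_head = nn.Linear(policy_dim, 1)
        # AcceptanceRateHead: predicts E[tau/K] from the pooled encoding
        self.acceptance_head = nn.Linear(policy_dim, 1)

    def forward(self, layer_hiddens, scalar_feats):
        # layer_hiddens: [L, hidden_dim]; scalar_feats: [n_scalar_feats]
        x = self.projector(layer_hiddens) + self.scalar_embed(scalar_feats)  # broadcast over L
        x = self.encoder(x.unsqueeze(0)).squeeze(0)                          # [L, policy_dim]
        skip_logits = self.logit_head(x).squeeze(-1)                         # [L]
        acc_rate = torch.sigmoid(self.acceptance_head(x.mean(dim=0)))        # predicted E[tau/K]
        return skip_logits, acc_rate
```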

## Usage

```python
import torch
from src.grpo.policy import SkipPolicy

# Load policy
ckpt = torch.load("policy_best.pt", map_location="cpu")
policy = SkipPolicy(
    hidden_dim=4096,
    n_layers=32,
    n_skip=8,
    policy_dim=128,
    context_tokens=1,
)
policy.load_state_dict(ckpt["policy_state_dict"])
policy.eval()

# During self-speculative decoding, call after each verify pass:
# hidden_states: tuple of (L+1) tensors from model output_hidden_states=True
mask, draft_len = policy.greedy_mask(
    hidden_states,
    last_tau=accepted_tokens,
    draft_len=current_draft_len,
    position=current_position,
    max_len=max_new_tokens,
)
# mask: list of 0/1 per layer (1 = skip this layer during draft)
# draft_len: recommended tokens to draft next cycle
```
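
One way to obtain `hidden_states` for the call above is a standard Hugging Face forward pass with `output_hidden_states=True`, shown here outside the speculative loop for brevity; the actual verify pass lives in the grpo-clasp decoding loop.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B", torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

inputs = tokenizer("The quick brown fox", return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model(**inputs, output_hidden_states=True)

# Tuple of (L+1) = 33 tensors: the embedding output plus one per decoder layer.
hidden_states = out.hidden_states
```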

See [grpo-clasp](https://github.com/dayne-2stacks/grpo-clasp) for the full training and evaluation codebase.

## Training

Trained with GRPO on SpecBench-style prompts using `meta-llama/Meta-Llama-3-8B` on a single A100 80 GB for 10 000 steps. An imitation warm-start from CLaSp DP masks was used for the first ~1 000 steps.
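
The exact reward lives in the training code; the sketch below only illustrates a reward with the three ingredients named earlier (acceptance rate, decoding speed, mismatch regularization). The weights are hypothetical and are not scaled to reproduce the reward values reported in the table below.

```python
# Hypothetical reward shape for illustration; the real terms and coefficients
# are defined in the grpo-clasp training code.
def reward(accepted, draft_len, elapsed_ms, mismatch, w_acc=1.0, w_speed=0.1, w_mis=0.05):
    acceptance = accepted / max(draft_len, 1)   # token-acceptance rate tau / K
    speed = accepted / (elapsed_ms + 1e-6)      # accepted tokens per millisecond of the cycle
    return w_acc * acceptance + w_speed * speed - w_mis * mismatch  # mismatch term penalizes drift
```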

| Metric | Value |
|---|---|
| Training steps | 10 000 |
| Eval reward | 99.8 |
| Test reward | 100.4 |
| GPU | NVIDIA A100 80 GB |

## Citation

If you use CASM, please cite the CLaSp paper and this repository:

```bibtex
@misc{casm2026,
  author = {Dayne Guy},
  title  = {CASM: CLaSp Adaptive Skip Mask},
  year   = {2026},
  url    = {https://huggingface.co/dayngerous/CASM}
}
```