# CASM: CLaSp Adaptive Skip Mask
CASM is a lightweight GRPO-trained skip policy for self-speculative decoding with Meta-Llama-3-8B. It learns which transformer layers to bypass during the draft stage of speculative decoding, replacing the CLaSp dynamic-programming optimizer with a small neural policy that runs in microseconds.
## How it works
Self-speculative decoding runs the same frozen model in two modes, followed by a policy update:

1. **Draft**: selected decoder layers are skipped, producing K candidate tokens cheaply.
2. **Verify**: the full model validates the draft block and accepts the longest matching prefix.
3. **Policy update**: CASM observes the verify-pass hidden states and chooses a new skip mask for the next cycle.
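The verify step's acceptance rule can be sketched as follows. This is a minimal illustration of longest-prefix acceptance, not the repository's code; `draft_tokens` and `target_tokens` are hypothetical names for the draft block and the full model's greedy predictions over the same positions:

```python
def accept_longest_prefix(draft_tokens, target_tokens):
    """Greedy verification: keep draft tokens until the first mismatch
    with the full model's predictions, then substitute the full model's
    token at the mismatch position."""
    accepted = []
    for d, t in zip(draft_tokens, target_tokens):
        if d == t:
            accepted.append(d)
        else:
            accepted.append(t)  # replace the first mismatching token
            break
    else:
        # every drafted token matched; tau == K
        return accepted, len(draft_tokens)
    return accepted, len(accepted) - 1  # tau = number of draft tokens kept
```

`tau` (the accepted-token count) is the quantity the policy's reward and its acceptance-rate head reason about.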
CASM replaces step 3's DP solver with a 2-layer Transformer encoder (~200 K parameters) that maps per-layer hidden states to a skip-mask distribution. It is trained end-to-end with GRPO against a reward combining token-acceptance rate, decoding speed, and mismatch regularization.
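The model card does not specify how the three reward components are combined; a plausible scalarization, with made-up weights and hypothetical argument names, might look like:

```python
def skip_policy_reward(tau, K, draft_time, full_time, mismatch_kl,
                       w_speed=1.0, w_kl=0.1):
    """Hypothetical per-cycle reward (illustrative weights, not the
    repository's actual formula).

    tau          -- accepted draft tokens this cycle
    K            -- drafted tokens this cycle
    draft_time   -- wall-clock time of draft + verify
    full_time    -- estimated time to decode the same tokens undrafted
    mismatch_kl  -- regularizer penalizing draft/full-model drift
    """
    acceptance = tau / K              # token-acceptance rate term
    speedup = full_time / draft_time  # decoding-speed term
    return acceptance + w_speed * speedup - w_kl * mismatch_kl
```

GRPO then needs only this scalar per rollout; no learned value function is required.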
## Architecture
| Component | Description |
|---|---|
| `HiddenStateProjector` | Projects per-layer hidden states `[L, d_model]` → `[L, 128]` |
| `ScalarFeatureEmbedder` | Embeds 5 scalar context features (acceptance rate, latency, position, mask age, temperature) |
| `PolicyEncoder` | 2-layer Transformer encoder over layer positions |
| `logit_head` | Per-layer skip logits → top-M selection |
| `AcceptanceRateHead` | Predicts E[τ/K] → optimal draft length K* |

- Parameters: ~200 K
- Base model: `meta-llama/Meta-Llama-3-8B` (32 layers, `hidden_dim=4096`)
- Skip budget: 8 layers per draft cycle
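The "skip logits → top-M selection" step amounts to marking the `n_skip` layers with the highest skip logits. A minimal sketch of the assumed behavior (not the repository's implementation):

```python
def top_m_skip_mask(skip_logits, n_skip=8):
    """Return a 0/1 mask selecting the n_skip layers with the highest
    skip logits (1 = skip this layer during the draft pass)."""
    order = sorted(range(len(skip_logits)),
                   key=lambda i: skip_logits[i], reverse=True)
    mask = [0] * len(skip_logits)
    for i in order[:n_skip]:
        mask[i] = 1
    return mask
```

With a hard top-M rule the skip budget is always met exactly, so draft-pass cost per token is constant across cycles.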
## Usage
```python
import torch
from src.grpo.policy import SkipPolicy

# Load the trained policy
ckpt = torch.load("policy_best.pt", map_location="cpu")
policy = SkipPolicy(
    hidden_dim=4096,
    n_layers=32,
    n_skip=8,
    policy_dim=128,
    context_tokens=1,
)
policy.load_state_dict(ckpt["policy_state_dict"])
policy.eval()

# During self-speculative decoding, call after each verify pass.
# hidden_states: tuple of (L+1) tensors from the model run with
# output_hidden_states=True
mask, draft_len = policy.greedy_mask(
    hidden_states,
    last_tau=accepted_tokens,
    draft_len=current_draft_len,
    position=current_position,
    max_len=max_new_tokens,
)
# mask: list of 0/1 per layer (1 = skip this layer during draft)
# draft_len: recommended number of tokens to draft next cycle
```
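How the returned mask is consumed is up to the decoding loop. One simple scheme (a hypothetical sketch, assuming the decoder layers are exposed as a list) is to bypass masked layers in the draft forward pass, letting the residual stream carry through:

```python
def draft_forward(hidden, layers, skip_mask):
    """Run a draft-mode forward pass, bypassing layers where
    skip_mask[i] == 1; skipped layers contribute nothing and the
    residual stream passes through unchanged."""
    for layer, skip in zip(layers, skip_mask):
        if not skip:
            hidden = layer(hidden)
    return hidden
```

The verify pass simply calls the same stack with an all-zeros mask.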
See grpo-clasp for the full training and evaluation codebase.
## Training
Trained with GRPO on SpecBench-style prompts using meta-llama/Meta-Llama-3-8B on a single A100 80 GB for 10 000 steps. Imitation warm-start from CLaSp DP masks was used for the first ~1000 steps.
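GRPO dispenses with a learned critic by normalizing each reward against its rollout group; a generic sketch of that advantage computation (illustrative, not this repository's code):

```python
def grpo_advantages(group_rewards, eps=1e-8):
    """Group-relative advantages: standardize each rollout's reward
    against the mean and std of its group (GRPO's critic-free
    baseline)."""
    n = len(group_rewards)
    mean = sum(group_rewards) / n
    var = sum((r - mean) ** 2 for r in group_rewards) / n
    return [(r - mean) / (var ** 0.5 + eps) for r in group_rewards]
```

Each group here would be several skip masks sampled for the same prompt, scored by the cycle reward above.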
| Metric | Value |
|---|---|
| Training steps | 10 000 |
| Eval reward | 99.8 |
| Test reward | 100.4 |
| GPU | NVIDIA A100 80 GB |
## Citation
If you use CASM, please cite the CLaSp paper and this repository:
```bibtex
@misc{casm2026,
  author = {Dayne Guy},
  title  = {CASM: CLaSp Adaptive Skip Mask},
  year   = {2026},
  url    = {https://huggingface.co/dayngerous/CASM}
}
```