---
language:
- en
license: apache-2.0
tags:
- speculative-decoding
- layer-skipping
- grpo
- clasp
- llama
- efficiency
base_model: meta-llama/Meta-Llama-3-8B
---
# CASM: CLaSp Adaptive Skip Mask
CASM is a lightweight GRPO-trained skip policy for self-speculative decoding with [Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B). It learns which transformer layers to bypass during the draft stage of speculative decoding, replacing the CLaSp dynamic-programming optimizer with a small neural policy that runs in microseconds.
## How it works
Self-speculative decoding runs the *same* frozen model in two modes, alternating in a three-step cycle:
1. **Draft:** selected decoder layers are skipped, producing `K` candidate tokens cheaply.
2. **Verify:** the full model validates the draft block and accepts the longest matching prefix.
3. **Policy update:** CASM observes the verify-pass hidden states and chooses a new skip mask for the next cycle.
CASM replaces step 3's DP solver with a 2-layer Transformer encoder (~200 K parameters) that maps per-layer hidden states to a skip-mask distribution. It is trained end-to-end with GRPO against a reward combining token-acceptance rate, decoding speed, and mismatch regularization.
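The verify step's accept-longest-prefix rule is the signal the reward is built on. A minimal, framework-free sketch; the function names and the exact reward coefficients are illustrative, not taken from the CASM codebase:

```python
def longest_accepted_prefix(draft_tokens, verify_tokens):
    """Count tau: the number of leading draft tokens the full model agrees with."""
    tau = 0
    for d, v in zip(draft_tokens, verify_tokens):
        if d != v:
            break
        tau += 1
    return tau

def cycle_reward(tau, k, draft_time, full_time, mismatch_penalty=0.1):
    """Hypothetical reward shape: acceptance rate plus a speed term,
    minus a penalty on rejected draft tokens. The real weights and exact
    terms are not specified in this card."""
    acceptance = tau / k
    speed = full_time / (draft_time + full_time)
    return acceptance + speed - mismatch_penalty * (k - tau)
```

A fully accepted draft (`tau == k`) dominates a half-rejected one under any non-negative penalty, which is the behavior the policy is rewarded toward.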
## Architecture
| Component | Description |
|---|---|
| `HiddenStateProjector` | Projects per-layer hidden states `[L, d_model]` → `[L, 128]` |
| `ScalarFeatureEmbedder` | Embeds 5 scalar context features (acceptance rate, latency, position, mask age, temperature) |
| `PolicyEncoder` | 2-layer Transformer encoder over layer positions |
| `logit_head` | Per-layer skip logits → top-M selection |
| `AcceptanceRateHead` | Predicts E[τ/K] → optimal draft length K* |
- **Parameters:** ~200 K
- **Base model:** `meta-llama/Meta-Llama-3-8B` (32 layers, hidden_dim=4096)
- **Skip budget:** 8 layers per draft cycle
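The `logit_head` → top-M step can be illustrated without the network itself. A sketch, assuming the policy emits one skip logit per decoder layer and the budget fixes exactly how many layers are skipped (function name is hypothetical):

```python
def top_m_skip_mask(skip_logits, m):
    """Return a 0/1 mask with exactly m ones, placed at the m
    highest-logit layer indices (1 = skip that layer)."""
    order = sorted(range(len(skip_logits)), key=lambda i: skip_logits[i], reverse=True)
    mask = [0] * len(skip_logits)
    for i in order[:m]:
        mask[i] = 1
    return mask

# With a budget of 2 over 4 layers, the two largest logits win:
mask = top_m_skip_mask([0.1, 2.0, -1.0, 3.0], 2)  # -> [0, 1, 0, 1]
```

With the 32-layer base model and the skip budget of 8, each draft cycle's mask therefore contains exactly eight ones.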
## Usage
```python
import torch
from src.grpo.policy import SkipPolicy
# Load policy
ckpt = torch.load("policy_best.pt", map_location="cpu")
policy = SkipPolicy(
hidden_dim=4096,
n_layers=32,
n_skip=8,
policy_dim=128,
context_tokens=1,
)
policy.load_state_dict(ckpt["policy_state_dict"])
policy.eval()
# During self-speculative decoding, call after each verify pass:
# hidden_states: tuple of (L+1) tensors from model output_hidden_states=True
mask, draft_len = policy.greedy_mask(
hidden_states,
last_tau=accepted_tokens,
draft_len=current_draft_len,
position=current_position,
max_len=max_new_tokens,
)
# mask: list of 0/1 per layer (1 = skip this layer during draft)
# draft_len: recommended tokens to draft next cycle
```
See [grpo-clasp](https://github.com/dayne-2stacks/grpo-clasp) for the full training and evaluation codebase.
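To show where the returned mask plugs in, here is a toy draft pass that bypasses masked layers. This is an illustrative sketch, not the repo's implementation; real code would operate on the model's decoder blocks and manage the KV cache:

```python
def draft_forward(decoder_layers, hidden, skip_mask):
    """Run the decoder stack, bypassing any layer whose mask bit is 1."""
    for layer, skip in zip(decoder_layers, skip_mask):
        if not skip:
            hidden = layer(hidden)
    return hidden

# Toy stand-ins for decoder layers: each "layer" adds 1 to a scalar state.
layers = [lambda h: h + 1 for _ in range(4)]
out = draft_forward(layers, 0, [0, 1, 0, 1])  # two of four layers skipped -> out == 2
```

Skipped layers simply pass the residual stream through unchanged, which is what makes the draft pass cheaper than the verify pass.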
## Training
Trained with GRPO on SpecBench-style prompts using `meta-llama/Meta-Llama-3-8B` on a single A100 80 GB for 10 000 steps. Imitation warm-start from CLaSp DP masks was used for the first ~1000 steps.
| Metric | Value |
|---|---|
| Training steps | 10 000 |
| Eval reward | 99.8 |
| Test reward | 100.4 |
| GPU | NVIDIA A100 80 GB |
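GRPO's core step, group-relative advantage normalization, can be sketched in a few lines. This is a standalone illustration of the general technique; the actual training loop lives in the grpo-clasp repository:

```python
def grpo_advantages(group_rewards, eps=1e-8):
    """Normalize each rollout's reward against its group's mean and std,
    so the policy gradient needs no learned value baseline."""
    n = len(group_rewards)
    mean = sum(group_rewards) / n
    std = (sum((r - mean) ** 2 for r in group_rewards) / n) ** 0.5
    return [(r - mean) / (std + eps) for r in group_rewards]
```

Each sampled skip mask in a group is scored by the decoding reward, and masks that beat their group's average get positive advantage regardless of the reward's absolute scale.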
## Citation
If you use CASM, please cite the CLaSp paper and this repository:
```bibtex
@misc{casm2026,
author = {Dayne Guy},
title = {CASM: CLaSp Adaptive Skip Mask},
year = {2026},
url = {https://huggingface.co/dayngerous/CASM}
}
```