---
language:
- en
license: apache-2.0
tags:
- speculative-decoding
- layer-skipping
- grpo
- clasp
- llama
- efficiency
base_model: meta-llama/Meta-Llama-3-8B
---

# CASM: CLaSp Adaptive Skip Mask
CASM is a lightweight GRPO-trained skip policy for self-speculative decoding with [Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B). It learns which transformer layers to bypass during the draft stage of speculative decoding, replacing the CLaSp dynamic-programming optimizer with a small neural policy that runs in microseconds.
## How it works

Self-speculative decoding runs the *same* frozen model in two modes:

1. **Draft** – selected decoder layers are skipped, producing `K` candidate tokens cheaply.
2. **Verify** – the full model validates the draft block and accepts the longest matching prefix.
3. **Policy update** – CASM observes the verify hidden states and chooses a new skip mask for the next cycle.

CASM replaces step 3's DP solver with a 2-layer Transformer encoder (~200 K parameters) that maps per-layer hidden states to a skip-mask distribution. It is trained end-to-end with GRPO against a reward combining token-acceptance rate, decoding speed, and mismatch regularization.
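The draft/verify/accept loop can be sketched in a few lines of plain Python. `draft_model` and `verify_model` below are toy stand-in callables (the real system runs one frozen Llama with and without skipped layers), so this illustrates only the accept-longest-prefix bookkeeping:

```python
def speculative_cycle(draft_model, verify_model, prefix, k):
    """One draft/verify cycle of self-speculative decoding (toy sketch)."""
    # Draft: cheaply generate k candidate tokens autoregressively.
    ctx, draft = list(prefix), []
    for _ in range(k):
        tok = draft_model(ctx)
        draft.append(tok)
        ctx.append(tok)
    # Verify: the full model re-predicts each drafted position,
    # teacher-forced on the draft (batched into one forward pass in practice).
    verified = [verify_model(list(prefix) + draft[:i]) for i in range(k)]
    # Accept the longest prefix where draft and verify agree.
    tau = 0
    for d, v in zip(draft, verified):
        if d != v:
            break
        tau += 1
    return draft[:tau], tau

# Toy models: the "cheap" model agrees with the full one until the context grows.
full = lambda ctx: (sum(ctx) + 1) % 10
cheap = lambda ctx: full(ctx) if len(ctx) < 6 else 0

accepted, tau = speculative_cycle(cheap, full, [1, 2, 3], k=4)
# accepted == [7, 4, 8], tau == 3: the fourth drafted token was rejected
```

The accepted prefix is guaranteed to match what the full model would have produced greedily, which is what makes skipping layers during drafting lossless.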

## Architecture

| Component | Description |
|---|---|
| `HiddenStateProjector` | Projects per-layer hidden states `[L, d_model]` → `[L, 128]` |
| `ScalarFeatureEmbedder` | Embeds 5 scalar context features (acceptance rate, latency, position, mask age, temperature) |
| `PolicyEncoder` | 2-layer Transformer encoder over layer positions |
| `logit_head` | Per-layer skip logits → top-M selection |
| `AcceptanceRateHead` | Predicts E[τ/K] → optimal draft length K* |

**Parameters:** ~200 K  
**Base model:** `meta-llama/Meta-Llama-3-8B` (32 layers, hidden_dim=4096)  
**Skip budget:** 8 layers per draft cycle
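The `logit_head` → top-M step reduces, at inference time, to a hard selection of the M highest-scoring layers under the skip budget. A minimal sketch (during training the policy samples from the skip distribution rather than taking a hard top-M):

```python
def top_m_skip_mask(logits, m=8):
    """Return a 0/1 mask marking the m layers with the highest skip logits."""
    ranked = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    chosen = set(ranked[:m])
    return [1 if i in chosen else 0 for i in range(len(logits))]

# 32 per-layer logits -> exactly 8 layers marked for skipping
mask = top_m_skip_mask([(i * 7) % 32 for i in range(32)])
```

The hard budget keeps the draft-stage cost fixed per cycle regardless of which layers the policy chooses.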

## Usage

```python
import torch
from src.grpo.policy import SkipPolicy

# Load the trained policy checkpoint
ckpt = torch.load("policy_best.pt", map_location="cpu")
policy = SkipPolicy(
    hidden_dim=4096,
    n_layers=32,
    n_skip=8,
    policy_dim=128,
    context_tokens=1,
)
policy.load_state_dict(ckpt["policy_state_dict"])
policy.eval()

# During self-speculative decoding, call after each verify pass:
# hidden_states: tuple of (L+1) tensors from model output_hidden_states=True
mask, draft_len = policy.greedy_mask(
    hidden_states,
    last_tau=accepted_tokens,
    draft_len=current_draft_len,
    position=current_position,
    max_len=max_new_tokens,
)
# mask: list of 0/1 per layer (1 = skip this layer during draft)
# draft_len: recommended number of tokens to draft next cycle
```

See [grpo-clasp](https://github.com/dayne-2stacks/grpo-clasp) for the full training and evaluation codebase.
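Conceptually, the recommended `draft_len` follows from the acceptance-rate prediction: given a predicted per-token acceptance probability, pick the K that maximizes expected accepted tokens per unit compute. The cost model and i.i.d.-acceptance assumption below are illustrative, not the repo's exact formula:

```python
def optimal_draft_len(p_accept, k_max=8, draft_cost=0.3, verify_cost=1.0):
    """Pick K maximizing expected accepted tokens per unit compute,
    assuming each drafted token is accepted independently with p_accept."""
    best_k, best_rate = 1, 0.0
    for k in range(1, k_max + 1):
        # Expected accepted-prefix length: sum_{i=1..k} p^i
        expected_tau = sum(p_accept ** i for i in range(1, k + 1))
        rate = expected_tau / (k * draft_cost + verify_cost)
        if rate > best_rate:
            best_k, best_rate = k, rate
    return best_k
```

A high predicted acceptance rate pushes toward longer drafts; a low one toward short, conservative drafts that waste fewer rejected tokens.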
## Training

Trained with GRPO on SpecBench-style prompts using `meta-llama/Meta-Llama-3-8B` on a single A100 80 GB for 10 000 steps. An imitation warm-start from CLaSp DP masks was used for the first ~1000 steps.
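At the core of the GRPO update is a group-relative advantage: several rollouts share a prompt, and each rollout's reward is normalized against its group. The reward shape and weights below are placeholders for illustration; the trained configuration lives in the repo:

```python
def skip_reward(tau, k, tokens_per_sec, baseline_tps, mismatch_weight=0.1):
    """Illustrative reward: acceptance rate + relative speed - mismatch penalty."""
    return tau / k + tokens_per_sec / baseline_tps - mismatch_weight * (k - tau)

def grpo_advantages(rewards):
    """Group-relative advantages: z-score each rollout's reward within its group."""
    mean = sum(rewards) / len(rewards)
    std = (sum((r - mean) ** 2 for r in rewards) / len(rewards)) ** 0.5
    if std == 0.0:
        return [0.0 for _ in rewards]
    return [(r - mean) / std for r in rewards]

# Three rollouts of the same prompt with different skip masks
group = [skip_reward(6, 8, 55.0, 40.0),
         skip_reward(8, 8, 60.0, 40.0),
         skip_reward(2, 8, 70.0, 40.0)]
advantages = grpo_advantages(group)
```

Because advantages are computed within the group, GRPO needs no learned value baseline, which keeps the ~200 K-parameter policy cheap to train.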
| Metric | Value |
|---|---|
| Training steps | 10 000 |
| Eval reward | 99.8 |
| Test reward | 100.4 |
| GPU | NVIDIA A100 80 GB |

## Citation

If you use CASM, please cite the CLaSp paper and this repository:

```bibtex
@misc{casm2026,
  author = {Dayne Guy},
  title  = {CASM: CLaSp Adaptive Skip Mask},
  year   = {2026},
  url    = {https://huggingface.co/dayngerous/CASM}
}
```