---
language:
- en
license: apache-2.0
tags:
- speculative-decoding
- layer-skipping
- grpo
- clasp
- llama
- efficiency
base_model: meta-llama/Meta-Llama-3-8B
---

# CASM — CLaSp Adaptive Skip Mask

CASM is a lightweight GRPO-trained skip policy for self-speculative decoding with [Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B). It learns which transformer layers to bypass during the draft stage of speculative decoding, replacing the CLaSp dynamic-programming optimizer with a small neural policy that runs in microseconds.

## How it works

Self-speculative decoding runs the *same* frozen model in two modes, with a policy update between cycles:

1. **Draft** — selected decoder layers are skipped, producing `K` candidate tokens cheaply.
2. **Verify** — the full model validates the draft block and accepts the longest matching prefix.
3. **Policy update** — CASM observes the verify hidden states and chooses a new skip mask for the next cycle.

CASM replaces step 3's DP solver with a 2-layer Transformer encoder (~200 K parameters) that maps per-layer hidden states to a skip-mask distribution. It is trained end-to-end with GRPO against a reward combining token-acceptance rate, decoding speed, and mismatch regularization.
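The accept step and the shape of the training reward can be sketched in a few lines. This is a minimal illustration, not the repository's implementation: `accept_longest_prefix` and `skip_policy_reward` are hypothetical names, and the reward weights `alpha`/`beta`/`gamma` are placeholders — the card only names the three reward components, not their coefficients.

```python
def accept_longest_prefix(draft_tokens, verified_tokens):
    """Step 2: accept the longest prefix of the draft block that the
    full model's verify pass reproduces token-for-token."""
    tau = 0
    for d, v in zip(draft_tokens, verified_tokens):
        if d != v:
            break
        tau += 1
    return tau


def skip_policy_reward(tau, k, draft_speedup, mismatch_penalty,
                       alpha=1.0, beta=0.5, gamma=0.1):
    """Illustrative GRPO reward combining the three components named
    above: token-acceptance rate (tau / K), decoding speed, and a
    mismatch regularizer. Weights are placeholder assumptions."""
    return alpha * (tau / k) + beta * draft_speedup - gamma * mismatch_penalty
```

A higher acceptance rate lets more cheap draft tokens survive verification, while the mismatch term discourages masks that draft fast but get rejected often.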
## Architecture

| Component | Description |
|---|---|
| `HiddenStateProjector` | Projects per-layer hidden states `[L, d_model]` → `[L, 128]` |
| `ScalarFeatureEmbedder` | Embeds 5 scalar context features (acceptance rate, latency, position, mask age, temperature) |
| `PolicyEncoder` | 2-layer Transformer encoder over layer positions |
| `logit_head` | Per-layer skip logits → top-M selection |
| `AcceptanceRateHead` | Predicts E[τ/K] → optimal draft length K* |

**Parameters:** ~200 K

**Base model:** `meta-llama/Meta-Llama-3-8B` (32 layers, hidden_dim=4096)

**Skip budget:** 8 layers per draft cycle

## Usage

```python
import torch

from src.grpo.policy import SkipPolicy

# Load the trained policy checkpoint
ckpt = torch.load("policy_best.pt", map_location="cpu")
policy = SkipPolicy(
    hidden_dim=4096,
    n_layers=32,
    n_skip=8,
    policy_dim=128,
    context_tokens=1,
)
policy.load_state_dict(ckpt["policy_state_dict"])
policy.eval()

# During self-speculative decoding, call after each verify pass.
# hidden_states: tuple of (L+1) tensors from the model, obtained with
# output_hidden_states=True
mask, draft_len = policy.greedy_mask(
    hidden_states,
    last_tau=accepted_tokens,
    draft_len=current_draft_len,
    position=current_position,
    max_len=max_new_tokens,
)
# mask: list of 0/1 per layer (1 = skip this layer during draft)
# draft_len: recommended number of tokens to draft next cycle
```

See [grpo-clasp](https://github.com/dayne-2stacks/grpo-clasp) for the full training and evaluation codebase.

## Training

Trained with GRPO on SpecBench-style prompts using `meta-llama/Meta-Llama-3-8B` on a single A100 80 GB for 10 000 steps. Imitation warm-start from CLaSp DP masks was used for the first ~1000 steps.
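The components in the architecture table can be wired together as below. This is an illustrative sketch, not the actual `SkipPolicy` in `src/grpo/policy.py`: layer shapes, head count, and the top-M selection rule are assumptions drawn only from the table, and the parameter count of this sketch is not guaranteed to match the quoted ~200 K.

```python
import torch
import torch.nn as nn


class SkipPolicySketch(nn.Module):
    """Illustrative wiring of the components named in the architecture
    table; the real SkipPolicy implementation may differ."""

    def __init__(self, hidden_dim=4096, n_layers=32, n_skip=8, policy_dim=128):
        super().__init__()
        self.n_skip = n_skip
        # HiddenStateProjector: [L, d_model] -> [L, policy_dim]
        self.projector = nn.Linear(hidden_dim, policy_dim)
        # ScalarFeatureEmbedder: 5 scalar context features -> policy_dim
        self.scalar_embed = nn.Linear(5, policy_dim)
        # Learned embedding of each layer position (assumption)
        self.pos_embed = nn.Parameter(torch.zeros(n_layers, policy_dim))
        # PolicyEncoder: 2-layer Transformer encoder over layer positions
        enc_layer = nn.TransformerEncoderLayer(
            d_model=policy_dim, nhead=4, batch_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=2)
        # logit_head: per-layer skip logits
        self.logit_head = nn.Linear(policy_dim, 1)
        # AcceptanceRateHead: predicts E[tau / K]
        self.accept_head = nn.Linear(policy_dim, 1)

    def forward(self, layer_states, scalars):
        # layer_states: [L, hidden_dim]; scalars: [5]
        x = self.projector(layer_states) + self.scalar_embed(scalars)
        x = x + self.pos_embed
        x = self.encoder(x.unsqueeze(0)).squeeze(0)  # [L, policy_dim]
        logits = self.logit_head(x).squeeze(-1)      # [L]
        # Top-M selection: skip the n_skip layers with the highest logits
        mask = torch.zeros_like(logits)
        mask[logits.topk(self.n_skip).indices] = 1.0
        # Pooled prediction of the expected acceptance rate E[tau / K]
        exp_accept = torch.sigmoid(self.accept_head(x.mean(dim=0)))
        return mask, exp_accept
```

The top-M hard selection keeps the skip budget fixed at `n_skip` layers per draft cycle, matching the skip budget stated above.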
| Metric | Value |
|---|---|
| Training steps | 10 000 |
| Eval reward | 99.8 |
| Test reward | 100.4 |
| GPU | NVIDIA A100 80 GB |

## Citation

If you use CASM, please cite the CLaSp paper and this repository:

```bibtex
@misc{casm2026,
  author = {Dayne Guy},
  title  = {CASM: CLaSp Adaptive Skip Mask},
  year   = {2026},
  url    = {https://huggingface.co/dayngerous/CASM}
}
```