dayngerous commited on
Commit
81da06c
·
verified ·
1 Parent(s): d62f8ca

Initial upload: policy checkpoint, config, model card, source

Browse files
README.md ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - en
4
+ license: apache-2.0
5
+ tags:
6
+ - speculative-decoding
7
+ - layer-skipping
8
+ - grpo
9
+ - clasp
10
+ - llama
11
+ - efficiency
12
+ base_model: meta-llama/Meta-Llama-3-8B
13
+ ---
14
+
15
+ # CASM — CLaSp Adaptive Skip Mask
16
+
17
+ CASM is a lightweight GRPO-trained skip policy for self-speculative decoding with [Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B). It learns which transformer layers to bypass during the draft stage of speculative decoding, replacing the CLaSp dynamic-programming optimizer with a small neural policy that runs in microseconds.
18
+
19
+ ## How it works
20
+
21
+ Self-speculative decoding runs the *same* frozen model in two modes:
22
+
23
+ 1. **Draft** — selected decoder layers are skipped, producing `K` candidate tokens cheaply.
24
+ 2. **Verify** — the full model validates the draft block and accepts the longest matching prefix.
25
+ 3. **Policy update** — CASM observes the verify hidden states and chooses a new skip mask for the next cycle.
26
+
27
+ CASM replaces step 3's DP solver with a 2-layer Transformer encoder (~200 K parameters) that maps per-layer hidden states to a skip-mask distribution. It is trained end-to-end with GRPO against a reward combining token-acceptance rate, decoding speed, and mismatch regularization.
28
+
29
+ ## Architecture
30
+
31
+ | Component | Description |
32
+ |---|---|
33
+ | `HiddenStateProjector` | Projects per-layer hidden states `[L, d_model]` → `[L, 128]` |
34
+ | `ScalarFeatureEmbedder` | Embeds 5 scalar context features (acceptance rate, latency, position, mask age, temperature) |
35
+ | `PolicyEncoder` | 2-layer Transformer encoder over layer positions |
36
+ | `logit_head` | Per-layer skip logits → top-M selection |
37
+ | `AcceptanceRateHead` | Predicts E[τ/K] → optimal draft length K* |
38
+
39
+ **Parameters:** ~200 K
40
+ **Base model:** `meta-llama/Meta-Llama-3-8B` (32 layers, hidden_dim=4096)
41
+ **Skip budget:** 8 layers per draft cycle
42
+
43
+ ## Usage
44
+
45
+ ```python
46
+ import torch
47
+ from src.grpo.policy import SkipPolicy
48
+
49
+ # Load policy
50
+ ckpt = torch.load("policy_best.pt", map_location="cpu")
51
+ policy = SkipPolicy(
52
+ hidden_dim=4096,
53
+ n_layers=32,
54
+ n_skip=8,
55
+ policy_dim=128,
56
+ context_tokens=1,
57
+ )
58
+ policy.load_state_dict(ckpt["policy_state_dict"])
59
+ policy.eval()
60
+
61
+ # During self-speculative decoding, call after each verify pass:
62
+ # hidden_states: tuple of (L+1) tensors from model output_hidden_states=True
63
+ mask, draft_len = policy.greedy_mask(
64
+ hidden_states,
65
+ last_tau=accepted_tokens,
66
+ draft_len=current_draft_len,
67
+ position=current_position,
68
+ max_len=max_new_tokens,
69
+ )
70
+ # mask: list of 0/1 per layer (1 = skip this layer during draft)
71
+ # draft_len: recommended tokens to draft next cycle
72
+ ```
73
+
74
+ See [grpo-clasp](https://github.com/dayne-2stacks/grpo-clasp) for the full training and evaluation codebase.
75
+
76
+ ## Training
77
+
78
+ Trained with GRPO on SpecBench-style prompts using `meta-llama/Meta-Llama-3-8B` on a single A100 80 GB for 10 000 steps. Imitation warm-start from CLaSp DP masks was used for the first ~1000 steps.
79
+
80
+ | Metric | Value |
81
+ |---|---|
82
+ | Training steps | 10 000 |
83
+ | Eval reward | 99.8 |
84
+ | Test reward | 100.4 |
85
+ | GPU | NVIDIA A100 80 GB |
86
+
87
+ ## Citation
88
+
89
+ If you use CASM, please cite the CLaSp paper and this repository:
90
+
91
+ ```bibtex
92
+ @misc{casm2026,
93
+ author = {Dayne Guy},
94
+ title = {CASM: CLaSp Adaptive Skip Mask},
95
+ year = {2026},
96
+ url = {https://huggingface.co/dayngerous/CASM}
97
+ }
98
+ ```
config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "casm_skip_policy",
3
+ "hidden_dim": 4096,
4
+ "n_layers": 32,
5
+ "n_skip": 8,
6
+ "policy_dim": 128,
7
+ "n_heads": 4,
8
+ "n_encoder_layers": 2,
9
+ "keep_prefix": 0,
10
+ "keep_suffix": 0,
11
+ "context_tokens": 1,
12
+ "draft_len_choices": [4, 8, 12, 16, 24, 32, 48, 64],
13
+ "base_model": "meta-llama/Meta-Llama-3-8B",
14
+ "training_steps": 10000,
15
+ "eval_reward": 99.773,
16
+ "test_reward": 100.429
17
+ }
policy_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:762abdcb49e258e4bf2c29402429722f8e6cd0f09344e38cf137575946977783
3
+ size 3698949
src/__init__.py ADDED
File without changes
src/grpo/__init__.py ADDED
File without changes
src/grpo/action_space.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Action space definitions for the GRPO skip policy.
2
+
3
+ The action is a binary skip mask S ∈ {0,1}^L.
4
+ This module provides samplers and constraint enforcement for different
5
+ action space parameterizations.
6
+ """
7
+
8
+ from typing import List, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+
15
+ class TopMActionSampler(nn.Module):
16
+ """Fixed-budget action sampler: select exactly M layers to skip.
17
+
18
+ Policy outputs per-layer logits; the top-M scoring eligible layers are skipped.
19
+ During training, uses a straight-through Gumbel-top-K estimator for gradients.
20
+ During evaluation / rollouts, uses deterministic top-M argmax.
21
+
22
+ Args:
23
+ n_layers: total number of transformer layers.
24
+ n_skip: skip budget M.
25
+ keep_prefix: number of layers at start that cannot be skipped.
26
+ keep_suffix: number of layers at end that cannot be skipped.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ n_layers: int,
32
+ n_skip: int,
33
+ keep_prefix: int = 2,
34
+ keep_suffix: int = 2,
35
+ ):
36
+ super().__init__()
37
+ self.n_layers = n_layers
38
+ self.n_skip = n_skip
39
+ self.keep_prefix = keep_prefix
40
+ self.keep_suffix = keep_suffix
41
+ # Mask for eligible layers
42
+ eligible = torch.zeros(n_layers, dtype=torch.bool)
43
+ eligible[keep_prefix : n_layers - keep_suffix] = True
44
+ self.register_buffer("eligible", eligible)
45
+
46
+ def forward(self, logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
47
+ """Sample a soft skip mask from policy logits.
48
+
49
+ Args:
50
+ logits: [n_layers] raw skip logits from policy network.
51
+ temperature: sampling temperature (1.0 = standard, lower = more peaked).
52
+
53
+ Returns:
54
+ hard_mask: [n_layers] binary tensor (differentiable via straight-through).
55
+ """
56
+ # Zero out ineligible logits
57
+ masked_logits = logits.clone()
58
+ masked_logits[~self.eligible] = float("-inf")
59
+
60
+
61
+ scale = masked_logits[self.eligible].std().detach().clamp(min=1.0)
62
+ masked_logits = masked_logits / scale
63
+
64
+ # Gumbel-top-K for differentiable discrete selection
65
+ gumbel = -torch.log(-torch.log(torch.clamp(torch.rand_like(masked_logits), 1e-9, 1.0)))
66
+ perturbed = (masked_logits + gumbel) / temperature
67
+
68
+ # Select top n_skip eligible indices
69
+ topk_vals, topk_idx = torch.topk(perturbed, self.n_skip)
70
+ hard_mask = torch.zeros(self.n_layers, device=logits.device)
71
+ hard_mask.scatter_(0, topk_idx, 1.0)
72
+
73
+ # Straight-through estimator: use hard mask in forward, soft mask in backward
74
+ soft_mask = torch.sigmoid(masked_logits / temperature)
75
+ return hard_mask + (soft_mask - soft_mask.detach())
76
+
77
+ def greedy_mask(self, logits: torch.Tensor) -> List[int]:
78
+ """Deterministic top-M mask for inference."""
79
+ masked_logits = logits.clone()
80
+ masked_logits[~self.eligible] = float("-inf")
81
+ _, topk_idx = torch.topk(masked_logits, self.n_skip)
82
+ mask = torch.zeros(self.n_layers, dtype=torch.long)
83
+ mask[topk_idx] = 1
84
+ return mask.tolist()
85
+
86
+ def log_prob(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
87
+ """Log-probability of a discrete mask under the Plackett-Luce model.
88
+
89
+ Uses sequential conditioning: each selected layer is drawn from a
90
+ categorical over the remaining eligible layers, consistent with
91
+ Gumbel-top-K sampling.
92
+ """
93
+ # Apply the same unit-std normalization used in forward() so that log_p_old
94
+ # (computed at rollout time) and log_p_new (computed during the PPO update)
95
+ # are on the same scale and the ratio exp(log_p_new - log_p_old) is correct.
96
+ # Clamp after normalizing for numerical safety in log_softmax.
97
+ eligible_logits = logits[self.eligible]
98
+ scale = eligible_logits.std().detach().clamp(min=1.0)
99
+ eligible_logits = (eligible_logits / scale).clamp(-50.0, 50.0)
100
+ selected_indices = mask[self.eligible].bool().nonzero(as_tuple=True)[0]
101
+
102
+ log_p = logits.new_zeros(())
103
+ # Use a bool exclusion mask (no grad) instead of in-place modification of
104
+ # eligible_logits (which has requires_grad=True during _update_policy).
105
+ # In-place ops on a grad tensor mid-loop corrupt autograd's version counter
106
+ # and can silently produce NaN gradients.
107
+ exclusion = torch.zeros(eligible_logits.shape[0], dtype=torch.bool,
108
+ device=logits.device)
109
+ for idx in selected_indices:
110
+ masked = eligible_logits.masked_fill(exclusion, float("-inf"))
111
+ log_p = log_p + F.log_softmax(masked, dim=0)[idx]
112
+ exclusion = exclusion.clone()
113
+ exclusion[idx] = True
114
+ return log_p
src/grpo/policy.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """GRPO skip policy network.
2
+
3
+ Architecture:
4
+ Input: per-layer hidden state projections z_l ∈ R^{d'} for l=0..L-1
5
+ + scalar context features (last_tau, last_ms, position, age, ...)
6
+ Encoder: 2-layer Transformer encoder over layer index l (treating l as seq pos)
7
+ Output: per-layer skip logits u_l ∈ R^L
8
+ + scalar p̂ ∈ (0,1) predicting E[τ/K] (acceptance rate)
9
+ Action: top-M selection via TopMActionSampler; K derived from p̂ analytically
10
+
11
+ The policy is lightweight by design — it should be orders of magnitude smaller
12
+ than the verify model to keep training cost negligible.
13
+ """
14
+
15
+ from typing import List, Optional, Tuple
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+
21
+ from .action_space import TopMActionSampler
22
+
23
+ _DEFAULT_DRAFT_LEN_CHOICES = [4, 8, 12, 16, 24, 32, 48, 64]
24
+
25
+
26
+ def optimal_draft_len(p_hat: float, choices: List[int]) -> int:
27
+ """Return the K from choices closest to the natural optimum K* = 1/(1−p̂).
28
+
29
+ Intuition: under a geometric acceptance model, K* ≈ 1/(1−p) maximises
30
+ expected accepted tokens per verify pass. Beyond this point, extra draft
31
+ tokens are increasingly likely to be rejected.
32
+
33
+ Args:
34
+ p_hat: predicted per-token acceptance probability in (0, 1).
35
+ choices: discrete candidate K values (must be non-empty).
36
+ """
37
+ p_hat = max(0.01, min(p_hat, 0.99))
38
+ k_natural = 1.0 / (1.0 - p_hat)
39
+ return min(choices, key=lambda k: abs(k - k_natural))
40
+
41
+
42
+ class HiddenStateProjector(nn.Module):
43
+ """Project per-layer hidden states from d → d_policy.
44
+
45
+ Input: tuple of [1, seq_len, d] tensors (one per layer, L+1 total)
46
+ Output: [L, d_policy] (one projected vector per transformer layer)
47
+
48
+ When context_tokens > 1, the last K token positions are mean-pooled per
49
+ layer before projection, giving the policy a richer view of recent context.
50
+ If the sequence is shorter than K, all available tokens are used.
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ hidden_dim: int,
56
+ policy_dim: int,
57
+ n_layers: int,
58
+ context_tokens: int = 1,
59
+ ):
60
+ super().__init__()
61
+ self.n_layers = n_layers
62
+ self.context_tokens = context_tokens
63
+ self.proj = nn.Linear(hidden_dim, policy_dim, bias=False)
64
+
65
+ def forward(
66
+ self, hidden_states: Tuple[torch.Tensor, ...]
67
+ ) -> torch.Tensor:
68
+ """Extract last K-token mean hidden state from each layer, project, return [L, d_p]."""
69
+ # hidden_states has L+1 entries: [embed, layer_0_out, ..., layer_{L-1}_out]
70
+ # We use layers 1..L (i.e., skip the embedding output at index 0)
71
+ layer_hs = [
72
+ hs[0, -self.context_tokens:, :].mean(dim=0) # [d] (mean over last K tokens)
73
+ for hs in hidden_states[1:self.n_layers + 1]
74
+ ]
75
+ stacked = torch.stack(layer_hs, dim=0) # [L, d]
76
+ return self.proj(stacked.float()) # [L, d_policy]
77
+
78
+
79
+ class PolicyEncoder(nn.Module):
80
+ """2-layer Transformer encoder over layer indices (treating L layers as seq).
81
+
82
+ Input: [L, d_policy] + scalar features appended to each position
83
+ Output: [L, d_policy]
84
+ """
85
+
86
+ def __init__(self, d_policy: int, n_heads: int = 4, n_encoder_layers: int = 2):
87
+ super().__init__()
88
+ encoder_layer = nn.TransformerEncoderLayer(
89
+ d_model=d_policy,
90
+ nhead=n_heads,
91
+ dim_feedforward=d_policy * 4,
92
+ dropout=0.0,
93
+ batch_first=True,
94
+ norm_first=True,
95
+ )
96
+ self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_encoder_layers)
97
+
98
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
99
+ """x: [L, d_policy] → [L, d_policy]"""
100
+ return self.encoder(x.unsqueeze(0)).squeeze(0) # add/remove batch dim
101
+
102
+
103
+ class ScalarFeatureEmbedder(nn.Module):
104
+ """Embeds scalar context features and adds them to each layer position."""
105
+
106
+ FEATURE_NAMES = [
107
+ "last_tau_norm", # last_tau / draft_len (acceptance rate)
108
+ "latency_norm", # last cycle ms / expected ms (rough normalization)
109
+ "position_norm", # current token position / max_len
110
+ "age_norm", # mask age / update_interval
111
+ "temperature", # generation temperature (fixed per run)
112
+ ]
113
+ N_FEATURES = len(FEATURE_NAMES)
114
+
115
+ def __init__(self, d_policy: int):
116
+ super().__init__()
117
+ self.embed = nn.Linear(self.N_FEATURES, d_policy, bias=False)
118
+
119
+ def forward(
120
+ self,
121
+ last_tau: int,
122
+ draft_len: int,
123
+ last_ms: float,
124
+ position: int,
125
+ max_len: int,
126
+ age: int,
127
+ update_interval: int,
128
+ temperature: float,
129
+ ) -> torch.Tensor:
130
+ """Return [d_policy] scalar feature embedding."""
131
+ feats = torch.tensor([
132
+ last_tau / max(draft_len, 1),
133
+ last_ms / 1000.0,
134
+ position / max(max_len, 1),
135
+ age / max(update_interval, 1),
136
+ temperature,
137
+ ], dtype=torch.float32, device=self.embed.weight.device)
138
+ return self.embed(feats)
139
+
140
+
141
+ class AcceptanceRateHead(nn.Module):
142
+ """Predicts E[τ/K] from mean-pooled encoder output via scalar regression.
143
+
144
+ Trained with MSE against observed τ/K each rollout — no policy gradient needed.
145
+ The prediction p̂ is used to derive optimal draft length K* analytically via
146
+ optimal_draft_len().
147
+ """
148
+
149
+ def __init__(self, d_policy: int):
150
+ super().__init__()
151
+ self.head = nn.Linear(d_policy, 1, bias=True)
152
+
153
+ def forward(self, encoded: torch.Tensor) -> torch.Tensor:
154
+ """encoded: [L, d_policy] → scalar p̂ ∈ (0, 1)"""
155
+ pooled = encoded.mean(dim=0) # [d_policy]
156
+ return torch.sigmoid(self.head(pooled)).squeeze(-1) # scalar
157
+
158
+
159
+ class SkipPolicy(nn.Module):
160
+ """Full skip policy: projects hidden states → encodes → outputs skip logits
161
+ and a predicted acceptance rate p̂.
162
+
163
+ Usage::
164
+
165
+ policy = SkipPolicy(hidden_dim=4096, n_layers=32, n_skip=16, policy_dim=128)
166
+ skip_logits, p_hat = policy(hidden_states, last_tau=8, ...)
167
+ mask, draft_len = policy.greedy_mask(hidden_states, ...)
168
+ """
169
+
170
+ def __init__(
171
+ self,
172
+ hidden_dim: int,
173
+ n_layers: int,
174
+ n_skip: int,
175
+ policy_dim: int = 128,
176
+ n_heads: int = 4,
177
+ n_encoder_layers: int = 2,
178
+ keep_prefix: int = 2,
179
+ keep_suffix: int = 2,
180
+ draft_len_choices: Optional[List[int]] = None,
181
+ context_tokens: int = 1,
182
+ ):
183
+ super().__init__()
184
+ self.n_layers = n_layers
185
+ self.n_skip = n_skip
186
+ self.draft_len_choices = (
187
+ draft_len_choices if draft_len_choices is not None
188
+ else _DEFAULT_DRAFT_LEN_CHOICES
189
+ )
190
+ self.projector = HiddenStateProjector(hidden_dim, policy_dim, n_layers, context_tokens)
191
+ self.scalar_embedder = ScalarFeatureEmbedder(policy_dim)
192
+ self.encoder = PolicyEncoder(policy_dim, n_heads, n_encoder_layers)
193
+ self.logit_head = nn.Linear(policy_dim, 1, bias=True)
194
+ self.sampler = TopMActionSampler(n_layers, n_skip, keep_prefix, keep_suffix)
195
+ self.acceptance_head = AcceptanceRateHead(policy_dim)
196
+
197
+ def forward(
198
+ self,
199
+ hidden_states: Tuple[torch.Tensor, ...],
200
+ last_tau: int = 0,
201
+ draft_len: int = 16,
202
+ last_ms: float = 0.0,
203
+ position: int = 0,
204
+ max_len: int = 256,
205
+ age: int = 0,
206
+ update_interval: int = 1,
207
+ temperature: float = 0.0,
208
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
209
+ """Compute per-layer skip logits and predicted acceptance rate.
210
+
211
+ Returns:
212
+ skip_logits: [n_layers]
213
+ p_hat: scalar ∈ (0, 1), predicted E[τ/K]
214
+ """
215
+ z = self.projector(hidden_states) # [L, d_policy]
216
+
217
+ scalar_emb = self.scalar_embedder(
218
+ last_tau, draft_len, last_ms, position, max_len,
219
+ age, update_interval, temperature,
220
+ ) # [d_policy]
221
+ z = z + scalar_emb.unsqueeze(0) # [L, d_policy]
222
+
223
+ encoded = self.encoder(z) # [L, d_policy]
224
+
225
+ skip_logits = self.logit_head(encoded).squeeze(-1) # [L]
226
+ p_hat = self.acceptance_head(encoded) # scalar
227
+ return skip_logits, p_hat
228
+
229
+ def sample_mask(
230
+ self,
231
+ hidden_states: Tuple[torch.Tensor, ...],
232
+ temperature: float = 1.0,
233
+ **kwargs,
234
+ ) -> Tuple[List[int], int, torch.Tensor]:
235
+ """Sample a skip mask and derive draft length from predicted acceptance rate.
236
+
237
+ draft_len is selected deterministically via optimal_draft_len(p̂) — no RL.
238
+ log_p covers only the skip mask action.
239
+
240
+ Returns:
241
+ hard_mask: List[int] of length n_layers
242
+ draft_len: int, derived from p̂
243
+ log_p: scalar tensor, log π(mask | h)
244
+ """
245
+ skip_logits, p_hat = self.forward(hidden_states, **kwargs)
246
+
247
+ soft_mask = self.sampler(skip_logits, temperature=temperature)
248
+ hard_mask = (soft_mask.detach() > 0.5).long().tolist()
249
+ log_p = self.sampler.log_prob(skip_logits, soft_mask.detach())
250
+
251
+ draft_len = optimal_draft_len(p_hat.detach().item(), self.draft_len_choices)
252
+ return hard_mask, draft_len, log_p
253
+
254
+ def greedy_mask(
255
+ self,
256
+ hidden_states: Tuple[torch.Tensor, ...],
257
+ **kwargs,
258
+ ) -> Tuple[List[int], int]:
259
+ """Deterministic greedy mask and draft length for evaluation."""
260
+ with torch.no_grad():
261
+ skip_logits, p_hat = self.forward(hidden_states, **kwargs)
262
+ mask = self.sampler.greedy_mask(skip_logits)
263
+ draft_len = optimal_draft_len(p_hat.item(), self.draft_len_choices)
264
+ return mask, draft_len
265
+
266
+ def compile_for_inference(self) -> None:
267
+ """Replace forward with a torch.compile'd version for faster inference.
268
+
269
+ Call once after policy.eval() and before the generation loop.
270
+ Use fullgraph=False to tolerate the torch.tensor() call inside
271
+ ScalarFeatureEmbedder without needing to refactor it.
272
+ """
273
+ self.forward = torch.compile(self.forward, mode="max-autotune", fullgraph=False)