| """Action space definitions for the GRPO skip policy. |
| |
| The action is a binary skip mask S ∈ {0,1}^L. |
| This module provides samplers and constraint enforcement for different |
| action space parameterizations. |
| """ |
|
|
| from typing import List, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class TopMActionSampler(nn.Module):
    """Fixed-budget action sampler: select exactly M layers to skip.

    The policy supplies one logit per layer. Only layers outside the protected
    prefix/suffix are eligible; ineligible layers are forced to ``-inf`` before
    any selection so they are never chosen.

    * ``forward`` always draws a *random* set of ``n_skip`` eligible layers
      (Gumbel-perturbed top-k) and attaches a straight-through sigmoid
      surrogate so gradients reach the logits. It does not consult
      ``self.training``.
    * ``greedy_mask`` returns the deterministic top-M choice; call it
      explicitly for evaluation / greedy rollouts.
    * ``log_prob`` scores a discrete mask with sequential (Plackett-Luce style)
      conditioning.

    Args:
        n_layers: total number of transformer layers.
        n_skip: skip budget M.
        keep_prefix: number of layers at start that cannot be skipped.
        keep_suffix: number of layers at end that cannot be skipped.

    Raises:
        ValueError: if ``keep_prefix`` or ``keep_suffix`` is negative, or if
            ``n_skip`` is negative or larger than the number of eligible
            layers (top-k would otherwise fall back to protected layers).
    """

    def __init__(
        self,
        n_layers: int,
        n_skip: int,
        keep_prefix: int = 2,
        keep_suffix: int = 2,
    ):
        super().__init__()
        if keep_prefix < 0 or keep_suffix < 0:
            raise ValueError(
                "keep_prefix and keep_suffix must be non-negative, got "
                f"keep_prefix={keep_prefix}, keep_suffix={keep_suffix}"
            )

        self.n_layers = n_layers
        self.n_skip = n_skip
        self.keep_prefix = keep_prefix
        self.keep_suffix = keep_suffix

        eligible = torch.zeros(n_layers, dtype=torch.bool)
        # Guard the range explicitly: a negative slice end (keep_suffix larger
        # than n_layers) would wrap around and mark protected layers eligible.
        first_free = keep_prefix
        end_free = n_layers - keep_suffix
        if end_free > first_free:
            eligible[first_free:end_free] = True

        n_eligible = int(eligible.sum())
        if not 0 <= n_skip <= n_eligible:
            raise ValueError(
                f"n_skip={n_skip} must lie in [0, {n_eligible}] "
                f"(n_layers={n_layers}, keep_prefix={keep_prefix}, "
                f"keep_suffix={keep_suffix})"
            )

        # Buffer (not a parameter) so it follows the module across devices.
        self.register_buffer("eligible", eligible)

    @staticmethod
    def _logit_scale(eligible_logits: torch.Tensor) -> torch.Tensor:
        """Return the detached normaliser ``max(std, 1)`` of the eligible logits."""
        # The sample standard deviation of fewer than two values is NaN, which
        # would poison every downstream value; fall back to no rescaling.
        if eligible_logits.numel() < 2:
            return eligible_logits.new_ones(())
        return eligible_logits.std().detach().clamp(min=1.0)

    def forward(self, logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
        """Sample a hard skip mask with a straight-through gradient.

        Args:
            logits: [n_layers] raw skip logits from policy network.
            temperature: positive sharpness of the sigmoid surrogate that
                carries the gradient. The perturbed scores are divided by it
                as well, but a positive rescaling does not change their
                ranking, so it has no effect on which layers are drawn.

        Returns:
            [n_layers] tensor whose values are 1.0 at the ``n_skip`` sampled
            layers and 0.0 elsewhere; its gradient w.r.t. ``logits`` is that
            of ``sigmoid(scaled_logits / temperature)``.

        Raises:
            ValueError: if ``temperature`` is not strictly positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        # Protected layers get -inf so neither top-k nor the sigmoid picks them.
        masked_logits = logits.masked_fill(~self.eligible, float("-inf"))

        # Shrink wide logit spreads (std > 1) to unit scale; the factor is
        # detached, so it acts as a constant in the backward pass.
        masked_logits = masked_logits / self._logit_scale(logits[self.eligible])

        # Gumbel(0, 1) noise; the lower clamp keeps log(u) finite.
        uniform = torch.clamp(torch.rand_like(masked_logits), 1e-9, 1.0)
        gumbel = -torch.log(-torch.log(uniform))
        perturbed = (masked_logits + gumbel) / temperature

        # Top-k of the perturbed scores = sampling k layers without replacement.
        _, topk_idx = torch.topk(perturbed, self.n_skip)
        hard_mask = torch.zeros(self.n_layers, device=logits.device)
        hard_mask.scatter_(0, topk_idx, 1.0)

        # Straight-through: forward value is hard_mask (the bracket is exactly
        # zero), backward sees only the sigmoid surrogate.
        soft_mask = torch.sigmoid(masked_logits / temperature)
        return hard_mask + (soft_mask - soft_mask.detach())

    def greedy_mask(self, logits: torch.Tensor) -> List[int]:
        """Deterministic top-M mask for inference.

        Args:
            logits: [n_layers] raw skip logits.

        Returns:
            List of ``n_layers`` ints, 1 for the ``n_skip`` highest-scoring
            eligible layers and 0 elsewhere.
        """
        masked_logits = logits.masked_fill(~self.eligible, float("-inf"))
        _, topk_idx = torch.topk(masked_logits, self.n_skip)
        # Allocate on the logits' device: indexing a CPU tensor with indices
        # that live on another device raises at runtime.
        mask = torch.zeros(self.n_layers, dtype=torch.long, device=logits.device)
        mask[topk_idx] = 1
        return mask.tolist()

    def log_prob(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Log-probability of a discrete mask under the Plackett-Luce model.

        Uses sequential conditioning: each selected layer is drawn from a
        categorical over the remaining eligible layers, consistent with
        Gumbel-top-K sampling. The selected layers are scored in ascending
        layer-index order (the order ``nonzero`` yields), i.e. the result is
        the log-probability of that one ordering, not of the draw order and
        not the sum over all orderings.

        Args:
            logits: [n_layers] raw skip logits.
            mask: [n_layers] skip mask; non-zero entries mark skipped layers.
                A tensor or anything ``torch.as_tensor`` accepts (for example
                the list returned by ``greedy_mask``).

        Returns:
            Scalar tensor, differentiable w.r.t. ``logits``; zero when no
            eligible layer is selected.
        """
        mask = torch.as_tensor(mask, device=logits.device)

        eligible_logits = logits[self.eligible]
        scale = self._logit_scale(eligible_logits)
        # NOTE(review): forward() samples from the unclamped scaled logits;
        # the clamp only matters when |logit / scale| exceeds 50.
        eligible_logits = (eligible_logits / scale).clamp(-50.0, 50.0)

        # Positions are relative to the eligible subset, matching eligible_logits.
        selected_indices = mask[self.eligible].bool().nonzero(as_tuple=True)[0]

        log_p = logits.new_zeros(())
        exclusion = torch.zeros(
            eligible_logits.shape[0], dtype=torch.bool, device=logits.device
        )
        for idx in selected_indices:
            # Layers already chosen are removed from the categorical.
            remaining = eligible_logits.masked_fill(exclusion, float("-inf"))
            log_p = log_p + F.log_softmax(remaining, dim=0)[idx]
            # Clone before mutating so the mask tensor handed to the previous
            # masked_fill is not modified in place behind autograd's back.
            exclusion = exclusion.clone()
            exclusion[idx] = True
        return log_p
|
|