"""Action space definitions for the GRPO skip policy.

The action is a binary skip mask S ∈ {0,1}^L. This module provides samplers
and constraint enforcement for different action space parameterizations.
"""

from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class TopMActionSampler(nn.Module):
    """Fixed-budget action sampler: select exactly M layers to skip.

    The policy outputs per-layer logits. ``forward`` draws a stochastic
    top-M selection among the eligible layers with the Gumbel-top-K trick
    and returns it through a straight-through estimator so gradients reach
    the logits. ``greedy_mask`` returns the deterministic top-M selection,
    and ``log_prob`` scores a given mask under a Plackett-Luce model.

    Args:
        n_layers: total number of transformer layers.
        n_skip: skip budget M.
        keep_prefix: number of layers at start that cannot be skipped.
        keep_suffix: number of layers at end that cannot be skipped.

    Raises:
        ValueError: if any size is negative, or if ``n_skip`` exceeds the
            number of eligible (skippable) layers.
    """

    def __init__(
        self,
        n_layers: int,
        n_skip: int,
        keep_prefix: int = 2,
        keep_suffix: int = 2,
    ):
        super().__init__()
        if n_layers < 0 or keep_prefix < 0 or keep_suffix < 0:
            raise ValueError(
                "n_layers, keep_prefix and keep_suffix must be non-negative"
            )

        # Mask for eligible layers: everything outside the protected
        # prefix/suffix.
        eligible = torch.zeros(n_layers, dtype=torch.bool)
        eligible[keep_prefix : n_layers - keep_suffix] = True

        # Without this check torch.topk would silently pick protected layers
        # (their score is -inf) once the budget exceeds the eligible count.
        n_eligible = int(eligible.sum())
        if not 0 <= n_skip <= n_eligible:
            raise ValueError(
                f"n_skip={n_skip} must be between 0 and the number of "
                f"eligible layers ({n_eligible})"
            )

        self.n_layers = n_layers
        self.n_skip = n_skip
        self.keep_prefix = keep_prefix
        self.keep_suffix = keep_suffix
        self.register_buffer("eligible", eligible)

    @staticmethod
    def _logit_scale(eligible_logits: torch.Tensor) -> torch.Tensor:
        """Return the detached normalization scale for the eligible logits.

        The scale is ``max(std, 1)`` so logits are only ever shrunk. With
        fewer than two eligible logits the sample std is NaN, so the scale
        falls back to 1 instead of poisoning every downstream value.
        """
        if eligible_logits.numel() < 2:
            return eligible_logits.new_ones(())
        return eligible_logits.std().detach().clamp(min=1.0)

    @staticmethod
    def _check_temperature(temperature: float) -> None:
        """Reject temperatures that would divide by zero or flip rankings."""
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

    def forward(self, logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
        """Sample a hard skip mask from policy logits.

        Args:
            logits: [n_layers] raw skip logits from policy network.
            temperature: sampling temperature (1.0 = standard, lower = more
                peaked). Must be positive.

        Returns:
            hard_mask: [n_layers] tensor whose forward values are 0/1 with
            exactly ``n_skip`` ones, all on eligible layers. Gradients flow
            to the logits through a sigmoid surrogate (straight-through).

        Raises:
            ValueError: if ``temperature`` is not positive.
        """
        self._check_temperature(temperature)

        # Ineligible layers get -inf so they can never win the top-k.
        masked_logits = logits.clone()
        masked_logits[~self.eligible] = float("-inf")
        masked_logits = masked_logits / self._logit_scale(masked_logits[self.eligible])

        # The temperature must scale the logits *before* the Gumbel noise is
        # added: dividing the perturbed sum by a positive constant leaves the
        # top-k unchanged and would make the temperature a no-op for sampling.
        tempered = masked_logits / temperature

        # Gumbel-top-K for discrete selection without replacement.
        uniform = torch.clamp(torch.rand_like(tempered), 1e-9, 1.0)
        gumbel = -torch.log(-torch.log(uniform))
        perturbed = tempered + gumbel

        # Select top n_skip eligible indices.
        _, topk_idx = torch.topk(perturbed, self.n_skip)
        hard_mask = torch.zeros(self.n_layers, device=logits.device)
        hard_mask.scatter_(0, topk_idx, 1.0)

        # Straight-through estimator: hard mask in forward, soft mask in backward.
        soft_mask = torch.sigmoid(tempered)
        return hard_mask + (soft_mask - soft_mask.detach())

    def greedy_mask(self, logits: torch.Tensor) -> List[int]:
        """Deterministic top-M mask for inference.

        Args:
            logits: [n_layers] raw skip logits from policy network.

        Returns:
            A list of ``n_layers`` ints, 1 for the ``n_skip`` highest-scoring
            eligible layers and 0 elsewhere.
        """
        masked_logits = logits.clone()
        masked_logits[~self.eligible] = float("-inf")
        _, topk_idx = torch.topk(masked_logits, self.n_skip)
        # Allocate on the logits' device: indexing a CPU tensor with indices
        # living on another device raises.
        mask = torch.zeros(self.n_layers, dtype=torch.long, device=logits.device)
        mask[topk_idx] = 1
        return mask.tolist()

    def log_prob(
        self,
        logits: torch.Tensor,
        mask: torch.Tensor,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """Log-probability of a discrete mask under the Plackett-Luce model.

        Uses sequential conditioning: each selected layer is drawn from a
        categorical over the remaining eligible layers, consistent with
        Gumbel-top-K sampling.

        NOTE: the mask carries no selection order, so the selected layers are
        scored in ascending index order. The result is the log-probability of
        that one ordering, not the sum over all orderings of the set.

        Args:
            logits: [n_layers] raw skip logits from policy network.
            mask: [n_layers] 0/1 skip mask; a tensor or any sequence accepted
                by ``torch.as_tensor`` (e.g. the list from ``greedy_mask``).
                Entries on ineligible layers are ignored.
            temperature: temperature the mask was sampled with; pass the same
                value given to ``forward``. Must be positive.

        Returns:
            Scalar tensor holding the log-probability.

        Raises:
            ValueError: if ``temperature`` is not positive.
        """
        self._check_temperature(temperature)
        mask = torch.as_tensor(mask, device=logits.device)

        # Apply the same normalization (and temperature) used in forward() so
        # that log_p_old (computed at rollout time) and log_p_new (computed
        # during the policy update) are on the same scale and the ratio
        # exp(log_p_new - log_p_old) is correct.
        # Clamp after normalizing for numerical safety in log_softmax.
        eligible_logits = logits[self.eligible]
        scale = self._logit_scale(eligible_logits)
        eligible_logits = (eligible_logits / scale / temperature).clamp(-50.0, 50.0)

        selected_indices = mask[self.eligible].bool().nonzero(as_tuple=True)[0]
        log_p = logits.new_zeros(())

        # Track already-drawn layers in a separate bool mask rather than
        # editing eligible_logits in place (it may require grad). The mask is
        # cloned before each update because masked_fill keeps a reference to
        # it for the backward pass.
        exclusion = torch.zeros(
            eligible_logits.shape[0], dtype=torch.bool, device=logits.device
        )
        for idx in selected_indices:
            masked = eligible_logits.masked_fill(exclusion, float("-inf"))
            log_p = log_p + F.log_softmax(masked, dim=0)[idx]
            exclusion = exclusion.clone()
            exclusion[idx] = True
        return log_p