CASM / src /grpo /action_space.py
dayngerous's picture
Initial upload: policy checkpoint, config, model card, source
81da06c verified
"""Action space definitions for the GRPO layer-skip policy.

An action is a binary skip mask S ∈ {0,1}^L over the L transformer layers,
where S[i] = 1 means layer i is skipped.

This module provides samplers that turn policy logits into such masks while
enforcing action-space constraints (e.g. a fixed skip budget and protected
prefix/suffix layers that may never be skipped).
"""
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
class TopMActionSampler(nn.Module):
    """Fixed-budget action sampler: select exactly M layers to skip.

    The policy outputs one logit per layer; the M highest-scoring *eligible*
    layers are skipped. The first ``keep_prefix`` and last ``keep_suffix``
    layers are never eligible.

    - ``forward`` draws a stochastic mask with Gumbel-top-K and returns it
      with a straight-through gradient (hard mask in the forward pass,
      sigmoid surrogate in the backward pass). Use it during training.
    - ``greedy_mask`` returns the deterministic top-M mask (no noise).
    - ``log_prob`` scores a given mask under a Plackett-Luce model on the
      same std-normalized logits that ``forward`` perturbs.

    Args:
        n_layers: total number of transformer layers.
        n_skip: skip budget M; must satisfy 0 <= M <= number of eligible layers.
        keep_prefix: number of layers at the start that cannot be skipped.
        keep_suffix: number of layers at the end that cannot be skipped.

    Raises:
        ValueError: if any count is negative, or if ``n_skip`` exceeds the
            number of eligible layers (top-k would otherwise have to pick
            protected layers).
    """

    def __init__(
        self,
        n_layers: int,
        n_skip: int,
        keep_prefix: int = 2,
        keep_suffix: int = 2,
    ):
        super().__init__()
        if min(n_layers, n_skip, keep_prefix, keep_suffix) < 0:
            raise ValueError(
                "n_layers, n_skip, keep_prefix and keep_suffix must be non-negative"
            )
        self.n_layers = n_layers
        self.n_skip = n_skip
        self.keep_prefix = keep_prefix
        self.keep_suffix = keep_suffix
        # Clamp the bounds so an oversized prefix/suffix yields an empty
        # eligible range, instead of a negative slice stop that wraps around
        # and marks layers from the front as eligible.
        start = min(keep_prefix, n_layers)
        stop = max(start, n_layers - keep_suffix)
        n_eligible = stop - start
        if n_skip > n_eligible:
            raise ValueError(
                f"n_skip={n_skip} exceeds the {n_eligible} eligible layers "
                f"(n_layers={n_layers}, keep_prefix={keep_prefix}, "
                f"keep_suffix={keep_suffix})"
            )
        # Boolean mask of layers that may be skipped; registered as a buffer
        # so it follows the module across devices.
        eligible = torch.zeros(n_layers, dtype=torch.bool)
        eligible[start:stop] = True
        self.register_buffer("eligible", eligible)

    def _mask_ineligible(self, logits: torch.Tensor) -> torch.Tensor:
        """Return a copy of ``logits`` with protected layers set to -inf."""
        return logits.masked_fill(~self.eligible, float("-inf"))

    @staticmethod
    def _logit_scale(eligible_logits: torch.Tensor) -> torch.Tensor:
        """Detached normalization scale: std of the eligible logits, floored at 1.

        ``forward`` and ``log_prob`` must share this so the PPO ratio
        exp(log_p_new - log_p_old) compares like with like. With fewer than
        two eligible layers the unbiased std is NaN (and ``clamp`` keeps NaN),
        which would poison every downstream value, so fall back to 1.
        """
        if eligible_logits.numel() < 2:
            return eligible_logits.new_ones(())
        return eligible_logits.std().detach().clamp(min=1.0)

    def forward(self, logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
        """Sample a hard skip mask from policy logits with a straight-through gradient.

        Args:
            logits: [n_layers] raw skip logits from the policy network.
            temperature: must be > 0. Scales the sigmoid surrogate used in the
                backward pass. NOTE: dividing the Gumbel-perturbed scores by a
                positive temperature does not change their top-k order, so the
                sampled hard mask itself does not depend on ``temperature``.

        Returns:
            [n_layers] tensor whose forward value is a binary mask with
            ``n_skip`` ones on eligible layers, and whose gradient flows
            through ``sigmoid(normalized_logits / temperature)``.

        Raises:
            ValueError: if ``temperature`` is not positive (a negative value
                would invert the ordering and select protected -inf layers).
        """
        if not temperature > 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        # Protected layers get -inf so they can never enter the top-k.
        masked_logits = self._mask_ineligible(logits)
        masked_logits = masked_logits / self._logit_scale(logits[self.eligible])
        # Gumbel-top-K: adding i.i.d. Gumbel(0, 1) noise and taking the top-k
        # draws k items without replacement from softmax(masked_logits).
        uniform = torch.clamp(torch.rand_like(masked_logits), 1e-9, 1.0)
        gumbel = -torch.log(-torch.log(uniform))
        perturbed = (masked_logits + gumbel) / temperature
        _, topk_idx = torch.topk(perturbed, self.n_skip)
        hard_mask = torch.zeros(self.n_layers, device=logits.device)
        hard_mask.scatter_(0, topk_idx, 1.0)
        # Straight-through estimator: the value equals hard_mask exactly, while
        # the gradient is that of the soft (sigmoid) mask.
        soft_mask = torch.sigmoid(masked_logits / temperature)
        return hard_mask + (soft_mask - soft_mask.detach())

    def greedy_mask(self, logits: torch.Tensor) -> List[int]:
        """Deterministic top-M skip mask for inference (no Gumbel noise).

        Args:
            logits: [n_layers] raw skip logits.

        Returns:
            Python list of length ``n_layers`` with 1 for each skipped layer
            and 0 otherwise.
        """
        with torch.no_grad():
            _, topk_idx = torch.topk(self._mask_ineligible(logits), self.n_skip)
            # Allocate on the logits' device: indexing a CPU tensor with
            # indices that live on another device raises.
            mask = torch.zeros(self.n_layers, dtype=torch.long, device=logits.device)
            mask[topk_idx] = 1
        return mask.tolist()

    def log_prob(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Log-probability of a discrete mask under the Plackett-Luce model.

        Uses sequential conditioning: each selected layer is drawn from a
        categorical over the remaining eligible layers (the distribution that
        Gumbel-top-K sampling in ``forward`` draws from).

        NOTE: selected layers are conditioned in ascending layer order (the
        order ``nonzero`` returns). This is the probability of one particular
        draw sequence, not the marginal probability of the unordered set,
        which would sum over all M! orderings. The same ordering is used every
        time this is called, so old/new log-probs are computed consistently.

        Args:
            logits: [n_layers] raw skip logits.
            mask: [n_layers] mask with non-zero entries on skipped layers;
                entries on protected layers are ignored.

        Returns:
            Scalar tensor with the summed log-probability (0 if nothing is
            selected).
        """
        # Apply the same std normalization (floored at 1) used in forward() so
        # that log_p_old (computed at rollout time) and log_p_new (computed
        # during the PPO update) are on the same scale and the ratio
        # exp(log_p_new - log_p_old) is correct. Clamp after normalizing for
        # numerical safety in log_softmax (forward() does not clamp; the two
        # only differ when a normalized logit exceeds +/-50).
        eligible_logits = logits[self.eligible]
        scale = self._logit_scale(eligible_logits)
        eligible_logits = (eligible_logits / scale).clamp(-50.0, 50.0)
        selected_indices = mask[self.eligible].bool().nonzero(as_tuple=True)[0]
        log_p = logits.new_zeros(())
        # Use a bool exclusion mask (no grad) instead of in-place modification of
        # eligible_logits (which has requires_grad=True during _update_policy).
        # In-place ops on a grad tensor mid-loop corrupt autograd's version counter
        # and can silently produce NaN gradients.
        exclusion = torch.zeros(
            eligible_logits.shape[0], dtype=torch.bool, device=logits.device
        )
        for idx in selected_indices:
            masked = eligible_logits.masked_fill(exclusion, float("-inf"))
            log_p = log_p + F.log_softmax(masked, dim=0)[idx]
            # Clone before mutating: masked_fill saved the previous mask for
            # backward, so it must not be modified in place.
            exclusion = exclusion.clone()
            exclusion[idx] = True
        return log_p