"""k1-in-reward KL penalty — the Composer-2 / verl fidelity choice. THE FIDELITY GAP (F5 Rubric A item c2, the single highest-leverage fix). Composer-2 §4.1 explicitly chooses the **k1** KL estimator applied **in the reward** (``-log r``), citing a variance argument (Amini et al.). TRL's ``GRPOTrainer`` instead applies the **k3** estimator (``exp(Δ) - Δ - 1``, Δ = ref_logp - logp) **in the loss**, gated on ``beta != 0``. The 2025/26 literature says this is not cosmetic: * arXiv:2512.21852 ("A Comedy of Estimators") — k1-in-reward improves OOD generalization; k3-in-reward can collapse. * verl ships k1-in-reward as its default/recommended reverse-KL option (it also supports a k3-family "low_var_kl" — wording corrected per deepread finding V13). * TRL issue #4967 tracks the same divergence. OOD generalization is exactly the "take any model to the next level" axis, so this module gives the trainer an opt-in k1-in-reward path that matches Composer-2 / verl, leaving TRL's native k3-in-loss disabled (``beta = 0``). THE ALGEBRA (why this is a clean advantage adjustment, not a TRL fork). k1-in-reward means: penalize each sequence's reward by ``coef * KL_i`` before GRPO computes its group-relative advantage: reward'_i = reward_i - coef * KL_i KL_i = Σ_t mask_{i,t} · (logp_{i,t} - ref_logp_{i,t}) # k1 estimator # of KL(π‖π_ref) GRPO's advantage (with ``scale_rewards="none"``, the Dr.GRPO / Composer regime) is the group-mean baseline ``adv_i = reward_i - mean_group(reward)``. Because that baseline is LINEAR, folding-then-baselining equals adjusting the final advantage: adv'_i = reward'_i - mean_group(reward') = adv_i - coef · (KL_i - mean_group(KL)) So the trainer can let TRL compute advantages normally, then apply this exact correction — no reimplementation of TRL's reward→advantage code. THE STD-NORM CAVEAT (why we require scale_rewards="none"). The identity above is EXACT only when there is no per-group std normalization. 
With std-norm, folding KL into the reward also changes the group std, so the linear correction is no longer equivalent. Composer-2 and verl both train WITHOUT std scaling (Dr.GRPO's recommendation), so we make the math exact for that regime and the trainer raises if k1-in-reward is requested with std-norm on, rather than silently applying an approximation. Note: ``-log r`` (Composer-2's phrasing) with ``r = π/π_ref = exp(logp-ref_logp)`` gives ``-log r = ref_logp - logp = -(logp - ref_logp)`` *per token*. The KL PENALTY subtracted from reward is ``coef · Σ_t (logp - ref_logp)`` — i.e. the k1 estimator of the reverse KL, which is what discourages drift from π_ref. The sign convention here matches the standard RLHF KL-in-reward penalty (Stiennon et al. 2020; verl ``kl_penalty="kl"``). """ from __future__ import annotations import torch #: Supported KL estimators for the in-reward penalty. Only k1 is meaningful here #: (the whole point is to use k1 instead of TRL's native-in-loss k3); k3 is #: accepted as an explicit no-divergence opt-out for experiments. KL_ESTIMATORS = ("k1", "k3") def k1_kl_penalty_per_sequence( policy_logps: torch.Tensor, ref_logps: torch.Tensor, completion_mask: torch.Tensor, ) -> torch.Tensor: """Per-sequence k1 estimator of KL(π ‖ π_ref) over completion tokens. Args: policy_logps: ``(B, T)`` per-token logprobs under the (sampling) policy π. ref_logps: ``(B, T)`` per-token logprobs under the reference policy π_ref, on the SAME tokens/positions as ``policy_logps``. completion_mask: ``(B, T)`` 1.0 at real completion tokens, 0.0 at prompt / padding positions (the k1 sum is taken only over real tokens). Returns: ``(B,)`` per-sequence KL penalty ``Σ_t mask·(logp - ref_logp)``. The k1 estimator ``logp - ref_logp`` is the unbiased (higher-variance) single-sample estimate of the reverse KL; summed over the response it is the sequence-level KL used as the reward penalty. 
""" if policy_logps.shape != ref_logps.shape: raise ValueError( f"policy_logps {tuple(policy_logps.shape)} and ref_logps " f"{tuple(ref_logps.shape)} must have identical shape (same tokens)." ) if completion_mask.shape != policy_logps.shape: raise ValueError( f"completion_mask {tuple(completion_mask.shape)} must match " f"policy_logps {tuple(policy_logps.shape)}." ) per_token = (policy_logps - ref_logps) * completion_mask return per_token.sum(dim=-1) def k3_kl_penalty_per_sequence( policy_logps: torch.Tensor, ref_logps: torch.Tensor, completion_mask: torch.Tensor, ) -> torch.Tensor: """Per-sequence k3 (Schulman) estimator of KL over completion tokens. ``k3 = exp(Δ) - Δ - 1``, Δ = ref_logp - logp. Always ≥ 0, lower variance. Provided for the in-reward path so an experiment can A/B k1-in-reward against k3-in-reward (the comparison arXiv:2512.21852 makes) without touching TRL. """ if not (policy_logps.shape == ref_logps.shape == completion_mask.shape): raise ValueError("policy_logps, ref_logps, completion_mask must share shape.") delta = ref_logps - policy_logps per_token = (torch.exp(delta) - delta - 1.0) * completion_mask return per_token.sum(dim=-1) def kl_penalty_per_sequence( policy_logps: torch.Tensor, ref_logps: torch.Tensor, completion_mask: torch.Tensor, estimator: str = "k1", ) -> torch.Tensor: """Dispatch to the k1 or k3 per-sequence KL penalty.""" if estimator == "k1": return k1_kl_penalty_per_sequence(policy_logps, ref_logps, completion_mask) if estimator == "k3": return k3_kl_penalty_per_sequence(policy_logps, ref_logps, completion_mask) raise ValueError( f"Unknown KL estimator {estimator!r}; choose from {KL_ESTIMATORS}. " "k1 is the Composer-2 / verl in-reward choice this module exists for." ) def apply_kl_in_reward( advantages: torch.Tensor, kl_penalty: torch.Tensor, num_generations: int, coef: float, ) -> torch.Tensor: """Adjust GRPO advantages to fold a KL penalty into the reward. 
Exact (not approximate) under the group-mean baseline with NO std normalization (``scale_rewards="none"`` — the Dr.GRPO / Composer regime). See the module docstring for the linearity argument. Args: advantages: ``(B,)`` GRPO advantages as TRL computed them (= reward - group_mean(reward), no std division). kl_penalty: ``(B,)`` per-sequence KL penalty (from ``kl_penalty_per_sequence``). num_generations: G — the number of completions per prompt (group size). ``B`` must be divisible by G; groups are contiguous as TRL lays them out (``rewards.view(-1, num_generations)``). coef: the KL coefficient β. ``coef=0`` returns advantages unchanged. Returns: ``(B,)`` adjusted advantages ``adv - coef·(KL - group_mean(KL))``. """ if coef == 0.0: return advantages if advantages.shape != kl_penalty.shape: raise ValueError( f"advantages {tuple(advantages.shape)} and kl_penalty " f"{tuple(kl_penalty.shape)} must have identical shape (B,)." ) b = advantages.shape[0] if num_generations <= 0 or b % num_generations != 0: raise ValueError( f"batch size B={b} must be a positive multiple of num_generations=" f"{num_generations} (GRPO lays groups out contiguously)." ) kl_grouped = kl_penalty.view(-1, num_generations) kl_centered = (kl_grouped - kl_grouped.mean(dim=1, keepdim=True)).reshape(b) return advantages - coef * kl_centered __all__ = [ "KL_ESTIMATORS", "k1_kl_penalty_per_sequence", "k3_kl_penalty_per_sequence", "kl_penalty_per_sequence", "apply_kl_in_reward", ]
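

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: a worked numeric check
# of the linearity identity from the module docstring,
#     reward' - mean_group(reward') == adv - coef * (KL - mean_group(KL)),
# for a single group and with NO std normalization. It is written in plain
# Python so it runs without torch. The helper name and the sample numbers are
# hypothetical and chosen only for illustration.
# ---------------------------------------------------------------------------
def _linearity_identity_gap(rewards, kls, coef):
    """Return max |route1 - route2| over one group of completions."""
    n = len(rewards)
    mean_reward = sum(rewards) / n
    mean_kl = sum(kls) / n

    # Route 1: fold the KL penalty into the reward, then subtract the group mean.
    folded = [r - coef * k for r, k in zip(rewards, kls)]
    mean_folded = sum(folded) / n
    route1 = [f - mean_folded for f in folded]

    # Route 2: plain group-mean advantages, then the centered-KL correction.
    route2 = [(r - mean_reward) - coef * (k - mean_kl) for r, k in zip(rewards, kls)]

    return max(abs(a - b) for a, b in zip(route1, route2))


if __name__ == "__main__":
    # One group of G=4 completions with made-up rewards and KL values.
    gap = _linearity_identity_gap(
        rewards=[1.0, 0.0, 0.5, 0.25],
        kls=[0.2, 0.1, 0.4, 0.3],
        coef=0.05,
    )
    # The two routes should agree up to floating-point rounding.
    assert gap < 1e-12, gap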