Baladithya Balamurugan
Wave 21: Stage-0 dataset pipeline — swesmith engine, rollout harness, gates, contract
9a2ce20
"""k1-in-reward KL penalty — the Composer-2 / verl fidelity choice.
THE FIDELITY GAP (F5 Rubric A item c2, the single highest-leverage fix).
Composer-2 §4.1 explicitly chooses the **k1** KL estimator applied **in the
reward** (``-log r``), citing a variance argument (Amini et al.). TRL's
``GRPOTrainer`` instead applies the **k3** estimator (``exp(Δ) - Δ - 1``,
Δ = ref_logp - logp) **in the loss**, gated on ``beta != 0``. The 2025/26
literature says this is not cosmetic:
* arXiv:2512.21852 ("A Comedy of Estimators") — k1-in-reward improves OOD
generalization; k3-in-reward can collapse.
* verl ships k1-in-reward as its default/recommended reverse-KL option
(it also supports a k3-family "low_var_kl" — wording corrected per
deepread finding V13).
* TRL issue #4967 tracks the same divergence.
OOD generalization is exactly the "take any model to the next level" axis, so
this module gives the trainer an opt-in k1-in-reward path that matches
Composer-2 / verl, leaving TRL's native k3-in-loss disabled (``beta = 0``).
THE ALGEBRA (why this is a clean advantage adjustment, not a TRL fork).
k1-in-reward means: penalize each sequence's reward by ``coef * KL_i`` before
GRPO computes its group-relative advantage:
reward'_i = reward_i - coef * KL_i
KL_i = Σ_t mask_{i,t} · (logp_{i,t} - ref_logp_{i,t}) # k1 estimator
# of KL(π‖π_ref)
GRPO's advantage (with ``scale_rewards="none"``, the Dr.GRPO / Composer regime)
is the group-mean baseline ``adv_i = reward_i - mean_group(reward)``. Because
that baseline is LINEAR, folding-then-baselining equals adjusting the final
advantage:
adv'_i = reward'_i - mean_group(reward')
= adv_i - coef · (KL_i - mean_group(KL))
So the trainer can let TRL compute advantages normally, then apply this exact
correction — no reimplementation of TRL's reward→advantage code.
THE STD-NORM CAVEAT (why we require scale_rewards="none"). The identity above
is EXACT only when there is no per-group std normalization. With std-norm,
folding KL into the reward also changes the group std, so the linear correction
is no longer equivalent. Composer-2 and verl both train WITHOUT std scaling
(Dr.GRPO's recommendation), so we make the math exact for that regime and the
trainer raises if k1-in-reward is requested with std-norm on, rather than
silently applying an approximation.

Note on signs: Composer-2 phrases the term as ``-log r``. With
``r = π/π_ref = exp(logp - ref_logp)``, this gives ``-log r = ref_logp - logp``
*per token*. Adding ``coef · Σ_t (-log r)`` to the reward is therefore the same
as subtracting the KL PENALTY ``coef · Σ_t (logp - ref_logp)``. That penalty is
the k1 estimator of the reverse KL, which is what discourages drift from π_ref.
This sign convention matches the standard RLHF KL-in-reward penalty
(Stiennon et al. 2020; verl ``kl_penalty="kl"``).
"""
from __future__ import annotations
import torch
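

# Illustrative sketch (hypothetical helper ``_example_check_linearity_identity``,
# not part of the original module): a numeric check of THE ALGEBRA in the module
# docstring. It folds a KL penalty into made-up rewards and then re-baselines
# them ("fold, then baseline"). It compares the result with the advantage
# correction adv - coef * (KL - group_mean(KL)). Assumes contiguous groups of
# size G and no std normalization (scale_rewards="none"). The rewards and KL
# values are random stand-ins, not real rollout data.
def _example_check_linearity_identity(coef: float = 0.1, num_generations: int = 4) -> bool:
    import torch

    torch.manual_seed(0)
    b = 2 * num_generations  # two prompts, G completions each
    rewards = torch.randn(b)
    kl = torch.rand(b)  # stand-in per-sequence KL values

    def group_centered(x: torch.Tensor) -> torch.Tensor:
        # x_i minus the mean of x over the completion's own (contiguous) group
        g = x.view(-1, num_generations)
        return (g - g.mean(dim=1, keepdim=True)).reshape(-1)

    folded = group_centered(rewards - coef * kl)                     # fold, then baseline
    corrected = group_centered(rewards) - coef * group_centered(kl)  # baseline, then correct
    return torch.allclose(folded, corrected, atol=1e-6)

# Without std normalization the two orderings agree up to float rounding. That
# is why apply_kl_in_reward can post-process TRL's advantages instead of
# re-deriving them from rewards.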

#: Supported KL estimators for the in-reward penalty. Only k1 is the intended
#: choice here (the whole point is to use k1 instead of TRL's native k3-in-loss).
#: k3 is accepted only as an experimental comparison arm (k3-in-reward, for A/B
#: runs against k1).
KL_ESTIMATORS = ("k1", "k3")


def k1_kl_penalty_per_sequence(
    policy_logps: torch.Tensor,
    ref_logps: torch.Tensor,
    completion_mask: torch.Tensor,
) -> torch.Tensor:
    """Per-sequence k1 estimator of KL(π ‖ π_ref) over completion tokens.

    Args:
        policy_logps: ``(B, T)`` per-token logprobs under the (sampling) policy π.
        ref_logps: ``(B, T)`` per-token logprobs under the reference policy π_ref,
            on the SAME tokens/positions as ``policy_logps``.
        completion_mask: ``(B, T)`` 1.0 at real completion tokens, 0.0 at prompt /
            padding positions (the k1 sum is taken only over real tokens).

    Returns:
        ``(B,)`` per-sequence KL penalty ``Σ_t mask·(logp - ref_logp)``.

    The k1 estimator ``logp - ref_logp`` is an unbiased (but higher-variance)
    single-sample estimate of the reverse KL when the completion tokens are
    sampled from π. Summed over the response, it is the sequence-level KL used
    as the reward penalty.
    """
    if policy_logps.shape != ref_logps.shape:
        raise ValueError(
            f"policy_logps {tuple(policy_logps.shape)} and ref_logps "
            f"{tuple(ref_logps.shape)} must have identical shape (same tokens)."
        )
    if completion_mask.shape != policy_logps.shape:
        raise ValueError(
            f"completion_mask {tuple(completion_mask.shape)} must match "
            f"policy_logps {tuple(policy_logps.shape)}."
        )
    # Zero out prompt/padding positions, then sum the per-token k1 terms.
    per_token = (policy_logps - ref_logps) * completion_mask
    return per_token.sum(dim=-1)
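

# Illustrative sketch (hypothetical helper ``_example_k1_from_logits``, not part
# of the original module): one way to build the ``(B, T)`` inputs that
# k1_kl_penalty_per_sequence expects. Per-token logprobs are gathered from
# policy and reference logits at the SAME sampled token ids. The mask zeroes
# prompt/padding positions, and the masked k1 terms are summed per sequence.
# The random logits and the trailing-padding mask layout are made-up data.
def _example_k1_from_logits(batch: int = 2, seq_len: int = 5, vocab: int = 7) -> "torch.Tensor":
    import torch

    torch.manual_seed(0)
    policy_logits = torch.randn(batch, seq_len, vocab)
    ref_logits = torch.randn(batch, seq_len, vocab)
    token_ids = torch.randint(0, vocab, (batch, seq_len)).unsqueeze(-1)  # sampled tokens

    # log π(token) and log π_ref(token) at identical positions, each of shape (B, T).
    policy_logps = torch.log_softmax(policy_logits, dim=-1).gather(-1, token_ids).squeeze(-1)
    ref_logps = torch.log_softmax(ref_logits, dim=-1).gather(-1, token_ids).squeeze(-1)

    # 1.0 on real completion tokens, 0.0 on the (assumed) trailing padding of row 1.
    completion_mask = torch.ones(batch, seq_len)
    completion_mask[1, 3:] = 0.0

    # Same reduction as k1_kl_penalty_per_sequence: Σ_t mask · (logp - ref_logp).
    return ((policy_logps - ref_logps) * completion_mask).sum(dim=-1)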


def k3_kl_penalty_per_sequence(
    policy_logps: torch.Tensor,
    ref_logps: torch.Tensor,
    completion_mask: torch.Tensor,
) -> torch.Tensor:
    """Per-sequence k3 (Schulman) estimator of KL over completion tokens.

    ``k3 = exp(Δ) - Δ - 1`` with Δ = ref_logp - logp. Each term is always ≥ 0,
    and the estimator typically has lower variance than k1 when π is close to
    π_ref. Provided for the in-reward path so an experiment can A/B k1-in-reward
    against k3-in-reward (the comparison arXiv:2512.21852 makes) without
    touching TRL.
    """
    if not (policy_logps.shape == ref_logps.shape == completion_mask.shape):
        raise ValueError("policy_logps, ref_logps, completion_mask must share shape.")
    delta = ref_logps - policy_logps
    per_token = (torch.exp(delta) - delta - 1.0) * completion_mask
    return per_token.sum(dim=-1)
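

# Illustrative sketch (hypothetical helper ``_example_k1_vs_k3``, not part of
# the original module): a Monte Carlo comparison of the two per-token estimators
# on a single made-up categorical distribution. With tokens sampled from π, both
# k1 = logp - ref_logp and k3 = exp(Δ) - Δ - 1 (Δ = ref_logp - logp) average to
# roughly the exact KL(π ‖ π_ref). Individual k1 samples can be negative, while
# k3 samples never are. In this particular example k3 also has much lower
# variance; that is not a guarantee for arbitrary π, π_ref.
def _example_k1_vs_k3(num_samples: int = 200_000) -> dict:
    import torch

    torch.manual_seed(0)
    pi = torch.tensor([0.5, 0.3, 0.2])      # sampling policy π (made-up)
    pi_ref = torch.tensor([0.2, 0.5, 0.3])  # reference policy π_ref (made-up)
    exact_kl = float((pi * (pi.log() - pi_ref.log())).sum())

    tokens = torch.multinomial(pi, num_samples, replacement=True)  # tokens ~ π
    logp, ref_logp = pi.log()[tokens], pi_ref.log()[tokens]
    k1 = logp - ref_logp                 # per-token k1 estimates
    delta = ref_logp - logp              # same Δ as k3_kl_penalty_per_sequence
    k3 = torch.exp(delta) - delta - 1.0  # per-token k3 estimates
    return {
        "exact": exact_kl,
        "k1_mean": float(k1.mean()), "k1_var": float(k1.var()), "k1_min": float(k1.min()),
        "k3_mean": float(k3.mean()), "k3_var": float(k3.var()), "k3_min": float(k3.min()),
    }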


def kl_penalty_per_sequence(
    policy_logps: torch.Tensor,
    ref_logps: torch.Tensor,
    completion_mask: torch.Tensor,
    estimator: str = "k1",
) -> torch.Tensor:
    """Dispatch to the k1 or k3 per-sequence KL penalty."""
    if estimator == "k1":
        return k1_kl_penalty_per_sequence(policy_logps, ref_logps, completion_mask)
    if estimator == "k3":
        return k3_kl_penalty_per_sequence(policy_logps, ref_logps, completion_mask)
    raise ValueError(
        f"Unknown KL estimator {estimator!r}; choose from {KL_ESTIMATORS}. "
        "k1 is the Composer-2 / verl in-reward choice this module exists for."
    )


def apply_kl_in_reward(
    advantages: torch.Tensor,
    kl_penalty: torch.Tensor,
    num_generations: int,
    coef: float,
) -> torch.Tensor:
    """Adjust GRPO advantages to fold a KL penalty into the reward.

    Exact (not approximate) under the group-mean baseline with NO std
    normalization (``scale_rewards="none"``, the Dr.GRPO / Composer regime).
    See the module docstring for the linearity argument.

    Args:
        advantages: ``(B,)`` GRPO advantages as TRL computed them
            (= reward - group_mean(reward), no std division).
        kl_penalty: ``(B,)`` per-sequence KL penalty (from
            ``kl_penalty_per_sequence``).
        num_generations: G, the number of completions per prompt (group size).
            ``B`` must be divisible by G; groups are contiguous, as TRL lays
            them out (``rewards.view(-1, num_generations)``).
        coef: the KL coefficient β. ``coef=0`` returns advantages unchanged.

    Returns:
        ``(B,)`` adjusted advantages ``adv - coef·(KL - group_mean(KL))``.
    """
    if coef == 0.0:
        return advantages
    if advantages.shape != kl_penalty.shape:
        raise ValueError(
            f"advantages {tuple(advantages.shape)} and kl_penalty "
            f"{tuple(kl_penalty.shape)} must have identical shape (B,)."
        )
    b = advantages.shape[0]
    if num_generations <= 0 or b % num_generations != 0:
        raise ValueError(
            f"batch size B={b} must be a positive multiple of num_generations="
            f"{num_generations} (GRPO lays groups out contiguously)."
        )
    # reshape (rather than view) so a non-contiguous kl_penalty does not raise.
    kl_grouped = kl_penalty.reshape(-1, num_generations)
    kl_centered = (kl_grouped - kl_grouped.mean(dim=1, keepdim=True)).reshape(b)
    return advantages - coef * kl_centered
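

# Illustrative sketch (hypothetical helper ``_example_std_norm_mismatch``, not
# part of the original module) of THE STD-NORM CAVEAT. Once advantages are
# divided by the per-group reward std, folding the KL into the reward is no
# longer equivalent to the linear correction that apply_kl_in_reward applies.
# The epsilon added to the std (1e-4) and the random rewards/KL values are
# assumptions for illustration. The helper returns the largest absolute gap
# between the two advantage vectors; under scale_rewards="none" that gap would
# be at float-rounding level.
def _example_std_norm_mismatch(coef: float = 0.5, num_generations: int = 4) -> float:
    import torch

    torch.manual_seed(0)
    b = 2 * num_generations
    rewards = torch.randn(b)
    kl = torch.rand(b)  # made-up per-sequence KL values

    def std_normalized_advantage(x: torch.Tensor) -> torch.Tensor:
        g = x.view(-1, num_generations)
        adv = (g - g.mean(dim=1, keepdim=True)) / (g.std(dim=1, keepdim=True) + 1e-4)
        return adv.reshape(-1)

    # Exact k1-in-reward: fold the penalty into the reward, then std-normalize.
    folded = std_normalized_advantage(rewards - coef * kl)

    # Linear correction (the apply_kl_in_reward formula) on std-normalized advantages.
    kl_grouped = kl.view(-1, num_generations)
    kl_centered = (kl_grouped - kl_grouped.mean(dim=1, keepdim=True)).reshape(-1)
    corrected = std_normalized_advantage(rewards) - coef * kl_centered

    return float((folded - corrected).abs().max())

# A clearly non-zero gap is the reason the trainer refuses k1-in-reward when
# std normalization is on, instead of silently applying the approximation.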


__all__ = [
    "KL_ESTIMATORS",
    "k1_kl_penalty_per_sequence",
    "k3_kl_penalty_per_sequence",
    "kl_penalty_per_sequence",
    "apply_kl_in_reward",
]