Baladithya Balamurugan
Wave 20: Tier-0 fidelity fixes — k1-in-reward KL + Composer-2 behavior rewards
41289bf
Raw
History Blame Contribute Delete
44.2 kB
"""composer_trainer.py — TRL GRPOTrainer subclass with SDPO + trace-replay channels.
Architecture spec: docs/INTEGRATION_ARCHITECTURE.md § "Recipe A".
Verified extension point: GRPOTrainer._compute_loss(model, inputs)
(DeepWiki audit of huggingface/trl, 2026-05-25).
Total loss:
total_loss = grpo_loss
+ alpha_sdpo * sdpo_kl_at_error_turns
+ beta_replay * trace_replay_dpo_loss
Where:
- grpo_loss is the parent GRPOTrainer's loss (RLVR + DAPO patches).
- sdpo_kl_at_error_turns is generalized_jsd_loss between student's logits and
teacher's (= same-model-with-hint-context) logits, masked to error-turn tokens only.
- trace_replay_dpo_loss is DPO loss over (chosen, rejected) pairs derived from
disagreements between N external teachers and the student.
The data collator (data_collator.py) is responsible for:
- Detecting error sites in the rollout and constructing ctx_teacher = ctx_student + hint.
- Computing sdpo_loss_mask (1 at post-hint error-turn tokens, 0 elsewhere).
- Loading DPO pairs from the trace-replay output (see teacher_replay.py).
- Precomputing reference-policy logprobs for DPO.
"""
from __future__ import annotations
import logging
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
import torch
import torch.nn.functional as F # noqa: N812 — repo-wide torch convention
if TYPE_CHECKING: # type-only — never imported at runtime (keeps the dep lazy)
from composer_replication.safety import HeldOutGuard
# These imports work when TRL is installed — they're not skeleton imports.
# When TRL is missing we fall back to `object` so the module still imports
# (e.g. for documentation generation) but raise a clear ImportError at
# instantiation time rather than the cryptic `object.__init__()` error.
try:
from trl import GRPOTrainer # type: ignore
_TRL_AVAILABLE = True
except ImportError: # pragma: no cover — only hit in unit-test stubs without TRL
GRPOTrainer = object # type: ignore — fallback so module imports without TRL
_TRL_AVAILABLE = False
from composer_replication.opsd import generalized_jsd_loss
from composer_replication.trainer.kl_in_reward import (
apply_kl_in_reward,
kl_penalty_per_sequence,
)
logger = logging.getLogger(__name__)
class ComposerReplicationTrainer(GRPOTrainer):  # type: ignore[misc, valid-type]
    """TRL GRPOTrainer with Composer-recipe channels (SDPO) + novel trace-replay-DPO.

    Args (in addition to GRPOTrainer's):
        alpha_sdpo: weight on SDPO hint-distill loss. Default 0.0 (disabled).
            Opt in by passing >0 once your data collator produces
            `sdpo_loss_mask` and `ctx_teacher_input_ids` columns (plus the
            `student_response_idx` / `teacher_response_idx` alignment maps,
            which are required while ``strict_sdpo_alignment`` is True).
        beta_replay: weight on trace-replay DPO loss. Default 0.0 (disabled).
            Opt in by passing >0 once your data collator produces
            `dpo_chosen_input_ids` / `dpo_rejected_input_ids` etc.
        sdpo_jsd_beta: beta param of generalized_jsd_loss
            (0=KL(teacher||student), 0.5=JSD, 1=KL(student||teacher) per
            upstream OPSD convention; see composer_replication/opsd.py).
        sdpo_temperature: temperature for SDPO loss; SDPO paper uses 1.0.
        sdpo_token_clip: per-token JSD clip for stability; None = no clip.
        replay_dpo_beta: beta param of the DPO loss (β in the standard DPO formula).
        strict_sdpo_alignment: when True (default), an SDPO batch that lacks the
            `student_response_idx` / `teacher_response_idx` alignment maps
            raises ``ValueError``. When False the trainer logs a warning and
            falls back to a shape-only check, returning a zero SDPO term if
            the student/teacher logits shapes differ.
        kl_in_reward: when True, apply the KL-to-reference penalty in the
            **reward** (Composer-2 §4.1 / verl choice) instead of TRL's native
            **in-loss** k3 term. The penalty is folded into GRPO's advantages at
            scoring time (``adv -= beta·(KL - group_mean(KL))``) and TRL's
            in-loss KL is suppressed for that step. The F5 audit's #1 fidelity
            fix: the 2025/26 evidence (arXiv:2512.21852, verl, TRL #4967) shows
            k1-in-reward improves OOD generalization where k3-in-reward can
            collapse. REQUIRES ``beta>0`` (the KL coefficient — also how TRL
            decides to compute reference logprobs) and ``scale_rewards`` in
            {none,false} (the advantage-adjustment identity is exact only
            without std-normalization — the Dr.GRPO / Composer regime). Default
            False = TRL's native in-loss KL, byte-for-byte legacy behavior.
        kl_estimator: ``"k1"`` (default; ``logp - ref_logp``, the Composer-2 /
            verl choice this path exists for) or ``"k3"`` (Schulman; lets an
            experiment A/B k1-in-reward vs k3-in-reward). Only consulted when
            ``kl_in_reward=True``.
        heldout_guard: optional ``HeldOutGuard`` (the #2 collapse safeguard from
            ``composer_replication.safety``). Default None = OFF (no behavior
            change whatsoever). When supplied, the trainer folds at most one
            checkpoint's metrics per global step into the guard at the
            ``args.logging_steps`` cadence (the same place the loss components
            are logged) and HALTS the run on a fired verdict.
        heldout_eval_fn: zero-arg callable returning the held-out (real) eval
            score as a float, evaluated each guard cadence. Injectable so the
            trainer never hardcodes an eval — pass a closure over your disjoint
            held-out pool (the ``HeldoutSplit`` discipline). Required whenever
            ``heldout_guard`` is set; the guard's whole signal is in-loop reward
            vs. this held-out score.
        strict_killswitch: when True (default), a fired guard verdict raises
            ``CollapseStopError`` to hard-stop training (exception-based control
            flow, matching ``HeldOutGuard.raise_if_fired``). When False the
            verdict is logged and ``self.control.should_training_stop`` is set so
            the HF loop ends gracefully after the step (soft stop). Only consulted
            when ``heldout_guard`` is set.

    Raises (from ``__init__``):
        ImportError: TRL is not installed.
        ValueError: ``heldout_guard`` given without ``heldout_eval_fn``, or an
            invalid k1-in-reward config (see ``validate_kl_in_reward_config``).
    """

    def __init__(
        self,
        *args: Any,
        alpha_sdpo: float = 0.0,
        beta_replay: float = 0.0,
        sdpo_jsd_beta: float = 0.5,
        sdpo_temperature: float = 1.0,
        sdpo_token_clip: float | None = None,
        replay_dpo_beta: float = 0.1,
        strict_sdpo_alignment: bool = True,
        kl_in_reward: bool = False,
        kl_estimator: str = "k1",
        heldout_guard: HeldOutGuard | None = None,
        heldout_eval_fn: Callable[[], float] | None = None,
        strict_killswitch: bool = True,
        **kwargs: Any,
    ):
        if not _TRL_AVAILABLE:
            raise ImportError(
                "ComposerReplicationTrainer requires TRL. Install with "
                "`pip install -e .[train]`."
            )
        # Fail fast on an incomplete kill-switch config BEFORE the parent
        # __init__ runs (it may load the policy and, with beta>0, the reference
        # model) — this check depends only on our own kwargs.
        if heldout_guard is not None and heldout_eval_fn is None:
            raise ValueError(
                "heldout_guard was provided without heldout_eval_fn: the guard's "
                "tripwire compares in-loop reward against a DISJOINT held-out "
                "(real) eval score, so it needs an injectable zero-arg "
                "heldout_eval_fn() -> float. Pass a closure over your held-out "
                "pool (the HeldoutSplit discipline)."
            )
        super().__init__(*args, **kwargs)
        self.alpha_sdpo = alpha_sdpo
        self.beta_replay = beta_replay
        self.sdpo_jsd_beta = sdpo_jsd_beta
        self.sdpo_temperature = sdpo_temperature
        self.sdpo_token_clip = sdpo_token_clip
        self.replay_dpo_beta = replay_dpo_beta
        # When True (default), an SDPO batch WITHOUT explicit alignment indices
        # is a hard error — shape-only "alignment" can silently distill
        # misaligned tokens (the exact trust-gap flagged in ADR-008). Set False
        # only for production runs where a single malformed batch should
        # warn-and-fall-back rather than abort.
        self.strict_sdpo_alignment = strict_sdpo_alignment
        # --- k1-in-reward KL (F5 #1 fidelity fix; Composer-2 §4.1 / verl) ----
        # OFF by default → TRL's native in-loss k3 KL, byte-for-byte legacy.
        # When ON we keep self.beta as the KL coef (TRL needs beta>0 to even
        # create the ref model + compute ref logps), fold the k1 penalty into
        # advantages during scoring, and zero TRL's in-loss KL per step.
        self.kl_in_reward = kl_in_reward
        self.kl_estimator = kl_estimator
        if kl_in_reward:
            validate_kl_in_reward_config(
                kl_estimator=kl_estimator,
                beta=float(getattr(self.args, "beta", 0.0)),
                scale_rewards=getattr(self.args, "scale_rewards", "group"),
            )
        # --- run-level collapse kill-switch (#2 safeguard) -------------------
        # OPTIONAL + OFF BY DEFAULT: when heldout_guard is None the loss path is
        # byte-for-byte the legacy behavior. When set, _maybe_update_killswitch
        # folds metrics into the guard at the logging cadence (see _compute_loss).
        self.heldout_guard = heldout_guard
        self.heldout_eval_fn = heldout_eval_fn
        self.strict_killswitch = strict_killswitch

    # ----------------------------------------------------------------------
    # k1-in-reward: fold the KL penalty into advantages at scoring time, and
    # suppress TRL's native in-loss k3 KL inside _compute_loss.
    # ----------------------------------------------------------------------
    def _generate_and_score_completions(
        self,
        inputs: list[dict[str, Any]],
    ) -> dict[str, Any]:
        """Override: after TRL scores completions, fold a k1 KL penalty into the
        advantages (Composer-2 in-reward KL) when ``kl_in_reward`` is set.

        No-op (exact legacy path) when ``kl_in_reward`` is False. When set, TRL
        has already computed ``advantages``, ``ref_per_token_logps`` (because
        ``beta>0``), and — unless generation and optimization are aligned — the
        sampling-time completion logprobs; we recompute the per-sequence KL
        penalty and apply the group-mean-baseline correction.

        NOTE(review): ``apply_kl_in_reward`` groups rows by ``num_generations``;
        this assumes the local (per-process) slice TRL returns here contains
        whole, contiguous groups — confirm for multi-process runs.

        Raises:
            RuntimeError: TRL did not return ``ref_per_token_logps`` /
                ``completion_mask`` (an internals change).
        """
        output = super()._generate_and_score_completions(inputs)
        if not getattr(self, "kl_in_reward", False):
            return output
        ref_logps = output.get("ref_per_token_logps")
        # The "old" (sampling-time) policy logps are TRL's in-loss π term; they
        # may be lazily None when generation/optimization are aligned and not
        # vLLM (see TRL _compute_loss: old := per_token_logps.detach()). In that
        # aligned case we cannot read π logps here, so we defer to _compute_loss
        # (which always has per_token_logps).
        old_logps = output.get("old_per_token_logps")
        completion_mask = output.get("completion_mask")
        if ref_logps is None or completion_mask is None:
            # beta>0 guarantees ref_logps; this branch only trips on a TRL
            # internals change — fail loud rather than silently skip the penalty.
            raise RuntimeError(
                "kl_in_reward=True but TRL did not return ref_per_token_logps / "
                "completion_mask from scoring (beta>0 should guarantee them). "
                "TRL internals may have changed; re-verify the in-reward path."
            )
        if old_logps is not None:
            penalty = kl_penalty_per_sequence(
                policy_logps=old_logps,
                ref_logps=ref_logps,
                completion_mask=completion_mask,
                estimator=self.kl_estimator,
            )
            output["advantages"] = apply_kl_in_reward(
                advantages=output["advantages"],
                kl_penalty=penalty,
                num_generations=self.num_generations,
                coef=float(self.args.beta),
            )
            output["_kl_in_reward_applied"] = torch.tensor(True)
        else:
            # Aligned non-vLLM case: π logps materialize only in _compute_loss.
            # ref_per_token_logps + completion_mask already travel in `output`;
            # flag the batch so _compute_loss applies the penalty there.
            output["_kl_in_reward_applied"] = torch.tensor(False)
        return output

    # ----------------------------------------------------------------------
    # Loss override (the integration core)
    # ----------------------------------------------------------------------
    def _compute_loss(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Override: total_loss = grpo + α*sdpo + β*replay.

        When ``kl_in_reward`` is set, TRL's native in-loss KL term (gated on
        ``self.beta``) is suppressed by temporarily zeroing ``self.beta`` for the
        duration of the parent call — the KL has already been (or is about to be)
        accounted for in the reward/advantage, so double-counting it in the loss
        would be wrong. ``self.beta`` is restored in ``finally``.

        At the ``args.logging_steps`` cadence this also updates the optional
        collapse kill-switch (at most once per global step) and logs the
        per-channel loss components.
        """
        # Channel 1: standard GRPO loss. ``getattr`` (not ``self.kl_in_reward``)
        # so an instance built via ``__new__`` + manual wiring (the SDPO /
        # kill-switch unit-test pattern that skips __init__) defaults to the
        # legacy path instead of raising AttributeError.
        if getattr(self, "kl_in_reward", False):
            grpo_loss = self._grpo_loss_kl_in_reward(model, inputs)
        else:
            grpo_loss = super()._compute_loss(model, inputs)
        # Channel 2: SDPO hint-distill at error sites
        sdpo_kl = self._compute_sdpo_loss(model, inputs)
        # Channel 3: trace-replay DPO from teacher disagreement
        replay_dpo = self._compute_trace_replay_loss(model, inputs)
        # Compose
        total = grpo_loss + self.alpha_sdpo * sdpo_kl + self.beta_replay * replay_dpo
        # Log per-channel components (so we can ablate post-hoc)
        if hasattr(self, "state") and getattr(self, "args", None) is not None:
            step = self.state.global_step
            log_steps = getattr(self.args, "logging_steps", 50)
            # A non-positive cadence would raise ZeroDivisionError on `%`.
            if log_steps > 0 and step % log_steps == 0:
                # Kill-switch FIRST: TRL's GRPOTrainer.log() averages and then
                # clears self._metrics[mode], so logging the loss components
                # before reading the in-loop reward would leave the guard with
                # no reward metric (permanently inert). Under gradient
                # accumulation this branch is reached once per micro-batch, so
                # fold at most one checkpoint (one held-out eval) per step.
                if getattr(self, "_killswitch_last_step", None) != step:
                    self._killswitch_last_step = step
                    self._maybe_update_killswitch()
                self.log({  # type: ignore[attr-defined]
                    "loss/grpo": float(grpo_loss.detach()),
                    "loss/sdpo_kl": float(sdpo_kl.detach()),
                    "loss/trace_replay_dpo": float(replay_dpo.detach()),
                    "loss/total": float(total.detach()),
                    "loss/alpha_sdpo": self.alpha_sdpo,
                    "loss/beta_replay": self.beta_replay,
                })
        return total

    def _grpo_loss_kl_in_reward(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """GRPO loss with the KL applied in the reward, not the loss.

        Two responsibilities:
          1. Suppress TRL's native in-loss k3 KL term for this step by zeroing
             ``self.beta`` across the parent ``_compute_loss`` call (restored in
             ``finally``). ``self.beta`` gates the in-loss KL add (TRL
             ``_compute_loss``: ``if self.beta != 0.0: per_token_loss += beta*kl``).
          2. Handle the deferred case: when generation/optimization are aligned
             and not using vLLM, the sampling-time policy logps are None at
             scoring time, so ``_generate_and_score_completions`` could not fold
             the penalty into advantages. Here ``per_token_logps`` is available,
             so we apply the same advantage correction in-place on
             ``inputs["advantages"]`` BEFORE the parent computes the surrogate,
             and flag the batch as applied so re-entry cannot double-count it.

        NOTE(review): in recent TRL versions the buffered generation batch is
        shuffled and split before reaching ``_compute_loss``; if so, rows here
        may not form contiguous groups of ``num_generations`` and the group-mean
        baseline on this deferred path should be re-verified.
        """
        # Deferred-penalty path: advantages not yet KL-adjusted (aligned, no vLLM).
        applied = inputs.get("_kl_in_reward_applied")
        already_applied = bool(applied.item()) if applied is not None else False
        if not already_applied and "ref_per_token_logps" in inputs:
            with torch.no_grad():
                prompt_ids, completion_ids = inputs["prompt_ids"], inputs["completion_ids"]
                completion_mask = inputs["completion_mask"]
                input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
                attention_mask = torch.cat([inputs["prompt_mask"], completion_mask], dim=1)
                logits_to_keep = completion_ids.size(1)
                policy_logps, _ = self._get_per_token_logps_and_entropies(
                    model, input_ids, attention_mask, logits_to_keep
                )
                penalty = kl_penalty_per_sequence(
                    policy_logps=policy_logps,
                    ref_logps=inputs["ref_per_token_logps"],
                    completion_mask=completion_mask,
                    estimator=self.kl_estimator,
                )
                advantages = inputs["advantages"]
                # advantages may be (B,) or (B,1) — squeeze for the penalty math,
                # restore the original shape after.
                adv_flat = advantages.reshape(advantages.shape[0])
                adj = apply_kl_in_reward(
                    advantages=adv_flat,
                    kl_penalty=penalty,
                    num_generations=self.num_generations,
                    coef=float(self.args.beta),
                )
                inputs["advantages"] = adj.reshape(advantages.shape)
            # The advantages on this (possibly shared/buffered) dict are now
            # adjusted; mark them so a second pass over the same inputs does
            # not subtract the penalty again.
            inputs["_kl_in_reward_applied"] = torch.tensor(True)
        # Suppress TRL's in-loss KL: zero beta for the parent call, restore after.
        saved_beta = self.beta
        try:
            self.beta = 0.0
            return super()._compute_loss(model, inputs)
        finally:
            self.beta = saved_beta

    # ----------------------------------------------------------------------
    # Run-level collapse kill-switch (#2 safeguard) — optional, OFF by default
    # ----------------------------------------------------------------------
    def _maybe_update_killswitch(self) -> None:
        """Fold this checkpoint's metrics into ``heldout_guard`` and act on a fire.

        No-op when no guard was configured (the default) — this is the
        backward-compat guarantee: without ``heldout_guard`` the trainer behaves
        exactly as before. When a guard IS set:

          * ``in_loop_reward`` is the GRPO reward signal TRL aggregates into
            ``self._metrics["train"]["reward"]`` (we read the latest; no extra
            forward pass). If none is available yet, this cadence is skipped.
          * ``heldout_score`` comes from the injected ``heldout_eval_fn()`` — the
            trainer never hardcodes an eval.
          * ``kl_to_init`` (token-mean nats/token, the ``token_mean_kl``
            convention the guard expects) is read from TRL's logged ``"kl"``
            metric when present, else left None (KL path stays inert).

        On a fired verdict the verdict is logged. If ``strict_killswitch`` (the
        default) the verdict is converted into a ``CollapseStopError`` via
        ``HeldOutGuard.raise_if_fired`` (hard stop); otherwise the HF training
        loop is asked to stop gracefully after this step.
        """
        guard = self.heldout_guard
        if guard is None:
            return  # OFF by default — zero behavior change
        round_idx = int(getattr(self.state, "global_step", 0))
        in_loop_reward = self._latest_metric("reward")
        if in_loop_reward is None:
            # No reward aggregated yet (e.g. very first micro-step before TRL has
            # populated its metrics). Skip this cadence rather than feed a
            # fabricated 0.0 that would pollute the guard's baseline/EMA.
            logger.debug(
                "kill-switch: no in-loop reward metric yet at step %d; skipping.",
                round_idx,
            )
            return
        assert self.heldout_eval_fn is not None  # enforced in __init__
        heldout_score = float(self.heldout_eval_fn())
        kl_to_init = self._latest_metric("kl")  # token-mean KL, or None
        status = guard.update(
            round_idx=round_idx,
            in_loop_reward=in_loop_reward,
            heldout_score=heldout_score,
            kl_to_init=kl_to_init,
        )
        self.log({  # type: ignore[attr-defined]
            "killswitch/in_loop_reward": status.in_loop_ema,
            "killswitch/heldout_score": status.heldout_ema,
            "killswitch/proxy_real_gap": status.proxy_real_gap,
            "killswitch/fire": float(status.fire),
        })
        if status.fire:
            logger.error(
                "HeldOutGuard FIRED at step %d — halting run. reason: %s",
                round_idx, status.reason,
            )
            if self.strict_killswitch:
                # Typed exception — exception-based hard stop.
                guard.raise_if_fired(status)
            else:
                # Soft stop: let the HF loop terminate gracefully after this step.
                control = getattr(self, "control", None)
                if control is not None:
                    control.should_training_stop = True

    def _latest_metric(self, name: str) -> float | None:
        """Most-recent value of a TRL-aggregated train metric, or None.

        Reads the tail of ``self._metrics["train"][name]`` (e.g. ``"reward"``,
        ``"kl"``). Every lookup is defensive so a TRL internals rename, a missing
        series, or a non-numeric entry degrades to None (KL/reward path goes
        inert) rather than crashing training.
        """
        metrics = getattr(self, "_metrics", None)
        if not isinstance(metrics, dict):
            return None
        train = metrics.get("train")
        if not isinstance(train, dict):
            return None
        series = train.get(name)
        if not series:
            return None
        try:
            return float(series[-1])
        except (TypeError, ValueError, IndexError):
            return None

    # ----------------------------------------------------------------------
    # Channel 2: SDPO hint-distill
    # ----------------------------------------------------------------------
    def _compute_sdpo_loss(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Compute generalized_jsd_loss between student and hint-conditioned teacher.

        Both come from the SAME model — teacher just has hint inserted into context.
        Skipped (returns a zero tensor) when ``alpha_sdpo == 0`` or the batch has
        no error sites (the data collator emits no / empty ctx_teacher_input_ids).

        Raises:
            ValueError: alignment indices are missing while
                ``strict_sdpo_alignment`` is True, or the student/teacher index
                tensors have different shapes.
        """
        if (
            self.alpha_sdpo == 0.0
            or "ctx_teacher_input_ids" not in inputs
            or inputs["ctx_teacher_input_ids"].numel() == 0
        ):
            return torch.tensor(0.0, device=_device_of(model), requires_grad=True)
        # Student forward (with grad, on the original-context input)
        student_logits = model(input_ids=inputs["input_ids"]).logits
        # Teacher forward (no grad — same model, hint-conditioned context)
        with torch.no_grad():
            teacher_logits = model(input_ids=inputs["ctx_teacher_input_ids"]).logits
        # ------------------------------------------------------------------
        # ALIGNMENT (cross-family review 2026-05-29 — the 4/4-reviewer P0).
        #
        # The teacher context has a hint inserted at the error turn, so the
        # teacher's post-hint response tokens are shifted right by len(hint)
        # relative to the student's. A bare `student.shape == teacher.shape`
        # check does NOT establish token-level alignment: equal-length tensors
        # whose response regions are offset will be JSD'd position-by-position
        # against each other, distilling garbage into the policy.
        #
        # The ONLY correct alignment is an explicit map from the collator that
        # selects, for each response token, the matching index in each sequence.
        # We require it whenever SDPO is active:
        #   - `student_response_idx` / `teacher_response_idx`: LongTensors of
        #     equal shape selecting the aligned response positions in each
        #     sequence (the collator builds these knowing where it inserted the
        #     hint). JSD is computed over the gathered, aligned logits.
        #   - If the collator cannot yet supply them, strict mode raises (loud
        #     failure) rather than silently distilling misaligned tokens.
        s_idx = inputs.get("student_response_idx")
        t_idx = inputs.get("teacher_response_idx")
        if s_idx is None or t_idx is None:
            msg = (
                "SDPO alignment indices missing: the collator must emit "
                "`student_response_idx` and `teacher_response_idx` (matching "
                "LongTensors selecting the aligned post-hint response tokens) so "
                "the JSD compares corresponding tokens. A shape-only check does "
                "NOT establish alignment — the hint shifts the teacher's response "
                "tokens right, so equal-length sequences can still be misaligned "
                "and silently distill garbage into the policy (ADR-008 trust-gap)."
            )
            if self.strict_sdpo_alignment:
                raise ValueError(
                    msg + " (strict_sdpo_alignment=True; pass False to fall back "
                    "to the legacy shape-only check for resilience.)"
                )
            logger.warning("%s Falling back to shape-only alignment check.", msg)
            if student_logits.shape != teacher_logits.shape:
                logger.warning(
                    "SDPO shape mismatch student=%s teacher=%s; skipping.",
                    tuple(student_logits.shape), tuple(teacher_logits.shape),
                )
                return torch.tensor(0.0, device=_device_of(model), requires_grad=True)
            return generalized_jsd_loss(
                student_logits=student_logits,
                teacher_logits=teacher_logits,
                labels=inputs.get("sdpo_loss_mask"),
                beta=self.sdpo_jsd_beta,
                temperature=self.sdpo_temperature,
                token_clip=self.sdpo_token_clip,
                reduction="batchmean",
            )
        # Validate the index tensors describe a consistent 1:1 alignment.
        if s_idx.shape != t_idx.shape:
            raise ValueError(
                f"SDPO alignment index shape mismatch: student_response_idx="
                f"{tuple(s_idx.shape)} vs teacher_response_idx={tuple(t_idx.shape)}. "
                "They must select the same number of aligned response tokens."
            )
        # Gather the aligned response logits from each sequence, then JSD only
        # those positions (this is the masked error-turn distillation).
        # gather over the sequence dim (dim=1): expand index to the vocab dim.
        #
        # ADR-011: ragged-K rows are padded with a sentinel (-1) and a per-row
        # *_valid mask. Negative indices are illegal for torch.gather, so clamp
        # to 0 before gathering, then neutralize those positions by feeding
        # labels=-100 (the standard HF ignore convention that generalized_jsd_loss
        # already honors). This makes sentinel/padding positions contribute 0.
        #
        # Final-verify 2026-05-29: combine BOTH valid masks (not just student's)
        # AND the sentinel guard. If a future collator ever emits divergent
        # student/teacher valid tails, a teacher sentinel clamped to 0 would
        # otherwise be silently distilled against teacher position 0. Belt-and-
        # suspenders: valid iff student-valid AND teacher-valid AND both indices
        # non-sentinel.
        s_valid = inputs.get("student_response_valid")
        t_valid = inputs.get("teacher_response_valid")
        aligned_mask = (s_idx >= 0) & (t_idx >= 0)
        if s_valid is not None:
            aligned_mask = aligned_mask & s_valid.bool()
        if t_valid is not None:
            aligned_mask = aligned_mask & t_valid.bool()
        vocab = student_logits.size(-1)
        s_safe = s_idx.clamp_min(0)
        t_safe = t_idx.clamp_min(0)
        s_gather = s_safe.unsqueeze(-1).expand(-1, -1, vocab)
        t_gather = t_safe.unsqueeze(-1).expand(-1, -1, vocab)
        student_aligned = torch.gather(student_logits, 1, s_gather)
        teacher_aligned = torch.gather(teacher_logits, 1, t_gather)
        # Build labels shaped like the index tensors: 1 at valid aligned
        # positions, -100 (ignore) at sentinel/padding positions so they drop
        # out of the JSD reduction.
        aligned_labels = torch.where(
            aligned_mask,
            torch.ones_like(s_idx),
            torch.full_like(s_idx, -100),
        )
        return generalized_jsd_loss(
            student_logits=student_aligned,
            teacher_logits=teacher_aligned,
            labels=aligned_labels,  # sentinel-masked aligned error-turn positions
            beta=self.sdpo_jsd_beta,
            temperature=self.sdpo_temperature,
            token_clip=self.sdpo_token_clip,
            reduction="batchmean",
        )

    # ----------------------------------------------------------------------
    # Channel 3: trace-replay DPO
    # ----------------------------------------------------------------------
    def _compute_trace_replay_loss(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Standard DPO loss using (chosen, rejected) pairs from teacher disagreement.

        DPO loss formula (Rafailov et al. 2023):
            L = -log σ(β · (logπ(chosen) - logπ_ref(chosen)
                            - logπ(rejected) + logπ_ref(rejected)))
        Where logπ_ref are precomputed by the data collator using the
        reference (init student) policy. Returns a zero tensor when
        ``beta_replay == 0`` or the batch carries no DPO pairs.
        """
        if (
            self.beta_replay == 0.0
            or "dpo_chosen_input_ids" not in inputs
            or inputs["dpo_chosen_input_ids"].numel() == 0
        ):
            return torch.tensor(0.0, device=_device_of(model), requires_grad=True)
        # Forward passes for chosen and rejected, gather logprobs at response tokens
        chosen_logprobs = self._sequence_logprobs(
            model, inputs["dpo_chosen_input_ids"], inputs["dpo_chosen_response_mask"]
        )
        rejected_logprobs = self._sequence_logprobs(
            model, inputs["dpo_rejected_input_ids"], inputs["dpo_rejected_response_mask"]
        )
        ref_chosen_logprobs = inputs["dpo_chosen_ref_logprobs"]
        ref_rejected_logprobs = inputs["dpo_rejected_ref_logprobs"]
        logits = self.replay_dpo_beta * (
            (chosen_logprobs - ref_chosen_logprobs)
            - (rejected_logprobs - ref_rejected_logprobs)
        )
        return -F.logsigmoid(logits).mean()

    @staticmethod
    def _sequence_logprobs(
        model: torch.nn.Module,
        input_ids: torch.Tensor,
        response_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Sum logprob of response tokens given the prompt prefix.

        Standard DPO accounting: we only score the response tokens (where
        response_mask == 1), not the prompt tokens. Returns one summed logprob
        per row of ``input_ids``.
        """
        outputs = model(input_ids=input_ids)
        # Shift for next-token prediction: logits[t] predicts input_ids[t+1]
        logits = outputs.logits[:, :-1, :]
        targets = input_ids[:, 1:]
        log_probs = F.log_softmax(logits, dim=-1)
        token_logprobs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        # Mask out prompt + padding (mask shifted to match the targets);
        # sum response-token logprobs
        masked = token_logprobs * response_mask[:, 1:].float()
        return masked.sum(dim=-1)
def _device_of(model: torch.nn.Module) -> torch.device:
"""Return the device of any parameter of the model — robust to FSDP/DDP wrappers."""
return next(model.parameters()).device
def validate_kl_in_reward_config(
    *,
    kl_estimator: str,
    beta: float,
    scale_rewards: Any,
) -> None:
    """Validate the (kl_estimator, beta, scale_rewards) combo for k1-in-reward.

    Extracted so the preconditions are unit-testable without standing up a real
    GRPOTrainer (which needs a model + dataset). Raises ``ValueError`` on any
    invalid combination; returns None when the config is sound.

    Preconditions (see ``kl_in_reward.py`` for the algebra):
      * ``kl_estimator`` in {k1, k3}.
      * ``beta > 0`` — TRL only builds the reference model and computes ref
        logprobs when beta>0, and the in-reward penalty needs ref logps. beta
        doubles as the in-reward KL coefficient (the in-loss k3 term is
        suppressed per step), so a negative or NaN value is also rejected:
        it would turn the KL penalty into a bonus for drifting from the
        reference (or poison the advantages).
      * ``scale_rewards`` in {none, false} (case-insensitive string form, so
        ``None`` / ``False`` also pass) — the advantage-adjustment identity
        is exact only without per-group std-normalization (the Dr.GRPO /
        Composer regime).

    Raises:
        ValueError: on any violated precondition.
    """
    if kl_estimator not in ("k1", "k3"):
        raise ValueError(f"kl_estimator must be 'k1' or 'k3', got {kl_estimator!r}.")
    beta_value = float(beta)
    if beta_value == 0.0:
        raise ValueError(
            "kl_in_reward=True requires a non-zero `beta` (the KL coefficient): "
            "TRL only creates the reference model and computes ref logprobs when "
            "beta>0, and k1-in-reward needs those ref logps. Set beta to your KL "
            "coefficient (e.g. make_po_config('dr_grpo', beta=0.04)); the in-loss "
            "k3 term is suppressed automatically so beta acts purely as the "
            "in-reward k1 coefficient."
        )
    # `not (x > 0)` also catches NaN, for which every comparison is False.
    if not beta_value > 0.0:
        raise ValueError(
            f"kl_in_reward=True requires a positive `beta` (got {beta!r}): beta is "
            "the in-reward KL coefficient, and a negative or NaN value would "
            "reward drifting from the reference policy instead of penalizing it."
        )
    if str(scale_rewards).lower() not in ("none", "false"):
        raise ValueError(
            "kl_in_reward=True requires scale_rewards in {none,false} "
            f"(got {scale_rewards!r}). The advantage-adjustment identity "
            "adv -= beta·(KL - group_mean(KL)) is EXACT only without per-group "
            "std-normalization (the Dr.GRPO / Composer regime). With std-norm, "
            "folding KL into the reward also shifts the group std, so the linear "
            "correction no longer matches true in-reward KL. Use "
            "make_po_config('dr_grpo', beta=…) (scale_rewards='none')."
        )
def make_dr_grpo_config(**overrides: Any):
    """Build a `trl.GRPOConfig` configured to the **Dr. GRPO** recipe.

    Per the Composer 2 technical report (arXiv:2603.24477,
    research/10-composer2-techreport-mining.md) the RL base is Dr. GRPO
    (Liu et al., arXiv:2503.20783):

      - ``loss_type="dr_grpo"`` — removes GRPO's length-standardization term
        (which injects a length bias). TRL's own help text cites the Dr. GRPO
        paper for this.
      - ``scale_rewards="none"`` — NO std-dev advantage normalization. TRL docs:
        "The Dr. GRPO paper recommends not scaling rewards, as scaling by the
        standard deviation introduces a question-level difficulty bias."
      - ``num_iterations=1`` — single-epoch regime (a prompt is never
        trained on twice), matching the tech report.
      - ``beta`` (KL-to-ref coef) kept. NOTE on the KL estimator (ADR-012
        finding #1, verified against the installed trl==1.5.0 source):
        ``GRPOTrainer._compute_loss`` uses the **k3** estimator
        ``exp(ref_logp - logp) - (ref_logp - logp) - 1``
        (trl/trainer/grpo_trainer.py ~L2513), NOT the k1 estimator
        ``-log r == (ref_logp - logp)``. k3 is Schulman's low-variance,
        always-non-negative KL approximation; k1 is its unbiased but
        higher-variance counterpart. The Dr. GRPO / Composer 2 report discusses
        KL in k1 terms, but the delta is small for r≈1 (k3 = k1 + O((Δlogp)^2))
        and TRL's k3 choice is the production reality. We do NOT monkeypatch TRL
        to force k1; we document the honest delta. See
        ``test_dr_grpo_config_and_alignment.py::test_trl_kl_estimator_is_k3_not_k1``.

    Any field can be overridden via kwargs (e.g. ``learning_rate=...``,
    ``output_dir=...``). ``loss_type`` and ``num_iterations`` default to the
    Dr. GRPO values but may be overridden. ``scale_rewards`` may only be
    overridden with another "disabled" spelling (``"none"`` / ``False``).
    Anything else trips the guard below.

    Raises:
        AssertionError: the built config does not reflect the requested
            knobs (TRL drift) or std-normalization is enabled. Raised
            explicitly rather than via ``assert`` so the guards stay active
            under ``python -O``.
    """
    from trl import GRPOConfig  # local import: only when actually building a config

    dr_grpo_defaults: dict[str, Any] = {
        "loss_type": "dr_grpo",
        "scale_rewards": "none",
        "num_iterations": 1,
    }
    merged = {**dr_grpo_defaults, **overrides}
    cfg = GRPOConfig(**merged)

    # Guard: fail loudly if a future TRL renames/repurposes these knobs.
    if cfg.loss_type != merged["loss_type"]:
        raise AssertionError(
            f"GRPOConfig loss_type drifted: requested {merged['loss_type']!r}, "
            f"got {cfg.loss_type!r} — TRL may have renamed/repurposed the knob."
        )
    # Dr. GRPO requires NO std-dev advantage normalization. TRL accepts either
    # the string "none" or the bool False to disable it; normalize before
    # comparing so a future TRL that switches the representation still passes
    # (and a genuinely-wrong value like "batch"/"group"/True fails loudly).
    # (Cross-family review 2026-05-29: the prior literal `("none","False","False")`
    # had a duplicated "False" and did a brittle case-sensitive str compare.)
    if str(cfg.scale_rewards).lower() not in ("none", "false"):
        raise AssertionError(
            f"Dr. GRPO requires scale_rewards disabled (no std-norm); got "
            f"{cfg.scale_rewards!r}. TRL knob may have drifted — re-verify against trl version."
        )
    if cfg.num_iterations != merged["num_iterations"]:
        raise AssertionError("GRPOConfig dropped num_iterations")
    return cfg
# ---------------------------------------------------------------------------
# Policy-optimization objective MENU (ADR-014)
# ---------------------------------------------------------------------------
#
# The base RL objective used to be hardcoded to Dr.GRPO (make_dr_grpo_config).
# make_po_config gives RL a real menu: GRPO-family objectives selectable by name.
# Verified against the installed trl==1.5.0 (introspected 2026-05-30): its
# GRPOTrainer already implements these as `loss_type` branches + knobs, so EVERY
# preset below is pure config — no custom _compute_loss override needed.
#
# Knob-space each preset sets (all real GRPOConfig fields in trl 1.5.0):
# loss_type ∈ {grpo, dr_grpo, bnpo, dapo, cispo} (gspo = grpo loss +
# importance_sampling_level="sequence"; trl has no literal "gspo")
# scale_rewards ∈ {"group"(std-norm), "batch", "none"(no std-norm, Dr.GRPO)}
# epsilon / epsilon_high — symmetric vs decoupled "clip-higher" (DAPO)
# importance_sampling_level ∈ {"token", "sequence"(GSPO)}
# beta — KL-to-ref coef (0.0 = reference-free)
# mask_truncated_completions — DAPO overlong masking
# num_iterations — on-policy reuse (1 = strict on-policy)
#: Selectable base policy-optimization objectives (named presets over trl knobs).
#: Each value is a dict of GRPOConfig keyword arguments; any knob a preset does
#: not list falls back to trl's GRPOConfig default unless the caller overrides it.
PO_OBJECTIVES: dict[str, dict[str, Any]] = {
    # Vanilla GRPO (DeepSeekMath, arXiv 2402.03300): group-relative advantage
    # WITH std normalization + per-sequence length normalization. The paper
    # keeps a KL-to-ref term, but this preset does NOT set ``beta``, so KL is
    # governed by trl's GRPOConfig default. Recent trl releases default to 0.0,
    # i.e. KL OFF; re-verify for the installed version. Pass ``beta=...`` to
    # enable KL.
    "grpo": {
        "loss_type": "grpo",
        "scale_rewards": "group",
        "importance_sampling_level": "token",
        "num_iterations": 1,
    },
    # Dr.GRPO (arXiv 2503.20783): remove length-std normalization bias (no
    # advantage /std, length-independent aggregation). Framework's historical
    # default (same Dr. GRPO knobs as make_dr_grpo_config). Composer 2.5's base
    # objective.
    "dr_grpo": {
        "loss_type": "dr_grpo",
        "scale_rewards": "none",
        "importance_sampling_level": "token",
        "num_iterations": 1,
    },
    # BNPO: batch-normalized variant (trl loss_type), std over the batch.
    "bnpo": {
        "loss_type": "bnpo",
        "scale_rewards": "batch",
        "importance_sampling_level": "token",
        "num_iterations": 1,
    },
    # DAPO (arXiv 2503.14476): decoupled "clip-higher" (epsilon_high > epsilon)
    # + token-level loss + overlong masking + KL removed. High-value, low-cost
    # anti-entropy-collapse objective. epsilon_high=0.28 per the paper.
    # NOTE: the DAPO paper keeps GRPO's per-group std normalization of the
    # advantage. This preset deliberately uses scale_rewards="none" (no
    # std-norm, the Dr.GRPO/Composer regime). Pass scale_rewards="group" for
    # closer paper fidelity.
    "dapo": {
        "loss_type": "dapo",
        "scale_rewards": "none",
        "epsilon": 0.2,
        "epsilon_high": 0.28,
        "mask_truncated_completions": True,
        "beta": 0.0,
        "importance_sampling_level": "token",
        "num_iterations": 1,
    },
    # GSPO (Qwen, arXiv 2507.18071): SEQUENCE-level importance ratio (one length-
    # normalized ratio per response) — stabilizes long-CoT and especially MoE RL.
    # trl expresses this as the grpo loss + importance_sampling_level="sequence".
    # NOTE: the GSPO paper uses much smaller clip ranges for the sequence-level
    # ratio (≈3e-4 / 4e-4). This preset leaves epsilon/epsilon_high at trl's
    # defaults, so pass them explicitly to match the paper.
    "gspo": {
        "loss_type": "grpo",
        "scale_rewards": "group",
        "importance_sampling_level": "sequence",
        "num_iterations": 1,
    },
    # CISPO (MiniMax-M1, arXiv 2506.13585): clip the IS weight and detach it as a
    # constant coefficient on log π — every token keeps a gradient (fixes the
    # "rare reasoning tokens get zeroed by the clip" pathology). eps_max≈5 (ScaleRL),
    # expressed through trl's epsilon_high knob.
    "cispo": {
        "loss_type": "cispo",
        "scale_rewards": "none",
        "epsilon_high": 5.0,
        "importance_sampling_level": "token",
        "num_iterations": 1,
    },
}
def _canonical_scale_rewards(value: Any) -> str:
    """Map a ``scale_rewards`` value onto trl's string vocabulary for comparison.

    trl accepts bool aliases besides the strings ``"group"``/``"batch"``/``"none"``
    (``False`` disables std-norm; ``True`` historically meant per-group std-norm).
    Canonicalizing both the requested and the applied value lets the drift guard
    compare them regardless of which representation the installed trl stores.
    """
    text = str(value).strip().lower()
    return {"true": "group", "false": "none"}.get(text, text)


def make_po_config(objective: str = "dr_grpo", **overrides: Any):
    """Build a `trl.GRPOConfig` for a NAMED policy-optimization objective.

    The menu that gives RL real options beyond the single hardcoded Dr.GRPO
    recipe. ``objective`` selects a preset from ``PO_OBJECTIVES`` (grpo /
    dr_grpo / bnpo / dapo / gspo / cispo). Matching is case-insensitive and
    ignores surrounding whitespace; ``None``/``""`` select ``"dr_grpo"``.
    ``**overrides`` set or override any GRPOConfig field on top (e.g.
    ``output_dir=...``, ``beta=...``, ``learning_rate=...``).

    All presets are PURE CONFIG over trl 1.5.0's GRPOTrainer (verified by
    introspecting the installed package 2026-05-30): the trainer already
    implements each ``loss_type`` branch and the ``importance_sampling_level`` /
    ``epsilon_high`` knobs, so no custom ``_compute_loss`` is needed. See ADR-014.

    Raises:
        ValueError: unknown objective (lists the valid menu). Checked before trl
            is imported, so a typo is reported even where trl is not installed.
        ImportError: trl is not installed (valid objective names only).
        AssertionError: a requested knob silently failed to apply (drift guard).
            Raised explicitly rather than via ``assert`` so the guard still runs
            under ``python -O``.
    """
    # str() so a non-string objective hits the documented ValueError below
    # instead of an AttributeError on `.lower()`.
    key = str(objective or "dr_grpo").strip().lower()
    if key not in PO_OBJECTIVES:
        raise ValueError(
            f"Unknown PO objective {objective!r}. Choose from: "
            f"{sorted(PO_OBJECTIVES)}. (Each is a named preset over trl 1.5.0's "
            f"GRPOConfig knobs — see PO_OBJECTIVES / ADR-014.)"
        )

    from trl import GRPOConfig  # local import: only when actually building a config

    # Build a fresh dict so the shared preset table is never mutated.
    merged = {**PO_OBJECTIVES[key], **overrides}
    cfg = GRPOConfig(**merged)

    # Drift guards: fail loudly if a future trl renamed/repurposed a knob we set,
    # so a preset can never silently degrade to a different objective.
    if "loss_type" in merged and str(cfg.loss_type) != str(merged["loss_type"]):
        raise AssertionError(
            f"GRPOConfig.loss_type drifted: requested {merged['loss_type']!r}, "
            f"got {cfg.loss_type!r} — trl may have renamed the knob."
        )
    if (
        "importance_sampling_level" in merged
        and hasattr(cfg, "importance_sampling_level")
        and str(cfg.importance_sampling_level) != str(merged["importance_sampling_level"])
    ):
        raise AssertionError(
            f"importance_sampling_level drifted for objective {key!r}: requested "
            f"{merged['importance_sampling_level']!r}, got {cfg.importance_sampling_level!r}."
        )
    # Std-normalization choice matters beyond the loss (e.g. in-reward KL is only
    # exact without std-norm), so guard it like make_dr_grpo_config does.
    if "scale_rewards" in merged:
        requested_scale = _canonical_scale_rewards(merged["scale_rewards"])
        applied_scale = _canonical_scale_rewards(cfg.scale_rewards)
        if requested_scale != applied_scale:
            raise AssertionError(
                f"scale_rewards drifted for objective {key!r}: requested "
                f"{merged['scale_rewards']!r}, got {cfg.scale_rewards!r}."
            )
    if key == "gspo" and str(getattr(cfg, "importance_sampling_level", "token")) != "sequence":
        raise AssertionError(
            "GSPO requires importance_sampling_level='sequence'; it was overridden "
            "to token, which silently degrades GSPO to GRPO. Drop that override."
        )
    if merged.get("epsilon_high") is not None:
        applied_eps = float(getattr(cfg, "epsilon_high", merged["epsilon_high"]))
        # `not (... < tol)` (rather than `>= tol`) also rejects NaN, like the
        # original assert did.
        if not abs(applied_eps - float(merged["epsilon_high"])) < 1e-9:
            raise AssertionError(f"epsilon_high (decoupled clip) drifted for {key!r}.")
    return cfg
# Public API of this module, in isort-style order (constants, classes, functions).
__all__ = [
    "PO_OBJECTIVES",
    "ComposerReplicationTrainer",
    "make_dr_grpo_config",
    "make_po_config",
    "validate_kl_in_reward_config",
]