Baladithya Balamurugan
Wave 20: Tier-0 fidelity fixes — k1-in-reward KL + Composer-2 behavior rewards
41289bf
| """composer_trainer.py: TRL GRPOTrainer subclass with SDPO + trace-replay channels. | |
| Architecture spec: docs/INTEGRATION_ARCHITECTURE.md § "Recipe A". | |
| Verified extension point: GRPOTrainer._compute_loss(model, inputs) | |
| (DeepWiki audit of huggingface/trl, 2026-05-25). | |
| Total loss: | |
| total_loss = grpo_loss | |
| + alpha_sdpo * sdpo_kl_at_error_turns | |
| + beta_replay * trace_replay_dpo_loss | |
| Where: | |
| - grpo_loss is the parent GRPOTrainer's loss (RLVR + DAPO patches). | |
| - sdpo_kl_at_error_turns is generalized_jsd_loss between student's logits and | |
| teacher's (= same-model-with-hint-context) logits, masked to error-turn tokens only. | |
| - trace_replay_dpo_loss is DPO loss over (chosen, rejected) pairs derived from | |
| the disagreement of N external teachers with the student. | |
| The data collator (data_collator.py) is responsible for: | |
| - Detecting error sites in the rollout and constructing ctx_teacher = ctx_student + hint. | |
| - Computing sdpo_loss_mask (1 at post-hint error-turn tokens, 0 elsewhere). | |
| - Loading DPO pairs from the trace-replay output (see teacher_replay.py). | |
| - Precomputing reference-policy logprobs for DPO. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from collections.abc import Callable | |
| from typing import TYPE_CHECKING, Any | |
| import torch | |
| import torch.nn.functional as F # noqa: N812 — repo-wide torch convention | |
| if TYPE_CHECKING: # type-only — never imported at runtime (keeps the dep lazy) | |
| from composer_replication.safety import HeldOutGuard | |
| # These imports work when TRL is installed — they're not skeleton imports. | |
| # When TRL is missing we fall back to `object` so the module still imports | |
| # (e.g. for documentation generation) but raise a clear ImportError at | |
| # instantiation time rather than the cryptic `object.__init__()` error. | |
| try: | |
| from trl import GRPOTrainer # type: ignore | |
| _TRL_AVAILABLE = True | |
| except ImportError: # pragma: no cover — only hit in unit-test stubs without TRL | |
| GRPOTrainer = object # type: ignore — fallback so module imports without TRL | |
| _TRL_AVAILABLE = False | |
| from composer_replication.opsd import generalized_jsd_loss | |
| from composer_replication.trainer.kl_in_reward import ( | |
| apply_kl_in_reward, | |
| kl_penalty_per_sequence, | |
| ) | |
| logger = logging.getLogger(__name__) | |
| class ComposerReplicationTrainer(GRPOTrainer): # type: ignore[misc, valid-type] | |
| """TRL GRPOTrainer with Composer-recipe channels (SDPO) + novel trace-replay-DPO. | |
| Args (in addition to GRPOTrainer's): | |
| alpha_sdpo: weight on SDPO hint-distill loss. Default 0.0 (disabled). | |
| Opt in by passing >0 once your data collator produces | |
| `sdpo_loss_mask` and `ctx_teacher_input_ids` columns. | |
| beta_replay: weight on trace-replay DPO loss. Default 0.0 (disabled). | |
| Opt in by passing >0 once your data collator produces | |
| `dpo_chosen_input_ids` / `dpo_rejected_input_ids` etc. | |
| sdpo_jsd_beta: beta param of generalized_jsd_loss | |
| (0=KL(teacher||student), 0.5=JSD, 1=KL(student||teacher) per | |
| upstream OPSD convention; see composer_replication/opsd.py). | |
| sdpo_temperature: temperature for SDPO loss; SDPO paper uses 1.0. | |
| sdpo_token_clip: per-token JSD clip for stability; None = no clip. | |
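| replay_dpo_beta: beta param of the DPO loss (β in the standard DPO formula). | |
| strict_sdpo_alignment: when True (default), missing SDPO alignment | |
| indices raise ``ValueError``; pass False to warn and fall back to the | |
| legacy shape-only check. | |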
| kl_in_reward: when True, apply the KL-to-reference penalty in the | |
| **reward** (Composer-2 §4.1 / verl choice) instead of TRL's native | |
| **in-loss** k3 term. The penalty is folded into GRPO's advantages at | |
| scoring time (``adv -= beta·(KL - group_mean(KL))``) and TRL's | |
| in-loss KL is suppressed for that step. The F5 audit's #1 fidelity | |
| fix: the 2025/26 evidence (arXiv:2512.21852, verl, TRL #4967) shows | |
| k1-in-reward improves OOD generalization where k3-in-reward can | |
| collapse. REQUIRES ``beta>0`` (the KL coefficient — also how TRL | |
| decides to compute reference logprobs) and ``scale_rewards`` in | |
| {none,false} (the advantage-adjustment identity is exact only | |
| without std-normalization — the Dr.GRPO / Composer regime). Default | |
| False = TRL's native in-loss KL, byte-for-byte legacy behavior. | |
| kl_estimator: ``"k1"`` (default; ``logp - ref_logp``, the Composer-2 / | |
| verl choice this path exists for) or ``"k3"`` (Schulman; lets an | |
| experiment A/B k1-in-reward vs k3-in-reward). Only consulted when | |
| ``kl_in_reward=True``. | |
| heldout_guard: optional ``HeldOutGuard`` (the #2 collapse safeguard from | |
| ``composer_replication.safety``). Default None = OFF (no behavior | |
| change whatsoever). When supplied, the trainer folds one checkpoint's | |
| metrics into the guard at the ``args.logging_steps`` cadence (the same | |
| place the loss components are logged) and HALTS the run on a fired | |
| verdict — the run-level reward-hacking / collapse tripwire actually | |
| firing instead of sitting inert. | |
| heldout_eval_fn: zero-arg callable returning the held-out (real) eval | |
| score as a float, evaluated each guard cadence. Injectable so the | |
| trainer never hardcodes an eval — pass a closure over your disjoint | |
| held-out pool (the ``HeldoutSplit`` discipline). Required whenever | |
| ``heldout_guard`` is set; the guard's whole signal is in-loop reward | |
| vs. this held-out score. | |
| strict_killswitch: when True (default), a fired guard verdict raises | |
| ``CollapseStopError`` to hard-stop training (exception-based control | |
| flow, matching ``HeldOutGuard.raise_if_fired``). When False the | |
| verdict is logged and ``self.control.should_training_stop`` is set so | |
| the HF loop ends gracefully after the step (soft stop). Only consulted | |
| when ``heldout_guard`` is set. | |
| """ | |
| def __init__( | |
| self, | |
| *args: Any, | |
| alpha_sdpo: float = 0.0, | |
| beta_replay: float = 0.0, | |
| sdpo_jsd_beta: float = 0.5, | |
| sdpo_temperature: float = 1.0, | |
| sdpo_token_clip: float | None = None, | |
| replay_dpo_beta: float = 0.1, | |
| strict_sdpo_alignment: bool = True, | |
| kl_in_reward: bool = False, | |
| kl_estimator: str = "k1", | |
| heldout_guard: HeldOutGuard | None = None, | |
| heldout_eval_fn: Callable[[], float] | None = None, | |
| strict_killswitch: bool = True, | |
| **kwargs: Any, | |
| ): | |
| if not _TRL_AVAILABLE: | |
| raise ImportError( | |
| "ComposerReplicationTrainer requires TRL. Install with " | |
| "`pip install -e .[train]`." | |
| ) | |
| super().__init__(*args, **kwargs) | |
| self.alpha_sdpo = alpha_sdpo | |
| self.beta_replay = beta_replay | |
| self.sdpo_jsd_beta = sdpo_jsd_beta | |
| self.sdpo_temperature = sdpo_temperature | |
| self.sdpo_token_clip = sdpo_token_clip | |
| self.replay_dpo_beta = replay_dpo_beta | |
| # When True (default), an SDPO student/teacher shape mismatch is a hard | |
| # error — it means the data collator failed to align the post-hint | |
| # section, which silently zeroes the distillation signal (the exact | |
| # trust-gap flagged in ADR-008). Set False only for production runs | |
| # where a single malformed batch should warn-and-skip rather than abort. | |
| self.strict_sdpo_alignment = strict_sdpo_alignment | |
| # --- k1-in-reward KL (F5 #1 fidelity fix; Composer-2 §4.1 / verl) ---- | |
| # OFF by default → TRL's native in-loss k3 KL, byte-for-byte legacy. | |
| # When ON we keep self.beta as the KL coef (TRL needs beta>0 to even | |
| # create the ref model + compute ref logps), fold the k1 penalty into | |
| # advantages during scoring, and zero TRL's in-loss KL per step. | |
| self.kl_in_reward = kl_in_reward | |
| self.kl_estimator = kl_estimator | |
| if kl_in_reward: | |
| validate_kl_in_reward_config( | |
| kl_estimator=kl_estimator, | |
| beta=float(getattr(self.args, "beta", 0.0)), | |
| scale_rewards=getattr(self.args, "scale_rewards", "group"), | |
| ) | |
| # --- run-level collapse kill-switch (#2 safeguard) ------------------- | |
| # OPTIONAL + OFF BY DEFAULT: when heldout_guard is None the loss path is | |
| # byte-for-byte the legacy behavior. When set, _maybe_update_killswitch | |
| # folds metrics into the guard at the logging cadence (see _compute_loss). | |
| self.heldout_guard = heldout_guard | |
| self.heldout_eval_fn = heldout_eval_fn | |
| self.strict_killswitch = strict_killswitch | |
| if heldout_guard is not None and heldout_eval_fn is None: | |
| raise ValueError( | |
| "heldout_guard was provided without heldout_eval_fn: the guard's " | |
| "tripwire compares in-loop reward against a DISJOINT held-out " | |
| "(real) eval score, so it needs an injectable zero-arg " | |
| "heldout_eval_fn() -> float. Pass a closure over your held-out " | |
| "pool (the HeldoutSplit discipline)." | |
| ) | |
| # ---------------------------------------------------------------------- | |
| # Loss override (the integration core) | |
| # ---------------------------------------------------------------------- | |
| # ---------------------------------------------------------------------- | |
| # k1-in-reward: fold the KL penalty into advantages at scoring time, and | |
| # suppress TRL's native in-loss k3 KL inside _compute_loss. | |
| # ---------------------------------------------------------------------- | |
| def _generate_and_score_completions( | |
| self, | |
| inputs: list[dict[str, Any]], | |
| ) -> dict[str, Any]: | |
| """Override: after TRL scores completions, fold a k1 KL penalty into the | |
| advantages (Composer-2 in-reward KL) when ``kl_in_reward`` is set. | |
| No-op (exact legacy path) when ``kl_in_reward`` is False. When set, TRL | |
| has already computed ``advantages``, ``ref_per_token_logps`` (because | |
| ``beta>0``), and the completion logprobs; we recompute the per-sequence | |
| k1 penalty and apply the exact group-mean-baseline correction. | |
| """ | |
| output = super()._generate_and_score_completions(inputs) | |
| if not getattr(self, "kl_in_reward", False): | |
| return output | |
| ref_logps = output.get("ref_per_token_logps") | |
| # The "old" (sampling-time) policy logps are TRL's in-loss π term; they | |
| # may be lazily None when generation/optimization are aligned and not | |
| # vLLM (see TRL _compute_loss: old := per_token_logps.detach()). In that | |
| # aligned case we cannot read π logps here, so we defer to _compute_loss | |
| # (which always has per_token_logps) and flag the penalty as not yet applied. | |
| old_logps = output.get("old_per_token_logps") | |
| completion_mask = output.get("completion_mask") | |
| if ref_logps is None or completion_mask is None: | |
| # beta>0 guarantees ref_logps; this branch only trips on a TRL | |
| # internals change — fail loud rather than silently skip the penalty. | |
| raise RuntimeError( | |
| "kl_in_reward=True but TRL did not return ref_per_token_logps / " | |
| "completion_mask from scoring (beta>0 should guarantee them). " | |
| "TRL internals may have changed; re-verify the in-reward path." | |
| ) | |
| if old_logps is not None: | |
| penalty = kl_penalty_per_sequence( | |
| policy_logps=old_logps, | |
| ref_logps=ref_logps, | |
| completion_mask=completion_mask, | |
| estimator=self.kl_estimator, | |
| ) | |
| output["advantages"] = apply_kl_in_reward( | |
| advantages=output["advantages"], | |
| kl_penalty=penalty, | |
| num_generations=self.num_generations, | |
| coef=float(self.args.beta), | |
| ) | |
| output["_kl_in_reward_applied"] = torch.tensor(True) | |
| else: | |
| # Aligned non-vLLM case: π logps materialize only in _compute_loss. | |
| # ref logps + mask are already in `output`; flag the penalty as pending | |
| # so _grpo_loss_kl_in_reward applies it there. | |
| output["_kl_in_reward_applied"] = torch.tensor(False) | |
| return output | |
| def _compute_loss( | |
| self, | |
| model: torch.nn.Module, | |
| inputs: dict[str, torch.Tensor], | |
| ) -> torch.Tensor: | |
| """Override: total_loss = grpo + α*sdpo + β*replay. | |
| When ``kl_in_reward`` is set, TRL's native in-loss KL term (gated on | |
| ``self.beta``) is suppressed by temporarily zeroing ``self.beta`` for the | |
| duration of the parent call — the KL has already been (or is about to be) | |
| accounted for in the reward/advantage, so double-counting it in the loss | |
| would be wrong. ``self.beta`` is restored in ``finally``. | |
| """ | |
| # Channel 1: standard GRPO loss. ``getattr`` (not ``self.kl_in_reward``) | |
| # so an instance built via ``__new__`` + manual wiring (the SDPO / | |
| # kill-switch unit-test pattern that skips __init__) defaults to the | |
| # legacy path instead of raising AttributeError. | |
| if getattr(self, "kl_in_reward", False): | |
| grpo_loss = self._grpo_loss_kl_in_reward(model, inputs) | |
| else: | |
| grpo_loss = super()._compute_loss(model, inputs) | |
| # Channel 2: SDPO hint-distill at error sites | |
| sdpo_kl = self._compute_sdpo_loss(model, inputs) | |
| # Channel 3: trace-replay DPO from teacher disagreement | |
| replay_dpo = self._compute_trace_replay_loss(model, inputs) | |
| # Compose | |
| total = grpo_loss + self.alpha_sdpo * sdpo_kl + self.beta_replay * replay_dpo | |
| # Log per-channel components (so we can ablate post-hoc) | |
| if hasattr(self, "state") and getattr(self, "args", None) is not None: | |
| log_steps = getattr(self.args, "logging_steps", 50) | |
| if self.state.global_step % log_steps == 0: | |
| self.log({ # type: ignore[attr-defined] | |
| "loss/grpo": float(grpo_loss.detach()), | |
| "loss/sdpo_kl": float(sdpo_kl.detach()), | |
| "loss/trace_replay_dpo": float(replay_dpo.detach()), | |
| "loss/total": float(total.detach()), | |
| "loss/alpha_sdpo": self.alpha_sdpo, | |
| "loss/beta_replay": self.beta_replay, | |
| }) | |
| # Fold one checkpoint into the run-level collapse kill-switch at | |
| # the SAME cadence (no-op unless a guard was configured). | |
| self._maybe_update_killswitch() | |
| return total | |
| def _grpo_loss_kl_in_reward( | |
| self, | |
| model: torch.nn.Module, | |
| inputs: dict[str, torch.Tensor], | |
| ) -> torch.Tensor: | |
| """GRPO loss with the KL applied in the reward, not the loss. | |
| Two responsibilities: | |
| 1. Suppress TRL's native in-loss k3 KL term for this step by zeroing | |
| ``self.beta`` across the parent ``_compute_loss`` call (restored in | |
| ``finally``). ``self.beta`` gates the in-loss KL add (TRL | |
| ``_compute_loss``: ``if self.beta != 0.0: per_token_loss += beta*kl``). | |
| 2. Handle the deferred case: when generation/optimization are aligned | |
| and not using vLLM, the sampling-time policy logps are None at | |
| scoring time, so ``_generate_and_score_completions`` could not fold | |
| the penalty into advantages. Here ``per_token_logps`` is available, | |
| so we apply the exact same advantage correction in-place on | |
| ``inputs["advantages"]`` BEFORE the parent computes the surrogate. | |
| """ | |
| # Deferred-penalty path: advantages not yet KL-adjusted (aligned, no vLLM). | |
| applied = inputs.get("_kl_in_reward_applied") | |
| already_applied = bool(applied.item()) if applied is not None else False | |
| if not already_applied and "ref_per_token_logps" in inputs: | |
| with torch.no_grad(): | |
| prompt_ids, completion_ids = inputs["prompt_ids"], inputs["completion_ids"] | |
| completion_mask = inputs["completion_mask"] | |
| input_ids = torch.cat([prompt_ids, completion_ids], dim=1) | |
| attention_mask = torch.cat([inputs["prompt_mask"], completion_mask], dim=1) | |
| logits_to_keep = completion_ids.size(1) | |
| policy_logps, _ = self._get_per_token_logps_and_entropies( | |
| model, input_ids, attention_mask, logits_to_keep | |
| ) | |
| penalty = kl_penalty_per_sequence( | |
| policy_logps=policy_logps, | |
| ref_logps=inputs["ref_per_token_logps"], | |
| completion_mask=completion_mask, | |
| estimator=self.kl_estimator, | |
| ) | |
| advantages = inputs["advantages"] | |
| # advantages may be (B,) or (B,1) — squeeze for the penalty math, | |
| # restore the original shape after. | |
| adv_flat = advantages.reshape(advantages.shape[0]) | |
| adj = apply_kl_in_reward( | |
| advantages=adv_flat, | |
| kl_penalty=penalty, | |
| num_generations=self.num_generations, | |
| coef=float(self.args.beta), | |
| ) | |
| inputs["advantages"] = adj.reshape(advantages.shape) | |
| # Suppress TRL's in-loss KL: zero beta for the parent call, restore after. | |
| saved_beta = self.beta | |
| try: | |
| self.beta = 0.0 | |
| return super()._compute_loss(model, inputs) | |
| finally: | |
| self.beta = saved_beta | |
| # ---------------------------------------------------------------------- | |
| # Run-level collapse kill-switch (#2 safeguard) — optional, OFF by default | |
| # ---------------------------------------------------------------------- | |
| def _maybe_update_killswitch(self) -> None: | |
| """Fold this checkpoint's metrics into ``heldout_guard`` and act on a fire. | |
| No-op when no guard was configured (the default) — this is the | |
| backward-compat guarantee: without ``heldout_guard`` the trainer behaves | |
| exactly as before. When a guard IS set: | |
| * ``in_loop_reward`` is the GRPO reward signal TRL already aggregates | |
| into ``self._metrics[mode]["reward"]`` each step (we read the latest; | |
| no extra forward pass). | |
| * ``heldout_score`` comes from the injected ``heldout_eval_fn()`` — the | |
| trainer never hardcodes an eval. | |
| * ``kl_to_init`` (token-mean nats/token, the ``token_mean_kl`` | |
| convention the guard expects) is read from TRL's logged ``"kl"`` | |
| metric when present, else left None (KL path stays inert). | |
| On a fired verdict the verdict is logged. If ``strict_killswitch`` (the | |
| default) the verdict is converted into a ``CollapseStopError`` via | |
| ``HeldOutGuard.raise_if_fired`` (hard stop); otherwise the HF training | |
| loop is asked to stop gracefully after this step. | |
| """ | |
| guard = self.heldout_guard | |
| if guard is None: | |
| return # OFF by default — zero behavior change | |
| round_idx = int(getattr(self.state, "global_step", 0)) | |
| in_loop_reward = self._latest_metric("reward") | |
| if in_loop_reward is None: | |
| # No reward aggregated yet (e.g. very first micro-step before TRL has | |
| # populated its metrics). Skip this cadence rather than feed a | |
| # fabricated 0.0 that would pollute the guard's baseline/EMA. | |
| logger.debug( | |
| "kill-switch: no in-loop reward metric yet at step %d; skipping.", | |
| round_idx, | |
| ) | |
| return | |
| assert self.heldout_eval_fn is not None # enforced in __init__ | |
| heldout_score = float(self.heldout_eval_fn()) | |
| kl_to_init = self._latest_metric("kl") # token-mean KL, or None | |
| status = guard.update( | |
| round_idx=round_idx, | |
| in_loop_reward=in_loop_reward, | |
| heldout_score=heldout_score, | |
| kl_to_init=kl_to_init, | |
| ) | |
| self.log({ # type: ignore[attr-defined] | |
| "killswitch/in_loop_reward": status.in_loop_ema, | |
| "killswitch/heldout_score": status.heldout_ema, | |
| "killswitch/proxy_real_gap": status.proxy_real_gap, | |
| "killswitch/fire": float(status.fire), | |
| }) | |
| if status.fire: | |
| logger.error( | |
| "HeldOutGuard FIRED at step %d — halting run. reason: %s", | |
| round_idx, status.reason, | |
| ) | |
| if self.strict_killswitch: | |
| # Typed exception — exception-based hard stop. | |
| guard.raise_if_fired(status) | |
| else: | |
| # Soft stop: let the HF loop terminate gracefully after this step. | |
| control = getattr(self, "control", None) | |
| if control is not None: | |
| control.should_training_stop = True | |
| def _latest_metric(self, name: str) -> float | None: | |
| """Most-recent value of a TRL-aggregated train metric, or None. | |
| TRL's GRPOTrainer appends per-step aggregates to | |
| ``self._metrics["train"][name]`` (e.g. ``"reward"``, ``"kl"``). We read | |
| the tail defensively so a TRL internals rename degrades to None (KL/reward | |
| path goes inert) rather than crashing training. | |
| """ | |
| metrics = getattr(self, "_metrics", None) | |
| if not isinstance(metrics, dict): | |
| return None | |
| train = metrics.get("train") | |
| if not isinstance(train, dict): | |
| return None | |
| series = train.get(name) | |
| if not series: | |
| return None | |
| try: | |
| return float(series[-1]) | |
| except (TypeError, ValueError, IndexError): | |
| return None | |
| # ---------------------------------------------------------------------- | |
| # Channel 2: SDPO hint-distill | |
| # ---------------------------------------------------------------------- | |
| def _compute_sdpo_loss( | |
| self, | |
| model: torch.nn.Module, | |
| inputs: dict[str, torch.Tensor], | |
| ) -> torch.Tensor: | |
| """Compute generalized_jsd_loss between student and hint-conditioned teacher. | |
| Both come from the SAME model — teacher just has hint inserted into context. | |
| Skipped (returns 0) if the batch has no error sites (data collator emits | |
| empty ctx_teacher_input_ids). | |
| """ | |
| if ( | |
| self.alpha_sdpo == 0.0 | |
| or "ctx_teacher_input_ids" not in inputs | |
| or inputs["ctx_teacher_input_ids"].numel() == 0 | |
| ): | |
| return torch.tensor(0.0, device=_device_of(model), requires_grad=True) | |
| # Student forward (with grad, on the original-context input) | |
| student_logits = model(input_ids=inputs["input_ids"]).logits | |
| # Teacher forward (no grad — same model, hint-conditioned context) | |
| with torch.no_grad(): | |
| teacher_logits = model(input_ids=inputs["ctx_teacher_input_ids"]).logits | |
| # ------------------------------------------------------------------ | |
| # ALIGNMENT (cross-family review 2026-05-29 — the 4/4-reviewer P0). | |
| # | |
| # The teacher context has a hint inserted at the error turn, so the | |
| # teacher's post-hint response tokens are shifted right by len(hint) | |
| # relative to the student's. A bare `student.shape == teacher.shape` | |
| # check does NOT establish token-level alignment: equal-length tensors | |
| # whose response regions are offset will be JSD'd position-by-position | |
| # against each other, distilling garbage into the policy. | |
| # | |
| # The ONLY correct alignment is an explicit map from the collator that | |
| # selects, for each response token, the matching index in each sequence. | |
| # We require it whenever SDPO is active: | |
| # - `student_response_idx` / `teacher_response_idx`: LongTensors of | |
| # equal length selecting the aligned response positions in each | |
| # sequence (the collator builds these knowing where it inserted the | |
| # hint). JSD is computed over the gathered, provably-aligned logits. | |
| # - If the collator cannot yet supply them, strict mode raises (loud | |
| # failure) rather than silently distilling misaligned tokens. | |
| s_idx = inputs.get("student_response_idx") | |
| t_idx = inputs.get("teacher_response_idx") | |
| if s_idx is None or t_idx is None: | |
| msg = ( | |
| "SDPO alignment indices missing: the collator must emit " | |
| "`student_response_idx` and `teacher_response_idx` (matching " | |
| "LongTensors selecting the aligned post-hint response tokens) so " | |
| "the JSD compares corresponding tokens. A shape-only check does " | |
| "NOT establish alignment — the hint shifts the teacher's response " | |
| "tokens right, so equal-length sequences can still be misaligned " | |
| "and silently distill garbage into the policy (ADR-008 trust-gap)." | |
| ) | |
| if self.strict_sdpo_alignment: | |
| raise ValueError( | |
| msg + " (strict_sdpo_alignment=True; pass False to fall back " | |
| "to the legacy shape-only check for resilience.)" | |
| ) | |
| logger.warning("%s Falling back to shape-only alignment check.", msg) | |
| if student_logits.shape != teacher_logits.shape: | |
| logger.warning( | |
| "SDPO shape mismatch student=%s teacher=%s; skipping.", | |
| tuple(student_logits.shape), tuple(teacher_logits.shape), | |
| ) | |
| return torch.tensor(0.0, device=_device_of(model), requires_grad=True) | |
| return generalized_jsd_loss( | |
| student_logits=student_logits, | |
| teacher_logits=teacher_logits, | |
| labels=inputs.get("sdpo_loss_mask"), | |
| beta=self.sdpo_jsd_beta, | |
| temperature=self.sdpo_temperature, | |
| token_clip=self.sdpo_token_clip, | |
| reduction="batchmean", | |
| ) | |
| # Validate the index tensors describe a consistent 1:1 alignment. | |
| if s_idx.shape != t_idx.shape: | |
| raise ValueError( | |
| f"SDPO alignment index shape mismatch: student_response_idx=" | |
| f"{tuple(s_idx.shape)} vs teacher_response_idx={tuple(t_idx.shape)}. " | |
| "They must select the same number of aligned response tokens." | |
| ) | |
| # Gather the provably-aligned response logits from each sequence, then | |
| # JSD only those positions (this is the masked error-turn distillation). | |
| # gather over the sequence dim (dim=1): expand index to the vocab dim. | |
| # | |
| # ADR-011: ragged-K rows are padded with a sentinel (-1) and a per-row | |
| # *_valid mask. Negative indices are illegal for torch.gather, so clamp | |
| # to 0 before gathering, then neutralize those positions by feeding | |
| # labels=-100 (the standard HF ignore convention that generalized_jsd_loss | |
| # already honors). This makes sentinel/padding positions contribute 0. | |
| # | |
| # Final-verify 2026-05-29: combine BOTH valid masks (not just student's) | |
| # AND the sentinel guard. If a future collator ever emits divergent | |
| # student/teacher valid tails, a teacher sentinel clamped to 0 would | |
| # otherwise be silently distilled against teacher position 0. Belt-and- | |
| # suspenders: valid iff student-valid AND teacher-valid AND both indices | |
| # non-sentinel. | |
| s_valid = inputs.get("student_response_valid") | |
| t_valid = inputs.get("teacher_response_valid") | |
| aligned_mask = (s_idx >= 0) & (t_idx >= 0) | |
| if s_valid is not None: | |
| aligned_mask = aligned_mask & s_valid.bool() | |
| if t_valid is not None: | |
| aligned_mask = aligned_mask & t_valid.bool() | |
| vocab = student_logits.size(-1) | |
| s_safe = s_idx.clamp_min(0) | |
| t_safe = t_idx.clamp_min(0) | |
| s_gather = s_safe.unsqueeze(-1).expand(-1, -1, vocab) | |
| t_gather = t_safe.unsqueeze(-1).expand(-1, -1, vocab) | |
| student_aligned = torch.gather(student_logits, 1, s_gather) | |
| teacher_aligned = torch.gather(teacher_logits, 1, t_gather) | |
| # Build (B, K) labels: 1 at valid aligned positions, -100 (ignore) at | |
| # sentinel/padding positions so they drop out of the JSD reduction. | |
| aligned_labels = torch.where( | |
| aligned_mask, | |
| torch.ones_like(s_idx), | |
| torch.full_like(s_idx, -100), | |
| ) | |
| return generalized_jsd_loss( | |
| student_logits=student_aligned, | |
| teacher_logits=teacher_aligned, | |
| labels=aligned_labels, # sentinel-masked aligned error-turn positions | |
| beta=self.sdpo_jsd_beta, | |
| temperature=self.sdpo_temperature, | |
| token_clip=self.sdpo_token_clip, | |
| reduction="batchmean", | |
| ) | |
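The sentinel handling above can be illustrated without tensors. The following is a minimal standard-library sketch, not the trainer's code: `aligned_pairs`, the token strings, and the index values are hypothetical, and plain list lookups stand in for `torch.gather`. It assumes, as the comments describe, that the hint shifts the teacher's response right and that ragged index rows are padded with -1.

```python
IGNORE = -100  # Hugging Face ignore-label convention
SENTINEL = -1  # padding value for ragged index rows


def aligned_pairs(student_vals, teacher_vals, s_idx, t_idx):
    """Look up aligned (student, teacher) values for one sequence.

    Sentinel indices are clamped to 0 so the lookup stays legal, and the
    matching label is set to IGNORE so that position is dropped afterwards.
    """
    gathered, labels = [], []
    for si, ti in zip(s_idx, t_idx):
        valid = si >= 0 and ti >= 0
        gathered.append((student_vals[max(si, 0)], teacher_vals[max(ti, 0)]))
        labels.append(1 if valid else IGNORE)
    return gathered, labels


# Student sequence: response tokens sit at positions 2..4.
student = ["p0", "p1", "a", "b", "c"]
# Teacher sequence: a 2-token hint shifts the same response to positions 4..6.
teacher = ["p0", "p1", "h0", "h1", "a", "b", "c"]

# K is padded to 4; the last slot is a sentinel in both index rows.
s_idx = [2, 3, 4, SENTINEL]
t_idx = [4, 5, 6, SENTINEL]

gathered, labels = aligned_pairs(student, teacher, s_idx, t_idx)
# Only positions whose label is not IGNORE take part in the comparison.
kept = [pair for pair, lab in zip(gathered, labels) if lab != IGNORE]
```

The padded slot still produces a lookup (position 0 of each sequence), which is why the label mask, not the clamp, is what keeps it out of the loss.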
| # ---------------------------------------------------------------------- | |
| # Channel 3: trace-replay DPO | |
| # ---------------------------------------------------------------------- | |
| def _compute_trace_replay_loss( | |
| self, | |
| model: torch.nn.Module, | |
| inputs: dict[str, torch.Tensor], | |
| ) -> torch.Tensor: | |
| """Standard DPO loss using (chosen, rejected) pairs from teacher disagreement. | |
| DPO loss formula (Rafailov et al. 2023): | |
| L = -log σ(β · (logπ(chosen) - logπ_ref(chosen) | |
| - logπ(rejected) + logπ_ref(rejected))) | |
| Where logπ_ref are precomputed by the data collator using the | |
| reference (init student) policy. | |
| """ | |
| if ( | |
| self.beta_replay == 0.0 | |
| or "dpo_chosen_input_ids" not in inputs | |
| or inputs["dpo_chosen_input_ids"].numel() == 0 | |
| ): | |
| return torch.tensor(0.0, device=_device_of(model), requires_grad=True) | |
| # Forward passes for chosen and rejected, gather logprobs at response tokens | |
| chosen_logprobs = self._sequence_logprobs( | |
| model, inputs["dpo_chosen_input_ids"], inputs["dpo_chosen_response_mask"] | |
| ) | |
| rejected_logprobs = self._sequence_logprobs( | |
| model, inputs["dpo_rejected_input_ids"], inputs["dpo_rejected_response_mask"] | |
| ) | |
| ref_chosen_logprobs = inputs["dpo_chosen_ref_logprobs"] | |
| ref_rejected_logprobs = inputs["dpo_rejected_ref_logprobs"] | |
| logits = self.replay_dpo_beta * ( | |
| (chosen_logprobs - ref_chosen_logprobs) | |
| - (rejected_logprobs - ref_rejected_logprobs) | |
| ) | |
| return -F.logsigmoid(logits).mean() | |
| @staticmethod | |
| def _sequence_logprobs( | |
| model: torch.nn.Module, | |
| input_ids: torch.Tensor, | |
| response_mask: torch.Tensor, | |
| ) -> torch.Tensor: | |
| """Sum logprob of response tokens given the prompt prefix. | |
| Standard DPO accounting: we only score the response tokens (where | |
| response_mask == 1), not the prompt tokens. | |
| """ | |
| outputs = model(input_ids=input_ids) | |
| # Shift for next-token prediction: logits[t] predicts input_ids[t+1] | |
| logits = outputs.logits[:, :-1, :] | |
| targets = input_ids[:, 1:] | |
| log_probs = F.log_softmax(logits, dim=-1) | |
| token_logprobs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1) | |
| # Mask out prompt + padding; sum response-token logprobs | |
| masked = token_logprobs * response_mask[:, 1:].float() | |
| return masked.sum(dim=-1) | |
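As a sanity check on the sign conventions in the DPO formula above, the per-pair loss can be worked through with plain floats. This is a minimal standard-library sketch, not the trainer's implementation: `dpo_pair_loss`, `log_sigmoid`, and the log-probability values are hypothetical and chosen only for illustration.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    # x >= 0: -log1p(exp(-x)); x < 0: x - log1p(exp(x)). Both avoid overflow.
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))


def dpo_pair_loss(
    logp_chosen: float,
    logp_rejected: float,
    ref_logp_chosen: float,
    ref_logp_rejected: float,
    beta: float = 0.1,
) -> float:
    """-log sigmoid(beta * ((chosen - ref_chosen) - (rejected - ref_rejected)))."""
    margin = (logp_chosen - ref_logp_chosen) - (logp_rejected - ref_logp_rejected)
    return -log_sigmoid(beta * margin)


# Policy identical to the reference: the margin is 0, so the loss is log(2).
loss_at_init = dpo_pair_loss(-12.0, -15.0, -12.0, -15.0)

# Policy has raised the chosen sequence and lowered the rejected one relative
# to the reference: positive margin, so the loss drops below log(2).
loss_improved = dpo_pair_loss(-10.0, -17.0, -12.0, -15.0)
```

The inputs here are summed response-token log-probabilities, the quantity `_sequence_logprobs` returns per sequence.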
| def _device_of(model: torch.nn.Module) -> torch.device: | |
| """Return the device of any parameter of the model — robust to FSDP/DDP wrappers.""" | |
| return next(model.parameters()).device | |
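The two `kl_estimator` choices, k1 and k3, can be compared exactly on a tiny discrete distribution. This is a minimal standard-library sketch under stated assumptions: the distributions `pi` and `pi_ref` are made-up numbers, the helpers `k1` and `k3` are illustrative and not the functions in `kl_in_reward.py`, and expectations are computed by summing over outcomes, not by sampling.

```python
import math

# Hypothetical 3-outcome distributions: pi is the sampling policy,
# pi_ref the reference policy. Values are illustrative only.
pi = [0.5, 0.3, 0.2]
pi_ref = [0.4, 0.4, 0.2]


def k1(logp: float, ref_logp: float) -> float:
    """k1 = logp - ref_logp (can be negative for an individual sample)."""
    return logp - ref_logp


def k3(logp: float, ref_logp: float) -> float:
    """k3 = exp(d) - d - 1 with d = ref_logp - logp (never negative)."""
    d = ref_logp - logp
    return math.exp(d) - d - 1.0


logps = [math.log(p) for p in pi]
ref_logps = [math.log(q) for q in pi_ref]

k1_values = [k1(lp, rl) for lp, rl in zip(logps, ref_logps)]
k3_values = [k3(lp, rl) for lp, rl in zip(logps, ref_logps)]

# Exact expectations under pi (sum over outcomes instead of sampling).
true_kl = sum(p * (lp - rl) for p, lp, rl in zip(pi, logps, ref_logps))
mean_k1 = sum(p * v for p, v in zip(pi, k1_values))
mean_k3 = sum(p * v for p, v in zip(pi, k3_values))
```

Both means equal KL(pi || pi_ref) here, while only k1 takes a negative value on an individual outcome. That per-sample sign behavior is the practical difference between the two estimators.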
| def validate_kl_in_reward_config( | |
| *, | |
| kl_estimator: str, | |
| beta: float, | |
| scale_rewards: Any, | |
| ) -> None: | |
| """Validate the (kl_estimator, beta, scale_rewards) combo for k1-in-reward. | |
| Extracted so the preconditions are unit-testable without standing up a real | |
| GRPOTrainer (which needs a model + dataset). Raises ``ValueError`` on any | |
| invalid combination; returns None when the config is sound. | |
| Preconditions (see ``kl_in_reward.py`` for the algebra): | |
| * ``kl_estimator`` in {k1, k3}. | |
| * ``beta != 0`` — TRL only builds the reference model and computes ref | |
| logprobs when beta>0, and the in-reward penalty needs ref logps. beta | |
| doubles as the in-reward KL coefficient (the in-loss k3 term is | |
| suppressed per step). | |
| * ``scale_rewards`` in {none, false} — the advantage-adjustment identity | |
| is exact only without per-group std-normalization (the Dr.GRPO / | |
| Composer regime). | |
| """ | |
| if kl_estimator not in ("k1", "k3"): | |
| raise ValueError(f"kl_estimator must be 'k1' or 'k3', got {kl_estimator!r}.") | |
| if float(beta) == 0.0: | |
| raise ValueError( | |
| "kl_in_reward=True requires a non-zero `beta` (the KL coefficient): " | |
| "TRL only creates the reference model and computes ref logprobs when " | |
| "beta>0, and k1-in-reward needs those ref logps. Set beta to your KL " | |
| "coefficient (e.g. make_po_config('dr_grpo', beta=0.04)); the in-loss " | |
| "k3 term is suppressed automatically so beta acts purely as the " | |
| "in-reward k1 coefficient." | |
| ) | |
| if str(scale_rewards).lower() not in ("none", "false"): | |
| raise ValueError( | |
| "kl_in_reward=True requires scale_rewards in {none,false} " | |
| f"(got {scale_rewards!r}). The advantage-adjustment identity " | |
| "adv -= beta·(KL - group_mean(KL)) is EXACT only without per-group " | |
| "std-normalization (the Dr.GRPO / Composer regime). With std-norm, " | |
| "folding KL into the reward also shifts the group std, so the linear " | |
| "correction no longer matches true in-reward KL. Use " | |
| "make_po_config('dr_grpo', beta=…) (scale_rewards='none')." | |
| ) | |
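# Illustrative sketch (not part of this module): the identity quoted in the
# error message above, checked on made-up numbers. The helper names, rewards
# and per-sequence KL values are assumptions chosen for the example. It compares
# "fold beta*KL into the reward, then centre" against "centre the raw reward,
# then subtract beta*(KL - group_mean(KL))", first without and then with
# per-group std-normalization.

```python
from statistics import mean, pstdev

def centred(rewards):
    """Group-relative advantages with no std-normalization (Dr. GRPO style)."""
    m = mean(rewards)
    return [r - m for r in rewards]

def std_normalized(rewards):
    """Group-relative advantages divided by the group std (vanilla GRPO style)."""
    m, s = mean(rewards), pstdev(rewards)
    return [(r - m) / s for r in rewards]

beta = 0.04
rewards = [1.0, 0.0, 0.5, 0.0]   # hypothetical rewards for one prompt group
kl = [0.30, 0.10, 0.20, 0.00]    # hypothetical per-sequence KL values
kl_mean = mean(kl)

# Route 1: true in-reward KL (shape the reward, then compute advantages).
shaped = [r - beta * k for r, k in zip(rewards, kl)]
# Route 2: linear correction applied to advantages of the raw reward.
def corrected(advantages):
    return [a - beta * (k - kl_mean) for a, k in zip(advantages, kl)]

gap_plain = max(abs(x - y) for x, y in zip(centred(shaped), corrected(centred(rewards))))
gap_std = max(
    abs(x - y)
    for x, y in zip(std_normalized(shaped), corrected(std_normalized(rewards)))
)

print(f"no std-norm gap:   {gap_plain:.2e}")  # zero up to float rounding
print(f"with std-norm gap: {gap_std:.2e}")    # clearly non-zero
```

# Without std-normalization the two routes coincide, which is why the validator
# insists on scale_rewards in {none, false}. With std-normalization the shaped
# reward has a different group std, so the linear correction drifts.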
| def make_dr_grpo_config(**overrides: Any): | |
| """Build a `trl.GRPOConfig` configured to the **Dr. GRPO** recipe. | |
| Per the Composer 2 technical report (arXiv:2603.24477, | |
| research/10-composer2-techreport-mining.md) the RL base is Dr. GRPO | |
| (Liu et al., arXiv:2503.20783): | |
| - ``loss_type="dr_grpo"`` — removes GRPO's length-standardization term | |
| (which injects a length bias). TRL's own help text cites the Dr. GRPO | |
| paper for this. | |
| - ``scale_rewards="none"`` — NO std-dev advantage normalization. TRL docs: | |
| "The Dr. GRPO paper recommends not scaling rewards, as scaling by the | |
| standard deviation introduces a question-level difficulty bias." | |
| - ``num_iterations=1`` — single-epoch regime (a prompt is never | |
| trained on twice), matching the tech report. | |
| - ``beta`` (KL-to-ref coef) kept. NOTE on the KL estimator (ADR-012 | |
| finding #1, verified against the installed trl==1.5.0 source): | |
| ``GRPOTrainer._compute_loss`` uses the **k3** estimator | |
| ``exp(ref_logp - logp) - (ref_logp - logp) - 1`` | |
| (trl/trainer/grpo_trainer.py ~L2513), NOT the k1 estimator | |
| ``-log r == (logp - ref_logp)``, where ``r = pi_ref / pi``. k3 is | |
| Schulman's low-variance, always-non-negative KL approximation; k1 is the | |
| plain log-ratio estimator, also unbiased but higher-variance and able to | |
| go negative per token. The Dr. GRPO / Composer 2 report discusses KL in | |
| k1 terms. Per token, ``k3 = k1 + (r - 1)``, and ``E[r - 1] = 0`` under the | |
| policy, so the two agree in expectation; for r≈1, k3 itself is | |
| ≈ (Δlogp)^2 / 2. TRL's k3 choice is the production reality. We do NOT | |
| monkeypatch TRL to force k1; we document the honest delta. See | |
| ``test_dr_grpo_config_and_alignment.py::test_trl_kl_estimator_is_k3_not_k1``. | |
| Any field can be overridden via kwargs (e.g. ``learning_rate=...``, | |
| ``output_dir=...``). The three Dr. GRPO-defining knobs are defaults, not | |
| hard locks: an override wins, but the sanity assertions below still reject | |
| a ``scale_rewards`` value that re-enables std-normalization and guard | |
| against silent drift. | |
| """ | |
| from trl import GRPOConfig # local import: only when actually building a config | |
| dr_grpo_defaults: dict[str, Any] = { | |
| "loss_type": "dr_grpo", | |
| "scale_rewards": "none", | |
| "num_iterations": 1, | |
| } | |
| merged = {**dr_grpo_defaults, **overrides} | |
| cfg = GRPOConfig(**merged) | |
| # Guard: fail loudly if a future TRL renames/repurposes these knobs. | |
| assert cfg.loss_type == merged["loss_type"], ( | |
| f"GRPOConfig loss_type drifted: requested {merged['loss_type']!r}, " | |
| f"got {cfg.loss_type!r} — TRL may have renamed/repurposed the knob." | |
| ) | |
| # Dr. GRPO requires NO std-dev advantage normalization. TRL accepts either | |
| # the string "none" or the bool False to disable it; normalize before | |
| # comparing so a future TRL that switches the representation still passes | |
| # (and a genuinely-wrong value like "batch"/"group"/True fails loudly). | |
| # (Cross-family review 2026-05-29: the prior literal `("none","False","False")` | |
| # had a duplicated "False" and did a brittle case-sensitive str compare.) | |
| assert str(cfg.scale_rewards).lower() in ("none", "false"), ( | |
| f"Dr. GRPO requires scale_rewards disabled (no std-norm); got " | |
| f"{cfg.scale_rewards!r}. TRL knob may have drifted — re-verify against trl version." | |
| ) | |
| assert cfg.num_iterations == merged["num_iterations"], "GRPOConfig dropped num_iterations" | |
| return cfg | |
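# Illustrative sketch (not part of this module): the k1 and k3 estimators
# discussed in the docstring above, evaluated exactly on a tiny three-token
# distribution. It assumes Schulman's convention r = pi_ref / pi with samples
# drawn from the policy; the function names and probabilities are made up for
# the example.

```python
import math

def k1(logp, ref_logp):
    """Plain log-ratio estimator: -log r with r = pi_ref / pi."""
    return logp - ref_logp

def k3(logp, ref_logp):
    """Schulman's k3: (r - 1) - log r, non-negative for every sample."""
    d = ref_logp - logp
    return math.exp(d) - d - 1.0

policy = [0.5, 0.3, 0.2]      # hypothetical policy probabilities
reference = [0.4, 0.4, 0.2]   # hypothetical reference probabilities

k1_values = [k1(math.log(p), math.log(q)) for p, q in zip(policy, reference)]
k3_values = [k3(math.log(p), math.log(q)) for p, q in zip(policy, reference)]

# Exact expectations under the policy, compared with the true KL(policy || ref).
exact_kl = sum(p * math.log(p / q) for p, q in zip(policy, reference))
expected_k1 = sum(p * v for p, v in zip(policy, k1_values))
expected_k3 = sum(p * v for p, v in zip(policy, k3_values))

print("per-token k1:", [round(v, 4) for v in k1_values])
print("per-token k3:", [round(v, 4) for v in k3_values])
print("E[k1], E[k3], KL:", expected_k1, expected_k3, exact_kl)
```

# Both expectations match the true KL, while only k1 takes a negative value on
# an individual token. The per-token difference is exactly r - 1.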
| # --------------------------------------------------------------------------- | |
| # Policy-optimization objective MENU (ADR-014) | |
| # --------------------------------------------------------------------------- | |
| # | |
| # The base RL objective used to be hardcoded to Dr.GRPO (make_dr_grpo_config). | |
| # make_po_config gives RL a real menu: GRPO-family objectives selectable by name. | |
| # Verified against the installed trl==1.5.0 (introspected 2026-05-30): its | |
| # GRPOTrainer already implements these as `loss_type` branches + knobs, so EVERY | |
| # preset below is pure config — no custom _compute_loss override needed. | |
| # | |
| # Knob-space each preset sets (all real GRPOConfig fields in trl 1.5.0): | |
| # loss_type ∈ {grpo, dr_grpo, bnpo, dapo, cispo} (gspo = grpo loss + | |
| # importance_sampling_level="sequence"; trl has no literal "gspo") | |
| # scale_rewards ∈ {"group"(std-norm), "batch", "none"(no std-norm, Dr.GRPO)} | |
| # epsilon / epsilon_high — symmetric vs decoupled "clip-higher" (DAPO) | |
| # importance_sampling_level ∈ {"token", "sequence"(GSPO)} | |
| # beta — KL-to-ref coef (0.0 = reference-free) | |
| # mask_truncated_completions — DAPO overlong masking | |
| # num_iterations — on-policy reuse (1 = strict on-policy) | |
| #: Selectable base policy-optimization objectives (named presets over trl knobs). | |
| PO_OBJECTIVES: dict[str, dict[str, Any]] = { | |
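| # Vanilla GRPO (DeepSeekMath, arXiv 2402.03300): group-relative advantage | |
| # WITH std normalization + per-sequence length normalization. The recipe | |
| # keeps the KL term, but this preset does not set ``beta``; pass | |
| # ``beta=...`` as an override to control it. | |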
| "grpo": { | |
| "loss_type": "grpo", | |
| "scale_rewards": "group", | |
| "importance_sampling_level": "token", | |
| "num_iterations": 1, | |
| }, | |
| # Dr.GRPO (arXiv 2503.20783): removes GRPO's length- and std-normalization | |
| # biases (no advantage / std, length-independent aggregation). Framework's | |
| # historical default (== make_dr_grpo_config). Composer 2.5's base objective. | |
| "dr_grpo": { | |
| "loss_type": "dr_grpo", | |
| "scale_rewards": "none", | |
| "importance_sampling_level": "token", | |
| "num_iterations": 1, | |
| }, | |
| # BNPO: batch-normalized variant (trl loss_type), std over the batch. | |
| "bnpo": { | |
| "loss_type": "bnpo", | |
| "scale_rewards": "batch", | |
| "importance_sampling_level": "token", | |
| "num_iterations": 1, | |
| }, | |
| # DAPO (arXiv 2503.14476): decoupled "clip-higher" (epsilon_high > epsilon) | |
| # + token-level loss + overlong masking + KL removed. High-value, low-cost | |
| # anti-entropy-collapse objective. epsilon_high=0.28 per the paper. | |
| "dapo": { | |
| "loss_type": "dapo", | |
| "scale_rewards": "none", | |
| "epsilon": 0.2, | |
| "epsilon_high": 0.28, | |
| "mask_truncated_completions": True, | |
| "beta": 0.0, | |
| "importance_sampling_level": "token", | |
| "num_iterations": 1, | |
| }, | |
| # GSPO (Qwen, arXiv 2507.18071): SEQUENCE-level importance ratio (one length- | |
| # normalized ratio per response) — stabilizes long-CoT and especially MoE RL. | |
| # trl expresses this as the grpo loss + importance_sampling_level="sequence". | |
| "gspo": { | |
| "loss_type": "grpo", | |
| "scale_rewards": "group", | |
| "importance_sampling_level": "sequence", | |
| "num_iterations": 1, | |
| }, | |
| # CISPO (MiniMax-M1, arXiv 2506.13585): clip the IS weight and detach it as a | |
| # constant coefficient on log π — every token keeps a gradient (fixes the | |
| # "rare reasoning tokens get zeroed by the clip" pathology). eps_max≈5 (ScaleRL). | |
| "cispo": { | |
| "loss_type": "cispo", | |
| "scale_rewards": "none", | |
| "epsilon_high": 5.0, | |
| "importance_sampling_level": "token", | |
| "num_iterations": 1, | |
| }, | |
| } | |
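# Illustrative sketch (not part of this module): what the epsilon /
# epsilon_high knobs in the presets above change. It uses the standard
# PPO-style clipped surrogate for a single token; the function name and the
# sample ratio are assumptions for the example, not trl's code.

```python
def clipped_surrogate(ratio, advantage, eps_low, eps_high):
    """min(r * A, clip(r, 1 - eps_low, 1 + eps_high) * A) for one token."""
    clipped_ratio = min(max(ratio, 1.0 - eps_low), 1.0 + eps_high)
    return min(ratio * advantage, clipped_ratio * advantage)

ratio, advantage = 1.25, 1.0  # a token whose probability rose by 25%

# Symmetric clip (epsilon = epsilon_high = 0.2): the ratio is capped at 1.2, so
# this token sits on the flat part of the objective and gets no gradient.
symmetric = clipped_surrogate(ratio, advantage, eps_low=0.2, eps_high=0.2)

# Decoupled "clip-higher" (epsilon = 0.2, epsilon_high = 0.28): the cap moves to
# 1.28, so the same token is still inside the active region.
clip_higher = clipped_surrogate(ratio, advantage, eps_low=0.2, eps_high=0.28)

# The lower bound is untouched: a negative-advantage token clips the same way.
neg_symmetric = clipped_surrogate(0.7, -1.0, eps_low=0.2, eps_high=0.2)
neg_clip_higher = clipped_surrogate(0.7, -1.0, eps_low=0.2, eps_high=0.28)

print(symmetric, clip_higher, neg_symmetric, neg_clip_higher)
```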
| def make_po_config(objective: str = "dr_grpo", **overrides: Any): | |
| """Build a `trl.GRPOConfig` for a NAMED policy-optimization objective. | |
| This is the menu that gives RL real options beyond the single hardcoded | |
| Dr.GRPO recipe. ``objective`` selects a preset from ``PO_OBJECTIVES`` (grpo / | |
| dr_grpo / bnpo / dapo / gspo / cispo); ``**overrides`` set or override any | |
| GRPOConfig field on top (e.g. ``output_dir=...``, ``beta=...``, | |
| ``learning_rate=...``). | |
| All presets are PURE CONFIG over trl 1.5.0's GRPOTrainer (verified by | |
| introspecting the installed package 2026-05-30): the trainer already | |
| implements each ``loss_type`` branch and the ``importance_sampling_level`` / | |
| ``epsilon_high`` knobs, so no custom ``_compute_loss`` is needed. See ADR-014. | |
| Raises: | |
| ValueError: unknown objective (lists the valid menu). | |
| AssertionError: a requested knob silently failed to apply (drift guard). | |
| """ | |
| from trl import GRPOConfig # local import: only when actually building a config | |
| key = (objective or "dr_grpo").lower() | |
| if key not in PO_OBJECTIVES: | |
| raise ValueError( | |
| f"Unknown PO objective {objective!r}. Choose from: " | |
| f"{sorted(PO_OBJECTIVES)}. (Each is a named preset over trl 1.5.0's " | |
| f"GRPOConfig knobs — see PO_OBJECTIVES / ADR-014.)" | |
| ) | |
| preset = dict(PO_OBJECTIVES[key]) | |
| merged = {**preset, **overrides} | |
| cfg = GRPOConfig(**merged) | |
| # Drift guards: fail loudly if a future trl renamed/repurposed a knob we set, | |
| # so a preset can never silently degrade to a different objective. | |
| if "loss_type" in merged: | |
| assert str(cfg.loss_type) == str(merged["loss_type"]), ( | |
| f"GRPOConfig.loss_type drifted: requested {merged['loss_type']!r}, " | |
| f"got {cfg.loss_type!r} — trl may have renamed the knob." | |
| ) | |
| if "importance_sampling_level" in merged and hasattr(cfg, "importance_sampling_level"): | |
| assert str(cfg.importance_sampling_level) == str( | |
| merged["importance_sampling_level"] | |
| ), ( | |
| f"importance_sampling_level drifted for objective {key!r}: requested " | |
| f"{merged['importance_sampling_level']!r}, got {cfg.importance_sampling_level!r}." | |
| ) | |
| if key == "gspo": | |
| assert str(getattr(cfg, "importance_sampling_level", "token")) == "sequence", ( | |
| "GSPO requires importance_sampling_level='sequence'; it was overridden " | |
| "to token, which silently degrades GSPO to GRPO. Drop that override." | |
| ) | |
| if merged.get("epsilon_high") is not None: | |
| assert abs( | |
| float(getattr(cfg, "epsilon_high", merged["epsilon_high"])) | |
| - float(merged["epsilon_high"]) | |
| ) < 1e-9, f"epsilon_high (decoupled clip) drifted for {key!r}." | |
| return cfg | |
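# Illustrative sketch (not part of this module): why the GSPO guard above
# matters. It contrasts token-level importance ratios with the single
# length-normalized sequence-level ratio described in the GSPO preset comment.
# The function names and log-probabilities are assumptions for the example.

```python
import math

def token_ratios(logp, old_logp):
    """One importance ratio per token: exp(logp_t - old_logp_t)."""
    return [math.exp(a - b) for a, b in zip(logp, old_logp)]

def sequence_ratio(logp, old_logp):
    """One ratio per response: exp(mean_t(logp_t - old_logp_t))."""
    diffs = [a - b for a, b in zip(logp, old_logp)]
    return math.exp(sum(diffs) / len(diffs))

logp = [-1.0, -2.0, -0.5]      # hypothetical current-policy token logprobs
old_logp = [-1.2, -1.5, -0.5]  # hypothetical behaviour-policy token logprobs

per_token = token_ratios(logp, old_logp)
per_sequence = sequence_ratio(logp, old_logp)

print("token-level ratios:  ", [round(r, 4) for r in per_token])
print("sequence-level ratio:", round(per_sequence, 4))
```

# The token-level ratios swing above and below 1 within one response, while the
# sequence-level ratio is their geometric mean. Overriding
# importance_sampling_level back to "token" therefore changes what gets clipped.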
| __all__ = [ | |
| "ComposerReplicationTrainer", | |
| "make_dr_grpo_config", | |
| "make_po_config", | |
| "PO_OBJECTIVES", | |
| "validate_kl_in_reward_config", | |
| ] | |