"""composer_trainer.py — TRL GRPOTrainer subclass with SDPO + trace-replay channels. Architecture spec: docs/INTEGRATION_ARCHITECTURE.md § "Recipe A". Verified extension point: GRPOTrainer._compute_loss(model, inputs) (DeepWiki audit of huggingface/trl, 2026-05-25). Total loss: total_loss = grpo_loss + alpha_sdpo * sdpo_kl_at_error_turns + beta_replay * trace_replay_dpo_loss Where: - grpo_loss is the parent GRPOTrainer's loss (RLVR + DAPO patches). - sdpo_kl_at_error_turns is generalized_jsd_loss between student's logits and teacher's (= same-model-with-hint-context) logits, masked to error-turn tokens only. - trace_replay_dpo_loss is DPO loss over (chosen, rejected) pairs derived from N external teacher disagreement with the student. The data collator (data_collator.py) is responsible for: - Detecting error sites in the rollout and constructing ctx_teacher = ctx_student + hint. - Computing sdpo_loss_mask (1 at post-hint error-turn tokens, 0 elsewhere). - Loading DPO pairs from the trace-replay output (see teacher_replay.py). - Precomputing reference-policy logprobs for DPO. """ from __future__ import annotations import logging from collections.abc import Callable from typing import TYPE_CHECKING, Any import torch import torch.nn.functional as F # noqa: N812 — repo-wide torch convention if TYPE_CHECKING: # type-only — never imported at runtime (keeps the dep lazy) from composer_replication.safety import HeldOutGuard # These imports work when TRL is installed — they're not skeleton imports. # When TRL is missing we fall back to `object` so the module still imports # (e.g. for documentation generation) but raise a clear ImportError at # instantiation time rather than the cryptic `object.__init__()` error. 
try:
    from trl import GRPOTrainer  # type: ignore

    _TRL_AVAILABLE = True
except ImportError:  # pragma: no cover — only hit in unit-test stubs without TRL
    GRPOTrainer = object  # type: ignore — fallback so module imports without TRL
    _TRL_AVAILABLE = False

from composer_replication.opsd import generalized_jsd_loss
from composer_replication.trainer.kl_in_reward import (
    apply_kl_in_reward,
    kl_penalty_per_sequence,
)

logger = logging.getLogger(__name__)


class ComposerReplicationTrainer(GRPOTrainer):  # type: ignore[misc, valid-type]
    """TRL GRPOTrainer with Composer-recipe channels (SDPO) + novel trace-replay-DPO.

    Args (in addition to GRPOTrainer's):
        alpha_sdpo: weight on SDPO hint-distill loss. Default 0.0 (disabled).
            Opt in by passing >0 once your data collator produces
            `sdpo_loss_mask` and `ctx_teacher_input_ids` columns.
        beta_replay: weight on trace-replay DPO loss. Default 0.0 (disabled).
            Opt in by passing >0 once your data collator produces
            `dpo_chosen_input_ids` / `dpo_rejected_input_ids` etc.
        sdpo_jsd_beta: beta param of generalized_jsd_loss (0=KL(teacher||student),
            0.5=JSD, 1=KL(student||teacher) per upstream OPSD convention; see
            composer_replication/opsd.py).
        sdpo_temperature: temperature for SDPO loss; SDPO paper uses 1.0.
        sdpo_token_clip: per-token JSD clip for stability; None = no clip.
        replay_dpo_beta: beta param of the DPO loss (β in the standard DPO formula).
        strict_sdpo_alignment: when True (default), missing SDPO alignment indices
            (``student_response_idx`` / ``teacher_response_idx``) raise
            ``ValueError``. When False the trainer warns and falls back to a
            shape-only student/teacher check, skipping (zero loss) on mismatch.
        kl_in_reward: when True, apply the KL-to-reference penalty in the
            **reward** (Composer-2 §4.1 / verl choice) instead of TRL's native
            **in-loss** k3 term. The penalty is folded into GRPO's advantages at
            scoring time (``adv -= beta·(KL - group_mean(KL))``) and TRL's
            in-loss KL is suppressed for that step. The F5 audit's #1 fidelity
            fix: the 2025/26 evidence (arXiv:2512.21852, verl, TRL #4967) shows
            k1-in-reward improves OOD generalization where k3-in-reward can
            collapse. REQUIRES ``beta>0`` (the KL coefficient — also how TRL
            decides to compute reference logprobs) and ``scale_rewards`` in
            {none,false} (the advantage-adjustment identity is exact only without
            std-normalization — the Dr.GRPO / Composer regime). Default False =
            TRL's native in-loss KL, byte-for-byte legacy behavior.
        kl_estimator: ``"k1"`` (default; ``logp - ref_logp``, the Composer-2 /
            verl choice this path exists for) or ``"k3"`` (Schulman; lets an
            experiment A/B k1-in-reward vs k3-in-reward). Only consulted when
            ``kl_in_reward=True``.
        heldout_guard: optional ``HeldOutGuard`` (the #2 collapse safeguard from
            ``composer_replication.safety``). Default None = OFF (no behavior
            change whatsoever). When supplied, the trainer folds one checkpoint's
            metrics into the guard at the logging cadence (at most once per
            optimizer step, training mode only) and HALTS the run on a fired
            verdict — the run-level reward-hacking / collapse tripwire actually
            firing instead of sitting inert.
        heldout_eval_fn: zero-arg callable returning the held-out (real) eval
            score as a float, evaluated each guard cadence. Injectable so the
            trainer never hardcodes an eval — pass a closure over your disjoint
            held-out pool (the ``HeldoutSplit`` discipline). Required whenever
            ``heldout_guard`` is set; the guard's whole signal is in-loop reward
            vs. this held-out score.
        strict_killswitch: when True (default), a fired guard verdict raises
            ``CollapseStopError`` to hard-stop training (exception-based control
            flow, matching ``HeldOutGuard.raise_if_fired``). When False the
            verdict is logged and ``self.control.should_training_stop`` is set so
            the HF loop ends gracefully after the step (soft stop). Only consulted
            when ``heldout_guard`` is set.
    """

    def __init__(
        self,
        *args: Any,
        alpha_sdpo: float = 0.0,
        beta_replay: float = 0.0,
        sdpo_jsd_beta: float = 0.5,
        sdpo_temperature: float = 1.0,
        sdpo_token_clip: float | None = None,
        replay_dpo_beta: float = 0.1,
        strict_sdpo_alignment: bool = True,
        kl_in_reward: bool = False,
        kl_estimator: str = "k1",
        heldout_guard: HeldOutGuard | None = None,
        heldout_eval_fn: Callable[[], float] | None = None,
        strict_killswitch: bool = True,
        **kwargs: Any,
    ):
        """Build the parent GRPOTrainer, then wire the extra channels.

        Raises:
            ImportError: TRL is not installed.
            ValueError: ``heldout_guard`` was given without ``heldout_eval_fn``
                (checked BEFORE the parent constructor so a misconfigured run
                fails fast), or the ``kl_in_reward`` configuration is invalid
                (see ``validate_kl_in_reward_config``).
        """
        if not _TRL_AVAILABLE:
            raise ImportError(
                "ComposerReplicationTrainer requires TRL. Install with "
                "`pip install -e .[train]`."
            )
        # Pure argument check — do it before the (expensive) parent constructor.
        if heldout_guard is not None and heldout_eval_fn is None:
            raise ValueError(
                "heldout_guard was provided without heldout_eval_fn: the guard's "
                "tripwire compares in-loop reward against a DISJOINT held-out "
                "(real) eval score, so it needs an injectable zero-arg "
                "heldout_eval_fn() -> float. Pass a closure over your held-out "
                "pool (the HeldoutSplit discipline)."
            )
        super().__init__(*args, **kwargs)
        self.alpha_sdpo = alpha_sdpo
        self.beta_replay = beta_replay
        self.sdpo_jsd_beta = sdpo_jsd_beta
        self.sdpo_temperature = sdpo_temperature
        self.sdpo_token_clip = sdpo_token_clip
        self.replay_dpo_beta = replay_dpo_beta
        # When True (default), an SDPO student/teacher shape mismatch is a hard
        # error — it means the data collator failed to align the post-hint
        # section, which silently zeroes the distillation signal (the exact
        # trust-gap flagged in ADR-008). Set False only for production runs
        # where a single malformed batch should warn-and-skip rather than abort.
        self.strict_sdpo_alignment = strict_sdpo_alignment

        # --- k1-in-reward KL (F5 #1 fidelity fix; Composer-2 §4.1 / verl) ----
        # OFF by default → TRL's native in-loss k3 KL, byte-for-byte legacy.
        # When ON we keep self.beta as the KL coef (TRL needs beta>0 to even
        # create the ref model + compute ref logps), fold the k1 penalty into
        # advantages during scoring, and zero TRL's in-loss KL per step.
        self.kl_in_reward = kl_in_reward
        self.kl_estimator = kl_estimator
        if kl_in_reward:
            validate_kl_in_reward_config(
                kl_estimator=kl_estimator,
                beta=float(getattr(self.args, "beta", 0.0)),
                scale_rewards=getattr(self.args, "scale_rewards", "group"),
            )

        # --- run-level collapse kill-switch (#2 safeguard) -------------------
        # OPTIONAL + OFF BY DEFAULT: when heldout_guard is None the loss path is
        # byte-for-byte the legacy behavior. When set, _maybe_update_killswitch
        # folds metrics into the guard at the logging cadence (see _compute_loss).
        self.heldout_guard = heldout_guard
        self.heldout_eval_fn = heldout_eval_fn
        self.strict_killswitch = strict_killswitch

        # global_step at which per-channel components were last logged; used to
        # log (and consult the guard) at most once per optimizer step.
        self._last_component_log_step: int | None = None

    # ----------------------------------------------------------------------
    # Loss override (the integration core)
    # ----------------------------------------------------------------------

    # ----------------------------------------------------------------------
    # k1-in-reward: fold the KL penalty into advantages at scoring time, and
    # suppress TRL's native in-loss k3 KL inside _compute_loss.
    # ----------------------------------------------------------------------

    def _generate_and_score_completions(
        self,
        inputs: list[dict[str, Any]],
    ) -> dict[str, Any]:
        """Override: after TRL scores completions, fold a k1 KL penalty into the
        advantages (Composer-2 in-reward KL) when ``kl_in_reward`` is set.

        No-op (exact legacy path) when ``kl_in_reward`` is False. When set, TRL
        has already computed ``advantages``, ``ref_per_token_logps`` (because
        ``beta>0``), and the completion logprobs; we recompute the per-sequence
        k1 penalty and apply the exact group-mean-baseline correction.

        Raises:
            RuntimeError: ``kl_in_reward`` is set but TRL returned no
                ``ref_per_token_logps`` / ``completion_mask``.
        """
        output = super()._generate_and_score_completions(inputs)
        if not getattr(self, "kl_in_reward", False):
            return output

        ref_logps = output.get("ref_per_token_logps")
        # The "old" (sampling-time) policy logps are TRL's in-loss π term; they
        # may be lazily None when generation/optimization are aligned and not
        # vLLM (see TRL _compute_loss: old := per_token_logps.detach()). In that
        # aligned case we cannot read π logps here, so we defer to _compute_loss
        # (which always has per_token_logps) by stashing what we need.
        old_logps = output.get("old_per_token_logps")
        completion_mask = output.get("completion_mask")
        if ref_logps is None or completion_mask is None:
            # beta>0 guarantees ref_logps; this branch only trips on a TRL
            # internals change — fail loud rather than silently skip the penalty.
            raise RuntimeError(
                "kl_in_reward=True but TRL did not return ref_per_token_logps / "
                "completion_mask from scoring (beta>0 should guarantee them). "
                "TRL internals may have changed; re-verify the in-reward path."
            )

        if old_logps is not None:
            penalty = kl_penalty_per_sequence(
                policy_logps=old_logps,
                ref_logps=ref_logps,
                completion_mask=completion_mask,
                estimator=self.kl_estimator,
            )
            output["advantages"] = apply_kl_in_reward(
                advantages=output["advantages"],
                kl_penalty=penalty,
                num_generations=self.num_generations,
                coef=float(self.args.beta),
            )
            output["_kl_in_reward_applied"] = torch.tensor(True)
        else:
            # Aligned non-vLLM case: π logps materialize only in _compute_loss.
            # ref logps + mask stay in `output`, so _compute_loss can apply the
            # penalty there.
            output["_kl_in_reward_applied"] = torch.tensor(False)
        return output

    def _compute_loss(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Override: total_loss = grpo + α*sdpo + β*replay.

        When ``kl_in_reward`` is set, TRL's native in-loss KL term (gated on
        ``self.beta``) is suppressed by temporarily zeroing ``self.beta`` for the
        duration of the parent call — the KL has already been (or is about to be)
        accounted for in the reward/advantage, so double-counting it in the loss
        would be wrong. ``self.beta`` is restored in ``finally``.

        Per-channel components are logged, and the optional kill-switch is
        consulted, at most once per ``global_step`` and only in training mode
        (see ``_should_log_components``). This method runs once per micro-batch,
        so without that guard gradient accumulation would repeat the held-out
        eval and the guard update for the same step.
        """
        # Channel 1: standard GRPO loss. ``getattr`` (not ``self.kl_in_reward``)
        # so an instance built via ``__new__`` + manual wiring (the SDPO /
        # kill-switch unit-test pattern that skips __init__) defaults to the
        # legacy path instead of raising AttributeError.
        if getattr(self, "kl_in_reward", False):
            grpo_loss = self._grpo_loss_kl_in_reward(model, inputs)
        else:
            grpo_loss = super()._compute_loss(model, inputs)

        # Channel 2: SDPO hint-distill at error sites
        sdpo_kl = self._compute_sdpo_loss(model, inputs)

        # Channel 3: trace-replay DPO from teacher disagreement
        replay_dpo = self._compute_trace_replay_loss(model, inputs)

        # Compose
        total = grpo_loss + self.alpha_sdpo * sdpo_kl + self.beta_replay * replay_dpo

        if self._should_log_components(model):
            # Kill-switch FIRST: it reads TRL's aggregated ``_metrics["train"]``,
            # and TRL's GRPOTrainer.log() flushes (averages, then clears) those
            # metrics. Logging the components first would leave the guard no
            # reward to read, so it would skip every cadence and never fire.
            # No-op unless a guard was configured.
            self._maybe_update_killswitch()
            # Log per-channel components (so we can ablate post-hoc)
            self.log({  # type: ignore[attr-defined]
                "loss/grpo": float(grpo_loss.detach()),
                "loss/sdpo_kl": float(sdpo_kl.detach()),
                "loss/trace_replay_dpo": float(replay_dpo.detach()),
                "loss/total": float(total.detach()),
                "loss/alpha_sdpo": self.alpha_sdpo,
                "loss/beta_replay": self.beta_replay,
            })
        return total

    def _logging_interval(self) -> int:
        """Return the component-logging cadence in optimizer steps (always >= 1).

        Reads ``args.logging_steps`` (default 50 when missing or None). HF allows a
        float in [0, 1) there, meaning a ratio of total steps; in that case the
        value the Trainer resolved into ``state.logging_steps`` is used instead
        (1 if unavailable). This avoids a modulo by a fraction or by zero.
        """
        configured = getattr(self.args, "logging_steps", 50)
        if configured is None:
            configured = 50
        if configured < 1:
            configured = getattr(self.state, "logging_steps", None) or 1
        return max(1, int(configured))

    def _should_log_components(self, model: torch.nn.Module) -> bool:
        """Decide whether this ``_compute_loss`` call emits logs / updates the guard.

        True at most once per ``global_step``. All of the following must hold:
        trainer ``state`` and ``args`` exist, the model is in training mode,
        ``global_step`` is on the logging cadence, and this step has not
        already been logged. On True the step is recorded, so later micro-batches
        of the same optimizer step (gradient accumulation) return False.
        """
        if not hasattr(self, "state") or getattr(self, "args", None) is None:
            return False
        if not getattr(model, "training", True):
            # Evaluation may reuse the loss path; never log/trip the guard there.
            return False
        step = self.state.global_step
        if step % self._logging_interval() != 0:
            return False
        if getattr(self, "_last_component_log_step", None) == step:
            return False
        self._last_component_log_step = step
        return True

    def _grpo_loss_kl_in_reward(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """GRPO loss with the KL applied in the reward, not the loss.

        Two responsibilities:
          1. Suppress TRL's native in-loss k3 KL term for this step by zeroing
             ``self.beta`` across the parent ``_compute_loss`` call (restored in
             ``finally``). ``self.beta`` gates the in-loss KL add (TRL
             ``_compute_loss``: ``if self.beta != 0.0: per_token_loss += beta*kl``).
          2. Handle the deferred case: when generation/optimization are aligned
             and not using vLLM, the sampling-time policy logps are None at
             scoring time, so ``_generate_and_score_completions`` could not fold
             the penalty into advantages. Here ``per_token_logps`` is available,
             so we apply the exact same advantage correction in-place on
             ``inputs["advantages"]`` BEFORE the parent computes the surrogate,
             then mark the batch as adjusted so a re-used batch is not
             penalized a second time.
        """
        # Deferred-penalty path: advantages not yet KL-adjusted (aligned, no vLLM).
        applied = inputs.get("_kl_in_reward_applied")
        already_applied = bool(applied.item()) if applied is not None else False
        if not already_applied and "ref_per_token_logps" in inputs:
            with torch.no_grad():
                prompt_ids, completion_ids = inputs["prompt_ids"], inputs["completion_ids"]
                completion_mask = inputs["completion_mask"]
                input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
                attention_mask = torch.cat([inputs["prompt_mask"], completion_mask], dim=1)
                logits_to_keep = completion_ids.size(1)
                policy_logps, _ = self._get_per_token_logps_and_entropies(
                    model, input_ids, attention_mask, logits_to_keep
                )
                penalty = kl_penalty_per_sequence(
                    policy_logps=policy_logps,
                    ref_logps=inputs["ref_per_token_logps"],
                    completion_mask=completion_mask,
                    estimator=self.kl_estimator,
                )
                advantages = inputs["advantages"]
                # advantages may be (B,) or (B,1) — squeeze for the penalty math,
                # restore the original shape after.
                adv_flat = advantages.reshape(advantages.shape[0])
                adj = apply_kl_in_reward(
                    advantages=adv_flat,
                    kl_penalty=penalty,
                    num_generations=self.num_generations,
                    coef=float(self.args.beta),
                )
                inputs["advantages"] = adj.reshape(advantages.shape)
                # The adjustment mutated `inputs` in place: flag it so a batch
                # that is fed through this method again is not penalized twice.
                inputs["_kl_in_reward_applied"] = torch.tensor(True)

        # Suppress TRL's in-loss KL: zero beta for the parent call, restore after.
        saved_beta = self.beta
        try:
            self.beta = 0.0
            return super()._compute_loss(model, inputs)
        finally:
            self.beta = saved_beta

    # ----------------------------------------------------------------------
    # Run-level collapse kill-switch (#2 safeguard) — optional, OFF by default
    # ----------------------------------------------------------------------

    def _maybe_update_killswitch(self) -> None:
        """Fold this checkpoint's metrics into ``heldout_guard`` and act on a fire.

        No-op when no guard was configured (the default) — this is the
        backward-compat guarantee: without ``heldout_guard`` the trainer behaves
        exactly as before.

        When a guard IS set:
          * ``in_loop_reward`` is the GRPO reward signal TRL already aggregates
            into ``self._metrics[mode]["reward"]`` each step (we read the latest;
            no extra forward pass).
          * ``heldout_score`` comes from the injected ``heldout_eval_fn()`` — the
            trainer never hardcodes an eval.
          * ``kl_to_init`` (token-mean nats/token, the ``token_mean_kl``
            convention the guard expects) is read from TRL's logged ``"kl"``
            metric when present, else left None (KL path stays inert).

        On a fired verdict the verdict is logged. If ``strict_killswitch`` (the
        default) the verdict is converted into a ``CollapseStopError`` via
        ``HeldOutGuard.raise_if_fired`` (hard stop); otherwise the HF training
        loop is asked to stop gracefully after this step.
        """
        guard = self.heldout_guard
        if guard is None:
            return  # OFF by default — zero behavior change

        round_idx = int(getattr(self.state, "global_step", 0))
        in_loop_reward = self._latest_metric("reward")
        if in_loop_reward is None:
            # No reward aggregated yet (e.g. very first micro-step before TRL has
            # populated its metrics). Skip this cadence rather than feed a
            # fabricated 0.0 that would pollute the guard's baseline/EMA.
            logger.debug(
                "kill-switch: no in-loop reward metric yet at step %d; skipping.",
                round_idx,
            )
            return

        assert self.heldout_eval_fn is not None  # enforced in __init__
        heldout_score = float(self.heldout_eval_fn())
        kl_to_init = self._latest_metric("kl")  # token-mean KL, or None

        status = guard.update(
            round_idx=round_idx,
            in_loop_reward=in_loop_reward,
            heldout_score=heldout_score,
            kl_to_init=kl_to_init,
        )
        self.log({  # type: ignore[attr-defined]
            "killswitch/in_loop_reward": status.in_loop_ema,
            "killswitch/heldout_score": status.heldout_ema,
            "killswitch/proxy_real_gap": status.proxy_real_gap,
            "killswitch/fire": float(status.fire),
        })

        if status.fire:
            logger.error(
                "HeldOutGuard FIRED at step %d — halting run. reason: %s",
                round_idx,
                status.reason,
            )
            if self.strict_killswitch:
                # Typed exception — exception-based hard stop.
                guard.raise_if_fired(status)
            else:
                # Soft stop: let the HF loop terminate gracefully after this step.
                control = getattr(self, "control", None)
                if control is not None:
                    control.should_training_stop = True

    def _latest_metric(self, name: str) -> float | None:
        """Most-recent value of a TRL-aggregated train metric, or None.

        TRL's GRPOTrainer appends per-step aggregates to
        ``self._metrics["train"][name]`` (e.g. ``"reward"``, ``"kl"``). We read
        the tail defensively so a TRL internals rename degrades to None
        (KL/reward path goes inert) rather than crashing training.
        """
        metrics = getattr(self, "_metrics", None)
        if not isinstance(metrics, dict):
            return None
        train = metrics.get("train")
        if not isinstance(train, dict):
            return None
        series = train.get(name)
        if not series:
            return None
        try:
            return float(series[-1])
        except (TypeError, ValueError, IndexError):
            return None

    # ----------------------------------------------------------------------
    # Channel 2: SDPO hint-distill
    # ----------------------------------------------------------------------

    def _compute_sdpo_loss(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Compute generalized_jsd_loss between student and hint-conditioned teacher.

        Both come from the SAME model — teacher just has hint inserted into
        context. Skipped (returns 0) if the batch has no error sites (data
        collator emits empty ctx_teacher_input_ids).

        Raises:
            ValueError: alignment indices are missing while
                ``strict_sdpo_alignment`` is True, or the student/teacher index
                tensors have different shapes.
        """
        if (
            self.alpha_sdpo == 0.0
            or "ctx_teacher_input_ids" not in inputs
            or inputs["ctx_teacher_input_ids"].numel() == 0
        ):
            return torch.tensor(0.0, device=_device_of(model), requires_grad=True)

        # Student forward (with grad, on the original-context input)
        student_logits = model(input_ids=inputs["input_ids"]).logits
        # Teacher forward (no grad — same model, hint-conditioned context)
        with torch.no_grad():
            teacher_logits = model(input_ids=inputs["ctx_teacher_input_ids"]).logits

        # ------------------------------------------------------------------
        # ALIGNMENT (cross-family review 2026-05-29 — the 4/4-reviewer P0).
        #
        # The teacher context has a hint inserted at the error turn, so the
        # teacher's post-hint response tokens are shifted right by len(hint)
        # relative to the student's. A bare `student.shape == teacher.shape`
        # check does NOT establish token-level alignment: equal-length tensors
        # whose response regions are offset will be JSD'd position-by-position
        # against each other, distilling garbage into the policy.
        #
        # The ONLY correct alignment is an explicit map from the collator that
        # selects, for each response token, the matching index in each sequence.
        # We require it whenever SDPO is active:
        #   - `student_response_idx` / `teacher_response_idx`: LongTensors of
        #     equal length selecting the aligned response positions in each
        #     sequence (the collator builds these knowing where it inserted the
        #     hint). JSD is computed over the gathered, provably-aligned logits.
        #   - If the collator cannot yet supply them, strict mode raises (loud
        #     failure) rather than silently distilling misaligned tokens.
        s_idx = inputs.get("student_response_idx")
        t_idx = inputs.get("teacher_response_idx")

        if s_idx is None or t_idx is None:
            msg = (
                "SDPO alignment indices missing: the collator must emit "
                "`student_response_idx` and `teacher_response_idx` (matching "
                "LongTensors selecting the aligned post-hint response tokens) so "
                "the JSD compares corresponding tokens. A shape-only check does "
                "NOT establish alignment — the hint shifts the teacher's response "
                "tokens right, so equal-length sequences can still be misaligned "
                "and silently distill garbage into the policy (ADR-008 trust-gap)."
            )
            if self.strict_sdpo_alignment:
                raise ValueError(
                    msg + " (strict_sdpo_alignment=True; pass False to fall back "
                    "to the legacy shape-only check for resilience.)"
                )
            logger.warning("%s Falling back to shape-only alignment check.", msg)
            if student_logits.shape != teacher_logits.shape:
                logger.warning(
                    "SDPO shape mismatch student=%s teacher=%s; skipping.",
                    tuple(student_logits.shape),
                    tuple(teacher_logits.shape),
                )
                return torch.tensor(0.0, device=_device_of(model), requires_grad=True)
            # generalized_jsd_loss ignores positions whose label is -100 (the HF
            # convention the aligned path below relies on), NOT positions whose
            # label is 0. Passing the raw 1/0 sdpo_loss_mask would therefore
            # distill every non-error token as well. Convert: mask > 0 → 1
            # (keep), everything else (0, or an already-ignored -100) → -100.
            loss_mask = inputs.get("sdpo_loss_mask")
            fallback_labels = None
            if loss_mask is not None:
                fallback_labels = torch.where(
                    loss_mask > 0,
                    torch.ones_like(loss_mask, dtype=torch.long),
                    torch.full_like(loss_mask, -100, dtype=torch.long),
                )
            return generalized_jsd_loss(
                student_logits=student_logits,
                teacher_logits=teacher_logits,
                labels=fallback_labels,
                beta=self.sdpo_jsd_beta,
                temperature=self.sdpo_temperature,
                token_clip=self.sdpo_token_clip,
                reduction="batchmean",
            )

        # Validate the index tensors describe a consistent 1:1 alignment.
        if s_idx.shape != t_idx.shape:
            raise ValueError(
                f"SDPO alignment index shape mismatch: student_response_idx="
                f"{tuple(s_idx.shape)} vs teacher_response_idx={tuple(t_idx.shape)}. "
                "They must select the same number of aligned response tokens."
            )

        # Gather the provably-aligned response logits from each sequence, then
        # JSD only those positions (this is the masked error-turn distillation).
        # gather over the sequence dim (dim=1): expand index to the vocab dim.
        #
        # ADR-011: ragged-K rows are padded with a sentinel (-1) and a per-row
        # *_valid mask. Negative indices are illegal for torch.gather, so clamp
        # to 0 before gathering, then neutralize those positions by feeding
        # labels=-100 (the standard HF ignore convention that generalized_jsd_loss
        # already honors). This makes sentinel/padding positions contribute 0.
        #
        # Final-verify 2026-05-29: combine BOTH valid masks (not just student's)
        # AND the sentinel guard. If a future collator ever emits divergent
        # student/teacher valid tails, a teacher sentinel clamped to 0 would
        # otherwise be silently distilled against teacher position 0. Belt-and-
        # suspenders: valid iff student-valid AND teacher-valid AND both indices
        # non-sentinel.
        s_valid = inputs.get("student_response_valid")
        t_valid = inputs.get("teacher_response_valid")
        aligned_mask = (s_idx >= 0) & (t_idx >= 0)
        if s_valid is not None:
            aligned_mask = aligned_mask & s_valid.bool()
        if t_valid is not None:
            aligned_mask = aligned_mask & t_valid.bool()

        vocab = student_logits.size(-1)
        s_safe = s_idx.clamp_min(0)
        t_safe = t_idx.clamp_min(0)
        s_gather = s_safe.unsqueeze(-1).expand(-1, -1, vocab)
        t_gather = t_safe.unsqueeze(-1).expand(-1, -1, vocab)
        student_aligned = torch.gather(student_logits, 1, s_gather)
        teacher_aligned = torch.gather(teacher_logits, 1, t_gather)

        # Build (B, K) labels: 1 at valid aligned positions, -100 (ignore) at
        # sentinel/padding positions so they drop out of the JSD reduction.
        aligned_labels = torch.where(
            aligned_mask,
            torch.ones_like(s_idx),
            torch.full_like(s_idx, -100),
        )

        return generalized_jsd_loss(
            student_logits=student_aligned,
            teacher_logits=teacher_aligned,
            labels=aligned_labels,  # sentinel-masked aligned error-turn positions
            beta=self.sdpo_jsd_beta,
            temperature=self.sdpo_temperature,
            token_clip=self.sdpo_token_clip,
            reduction="batchmean",
        )

    # ----------------------------------------------------------------------
    # Channel 3: trace-replay DPO
    # ----------------------------------------------------------------------

    def _compute_trace_replay_loss(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Standard DPO loss using (chosen, rejected) pairs from teacher disagreement.

        DPO loss formula (Rafailov et al. 2023):
            L = -log σ(β · (logπ(chosen) - logπ_ref(chosen)
                           - logπ(rejected) + logπ_ref(rejected)))

        Where logπ_ref are precomputed by the data collator using the reference
        (init student) policy. Returns 0 when the channel is disabled or the batch
        carries no DPO pairs.
        """
        if (
            self.beta_replay == 0.0
            or "dpo_chosen_input_ids" not in inputs
            or inputs["dpo_chosen_input_ids"].numel() == 0
        ):
            return torch.tensor(0.0, device=_device_of(model), requires_grad=True)

        # Forward passes for chosen and rejected, gather logprobs at response tokens
        chosen_logprobs = self._sequence_logprobs(
            model, inputs["dpo_chosen_input_ids"], inputs["dpo_chosen_response_mask"]
        )
        rejected_logprobs = self._sequence_logprobs(
            model, inputs["dpo_rejected_input_ids"], inputs["dpo_rejected_response_mask"]
        )

        ref_chosen_logprobs = inputs["dpo_chosen_ref_logprobs"]
        ref_rejected_logprobs = inputs["dpo_rejected_ref_logprobs"]

        logits = self.replay_dpo_beta * (
            (chosen_logprobs - ref_chosen_logprobs)
            - (rejected_logprobs - ref_rejected_logprobs)
        )
        return -F.logsigmoid(logits).mean()

    @staticmethod
    def _sequence_logprobs(
        model: torch.nn.Module,
        input_ids: torch.Tensor,
        response_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Sum logprob of response tokens given the prompt prefix.

        Standard DPO accounting: we only score the response tokens (where
        response_mask == 1), not the prompt tokens.
        """
        outputs = model(input_ids=input_ids)
        # Shift for next-token prediction: logits[t] predicts input_ids[t+1]
        logits = outputs.logits[:, :-1, :]
        targets = input_ids[:, 1:]
        log_probs = F.log_softmax(logits, dim=-1)
        token_logprobs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        # Mask out prompt + padding; sum response-token logprobs
        masked = token_logprobs * response_mask[:, 1:].float()
        return masked.sum(dim=-1)


def _device_of(model: torch.nn.Module) -> torch.device:
    """Return the device of the model's first parameter (CPU if it has none).

    Used to place the zero-loss placeholders returned by disabled channels. A
    parameterless module falls back to CPU instead of raising ``StopIteration``.
    """
    first_param = next(iter(model.parameters()), None)
    if first_param is None:
        return torch.device("cpu")
    return first_param.device


def validate_kl_in_reward_config(
    *,
    kl_estimator: str,
    beta: float,
    scale_rewards: Any,
) -> None:
    """Validate the (kl_estimator, beta, scale_rewards) combo for k1-in-reward.

    Extracted so the preconditions are unit-testable without standing up a real
    GRPOTrainer (which needs a model + dataset). Raises ``ValueError`` on any
    invalid combination; returns None when the config is sound.

    Preconditions (see ``kl_in_reward.py`` for the algebra):
      * ``kl_estimator`` in {k1, k3}.
      * ``beta != 0`` — TRL only builds the reference model and computes ref
        logprobs when beta>0, and the in-reward penalty needs ref logps. beta
        doubles as the in-reward KL coefficient (the in-loss k3 term is
        suppressed per step).
      * ``scale_rewards`` in {none, false} — the advantage-adjustment identity is
        exact only without per-group std-normalization (the Dr.GRPO / Composer
        regime).
    """
    if kl_estimator not in ("k1", "k3"):
        raise ValueError(f"kl_estimator must be 'k1' or 'k3', got {kl_estimator!r}.")
    if float(beta) == 0.0:
        raise ValueError(
            "kl_in_reward=True requires a non-zero `beta` (the KL coefficient): "
            "TRL only creates the reference model and computes ref logprobs when "
            "beta>0, and k1-in-reward needs those ref logps. Set beta to your KL "
            "coefficient (e.g. make_po_config('dr_grpo', beta=0.04)); the in-loss "
            "k3 term is suppressed automatically so beta acts purely as the "
            "in-reward k1 coefficient."
        )
    if str(scale_rewards).lower() not in ("none", "false"):
        raise ValueError(
            "kl_in_reward=True requires scale_rewards in {none,false} "
            f"(got {scale_rewards!r}). The advantage-adjustment identity "
            "adv -= beta·(KL - group_mean(KL)) is EXACT only without per-group "
            "std-normalization (the Dr.GRPO / Composer regime). With std-norm, "
            "folding KL into the reward also shifts the group std, so the linear "
            "correction no longer matches true in-reward KL. Use "
            "make_po_config('dr_grpo', beta=…) (scale_rewards='none')."
        )


def make_dr_grpo_config(**overrides: Any):
    """Build a `trl.GRPOConfig` configured to the **Dr. GRPO** recipe.

    Per the Composer 2 technical report (arXiv:2603.24477,
    research/10-composer2-techreport-mining.md) the RL base is Dr. GRPO
    (Liu et al., arXiv:2503.20783):

      - ``loss_type="dr_grpo"`` — removes GRPO's length-standardization term
        (which injects a length bias). TRL's own help text cites the Dr. GRPO
        paper for this.
      - ``scale_rewards="none"`` — NO std-dev advantage normalization. TRL docs:
        "The Dr. GRPO paper recommends not scaling rewards, as scaling by the
        standard deviation introduces a question-level difficulty bias."
      - ``num_iterations=1`` — single-epoch regime (a prompt is never trained on
        twice), matching the tech report.
      - ``beta`` (KL-to-ref coef) kept.

    NOTE on the KL estimator (ADR-012 finding #1, verified against the installed
    trl==1.5.0 source): ``GRPOTrainer._compute_loss`` uses the **k3** estimator
    ``exp(ref_logp - logp) - (ref_logp - logp) - 1`` (trl/trainer/grpo_trainer.py
    ~L2513), NOT the k1 estimator ``-log r == (ref_logp - logp)``. k3 is
    Schulman's low-variance, always-non-negative KL approximation; k1 is its
    unbiased but higher-variance counterpart. The Dr. GRPO / Composer 2 report
    discusses KL in k1 terms, but the delta is small for r≈1
    (k3 = k1 + O((Δlogp)^2)) and TRL's k3 choice is the production reality. We
    do NOT monkeypatch TRL to force k1; we document the honest delta. See
    ``test_dr_grpo_config_and_alignment.py::test_trl_kl_estimator_is_k3_not_k1``.

    Any field can be overridden via kwargs (e.g. ``learning_rate=...``,
    ``output_dir=...``). The three Dr. GRPO-defining knobs are applied unless
    explicitly overridden. Drift checks raise ``AssertionError`` explicitly (not
    via ``assert``), so they still run under ``python -O``.

    Raises:
        AssertionError: a requested knob did not survive ``GRPOConfig``
            construction, or ``scale_rewards`` ends up enabling std-normalization.
    """
    from trl import GRPOConfig  # local import: only when actually building a config

    dr_grpo_defaults: dict[str, Any] = {
        "loss_type": "dr_grpo",
        "scale_rewards": "none",
        "num_iterations": 1,
    }
    merged = {**dr_grpo_defaults, **overrides}
    cfg = GRPOConfig(**merged)

    # Guard: fail loudly if a future TRL renames/repurposes these knobs.
    if cfg.loss_type != merged["loss_type"]:
        raise AssertionError(
            f"GRPOConfig loss_type drifted: requested {merged['loss_type']!r}, "
            f"got {cfg.loss_type!r} — TRL may have renamed/repurposed the knob."
        )
    # Dr. GRPO requires NO std-dev advantage normalization. TRL accepts either
    # the string "none" or the bool False to disable it; normalize before
    # comparing so a future TRL that switches the representation still passes
    # (and a genuinely-wrong value like "batch"/"group"/True fails loudly).
    # (Cross-family review 2026-05-29: the prior literal `("none","False","False")`
    # had a duplicated "False" and did a brittle case-sensitive str compare.)
    if str(cfg.scale_rewards).lower() not in ("none", "false"):
        raise AssertionError(
            f"Dr. GRPO requires scale_rewards disabled (no std-norm); got "
            f"{cfg.scale_rewards!r}. TRL knob may have drifted — re-verify against trl version."
        )
    if cfg.num_iterations != merged["num_iterations"]:
        raise AssertionError("GRPOConfig dropped num_iterations")
    return cfg


# ---------------------------------------------------------------------------
# Policy-optimization objective MENU (ADR-014)
# ---------------------------------------------------------------------------
#
# The base RL objective used to be hardcoded to Dr.GRPO (make_dr_grpo_config).
# make_po_config gives RL a real menu: GRPO-family objectives selectable by name.
# Verified against the installed trl==1.5.0 (introspected 2026-05-30): its
# GRPOTrainer already implements these as `loss_type` branches + knobs, so EVERY
# preset below is pure config — no custom _compute_loss override needed.
#
# Knob-space each preset sets (all real GRPOConfig fields in trl 1.5.0):
#   loss_type ∈ {grpo, dr_grpo, bnpo, dapo, cispo}  (gspo = grpo loss +
#       importance_sampling_level="sequence"; trl has no literal "gspo")
#   scale_rewards ∈ {"group"(std-norm), "batch", "none"(no std-norm, Dr.GRPO)}
#   epsilon / epsilon_high — symmetric vs decoupled "clip-higher" (DAPO)
#   importance_sampling_level ∈ {"token", "sequence"(GSPO)}
#   beta — KL-to-ref coef (0.0 = reference-free)
#   mask_truncated_completions — DAPO overlong masking
#   num_iterations — on-policy reuse (1 = strict on-policy)

#: Selectable base policy-optimization objectives (named presets over trl knobs).
PO_OBJECTIVES: dict[str, dict[str, Any]] = {
    # Vanilla GRPO (DeepSeekMath, arXiv 2402.03300): group-relative advantage
    # WITH std normalization + per-sequence length normalization. This preset
    # does not set ``beta``, so the KL-to-ref coefficient is whatever the
    # installed GRPOConfig defaults to — pass ``beta=...`` to pin it.
    "grpo": {
        "loss_type": "grpo",
        "scale_rewards": "group",
        "importance_sampling_level": "token",
        "num_iterations": 1,
    },
    # Dr.GRPO (arXiv 2503.20783): remove length-std normalization bias (no
    # advantage /std, length-independent aggregation). Framework's historical
    # default (== make_dr_grpo_config). Composer 2.5's base objective.
    "dr_grpo": {
        "loss_type": "dr_grpo",
        "scale_rewards": "none",
        "importance_sampling_level": "token",
        "num_iterations": 1,
    },
    # BNPO: batch-normalized variant (trl loss_type), std over the batch.
    "bnpo": {
        "loss_type": "bnpo",
        "scale_rewards": "batch",
        "importance_sampling_level": "token",
        "num_iterations": 1,
    },
    # DAPO (arXiv 2503.14476): decoupled "clip-higher" (epsilon_high > epsilon)
    # + token-level loss + overlong masking + KL removed. High-value, low-cost
    # anti-entropy-collapse objective. epsilon_high=0.28 per the paper.
    "dapo": {
        "loss_type": "dapo",
        "scale_rewards": "none",
        "epsilon": 0.2,
        "epsilon_high": 0.28,
        "mask_truncated_completions": True,
        "beta": 0.0,
        "importance_sampling_level": "token",
        "num_iterations": 1,
    },
    # GSPO (Qwen, arXiv 2507.18071): SEQUENCE-level importance ratio (one length-
    # normalized ratio per response) — stabilizes long-CoT and especially MoE RL.
    # trl expresses this as the grpo loss + importance_sampling_level="sequence".
    # NOTE(review): this preset keeps GRPOConfig's default clip range. The GSPO
    # paper reports a much tighter sequence-level clip (~3e-4 / 4e-4, because
    # length-normalized sequence ratios sit very close to 1); pass
    # ``epsilon=3e-4, epsilon_high=4e-4`` to reproduce it — confirm before
    # changing the preset itself.
    "gspo": {
        "loss_type": "grpo",
        "scale_rewards": "group",
        "importance_sampling_level": "sequence",
        "num_iterations": 1,
    },
    # CISPO (MiniMax-M1, arXiv 2506.13585): clip the IS weight and detach it as a
    # constant coefficient on log π — every token keeps a gradient (fixes the
    # "rare reasoning tokens get zeroed by the clip" pathology). eps_max≈5 (ScaleRL).
    "cispo": {
        "loss_type": "cispo",
        "scale_rewards": "none",
        "epsilon_high": 5.0,
        "importance_sampling_level": "token",
        "num_iterations": 1,
    },
}


def make_po_config(objective: str = "dr_grpo", **overrides: Any):
    """Build a `trl.GRPOConfig` for a NAMED policy-optimization objective.

    The menu that gives RL real options beyond the single hardcoded Dr.GRPO
    recipe. ``objective`` selects a preset from ``PO_OBJECTIVES`` (grpo / dr_grpo
    / bnpo / dapo / gspo / cispo); ``**overrides`` set or override any GRPOConfig
    field on top (e.g. ``output_dir=...``, ``beta=...``, ``learning_rate=...``).

    All presets are PURE CONFIG over trl 1.5.0's GRPOTrainer (verified by
    introspecting the installed package 2026-05-30): the trainer already
    implements each ``loss_type`` branch and the ``importance_sampling_level`` /
    ``epsilon_high`` knobs, so no custom ``_compute_loss`` is needed. See ADR-014.

    Args:
        objective: preset name; case-insensitive, surrounding whitespace ignored.
            ``None`` / empty falls back to ``"dr_grpo"``.
        **overrides: GRPOConfig fields applied on top of the preset.

    Returns:
        The constructed ``trl.GRPOConfig``.

    Raises:
        ValueError: unknown objective (lists the valid menu). Checked BEFORE trl
            is imported, so a typo is reported even where trl is not installed.
        ImportError: trl is not installed (for a valid objective).
        AssertionError: a requested knob silently failed to apply (drift guard).
            Raised explicitly rather than via ``assert`` so the guards still run
            under ``python -O``.
    """
    key = (objective or "dr_grpo").strip().lower()
    if key not in PO_OBJECTIVES:
        raise ValueError(
            f"Unknown PO objective {objective!r}. Choose from: "
            f"{sorted(PO_OBJECTIVES)}. (Each is a named preset over trl 1.5.0's "
            f"GRPOConfig knobs — see PO_OBJECTIVES / ADR-014.)"
        )

    from trl import GRPOConfig  # local import: only when actually building a config

    merged = {**PO_OBJECTIVES[key], **overrides}
    cfg = GRPOConfig(**merged)

    def _canonical_scale_rewards(value: Any) -> str:
        # GRPOConfig accepts either a string or a bool for scale_rewards, so
        # compare canonical forms: False/"false" -> "none", True/"true" -> "group"
        # (NOTE(review): the True -> "group" mapping mirrors trl's own bool
        # conversion — confirm against the installed trl), everything else
        # lowercased. A pure representation change must not trip the guard.
        if isinstance(value, bool):
            return "group" if value else "none"
        text = str(value).lower()
        return {"true": "group", "false": "none"}.get(text, text)

    # Drift guards: fail loudly if a future trl renamed/repurposed a knob we set,
    # so a preset can never silently degrade to a different objective.
    if "loss_type" in merged and str(cfg.loss_type) != str(merged["loss_type"]):
        raise AssertionError(
            f"GRPOConfig.loss_type drifted: requested {merged['loss_type']!r}, "
            f"got {cfg.loss_type!r} — trl may have renamed the knob."
        )

    # scale_rewards is what separates std-normalized objectives (grpo/gspo/bnpo)
    # from the no-std-norm ones (dr_grpo/dapo/cispo), so guard it too.
    if "scale_rewards" in merged and hasattr(cfg, "scale_rewards"):
        if _canonical_scale_rewards(cfg.scale_rewards) != _canonical_scale_rewards(
            merged["scale_rewards"]
        ):
            raise AssertionError(
                f"scale_rewards drifted for objective {key!r}: requested "
                f"{merged['scale_rewards']!r}, got {cfg.scale_rewards!r} — trl may "
                f"have renamed/repurposed the knob."
            )

    if "importance_sampling_level" in merged and hasattr(cfg, "importance_sampling_level"):
        if str(cfg.importance_sampling_level) != str(merged["importance_sampling_level"]):
            raise AssertionError(
                f"importance_sampling_level drifted for objective {key!r}: requested "
                f"{merged['importance_sampling_level']!r}, got {cfg.importance_sampling_level!r}."
            )

    # A caller override of importance_sampling_level="token" passes the generic
    # guard above but would quietly turn GSPO back into GRPO.
    if key == "gspo" and str(getattr(cfg, "importance_sampling_level", "token")) != "sequence":
        raise AssertionError(
            "GSPO requires importance_sampling_level='sequence'; it was overridden "
            "to token, which silently degrades GSPO to GRPO. Drop that override."
        )

    requested_eps_high = merged.get("epsilon_high")
    if requested_eps_high is not None:
        applied_eps_high = float(getattr(cfg, "epsilon_high", requested_eps_high))
        # `not (... < tol)` (rather than `>= tol`) so a NaN also fails the guard.
        if not abs(applied_eps_high - float(requested_eps_high)) < 1e-9:
            raise AssertionError(f"epsilon_high (decoupled clip) drifted for {key!r}.")

    return cfg


__all__ = [
    "ComposerReplicationTrainer",
    "make_dr_grpo_config",
    "make_po_config",
    "PO_OBJECTIVES",
    "validate_kl_in_reward_config",
]