"""composer_trainer.py: TRL GRPOTrainer subclass with SDPO + trace-replay channels.
Architecture spec: docs/INTEGRATION_ARCHITECTURE.md § "Recipe A".
Verified extension point: GRPOTrainer._compute_loss(model, inputs)
(DeepWiki audit of huggingface/trl, 2026-05-25).
Total loss:
total_loss = grpo_loss
+ alpha_sdpo * sdpo_kl_at_error_turns
+ beta_replay * trace_replay_dpo_loss
Where:
- grpo_loss is the parent GRPOTrainer's loss (RLVR + DAPO patches).
- sdpo_kl_at_error_turns is generalized_jsd_loss between student's logits and
teacher's (= same-model-with-hint-context) logits, masked to error-turn tokens only.
- trace_replay_dpo_loss is DPO loss over (chosen, rejected) pairs derived from
the disagreement of N external teachers with the student.
The data collator (data_collator.py) is responsible for:
- Detecting error sites in the rollout and constructing ctx_teacher = ctx_student + hint.
- Computing sdpo_loss_mask (1 at post-hint error-turn tokens, 0 elsewhere).
- Loading DPO pairs from the trace-replay output (see teacher_replay.py).
- Precomputing reference-policy logprobs for DPO.
"""
from __future__ import annotations
import logging
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
import torch
import torch.nn.functional as F # noqa: N812 — repo-wide torch convention
if TYPE_CHECKING: # type-only — never imported at runtime (keeps the dep lazy)
from composer_replication.safety import HeldOutGuard
# These imports work when TRL is installed — they're not skeleton imports.
# When TRL is missing we fall back to `object` so the module still imports
# (e.g. for documentation generation) but raise a clear ImportError at
# instantiation time rather than the cryptic `object.__init__()` error.
try:
from trl import GRPOTrainer # type: ignore
_TRL_AVAILABLE = True
except ImportError: # pragma: no cover — only hit in unit-test stubs without TRL
GRPOTrainer = object # type: ignore — fallback so module imports without TRL
_TRL_AVAILABLE = False
from composer_replication.opsd import generalized_jsd_loss
from composer_replication.trainer.kl_in_reward import (
apply_kl_in_reward,
kl_penalty_per_sequence,
)
logger = logging.getLogger(__name__)
class ComposerReplicationTrainer(GRPOTrainer): # type: ignore[misc, valid-type]
"""TRL GRPOTrainer with Composer-recipe channels (SDPO) + novel trace-replay-DPO.
Args (in addition to GRPOTrainer's):
alpha_sdpo: weight on SDPO hint-distill loss. Default 0.0 (disabled).
Opt in by passing >0 once your data collator produces
`sdpo_loss_mask` and `ctx_teacher_input_ids` columns.
beta_replay: weight on trace-replay DPO loss. Default 0.0 (disabled).
Opt in by passing >0 once your data collator produces
`dpo_chosen_input_ids` / `dpo_rejected_input_ids` etc.
sdpo_jsd_beta: beta param of generalized_jsd_loss
(0=KL(teacher||student), 0.5=JSD, 1=KL(student||teacher) per
upstream OPSD convention; see composer_replication/opsd.py).
sdpo_temperature: temperature for SDPO loss; SDPO paper uses 1.0.
sdpo_token_clip: per-token JSD clip for stability; None = no clip.
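replay_dpo_beta: beta param of the DPO loss (β in the standard DPO formula).
strict_sdpo_alignment: when True (default), missing SDPO alignment indices
(`student_response_idx` / `teacher_response_idx`) raise ``ValueError``.
When False the trainer warns and falls back to the legacy shape-only
check, skipping the SDPO term for a batch whose shapes mismatch.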
kl_in_reward: when True, apply the KL-to-reference penalty in the
**reward** (Composer-2 §4.1 / verl choice) instead of TRL's native
**in-loss** k3 term. The penalty is folded into GRPO's advantages at
scoring time (``adv -= beta·(KL - group_mean(KL))``) and TRL's
in-loss KL is suppressed for that step. The F5 audit's #1 fidelity
fix: the 2025/26 evidence (arXiv:2512.21852, verl, TRL #4967) shows
k1-in-reward improves OOD generalization where k3-in-reward can
collapse. REQUIRES ``beta>0`` (the KL coefficient — also how TRL
decides to compute reference logprobs) and ``scale_rewards`` in
{none,false} (the advantage-adjustment identity is exact only
without std-normalization — the Dr.GRPO / Composer regime). Default
False = TRL's native in-loss KL, byte-for-byte legacy behavior.
kl_estimator: ``"k1"`` (default; ``logp - ref_logp``, the Composer-2 /
verl choice this path exists for) or ``"k3"`` (Schulman; lets an
experiment A/B k1-in-reward vs k3-in-reward). Only consulted when
``kl_in_reward=True``.
heldout_guard: optional ``HeldOutGuard`` (the #2 collapse safeguard from
``composer_replication.safety``). Default None = OFF (no behavior
change whatsoever). When supplied, the trainer folds one checkpoint's
metrics into the guard at the ``args.logging_steps`` cadence (the same
place the loss components are logged) and HALTS the run on a fired
verdict — the run-level reward-hacking / collapse tripwire actually
firing instead of sitting inert.
heldout_eval_fn: zero-arg callable returning the held-out (real) eval
score as a float, evaluated each guard cadence. Injectable so the
trainer never hardcodes an eval — pass a closure over your disjoint
held-out pool (the ``HeldoutSplit`` discipline). Required whenever
``heldout_guard`` is set; the guard's whole signal is in-loop reward
vs. this held-out score.
strict_killswitch: when True (default), a fired guard verdict raises
``CollapseStopError`` to hard-stop training (exception-based control
flow, matching ``HeldOutGuard.raise_if_fired``). When False the
verdict is logged and ``self.control.should_training_stop`` is set so
the HF loop ends gracefully after the step (soft stop). Only consulted
when ``heldout_guard`` is set.
"""
def __init__(
self,
*args: Any,
alpha_sdpo: float = 0.0,
beta_replay: float = 0.0,
sdpo_jsd_beta: float = 0.5,
sdpo_temperature: float = 1.0,
sdpo_token_clip: float | None = None,
replay_dpo_beta: float = 0.1,
strict_sdpo_alignment: bool = True,
kl_in_reward: bool = False,
kl_estimator: str = "k1",
heldout_guard: HeldOutGuard | None = None,
heldout_eval_fn: Callable[[], float] | None = None,
strict_killswitch: bool = True,
**kwargs: Any,
):
if not _TRL_AVAILABLE:
raise ImportError(
"ComposerReplicationTrainer requires TRL. Install with "
"`pip install -e .[train]`."
)
super().__init__(*args, **kwargs)
self.alpha_sdpo = alpha_sdpo
self.beta_replay = beta_replay
self.sdpo_jsd_beta = sdpo_jsd_beta
self.sdpo_temperature = sdpo_temperature
self.sdpo_token_clip = sdpo_token_clip
self.replay_dpo_beta = replay_dpo_beta
# When True (default), an SDPO student/teacher shape mismatch is a hard
# error — it means the data collator failed to align the post-hint
# section, which silently zeroes the distillation signal (the exact
# trust-gap flagged in ADR-008). Set False only for production runs
# where a single malformed batch should warn-and-skip rather than abort.
self.strict_sdpo_alignment = strict_sdpo_alignment
# --- k1-in-reward KL (F5 #1 fidelity fix; Composer-2 §4.1 / verl) ----
# OFF by default → TRL's native in-loss k3 KL, byte-for-byte legacy.
# When ON we keep self.beta as the KL coef (TRL needs beta>0 to even
# create the ref model + compute ref logps), fold the k1 penalty into
# advantages during scoring, and zero TRL's in-loss KL per step.
self.kl_in_reward = kl_in_reward
self.kl_estimator = kl_estimator
if kl_in_reward:
validate_kl_in_reward_config(
kl_estimator=kl_estimator,
beta=float(getattr(self.args, "beta", 0.0)),
scale_rewards=getattr(self.args, "scale_rewards", "group"),
)
# --- run-level collapse kill-switch (#2 safeguard) -------------------
# OPTIONAL + OFF BY DEFAULT: when heldout_guard is None the loss path is
# byte-for-byte the legacy behavior. When set, _maybe_update_killswitch
# folds metrics into the guard at the logging cadence (see _compute_loss).
self.heldout_guard = heldout_guard
self.heldout_eval_fn = heldout_eval_fn
self.strict_killswitch = strict_killswitch
if heldout_guard is not None and heldout_eval_fn is None:
raise ValueError(
"heldout_guard was provided without heldout_eval_fn: the guard's "
"tripwire compares in-loop reward against a DISJOINT held-out "
"(real) eval score, so it needs an injectable zero-arg "
"heldout_eval_fn() -> float. Pass a closure over your held-out "
"pool (the HeldoutSplit discipline)."
)
# ----------------------------------------------------------------------
# Loss override (the integration core)
#
# k1-in-reward: fold the KL penalty into advantages at scoring time, and
# suppress TRL's native in-loss k3 KL inside _compute_loss.
# ----------------------------------------------------------------------
def _generate_and_score_completions(
self,
inputs: list[dict[str, Any]],
) -> dict[str, Any]:
"""Override: after TRL scores completions, fold a k1 KL penalty into the
advantages (Composer-2 in-reward KL) when ``kl_in_reward`` is set.
No-op (exact legacy path) when ``kl_in_reward`` is False. When set, TRL
has already computed ``advantages``, ``ref_per_token_logps`` (because
``beta>0``), and the completion logprobs; we recompute the per-sequence
k1 penalty and apply the exact group-mean-baseline correction.
"""
output = super()._generate_and_score_completions(inputs)
if not getattr(self, "kl_in_reward", False):
return output
ref_logps = output.get("ref_per_token_logps")
# The "old" (sampling-time) policy logps are TRL's in-loss π term; they
# may be lazily None when generation/optimization are aligned and not
# vLLM (see TRL _compute_loss: old := per_token_logps.detach()). In that
# aligned case we cannot read π logps here, so we defer to _compute_loss
# (which always has per_token_logps) by stashing what we need.
old_logps = output.get("old_per_token_logps")
completion_mask = output.get("completion_mask")
if ref_logps is None or completion_mask is None:
# beta>0 guarantees ref_logps; this branch only trips on a TRL
# internals change — fail loud rather than silently skip the penalty.
raise RuntimeError(
"kl_in_reward=True but TRL did not return ref_per_token_logps / "
"completion_mask from scoring (beta>0 should guarantee them). "
"TRL internals may have changed; re-verify the in-reward path."
)
if old_logps is not None:
penalty = kl_penalty_per_sequence(
policy_logps=old_logps,
ref_logps=ref_logps,
completion_mask=completion_mask,
estimator=self.kl_estimator,
)
output["advantages"] = apply_kl_in_reward(
advantages=output["advantages"],
kl_penalty=penalty,
num_generations=self.num_generations,
coef=float(self.args.beta),
)
output["_kl_in_reward_applied"] = torch.tensor(True)
else:
# Aligned non-vLLM case: π logps materialize only in _compute_loss.
# ref_per_token_logps and completion_mask are already in `output`, so we
# only flag the penalty as pending; _compute_loss applies it there.
output["_kl_in_reward_applied"] = torch.tensor(False)
return output
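# Illustrative sketch (not part of this module): one way a per-sequence KL
# penalty could be built from per-token logprobs with the k1 and k3 estimators
# named above. `kl_penalty_per_sequence_sketch` is a hypothetical stand-in; the
# real `kl_penalty_per_sequence` lives in kl_in_reward.py and is not shown here,
# so the masked token-mean reduction and plain-list inputs are assumptions.

```python
import math


def kl_penalty_per_sequence_sketch(policy_logps, ref_logps, completion_mask, estimator="k1"):
    """Masked token-mean KL estimate per sequence (plain lists, no torch)."""
    penalties = []
    for lp_row, ref_row, mask_row in zip(policy_logps, ref_logps, completion_mask):
        total, count = 0.0, 0
        for lp, ref, keep in zip(lp_row, ref_row, mask_row):
            if not keep:
                continue  # padding token: excluded from the mean
            if estimator == "k1":
                total += lp - ref  # k1 = logp - ref_logp, may be negative per token
            elif estimator == "k3":
                d = ref - lp
                total += math.exp(d) - d - 1.0  # k3 = exp(d) - d - 1, never negative
            else:
                raise ValueError(f"unknown estimator {estimator!r}")
            count += 1
        penalties.append(total / max(count, 1))
    return penalties


policy = [[-1.0, -2.0, -0.5]]
ref = [[-1.2, -1.5, -9.0]]
mask = [[1, 1, 0]]  # the last token is padding and must not contribute

k1 = kl_penalty_per_sequence_sketch(policy, ref, mask, "k1")
k3 = kl_penalty_per_sequence_sketch(policy, ref, mask, "k3")
```

# In this toy batch the k1 estimate comes out negative while k3 stays
# non-negative, which is the per-sample difference between the two estimators.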
def _compute_loss(
self,
model: torch.nn.Module,
inputs: dict[str, torch.Tensor],
) -> torch.Tensor:
"""Override: total_loss = grpo + α*sdpo + β*replay.
When ``kl_in_reward`` is set, TRL's native in-loss KL term (gated on
``self.beta``) is suppressed by temporarily zeroing ``self.beta`` for the
duration of the parent call — the KL has already been (or is about to be)
accounted for in the reward/advantage, so double-counting it in the loss
would be wrong. ``self.beta`` is restored in ``finally``.
"""
# Channel 1: standard GRPO loss. ``getattr`` (not ``self.kl_in_reward``)
# so an instance built via ``__new__`` + manual wiring (the SDPO /
# kill-switch unit-test pattern that skips __init__) defaults to the
# legacy path instead of raising AttributeError.
if getattr(self, "kl_in_reward", False):
grpo_loss = self._grpo_loss_kl_in_reward(model, inputs)
else:
grpo_loss = super()._compute_loss(model, inputs)
# Channel 2: SDPO hint-distill at error sites
sdpo_kl = self._compute_sdpo_loss(model, inputs)
# Channel 3: trace-replay DPO from teacher disagreement
replay_dpo = self._compute_trace_replay_loss(model, inputs)
# Compose
total = grpo_loss + self.alpha_sdpo * sdpo_kl + self.beta_replay * replay_dpo
# Log per-channel components (so we can ablate post-hoc)
if hasattr(self, "state") and getattr(self, "args", None) is not None:
log_steps = getattr(self.args, "logging_steps", 50)
if self.state.global_step % log_steps == 0:
self.log({ # type: ignore[attr-defined]
"loss/grpo": float(grpo_loss.detach()),
"loss/sdpo_kl": float(sdpo_kl.detach()),
"loss/trace_replay_dpo": float(replay_dpo.detach()),
"loss/total": float(total.detach()),
"loss/alpha_sdpo": self.alpha_sdpo,
"loss/beta_replay": self.beta_replay,
})
# Fold one checkpoint into the run-level collapse kill-switch at
# the SAME cadence (no-op unless a guard was configured).
self._maybe_update_killswitch()
return total
def _grpo_loss_kl_in_reward(
self,
model: torch.nn.Module,
inputs: dict[str, torch.Tensor],
) -> torch.Tensor:
"""GRPO loss with the KL applied in the reward, not the loss.
Two responsibilities:
1. Suppress TRL's native in-loss k3 KL term for this step by zeroing
``self.beta`` across the parent ``_compute_loss`` call (restored in
``finally``). ``self.beta`` gates the in-loss KL add (TRL
``_compute_loss``: ``if self.beta != 0.0: per_token_loss += beta*kl``).
2. Handle the deferred case: when generation/optimization are aligned
and not using vLLM, the sampling-time policy logps are None at
scoring time, so ``_generate_and_score_completions`` could not fold
the penalty into advantages. Here ``per_token_logps`` is available,
so we apply the exact same advantage correction in-place on
``inputs["advantages"]`` BEFORE the parent computes the surrogate.
"""
# Deferred-penalty path: advantages not yet KL-adjusted (aligned, no vLLM).
applied = inputs.get("_kl_in_reward_applied")
already_applied = bool(applied.item()) if applied is not None else False
if not already_applied and "ref_per_token_logps" in inputs:
with torch.no_grad():
prompt_ids, completion_ids = inputs["prompt_ids"], inputs["completion_ids"]
completion_mask = inputs["completion_mask"]
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
attention_mask = torch.cat([inputs["prompt_mask"], completion_mask], dim=1)
logits_to_keep = completion_ids.size(1)
policy_logps, _ = self._get_per_token_logps_and_entropies(
model, input_ids, attention_mask, logits_to_keep
)
penalty = kl_penalty_per_sequence(
policy_logps=policy_logps,
ref_logps=inputs["ref_per_token_logps"],
completion_mask=completion_mask,
estimator=self.kl_estimator,
)
advantages = inputs["advantages"]
# advantages may be (B,) or (B,1) — squeeze for the penalty math,
# restore the original shape after.
adv_flat = advantages.reshape(advantages.shape[0])
adj = apply_kl_in_reward(
advantages=adv_flat,
kl_penalty=penalty,
num_generations=self.num_generations,
coef=float(self.args.beta),
)
inputs["advantages"] = adj.reshape(advantages.shape)
# Suppress TRL's in-loss KL: zero beta for the parent call, restore after.
saved_beta = self.beta
try:
self.beta = 0.0
return super()._compute_loss(model, inputs)
finally:
self.beta = saved_beta
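# Illustrative sketch (not part of this module): a numeric check of the
# advantage correction ``adv -= beta * (KL - group_mean(KL))`` described above.
# `group_mean` and `apply_kl_in_reward_sketch` are hypothetical helpers written
# for this example; the real `apply_kl_in_reward` is defined in kl_in_reward.py
# and is not shown here. The check assumes mean-baseline advantages with no
# std-normalization.

```python
def group_mean(values, num_generations):
    """Repeat each group's mean across its consecutive `num_generations` samples."""
    out = []
    for start in range(0, len(values), num_generations):
        group = values[start:start + num_generations]
        mean = sum(group) / len(group)
        out.extend([mean] * len(group))
    return out


def apply_kl_in_reward_sketch(advantages, kl_penalty, num_generations, coef):
    """Return adv_i - coef * (KL_i - group_mean(KL)_i) for every sample."""
    kl_mean = group_mean(kl_penalty, num_generations)
    return [a - coef * (k - m) for a, k, m in zip(advantages, kl_penalty, kl_mean)]


rewards = [1.0, 0.0, 0.5, 0.5]  # one prompt group with G = 4 completions
kl = [0.30, 0.10, 0.00, 0.20]   # per-sequence KL penalties (made-up values)
beta, G = 0.04, 4

# Route 1: mean-baseline advantages first, then the linear KL correction.
adv = [r - b for r, b in zip(rewards, group_mean(rewards, G))]
adjusted = apply_kl_in_reward_sketch(adv, kl, G, beta)

# Route 2: subtract the KL penalty from the reward, then take the mean baseline.
shaped = [r - beta * k for r, k in zip(rewards, kl)]
direct = [s - b for s, b in zip(shaped, group_mean(shaped, G))]
```

# Both routes give the same advantages, which is why the correction can be
# applied after TRL has already computed `advantages`.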
# ----------------------------------------------------------------------
# Run-level collapse kill-switch (#2 safeguard) — optional, OFF by default
# ----------------------------------------------------------------------
def _maybe_update_killswitch(self) -> None:
"""Fold this checkpoint's metrics into ``heldout_guard`` and act on a fire.
No-op when no guard was configured (the default) — this is the
backward-compat guarantee: without ``heldout_guard`` the trainer behaves
exactly as before. When a guard IS set:
* ``in_loop_reward`` is the GRPO reward signal TRL already aggregates
into ``self._metrics[mode]["reward"]`` each step (we read the latest;
no extra forward pass).
* ``heldout_score`` comes from the injected ``heldout_eval_fn()`` — the
trainer never hardcodes an eval.
* ``kl_to_init`` (token-mean nats/token, the ``token_mean_kl``
convention the guard expects) is read from TRL's logged ``"kl"``
metric when present, else left None (KL path stays inert).
On a fired verdict the verdict is logged. If ``strict_killswitch`` (the
default) the verdict is converted into a ``CollapseStopError`` via
``HeldOutGuard.raise_if_fired`` (hard stop); otherwise the HF training
loop is asked to stop gracefully after this step.
"""
guard = self.heldout_guard
if guard is None:
return # OFF by default — zero behavior change
round_idx = int(getattr(self.state, "global_step", 0))
in_loop_reward = self._latest_metric("reward")
if in_loop_reward is None:
# No reward aggregated yet (e.g. very first micro-step before TRL has
# populated its metrics). Skip this cadence rather than feed a
# fabricated 0.0 that would pollute the guard's baseline/EMA.
logger.debug(
"kill-switch: no in-loop reward metric yet at step %d; skipping.",
round_idx,
)
return
assert self.heldout_eval_fn is not None # enforced in __init__
heldout_score = float(self.heldout_eval_fn())
kl_to_init = self._latest_metric("kl") # token-mean KL, or None
status = guard.update(
round_idx=round_idx,
in_loop_reward=in_loop_reward,
heldout_score=heldout_score,
kl_to_init=kl_to_init,
)
self.log({ # type: ignore[attr-defined]
"killswitch/in_loop_reward": status.in_loop_ema,
"killswitch/heldout_score": status.heldout_ema,
"killswitch/proxy_real_gap": status.proxy_real_gap,
"killswitch/fire": float(status.fire),
})
if status.fire:
logger.error(
"HeldOutGuard FIRED at step %d — halting run. reason: %s",
round_idx, status.reason,
)
if self.strict_killswitch:
# Typed exception — exception-based hard stop.
guard.raise_if_fired(status)
else:
# Soft stop: let the HF loop terminate gracefully after this step.
control = getattr(self, "control", None)
if control is not None:
control.should_training_stop = True
def _latest_metric(self, name: str) -> float | None:
"""Most-recent value of a TRL-aggregated train metric, or None.
TRL's GRPOTrainer appends per-step aggregates to
``self._metrics["train"][name]`` (e.g. ``"reward"``, ``"kl"``). We read
the tail defensively so a TRL internals rename degrades to None (KL/reward
path goes inert) rather than crashing training.
"""
metrics = getattr(self, "_metrics", None)
if not isinstance(metrics, dict):
return None
train = metrics.get("train")
if not isinstance(train, dict):
return None
series = train.get(name)
if not series:
return None
try:
return float(series[-1])
except (TypeError, ValueError, IndexError):
return None
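# Illustrative sketch (not part of this module): a toy tripwire showing the kind
# of signal the kill-switch folds in at each cadence, namely in-loop reward
# versus a held-out score from an injected zero-arg callable. `ToyGapGuard` is
# hypothetical and is NOT the real `HeldOutGuard` from
# composer_replication.safety, whose thresholds and logic are not shown here.

```python
class ToyGapGuard:
    """Fires when the in-loop reward EMA exceeds the held-out EMA by more than max_gap."""

    def __init__(self, max_gap=0.2, ema_decay=0.5):
        self.max_gap = max_gap
        self.ema_decay = ema_decay
        self.in_loop_ema = None
        self.heldout_ema = None

    def _ema(self, prev, value):
        # The first observation seeds the EMA; later ones are blended in.
        if prev is None:
            return value
        return self.ema_decay * prev + (1.0 - self.ema_decay) * value

    def update(self, in_loop_reward, heldout_score):
        self.in_loop_ema = self._ema(self.in_loop_ema, in_loop_reward)
        self.heldout_ema = self._ema(self.heldout_ema, heldout_score)
        return (self.in_loop_ema - self.heldout_ema) > self.max_gap


# Zero-arg closure over a (made-up) held-out score stream, mirroring the
# `heldout_eval_fn() -> float` contract described in the class docstring.
_heldout_scores = iter([0.50, 0.52, 0.51, 0.50])


def heldout_eval_fn():
    return next(_heldout_scores)


guard = ToyGapGuard(max_gap=0.2)
fired_at = None
for step, in_loop_reward in enumerate([0.50, 0.60, 0.80, 0.95]):
    if guard.update(in_loop_reward, heldout_eval_fn()):
        fired_at = step
        break
```

# The toy guard fires only once the in-loop reward keeps climbing while the
# held-out score stays flat, the proxy-vs-real gap the docstring describes.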
# ----------------------------------------------------------------------
# Channel 2: SDPO hint-distill
# ----------------------------------------------------------------------
def _compute_sdpo_loss(
self,
model: torch.nn.Module,
inputs: dict[str, torch.Tensor],
) -> torch.Tensor:
"""Compute generalized_jsd_loss between student and hint-conditioned teacher.
Both come from the SAME model — teacher just has hint inserted into context.
Skipped (returns 0) if the batch has no error sites (data collator emits
empty ctx_teacher_input_ids).
"""
if (
self.alpha_sdpo == 0.0
or "ctx_teacher_input_ids" not in inputs
or inputs["ctx_teacher_input_ids"].numel() == 0
):
return torch.tensor(0.0, device=_device_of(model), requires_grad=True)
# Student forward (with grad, on the original-context input)
student_logits = model(input_ids=inputs["input_ids"]).logits
# Teacher forward (no grad — same model, hint-conditioned context)
with torch.no_grad():
teacher_logits = model(input_ids=inputs["ctx_teacher_input_ids"]).logits
# ------------------------------------------------------------------
# ALIGNMENT (cross-family review 2026-05-29 — the 4/4-reviewer P0).
#
# The teacher context has a hint inserted at the error turn, so the
# teacher's post-hint response tokens are shifted right by len(hint)
# relative to the student's. A bare `student.shape == teacher.shape`
# check does NOT establish token-level alignment: equal-length tensors
# whose response regions are offset will be JSD'd position-by-position
# against each other, distilling garbage into the policy.
#
# The ONLY correct alignment is an explicit map from the collator that
# selects, for each response token, the matching index in each sequence.
# We require it whenever SDPO is active:
# - `student_response_idx` / `teacher_response_idx`: LongTensors of
# equal length selecting the aligned response positions in each
# sequence (the collator builds these knowing where it inserted the
# hint). JSD is computed over the gathered, provably-aligned logits.
# - If the collator cannot yet supply them, strict mode raises (loud
# failure) rather than silently distilling misaligned tokens.
s_idx = inputs.get("student_response_idx")
t_idx = inputs.get("teacher_response_idx")
if s_idx is None or t_idx is None:
msg = (
"SDPO alignment indices missing: the collator must emit "
"`student_response_idx` and `teacher_response_idx` (matching "
"LongTensors selecting the aligned post-hint response tokens) so "
"the JSD compares corresponding tokens. A shape-only check does "
"NOT establish alignment — the hint shifts the teacher's response "
"tokens right, so equal-length sequences can still be misaligned "
"and silently distill garbage into the policy (ADR-008 trust-gap)."
)
if self.strict_sdpo_alignment:
raise ValueError(
msg + " (strict_sdpo_alignment=True; pass False to fall back "
"to the legacy shape-only check for resilience.)"
)
logger.warning("%s Falling back to shape-only alignment check.", msg)
if student_logits.shape != teacher_logits.shape:
logger.warning(
"SDPO shape mismatch student=%s teacher=%s; skipping.",
tuple(student_logits.shape), tuple(teacher_logits.shape),
)
return torch.tensor(0.0, device=_device_of(model), requires_grad=True)
return generalized_jsd_loss(
student_logits=student_logits,
teacher_logits=teacher_logits,
labels=inputs.get("sdpo_loss_mask"),
beta=self.sdpo_jsd_beta,
temperature=self.sdpo_temperature,
token_clip=self.sdpo_token_clip,
reduction="batchmean",
)
# Validate the index tensors describe a consistent 1:1 alignment.
if s_idx.shape != t_idx.shape:
raise ValueError(
f"SDPO alignment index shape mismatch: student_response_idx="
f"{tuple(s_idx.shape)} vs teacher_response_idx={tuple(t_idx.shape)}. "
"They must select the same number of aligned response tokens."
)
# Gather the provably-aligned response logits from each sequence, then
# JSD only those positions (this is the masked error-turn distillation).
# gather over the sequence dim (dim=1): expand index to the vocab dim.
#
# ADR-011: ragged-K rows are padded with a sentinel (-1) and a per-row
# *_valid mask. Negative indices are illegal for torch.gather, so clamp
# to 0 before gathering, then neutralize those positions by feeding
# labels=-100 (the standard HF ignore convention that generalized_jsd_loss
# already honors). This makes sentinel/padding positions contribute 0.
#
# Final-verify 2026-05-29: combine BOTH valid masks (not just student's)
# AND the sentinel guard. If a future collator ever emits divergent
# student/teacher valid tails, a teacher sentinel clamped to 0 would
# otherwise be silently distilled against teacher position 0. Belt-and-
# suspenders: valid iff student-valid AND teacher-valid AND both indices
# non-sentinel.
s_valid = inputs.get("student_response_valid")
t_valid = inputs.get("teacher_response_valid")
aligned_mask = (s_idx >= 0) & (t_idx >= 0)
if s_valid is not None:
aligned_mask = aligned_mask & s_valid.bool()
if t_valid is not None:
aligned_mask = aligned_mask & t_valid.bool()
vocab = student_logits.size(-1)
s_safe = s_idx.clamp_min(0)
t_safe = t_idx.clamp_min(0)
s_gather = s_safe.unsqueeze(-1).expand(-1, -1, vocab)
t_gather = t_safe.unsqueeze(-1).expand(-1, -1, vocab)
student_aligned = torch.gather(student_logits, 1, s_gather)
teacher_aligned = torch.gather(teacher_logits, 1, t_gather)
# Build (B, K) labels: 1 at valid aligned positions, -100 (ignore) at
# sentinel/padding positions so they drop out of the JSD reduction.
aligned_labels = torch.where(
aligned_mask,
torch.ones_like(s_idx),
torch.full_like(s_idx, -100),
)
return generalized_jsd_loss(
student_logits=student_aligned,
teacher_logits=teacher_aligned,
labels=aligned_labels, # sentinel-masked aligned error-turn positions
beta=self.sdpo_jsd_beta,
temperature=self.sdpo_temperature,
token_clip=self.sdpo_token_clip,
reduction="batchmean",
)
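# Illustrative sketch (not part of this module): the clamp-then-gather
# alignment used above, on tiny tensors. The shapes, the hint length of 2, and
# the helper name `gather_aligned` are assumptions made for this example; only
# the -1 sentinel and the -100 ignore label come from the comments above.

```python
import torch

SENTINEL = -1  # padding value for ragged index rows


def gather_aligned(logits, idx):
    """Select logits[b, idx[b, k], :] for each k; sentinel indices are clamped to 0."""
    safe = idx.clamp_min(0)  # torch.gather rejects negative indices
    gather_idx = safe.unsqueeze(-1).expand(-1, -1, logits.size(-1))
    return torch.gather(logits, 1, gather_idx)


B, T_student, T_teacher, V = 1, 5, 7, 4
student_logits = torch.arange(B * T_student * V, dtype=torch.float32).reshape(B, T_student, V)
teacher_logits = torch.arange(B * T_teacher * V, dtype=torch.float32).reshape(B, T_teacher, V)

# A 2-token hint shifts the teacher's response right by 2 positions.
# The third slot is padding, marked with the sentinel in both rows.
s_idx = torch.tensor([[2, 3, SENTINEL]])
t_idx = torch.tensor([[4, 5, SENTINEL]])

student_aligned = gather_aligned(student_logits, s_idx)
teacher_aligned = gather_aligned(teacher_logits, t_idx)

# 1 at aligned positions, -100 at sentinel slots so they drop out of the loss.
aligned_mask = (s_idx >= 0) & (t_idx >= 0)
labels = torch.where(aligned_mask, torch.ones_like(s_idx), torch.full_like(s_idx, -100))
```

# The clamped sentinel slot still gathers position 0, so the -100 label is what
# keeps it out of the loss; the clamp alone only makes the gather legal.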
# ----------------------------------------------------------------------
# Channel 3: trace-replay DPO
# ----------------------------------------------------------------------
def _compute_trace_replay_loss(
self,
model: torch.nn.Module,
inputs: dict[str, torch.Tensor],
) -> torch.Tensor:
"""Standard DPO loss using (chosen, rejected) pairs from teacher disagreement.
DPO loss formula (Rafailov et al. 2023):
L = -log σ(β · (logπ(chosen) - logπ_ref(chosen)
- logπ(rejected) + logπ_ref(rejected)))
Where logπ_ref are precomputed by the data collator using the
reference (init student) policy.
"""
if (
self.beta_replay == 0.0
or "dpo_chosen_input_ids" not in inputs
or inputs["dpo_chosen_input_ids"].numel() == 0
):
return torch.tensor(0.0, device=_device_of(model), requires_grad=True)
# Forward passes for chosen and rejected, gather logprobs at response tokens
chosen_logprobs = self._sequence_logprobs(
model, inputs["dpo_chosen_input_ids"], inputs["dpo_chosen_response_mask"]
)
rejected_logprobs = self._sequence_logprobs(
model, inputs["dpo_rejected_input_ids"], inputs["dpo_rejected_response_mask"]
)
ref_chosen_logprobs = inputs["dpo_chosen_ref_logprobs"]
ref_rejected_logprobs = inputs["dpo_rejected_ref_logprobs"]
logits = self.replay_dpo_beta * (
(chosen_logprobs - ref_chosen_logprobs)
- (rejected_logprobs - ref_rejected_logprobs)
)
return -F.logsigmoid(logits).mean()
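# Illustrative sketch (not part of this module): the DPO formula from the
# docstring above evaluated on scalar sequence logprobs, without torch. The
# function names and the example numbers are assumptions for illustration.

```python
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dpo_loss_sketch(chosen_lp, rejected_lp, ref_chosen_lp, ref_rejected_lp, beta=0.1):
    """Mean over pairs of -log sigmoid(beta * ((lp_c - ref_c) - (lp_r - ref_r)))."""
    losses = []
    for c, r, rc, rr in zip(chosen_lp, rejected_lp, ref_chosen_lp, ref_rejected_lp):
        margin = beta * ((c - rc) - (r - rr))
        losses.append(-log_sigmoid(margin))
    return sum(losses) / len(losses)


# Policy identical to the reference: the margin is 0 and the loss is log(2).
at_init = dpo_loss_sketch([-5.0], [-7.0], [-5.0], [-7.0])

# Policy raises the chosen logprob and lowers the rejected one: the loss drops.
improved = dpo_loss_sketch([-4.0], [-8.0], [-5.0], [-7.0])
```

# Starting at log(2) when policy and reference agree is a handy sanity check
# for the precomputed reference logprobs coming out of the collator.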
@staticmethod
def _sequence_logprobs(
model: torch.nn.Module,
input_ids: torch.Tensor,
response_mask: torch.Tensor,
) -> torch.Tensor:
"""Sum logprob of response tokens given the prompt prefix.
Standard DPO accounting: we only score the response tokens (where
response_mask == 1), not the prompt tokens.
"""
outputs = model(input_ids=input_ids)
# Shift for next-token prediction: logits[t] predicts input_ids[t+1]
logits = outputs.logits[:, :-1, :]
targets = input_ids[:, 1:]
log_probs = F.log_softmax(logits, dim=-1)
token_logprobs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
# Mask out prompt + padding; sum response-token logprobs
masked = token_logprobs * response_mask[:, 1:].float()
return masked.sum(dim=-1)
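# Illustrative sketch (not part of this module): the shift-and-mask accounting
# of `_sequence_logprobs` for a single sequence, in plain Python. The helper
# names and the uniform toy logits are assumptions for illustration.

```python
import math


def log_softmax(row):
    """Log-softmax of one logit row, using the max-shift for stability."""
    m = max(row)
    lse = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - lse for v in row]


def sequence_logprob_sketch(logits, input_ids, response_mask):
    """Sum log p(input_ids[t+1] | prefix) over positions where response_mask[t+1] is 1."""
    total = 0.0
    for t in range(len(input_ids) - 1):
        # logits[t] predicts input_ids[t+1], so the mask is read at t+1 as well.
        if response_mask[t + 1]:
            total += log_softmax(logits[t])[input_ids[t + 1]]
    return total


# Vocabulary of 3, sequence of 4 tokens: two prompt tokens, two response tokens.
logits = [[0.0, 0.0, 0.0] for _ in range(4)]  # uniform distribution at every step
input_ids = [0, 1, 2, 1]
response_mask = [0, 0, 1, 1]

score = sequence_logprob_sketch(logits, input_ids, response_mask)
```

# With uniform logits each response token contributes log(1/3), so only the
# two masked-in response tokens are counted and the prompt adds nothing.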
def _device_of(model: torch.nn.Module) -> torch.device:
"""Return the device of any parameter of the model — robust to FSDP/DDP wrappers."""
return next(model.parameters()).device


def validate_kl_in_reward_config(
    *,
    kl_estimator: str,
    beta: float,
    scale_rewards: Any,
) -> None:
    """Validate the (kl_estimator, beta, scale_rewards) combo for k1-in-reward.

    Extracted so the preconditions are unit-testable without standing up a real
    GRPOTrainer (which needs a model + dataset). Raises ``ValueError`` on any
    invalid combination; returns None when the config is sound.

    Preconditions (see ``kl_in_reward.py`` for the algebra):

    * ``kl_estimator`` in {k1, k3}.
    * ``beta != 0``: TRL only builds the reference model and computes ref
      logprobs when beta>0, and the in-reward penalty needs ref logps. beta
      doubles as the in-reward KL coefficient (the in-loss k3 term is
      suppressed per step).
    * ``scale_rewards`` in {none, false}: the advantage-adjustment identity
      is exact only without per-group std-normalization (the Dr.GRPO /
      Composer regime).
    """
    if kl_estimator not in ("k1", "k3"):
        raise ValueError(f"kl_estimator must be 'k1' or 'k3', got {kl_estimator!r}.")
    if float(beta) == 0.0:
        raise ValueError(
            "kl_in_reward=True requires a non-zero `beta` (the KL coefficient): "
            "TRL only creates the reference model and computes ref logprobs when "
            "beta>0, and k1-in-reward needs those ref logps. Set beta to your KL "
            "coefficient (e.g. make_po_config('dr_grpo', beta=0.04)); the in-loss "
            "k3 term is suppressed automatically so beta acts purely as the "
            "in-reward k1 coefficient."
        )
    if str(scale_rewards).lower() not in ("none", "false"):
        raise ValueError(
            "kl_in_reward=True requires scale_rewards in {none,false} "
            f"(got {scale_rewards!r}). The advantage-adjustment identity "
            "adv -= beta·(KL - group_mean(KL)) is EXACT only without per-group "
            "std-normalization (the Dr.GRPO / Composer regime). With std-norm, "
            "folding KL into the reward also shifts the group std, so the linear "
            "correction no longer matches true in-reward KL. Use "
            "make_po_config('dr_grpo', beta=…) (scale_rewards='none')."
        )
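
# Illustration (not part of the module): a small numeric check of the identity
# the validator relies on. The rewards, KL values, and helper names below are
# hypothetical, and population std stands in for whatever std the trainer uses.
# With mean-centering only, folding beta*KL into the reward equals correcting
# the advantage by beta*(KL - group_mean(KL)); with std-normalization it does not.

```python
from statistics import mean, pstdev


def centered(xs):
    """Group-relative advantage without std-normalization."""
    mu = mean(xs)
    return [x - mu for x in xs]


def std_normalized(xs):
    """Group-relative advantage with per-group std-normalization."""
    sd = pstdev(xs)
    return [x / sd for x in centered(xs)]


rewards = [1.0, 0.0, 0.5, 0.0]  # hypothetical task rewards for one prompt group
kl = [0.30, 0.10, 0.50, 0.05]   # hypothetical per-sample KL estimates
beta = 0.04

shaped = [r - beta * k for r, k in zip(rewards, kl)]  # KL folded into the reward
kl_centered = centered(kl)

# No std-norm: in-reward KL vs. linear correction on the raw advantage.
adv_in_reward = centered(shaped)
adv_corrected = [a - beta * k for a, k in zip(centered(rewards), kl_centered)]
gap_no_std = max(abs(x - y) for x, y in zip(adv_in_reward, adv_corrected))

# With std-norm: the shaped rewards have a different std, so the two diverge.
adv_in_reward_std = std_normalized(shaped)
adv_corrected_std = [a - beta * k for a, k in zip(std_normalized(rewards), kl_centered)]
gap_std = max(abs(x - y) for x, y in zip(adv_in_reward_std, adv_corrected_std))
```

# `gap_no_std` is zero up to floating-point rounding, while `gap_std` is not,
# which is why the validator rejects std-normalized configs.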


def make_dr_grpo_config(**overrides: Any):
    """Build a `trl.GRPOConfig` configured to the **Dr. GRPO** recipe.

    Per the Composer 2 technical report (arXiv:2603.24477,
    research/10-composer2-techreport-mining.md) the RL base is Dr. GRPO
    (Liu et al., arXiv:2503.20783):

    - ``loss_type="dr_grpo"``: removes GRPO's length-standardization term
      (which injects a length bias). TRL's own help text cites the Dr. GRPO
      paper for this.
    - ``scale_rewards="none"``: NO std-dev advantage normalization. TRL docs:
      "The Dr. GRPO paper recommends not scaling rewards, as scaling by the
      standard deviation introduces a question-level difficulty bias."
    - ``num_iterations=1``: single-pass, on-policy regime (each generated
      batch is optimized on once), matching the tech report.
    - ``beta`` (KL-to-ref coef) kept. NOTE on the KL estimator (ADR-012
      finding #1, verified against the installed trl==1.5.0 source):
      ``GRPOTrainer._compute_loss`` uses the **k3** estimator
      ``exp(ref_logp - logp) - (ref_logp - logp) - 1``
      (trl/trainer/grpo_trainer.py ~L2513), NOT the k1 estimator
      ``-log r == (logp - ref_logp)``, where ``r = pi_ref / pi``. k3 is
      Schulman's low-variance, always-non-negative KL estimator; k1 is the
      higher-variance counterpart, which can be negative per sample. Both are
      unbiased under samples from the policy: per sample ``k3 = k1 + (r - 1)``,
      and ``r - 1`` has zero mean. The Dr. GRPO / Composer 2 report discusses
      KL in k1 terms, but TRL's k3 choice is the production reality. We do NOT
      monkeypatch TRL to force k1; we document the honest delta. See
      ``test_dr_grpo_config_and_alignment.py::test_trl_kl_estimator_is_k3_not_k1``.

    Any field can be set via kwargs (e.g. ``learning_rate=...``,
    ``output_dir=...``). ``loss_type`` and ``num_iterations`` default to the
    Dr. GRPO values unless explicitly overridden; ``scale_rewards`` must stay
    disabled, because an override that enables std-normalization fails the
    sanity assertion below. The assertions also guard against silent drift.
    """
    from trl import GRPOConfig  # local import: only when actually building a config

    dr_grpo_defaults: dict[str, Any] = {
        "loss_type": "dr_grpo",
        "scale_rewards": "none",
        "num_iterations": 1,
    }
    merged = {**dr_grpo_defaults, **overrides}
    cfg = GRPOConfig(**merged)
    # Guard: fail loudly if a future TRL renames/repurposes these knobs.
    assert cfg.loss_type == merged["loss_type"], (
        f"GRPOConfig loss_type drifted: requested {merged['loss_type']!r}, "
        f"got {cfg.loss_type!r} — TRL may have renamed/repurposed the knob."
    )
    # Dr. GRPO requires NO std-dev advantage normalization. TRL accepts either
    # the string "none" or the bool False to disable it; normalize before
    # comparing so a future TRL that switches the representation still passes
    # (and a genuinely-wrong value like "batch"/"group"/True fails loudly).
    # (Cross-family review 2026-05-29: the prior literal `("none","False","False")`
    # had a duplicated "False" and did a brittle case-sensitive str compare.)
    assert str(cfg.scale_rewards).lower() in ("none", "false"), (
        f"Dr. GRPO requires scale_rewards disabled (no std-norm); got "
        f"{cfg.scale_rewards!r}. TRL knob may have drifted — re-verify against trl version."
    )
    assert cfg.num_iterations == merged["num_iterations"], "GRPOConfig dropped num_iterations"
    return cfg
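
# Illustration (not part of the module): a pure-Python sketch comparing the k1
# and k3 estimators on a toy 3-token distribution. It assumes samples come from
# the policy pi and takes r = pi_ref / pi, following Schulman's KL-approximation
# note; the distributions and helper names are hypothetical. Both estimators
# average to the exact KL(pi || pi_ref), but only k3 is non-negative per sample.

```python
import math


def k1(logp, ref_logp):
    """k1 = -log r, with log r = ref_logp - logp."""
    return logp - ref_logp


def k3(logp, ref_logp):
    """k3 = (r - 1) - log r, the form quoted from TRL in the docstring above."""
    d = ref_logp - logp
    return math.exp(d) - d - 1.0


pi = [0.5, 0.3, 0.2]      # hypothetical policy probabilities
pi_ref = [0.4, 0.4, 0.2]  # hypothetical reference probabilities

exact_kl = sum(p * math.log(p / q) for p, q in zip(pi, pi_ref))

per_sample_k1 = [k1(math.log(p), math.log(q)) for p, q in zip(pi, pi_ref)]
per_sample_k3 = [k3(math.log(p), math.log(q)) for p, q in zip(pi, pi_ref)]

# Expectations under the policy (each token weighted by pi).
mean_k1 = sum(p * v for p, v in zip(pi, per_sample_k1))
mean_k3 = sum(p * v for p, v in zip(pi, per_sample_k3))
```

# Here `per_sample_k1` contains a negative entry (the token the reference likes
# more than the policy), while every entry of `per_sample_k3` is >= 0.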


# ---------------------------------------------------------------------------
# Policy-optimization objective MENU (ADR-014)
# ---------------------------------------------------------------------------
#
# The base RL objective used to be hardcoded to Dr.GRPO (make_dr_grpo_config).
# make_po_config gives RL a real menu: GRPO-family objectives selectable by name.
# Verified against the installed trl==1.5.0 (introspected 2026-05-30): its
# GRPOTrainer already implements these as `loss_type` branches + knobs, so EVERY
# preset below is pure config: no custom _compute_loss override needed.
#
# Knob-space each preset sets (all real GRPOConfig fields in trl 1.5.0):
#   loss_type ∈ {grpo, dr_grpo, bnpo, dapo, cispo}  (gspo = grpo loss +
#       importance_sampling_level="sequence"; trl has no literal "gspo")
#   scale_rewards ∈ {"group"(std-norm), "batch", "none"(no std-norm, Dr.GRPO)}
#   epsilon / epsilon_high: symmetric vs decoupled "clip-higher" (DAPO)
#   importance_sampling_level ∈ {"token", "sequence"(GSPO)}
#   beta: KL-to-ref coef (0.0 = reference-free)
#   mask_truncated_completions: DAPO overlong masking
#   num_iterations: on-policy reuse (1 = strict on-policy)

#: Selectable base policy-optimization objectives (named presets over trl knobs).
PO_OBJECTIVES: dict[str, dict[str, Any]] = {
    # Vanilla GRPO (DeepSeekMath, arXiv 2402.03300): group-relative advantage
    # WITH std normalization + per-sequence length normalization, KL on.
    "grpo": {
        "loss_type": "grpo",
        "scale_rewards": "group",
        "importance_sampling_level": "token",
        "num_iterations": 1,
    },
    # Dr.GRPO (arXiv 2503.20783): remove length-std normalization bias (no
    # advantage /std, length-independent aggregation). Framework's historical
    # default (== make_dr_grpo_config). Composer 2.5's base objective.
    "dr_grpo": {
        "loss_type": "dr_grpo",
        "scale_rewards": "none",
        "importance_sampling_level": "token",
        "num_iterations": 1,
    },
    # BNPO: batch-normalized variant (trl loss_type), std over the batch.
    "bnpo": {
        "loss_type": "bnpo",
        "scale_rewards": "batch",
        "importance_sampling_level": "token",
        "num_iterations": 1,
    },
    # DAPO (arXiv 2503.14476): decoupled "clip-higher" (epsilon_high > epsilon)
    # + token-level loss + overlong masking + KL removed. High-value, low-cost
    # anti-entropy-collapse objective. epsilon_high=0.28 per the paper.
    "dapo": {
        "loss_type": "dapo",
        "scale_rewards": "none",
        "epsilon": 0.2,
        "epsilon_high": 0.28,
        "mask_truncated_completions": True,
        "beta": 0.0,
        "importance_sampling_level": "token",
        "num_iterations": 1,
    },
    # GSPO (Qwen, arXiv 2507.18071): SEQUENCE-level importance ratio (one
    # length-normalized ratio per response); stabilizes long-CoT and especially
    # MoE RL. trl expresses this as the grpo loss +
    # importance_sampling_level="sequence".
    "gspo": {
        "loss_type": "grpo",
        "scale_rewards": "group",
        "importance_sampling_level": "sequence",
        "num_iterations": 1,
    },
    # CISPO (MiniMax-M1, arXiv 2506.13585): clip the IS weight and detach it as a
    # constant coefficient on log π, so every token keeps a gradient (fixes the
    # "rare reasoning tokens get zeroed by the clip" pathology). eps_max≈5 (ScaleRL).
    "cispo": {
        "loss_type": "cispo",
        "scale_rewards": "none",
        "epsilon_high": 5.0,
        "importance_sampling_level": "token",
        "num_iterations": 1,
    },
}
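
# Illustration (not part of the module): a sketch of what the decoupled
# "clip-higher" knobs change, assuming the standard PPO-style clipped surrogate
# min(r*A, clip(r, 1 - eps_low, 1 + eps_high)*A) for a single token. The
# function name `clipped_surrogate` is hypothetical and this is not TRL's code.

```python
def clipped_surrogate(ratio, advantage, eps_low, eps_high):
    """Per-token PPO-style objective with separate lower/upper clip ranges."""
    clipped = min(max(ratio, 1.0 - eps_low), 1.0 + eps_high)
    return min(ratio * advantage, clipped * advantage)


# A token whose probability rose by 25% and that has positive advantage.
symmetric = clipped_surrogate(1.25, 1.0, eps_low=0.2, eps_high=0.2)
decoupled = clipped_surrogate(1.25, 1.0, eps_low=0.2, eps_high=0.28)

# A token whose probability fell by 30% and that has negative advantage:
# the lower bound (1 - eps_low) is the same in both settings.
lower_side = clipped_surrogate(0.7, -1.0, eps_low=0.2, eps_high=0.28)
```

# With the symmetric range the ratio is capped at 1.2, so the token stops
# contributing gradient; with epsilon_high=0.28 the ratio 1.25 is still inside
# the range and the token keeps its gradient.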


def make_po_config(objective: str = "dr_grpo", **overrides: Any):
    """Build a `trl.GRPOConfig` for a NAMED policy-optimization objective.

    This is the menu that gives RL real options beyond the single hardcoded
    Dr.GRPO recipe. ``objective`` selects a preset from ``PO_OBJECTIVES``
    (grpo / dr_grpo / bnpo / dapo / gspo / cispo); ``**overrides`` set or
    override any GRPOConfig field on top (e.g. ``output_dir=...``,
    ``beta=...``, ``learning_rate=...``).

    All presets are PURE CONFIG over trl 1.5.0's GRPOTrainer (verified by
    introspecting the installed package 2026-05-30): the trainer already
    implements each ``loss_type`` branch and the ``importance_sampling_level`` /
    ``epsilon_high`` knobs, so no custom ``_compute_loss`` is needed. See ADR-014.

    Raises:
        ValueError: unknown objective (lists the valid menu).
        AssertionError: a requested knob silently failed to apply (drift guard).
    """
    from trl import GRPOConfig  # local import: only when actually building a config

    key = (objective or "dr_grpo").lower()
    if key not in PO_OBJECTIVES:
        raise ValueError(
            f"Unknown PO objective {objective!r}. Choose from: "
            f"{sorted(PO_OBJECTIVES)}. (Each is a named preset over trl 1.5.0's "
            f"GRPOConfig knobs — see PO_OBJECTIVES / ADR-014.)"
        )
    preset = dict(PO_OBJECTIVES[key])
    merged = {**preset, **overrides}
    cfg = GRPOConfig(**merged)
    # Drift guards: fail loudly if a future trl renamed/repurposed a knob we set,
    # so a preset can never silently degrade to a different objective.
    if "loss_type" in merged:
        assert str(cfg.loss_type) == str(merged["loss_type"]), (
            f"GRPOConfig.loss_type drifted: requested {merged['loss_type']!r}, "
            f"got {cfg.loss_type!r} — trl may have renamed the knob."
        )
    if "importance_sampling_level" in merged and hasattr(cfg, "importance_sampling_level"):
        assert str(cfg.importance_sampling_level) == str(
            merged["importance_sampling_level"]
        ), (
            f"importance_sampling_level drifted for objective {key!r}: requested "
            f"{merged['importance_sampling_level']!r}, got {cfg.importance_sampling_level!r}."
        )
    if key == "gspo":
        assert str(getattr(cfg, "importance_sampling_level", "token")) == "sequence", (
            "GSPO requires importance_sampling_level='sequence'; it was overridden "
            "to token, which silently degrades GSPO to GRPO. Drop that override."
        )
    if merged.get("epsilon_high") is not None:
        assert abs(
            float(getattr(cfg, "epsilon_high", merged["epsilon_high"]))
            - float(merged["epsilon_high"])
        ) < 1e-9, f"epsilon_high (decoupled clip) drifted for {key!r}."
    return cfg
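
# Illustration (not part of the module): a sketch of the preset-plus-overrides
# resolution, without importing trl. `PRESETS` is a trimmed, hypothetical copy
# of two menu entries and `resolve` is a hypothetical stand-in that returns the
# merged kwargs instead of building a GRPOConfig. Overrides win over preset
# values, and the shared preset dict is never mutated.

```python
PRESETS = {
    "dapo": {"loss_type": "dapo", "epsilon": 0.2, "epsilon_high": 0.28, "beta": 0.0},
    "gspo": {"loss_type": "grpo", "importance_sampling_level": "sequence"},
}


def resolve(objective="dapo", **overrides):
    """Return the merged kwargs for a named preset (case-insensitive)."""
    key = (objective or "dapo").lower()
    if key not in PRESETS:
        raise ValueError(f"Unknown objective {objective!r}; choose from {sorted(PRESETS)}")
    # Later keys win, so caller overrides replace preset values.
    return {**PRESETS[key], **overrides}


dapo_with_kl = resolve("DAPO", beta=0.01)

# Overriding GSPO's defining knob produces plain GRPO kwargs; this is the case
# the dedicated GSPO assertion in make_po_config is meant to catch.
degraded_gspo = resolve("gspo", importance_sampling_level="token")
```

# After these calls `PRESETS["dapo"]["beta"]` is still 0.0, since each call
# builds a fresh merged dict.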


__all__ = [
    "ComposerReplicationTrainer",
    "make_dr_grpo_config",
    "make_po_config",
    "PO_OBJECTIVES",
    "validate_kl_in_reward_config",
]