File size: 44,217 Bytes
ac05fbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd0c358
 
ac05fbf
 
bd0c358
 
 
 
ac05fbf
 
e5add15
 
 
ac05fbf
 
e5add15
ac05fbf
 
e5add15
ac05fbf
 
41289bf
 
 
 
ac05fbf
 
 
 
 
 
 
 
e5add15
 
 
 
 
 
 
 
 
ac05fbf
 
 
41289bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd0c358
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac05fbf
 
 
 
 
e5add15
 
ac05fbf
 
 
 
bde5c5e
41289bf
 
bd0c358
 
 
ac05fbf
 
e5add15
 
 
 
 
ac05fbf
 
 
 
 
 
 
bde5c5e
 
 
 
 
 
41289bf
 
 
 
 
 
 
 
 
 
 
 
 
bd0c358
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac05fbf
 
 
 
 
41289bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac05fbf
 
 
 
 
41289bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac05fbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd0c358
 
 
ac05fbf
 
 
41289bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd0c358
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac05fbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185cce2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bde5c5e
185cce2
 
 
 
 
 
 
ac05fbf
bde5c5e
 
185cce2
 
bde5c5e
185cce2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d02d724
 
 
 
 
 
678d10b
 
 
 
 
 
 
 
 
 
 
 
 
 
d02d724
185cce2
d02d724
 
 
 
185cce2
 
ac05fbf
d02d724
 
 
 
 
 
 
 
ac05fbf
185cce2
 
d02d724
ac05fbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41289bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bde5c5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d02d724
 
 
 
 
 
 
 
 
 
 
 
bde5c5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185cce2
 
 
 
 
 
 
 
 
 
 
 
 
bde5c5e
 
 
 
 
aae66fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41289bf
aae66fa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
"""composer_trainer.py — TRL GRPOTrainer subclass with SDPO + trace-replay channels.

Architecture spec: docs/INTEGRATION_ARCHITECTURE.md § "Recipe A".
Verified extension point: GRPOTrainer._compute_loss(model, inputs)
  (DeepWiki audit of huggingface/trl, 2026-05-25).

Total loss:
    total_loss = grpo_loss
               + alpha_sdpo  * sdpo_kl_at_error_turns
               + beta_replay * trace_replay_dpo_loss

Where:
  - grpo_loss is the parent GRPOTrainer's loss (RLVR + DAPO patches).
  - sdpo_kl_at_error_turns is generalized_jsd_loss between student's logits and
    teacher's (= same-model-with-hint-context) logits, masked to error-turn tokens only.
  - trace_replay_dpo_loss is the DPO loss over (chosen, rejected) pairs derived from
    the disagreement of N external teachers with the student.

The data collator (data_collator.py) is responsible for:
  - Detecting error sites in the rollout and constructing ctx_teacher = ctx_student + hint.
  - Computing sdpo_loss_mask (1 at post-hint error-turn tokens, 0 elsewhere).
  - Loading DPO pairs from the trace-replay output (see teacher_replay.py).
  - Precomputing reference-policy logprobs for DPO.
"""

from __future__ import annotations

import logging
from collections.abc import Callable
from typing import TYPE_CHECKING, Any

import torch
import torch.nn.functional as F  # noqa: N812 — repo-wide torch convention

if TYPE_CHECKING:  # type-only — never imported at runtime (keeps the dep lazy)
    from composer_replication.safety import HeldOutGuard

# These imports work when TRL is installed — they're not skeleton imports.
# When TRL is missing we fall back to `object` so the module still imports
# (e.g. for documentation generation) but raise a clear ImportError at
# instantiation time rather than the cryptic `object.__init__()` error.
try:
    from trl import GRPOTrainer  # type: ignore
    _TRL_AVAILABLE = True
except ImportError:  # pragma: no cover — only hit in unit-test stubs without TRL
    GRPOTrainer = object  # type: ignore — fallback so module imports without TRL
    _TRL_AVAILABLE = False

from composer_replication.opsd import generalized_jsd_loss
from composer_replication.trainer.kl_in_reward import (
    apply_kl_in_reward,
    kl_penalty_per_sequence,
)

logger = logging.getLogger(__name__)


class ComposerReplicationTrainer(GRPOTrainer):  # type: ignore[misc, valid-type]
    """TRL GRPOTrainer with Composer-recipe channels (SDPO) + novel trace-replay-DPO.

    Args (in addition to GRPOTrainer's):
        alpha_sdpo: weight on SDPO hint-distill loss. Default 0.0 (disabled).
            Opt in by passing >0 once your data collator produces
            `sdpo_loss_mask` and `ctx_teacher_input_ids` columns.
        beta_replay: weight on trace-replay DPO loss. Default 0.0 (disabled).
            Opt in by passing >0 once your data collator produces
            `dpo_chosen_input_ids` / `dpo_rejected_input_ids` etc.
        sdpo_jsd_beta: beta param of generalized_jsd_loss
            (0=KL(teacher||student), 0.5=JSD, 1=KL(student||teacher) per
            upstream OPSD convention; see composer_replication/opsd.py).
        sdpo_temperature: temperature for SDPO loss; SDPO paper uses 1.0.
        sdpo_token_clip: per-token JSD clip for stability; None = no clip.
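        replay_dpo_beta: beta param of the DPO loss (β in the standard DPO formula).
        strict_sdpo_alignment: when True (default), the SDPO channel raises
            ``ValueError`` if the collator did not emit ``student_response_idx`` /
            ``teacher_response_idx``. When False, a warning is logged and the
            loss falls back to the legacy shape-only check (the batch is skipped
            on a student/teacher logits shape mismatch).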
        kl_in_reward: when True, apply the KL-to-reference penalty in the
            **reward** (Composer-2 §4.1 / verl choice) instead of TRL's native
            **in-loss** k3 term. The penalty is folded into GRPO's advantages at
            scoring time (``adv -= beta·(KL - group_mean(KL))``) and TRL's
            in-loss KL is suppressed for that step. The F5 audit's #1 fidelity
            fix: the 2025/26 evidence (arXiv:2512.21852, verl, TRL #4967) shows
            k1-in-reward improves OOD generalization where k3-in-reward can
            collapse. REQUIRES ``beta>0`` (the KL coefficient — also how TRL
            decides to compute reference logprobs) and ``scale_rewards`` in
            {none,false} (the advantage-adjustment identity is exact only
            without std-normalization — the Dr.GRPO / Composer regime). Default
            False = TRL's native in-loss KL, byte-for-byte legacy behavior.
        kl_estimator: ``"k1"`` (default; ``logp - ref_logp``, the Composer-2 /
            verl choice this path exists for) or ``"k3"`` (Schulman; lets an
            experiment A/B k1-in-reward vs k3-in-reward). Only consulted when
            ``kl_in_reward=True``.
        heldout_guard: optional ``HeldOutGuard`` (the #2 collapse safeguard from
            ``composer_replication.safety``). Default None = OFF (no behavior
            change whatsoever). When supplied, the trainer folds one checkpoint's
            metrics into the guard at the ``args.logging_steps`` cadence (the same
            place the loss components are logged) and HALTS the run on a fired
            verdict — the run-level reward-hacking / collapse tripwire actually
            firing instead of sitting inert.
        heldout_eval_fn: zero-arg callable returning the held-out (real) eval
            score as a float, evaluated each guard cadence. Injectable so the
            trainer never hardcodes an eval — pass a closure over your disjoint
            held-out pool (the ``HeldoutSplit`` discipline). Required whenever
            ``heldout_guard`` is set; the guard's whole signal is in-loop reward
            vs. this held-out score.
        strict_killswitch: when True (default), a fired guard verdict raises
            ``CollapseStopError`` to hard-stop training (exception-based control
            flow, matching ``HeldOutGuard.raise_if_fired``). When False the
            verdict is logged and ``self.control.should_training_stop`` is set so
            the HF loop ends gracefully after the step (soft stop). Only consulted
            when ``heldout_guard`` is set.
    """

    def __init__(
        self,
        *args: Any,
        alpha_sdpo: float = 0.0,
        beta_replay: float = 0.0,
        sdpo_jsd_beta: float = 0.5,
        sdpo_temperature: float = 1.0,
        sdpo_token_clip: float | None = None,
        replay_dpo_beta: float = 0.1,
        strict_sdpo_alignment: bool = True,
        kl_in_reward: bool = False,
        kl_estimator: str = "k1",
        heldout_guard: HeldOutGuard | None = None,
        heldout_eval_fn: Callable[[], float] | None = None,
        strict_killswitch: bool = True,
        **kwargs: Any,
    ) -> None:
        """Build the trainer and record the optional channel configuration.

        Positional arguments and unrecognized keyword arguments are forwarded
        unchanged to the parent constructor. The keyword-only options are stored
        on the instance and configure the SDPO, trace-replay, KL-in-reward and
        kill-switch channels.

        Raises:
            ImportError: TRL is not installed.
            ValueError: ``heldout_guard`` was supplied without
                ``heldout_eval_fn``.
        """
        if not _TRL_AVAILABLE:
            raise ImportError(
                "ComposerReplicationTrainer requires TRL. Install with "
                "`pip install -e .[train]`."
            )
        # Fail fast: this check depends only on the arguments, so run it BEFORE
        # the parent constructor — a misconfigured kill-switch is then reported
        # before any parent-side setup has been performed.
        if heldout_guard is not None and heldout_eval_fn is None:
            raise ValueError(
                "heldout_guard was provided without heldout_eval_fn: the guard's "
                "tripwire compares in-loop reward against a DISJOINT held-out "
                "(real) eval score, so it needs an injectable zero-arg "
                "heldout_eval_fn() -> float. Pass a closure over your held-out "
                "pool (the HeldoutSplit discipline)."
            )
        super().__init__(*args, **kwargs)
        self.alpha_sdpo = alpha_sdpo
        self.beta_replay = beta_replay
        self.sdpo_jsd_beta = sdpo_jsd_beta
        self.sdpo_temperature = sdpo_temperature
        self.sdpo_token_clip = sdpo_token_clip
        self.replay_dpo_beta = replay_dpo_beta
        # When True (default), an SDPO student/teacher shape mismatch is a hard
        # error — it means the data collator failed to align the post-hint
        # section, which silently zeroes the distillation signal (the exact
        # trust-gap flagged in ADR-008). Set False only for production runs
        # where a single malformed batch should warn-and-skip rather than abort.
        self.strict_sdpo_alignment = strict_sdpo_alignment
        # --- k1-in-reward KL (F5 #1 fidelity fix; Composer-2 §4.1 / verl) ----
        # OFF by default → TRL's native in-loss k3 KL, byte-for-byte legacy.
        # When ON we keep self.beta as the KL coef (TRL needs beta>0 to even
        # create the ref model + compute ref logps), fold the k1 penalty into
        # advantages during scoring, and zero TRL's in-loss KL per step.
        self.kl_in_reward = kl_in_reward
        self.kl_estimator = kl_estimator
        if kl_in_reward:
            # This validation reads self.args, so it must stay after
            # super().__init__().
            # NOTE(review): validate_kl_in_reward_config is not among the names
            # imported from composer_replication.trainer.kl_in_reward at the top
            # of this module — confirm it is defined or imported elsewhere in
            # the file, otherwise this branch raises NameError.
            validate_kl_in_reward_config(
                kl_estimator=kl_estimator,
                beta=float(getattr(self.args, "beta", 0.0)),
                scale_rewards=getattr(self.args, "scale_rewards", "group"),
            )
        # --- run-level collapse kill-switch (#2 safeguard) -------------------
        # OPTIONAL + OFF BY DEFAULT: when heldout_guard is None the loss path is
        # byte-for-byte the legacy behavior. When set, _maybe_update_killswitch
        # folds metrics into the guard at the logging cadence (see _compute_loss).
        self.heldout_guard = heldout_guard
        self.heldout_eval_fn = heldout_eval_fn
        self.strict_killswitch = strict_killswitch

    # ----------------------------------------------------------------------
    # Loss override (the integration core)
    # ----------------------------------------------------------------------

    # ----------------------------------------------------------------------
    # k1-in-reward: fold the KL penalty into advantages at scoring time, and
    # suppress TRL's native in-loss k3 KL inside _compute_loss.
    # ----------------------------------------------------------------------

    def _generate_and_score_completions(
        self,
        inputs: list[dict[str, Any]],
    ) -> dict[str, Any]:
        """Score completions via the parent, then optionally apply in-reward KL.

        With ``kl_in_reward`` unset (or missing on the instance) the parent's
        result is returned untouched. Otherwise the parent's ``advantages`` are
        replaced by the KL-adjusted advantages whenever the sampling-time policy
        logprobs (``old_per_token_logps``) are present in the parent's output;
        when they are absent the adjustment is left to ``_compute_loss``.

        In both KL-in-reward cases a 0-dim boolean tensor is stored under
        ``"_kl_in_reward_applied"`` recording whether the adjustment happened
        here.

        Raises:
            RuntimeError: ``kl_in_reward`` is set but the parent's output lacks
                ``ref_per_token_logps`` or ``completion_mask``.
        """
        scored = super()._generate_and_score_completions(inputs)
        if not getattr(self, "kl_in_reward", False):
            # Legacy path: hand back exactly what the parent produced.
            return scored

        ref_logps = scored.get("ref_per_token_logps")
        completion_mask = scored.get("completion_mask")
        if ref_logps is None or completion_mask is None:
            # beta>0 is expected to guarantee both; their absence signals a
            # change in TRL internals, so fail loudly instead of silently
            # dropping the penalty.
            raise RuntimeError(
                "kl_in_reward=True but TRL did not return ref_per_token_logps / "
                "completion_mask from scoring (beta>0 should guarantee them). "
                "TRL internals may have changed; re-verify the in-reward path."
            )

        # Sampling-time policy logprobs. These can be None (per the TRL
        # behaviour this override was written against: aligned
        # generation/optimization without vLLM), in which case the policy
        # logprobs only exist later, inside _compute_loss.
        sampling_logps = scored.get("old_per_token_logps")
        if sampling_logps is None:
            # Nothing extra is stored besides the flag: the reference logprobs
            # and completion mask are already part of `scored`, and the flag
            # tells _compute_loss that the penalty is still outstanding.
            scored["_kl_in_reward_applied"] = torch.tensor(False)
            return scored

        seq_penalty = kl_penalty_per_sequence(
            policy_logps=sampling_logps,
            ref_logps=ref_logps,
            completion_mask=completion_mask,
            estimator=self.kl_estimator,
        )
        scored["advantages"] = apply_kl_in_reward(
            advantages=scored["advantages"],
            kl_penalty=seq_penalty,
            num_generations=self.num_generations,
            coef=float(self.args.beta),
        )
        scored["_kl_in_reward_applied"] = torch.tensor(True)
        return scored

    def _compute_loss(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Override: total_loss = grpo + α*sdpo + β*replay.

        The GRPO term comes from ``_grpo_loss_kl_in_reward`` when
        ``kl_in_reward`` is set (that helper suppresses TRL's in-loss KL by
        temporarily zeroing ``self.beta``), otherwise from the parent's
        ``_compute_loss``. The SDPO and trace-replay terms are weighted by
        ``alpha_sdpo`` and ``beta_replay`` respectively.

        Side effects: when ``state`` and ``args`` are available and
        ``state.global_step`` is a multiple of the logging interval, the
        per-channel loss components are logged and the collapse kill-switch is
        updated (which may raise or request a training stop).

        Args:
            model: the policy model being optimized.
            inputs: the batch dict forwarded to each loss channel.

        Returns:
            The composed scalar loss.
        """
        # Channel 1: standard GRPO loss. ``getattr`` (not ``self.kl_in_reward``)
        # so an instance built via ``__new__`` + manual wiring (the SDPO /
        # kill-switch unit-test pattern that skips __init__) defaults to the
        # legacy path instead of raising AttributeError.
        if getattr(self, "kl_in_reward", False):
            grpo_loss = self._grpo_loss_kl_in_reward(model, inputs)
        else:
            grpo_loss = super()._compute_loss(model, inputs)

        # Channel 2: SDPO hint-distill at error sites
        sdpo_kl = self._compute_sdpo_loss(model, inputs)

        # Channel 3: trace-replay DPO from teacher disagreement
        replay_dpo = self._compute_trace_replay_loss(model, inputs)

        # Compose
        total = grpo_loss + self.alpha_sdpo * sdpo_kl + self.beta_replay * replay_dpo

        # Log per-channel components (so we can ablate post-hoc)
        if hasattr(self, "state") and getattr(self, "args", None) is not None:
            log_steps = getattr(self.args, "logging_steps", 50)
            # Transformers accepts a fractional logging_steps in (0, 1) meaning
            # "ratio of total training steps". Taking a modulo by that fraction
            # would almost never yield 0, silently disabling both the logging
            # and the kill-switch. Presumably the HF Trainer stores the resolved
            # absolute interval on state.logging_steps — use it when present
            # (TODO confirm against the installed transformers version).
            if isinstance(log_steps, float) and 0.0 < log_steps < 1.0:
                log_steps = getattr(self.state, "logging_steps", log_steps)
            # Guard against a None / zero / negative interval: the modulo below
            # would otherwise raise TypeError or ZeroDivisionError mid-training.
            # NOTE(review): global_step presumably does not advance between
            # gradient-accumulation micro-steps, so this branch may run more
            # than once per logged step — confirm that is acceptable for the
            # kill-switch's eval cost.
            if log_steps and log_steps > 0 and self.state.global_step % log_steps == 0:
                self.log({  # type: ignore[attr-defined]
                    "loss/grpo":               float(grpo_loss.detach()),
                    "loss/sdpo_kl":            float(sdpo_kl.detach()),
                    "loss/trace_replay_dpo":   float(replay_dpo.detach()),
                    "loss/total":              float(total.detach()),
                    "loss/alpha_sdpo":         self.alpha_sdpo,
                    "loss/beta_replay":        self.beta_replay,
                })
                # Fold one checkpoint into the run-level collapse kill-switch at
                # the SAME cadence (no-op unless a guard was configured).
                self._maybe_update_killswitch()

        return total

    def _grpo_loss_kl_in_reward(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """GRPO loss with the KL applied in the reward, not the loss.

        Two responsibilities:
          1. Suppress TRL's native in-loss k3 KL term for this step by zeroing
             ``self.beta`` across the parent ``_compute_loss`` call (restored in
             ``finally``). ``self.beta`` gates the in-loss KL add (TRL
             ``_compute_loss``: ``if self.beta != 0.0: per_token_loss += beta*kl``).
          2. Handle the deferred case: when ``_generate_and_score_completions``
             could not fold the penalty into the advantages (flag
             ``_kl_in_reward_applied`` false or absent), recompute the policy
             logprobs here and apply the same advantage correction in-place on
             ``inputs["advantages"]`` BEFORE the parent computes the surrogate.

        The deferred correction is idempotent per batch dict: once applied,
        ``inputs["_kl_in_reward_applied"]`` is set so that a later call with the
        same dict does not subtract the penalty a second time.

        Args:
            model: the policy model.
            inputs: scored batch dict; mutated in place in the deferred case.

        Returns:
            The parent's GRPO loss computed with ``self.beta`` forced to 0.
        """
        # Deferred-penalty path: advantages not yet KL-adjusted (aligned, no vLLM).
        applied = inputs.get("_kl_in_reward_applied")
        # Accept a plain bool as well as a (0-dim or larger) tensor flag.
        already_applied = applied is not None and bool(torch.as_tensor(applied).all())
        # A present-but-None reference entry is treated like a missing one.
        ref_logps = inputs.get("ref_per_token_logps")
        if not already_applied and ref_logps is not None:
            with torch.no_grad():
                prompt_ids, completion_ids = inputs["prompt_ids"], inputs["completion_ids"]
                completion_mask = inputs["completion_mask"]
                input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
                attention_mask = torch.cat([inputs["prompt_mask"], completion_mask], dim=1)
                logits_to_keep = completion_ids.size(1)
                policy_logps, _ = self._get_per_token_logps_and_entropies(
                    model, input_ids, attention_mask, logits_to_keep
                )
                penalty = kl_penalty_per_sequence(
                    policy_logps=policy_logps,
                    ref_logps=ref_logps,
                    completion_mask=completion_mask,
                    estimator=self.kl_estimator,
                )
                advantages = inputs["advantages"]
                # advantages may be (B,) or (B,1) — squeeze for the penalty math,
                # restore the original shape after.
                adv_flat = advantages.reshape(advantages.shape[0])
                # NOTE(review): apply_kl_in_reward receives num_generations,
                # presumably to take a per-group mean over contiguous rows.
                # Confirm the rows of `inputs` are still group-contiguous here —
                # TRL may shuffle/split the scored batch before _compute_loss.
                adj = apply_kl_in_reward(
                    advantages=adv_flat,
                    kl_penalty=penalty,
                    num_generations=self.num_generations,
                    coef=float(self.args.beta),
                )
                inputs["advantages"] = adj.reshape(advantages.shape)
            # Record that this dict's advantages now include the penalty. The
            # advantages were mutated in place, so without this marker a reuse
            # of the same batch dict would apply the penalty again on top of
            # the already-adjusted values.
            inputs["_kl_in_reward_applied"] = torch.tensor(True)

        # Suppress TRL's in-loss KL: zero beta for the parent call, restore after.
        saved_beta = self.beta
        try:
            self.beta = 0.0
            return super()._compute_loss(model, inputs)
        finally:
            self.beta = saved_beta

    # ----------------------------------------------------------------------
    # Run-level collapse kill-switch (#2 safeguard) — optional, OFF by default
    # ----------------------------------------------------------------------

    def _maybe_update_killswitch(self) -> None:
        """Fold this checkpoint's metrics into ``heldout_guard`` and act on a fire.

        No-op when no guard was configured (the default) — this is the
        backward-compat guarantee: without ``heldout_guard`` the trainer behaves
        exactly as before. When a guard IS set:

          * ``in_loop_reward`` is the GRPO reward signal TRL already aggregates
            into ``self._metrics[mode]["reward"]`` each step (we read the latest;
            no extra forward pass).
          * ``heldout_score`` comes from the injected ``heldout_eval_fn()`` — the
            trainer never hardcodes an eval.
          * ``kl_to_init`` (token-mean nats/token, the ``token_mean_kl``
            convention the guard expects) is read from TRL's logged ``"kl"``
            metric when present, else left None (KL path stays inert).

        On a fired verdict the verdict is logged. If ``strict_killswitch`` (the
        default) the verdict is converted into a ``CollapseStopError`` via
        ``HeldOutGuard.raise_if_fired`` (hard stop); otherwise the HF training
        loop is asked to stop gracefully after this step.
        """
        guard = self.heldout_guard
        if guard is None:
            return  # OFF by default — zero behavior change

        round_idx = int(getattr(self.state, "global_step", 0))
        in_loop_reward = self._latest_metric("reward")
        if in_loop_reward is None:
            # No reward aggregated yet (e.g. very first micro-step before TRL has
            # populated its metrics). Skip this cadence rather than feed a
            # fabricated 0.0 that would pollute the guard's baseline/EMA.
            logger.debug(
                "kill-switch: no in-loop reward metric yet at step %d; skipping.",
                round_idx,
            )
            return

        assert self.heldout_eval_fn is not None  # enforced in __init__
        heldout_score = float(self.heldout_eval_fn())
        kl_to_init = self._latest_metric("kl")  # token-mean KL, or None

        status = guard.update(
            round_idx=round_idx,
            in_loop_reward=in_loop_reward,
            heldout_score=heldout_score,
            kl_to_init=kl_to_init,
        )

        self.log({  # type: ignore[attr-defined]
            "killswitch/in_loop_reward":  status.in_loop_ema,
            "killswitch/heldout_score":   status.heldout_ema,
            "killswitch/proxy_real_gap":  status.proxy_real_gap,
            "killswitch/fire":            float(status.fire),
        })

        if status.fire:
            logger.error(
                "HeldOutGuard FIRED at step %d — halting run. reason: %s",
                round_idx, status.reason,
            )
            if self.strict_killswitch:
                # Typed exception — exception-based hard stop.
                guard.raise_if_fired(status)
            else:
                # Soft stop: let the HF loop terminate gracefully after this step.
                control = getattr(self, "control", None)
                if control is not None:
                    control.should_training_stop = True

    def _latest_metric(self, name: str) -> float | None:
        """Most-recent value of a TRL-aggregated train metric, or None.

        TRL's GRPOTrainer appends per-step aggregates to
        ``self._metrics["train"][name]`` (e.g. ``"reward"``, ``"kl"``). We read
        the tail defensively so a TRL internals rename degrades to None (KL/reward
        path goes inert) rather than crashing training.
        """
        metrics = getattr(self, "_metrics", None)
        if not isinstance(metrics, dict):
            return None
        train = metrics.get("train")
        if not isinstance(train, dict):
            return None
        series = train.get(name)
        if not series:
            return None
        try:
            return float(series[-1])
        except (TypeError, ValueError, IndexError):
            return None

    # ----------------------------------------------------------------------
    # Channel 2: SDPO hint-distill
    # ----------------------------------------------------------------------

    def _compute_sdpo_loss(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Compute generalized_jsd_loss between student and hint-conditioned teacher.

        Both come from the SAME model — teacher just has hint inserted into context.
        Skipped (returns 0) if the batch has no error sites (data collator emits
        empty ctx_teacher_input_ids).
        """
        if (
            self.alpha_sdpo == 0.0
            or "ctx_teacher_input_ids" not in inputs
            or inputs["ctx_teacher_input_ids"].numel() == 0
        ):
            return torch.tensor(0.0, device=_device_of(model), requires_grad=True)

        # Student forward (with grad, on the original-context input)
        student_logits = model(input_ids=inputs["input_ids"]).logits

        # Teacher forward (no grad — same model, hint-conditioned context)
        with torch.no_grad():
            teacher_logits = model(input_ids=inputs["ctx_teacher_input_ids"]).logits

        # ------------------------------------------------------------------
        # ALIGNMENT (cross-family review 2026-05-29 — the 4/4-reviewer P0).
        #
        # The teacher context has a hint inserted at the error turn, so the
        # teacher's post-hint response tokens are shifted right by len(hint)
        # relative to the student's. A bare `student.shape == teacher.shape`
        # check does NOT establish token-level alignment: equal-length tensors
        # whose response regions are offset will be JSD'd position-by-position
        # against each other, distilling garbage into the policy.
        #
        # The ONLY correct alignment is an explicit map from the collator that
        # selects, for each response token, the matching index in each sequence.
        # We require it whenever SDPO is active:
        #   - `student_response_idx` / `teacher_response_idx`: LongTensors of
        #     equal length selecting the aligned response positions in each
        #     sequence (the collator builds these knowing where it inserted the
        #     hint). JSD is computed over the gathered, provably-aligned logits.
        #   - If the collator cannot yet supply them, strict mode raises (loud
        #     failure) rather than silently distilling misaligned tokens.
        s_idx = inputs.get("student_response_idx")
        t_idx = inputs.get("teacher_response_idx")
        if s_idx is None or t_idx is None:
            msg = (
                "SDPO alignment indices missing: the collator must emit "
                "`student_response_idx` and `teacher_response_idx` (matching "
                "LongTensors selecting the aligned post-hint response tokens) so "
                "the JSD compares corresponding tokens. A shape-only check does "
                "NOT establish alignment — the hint shifts the teacher's response "
                "tokens right, so equal-length sequences can still be misaligned "
                "and silently distill garbage into the policy (ADR-008 trust-gap)."
            )
            if self.strict_sdpo_alignment:
                raise ValueError(
                    msg + " (strict_sdpo_alignment=True; pass False to fall back "
                    "to the legacy shape-only check for resilience.)"
                )
            logger.warning("%s Falling back to shape-only alignment check.", msg)
            if student_logits.shape != teacher_logits.shape:
                logger.warning(
                    "SDPO shape mismatch student=%s teacher=%s; skipping.",
                    tuple(student_logits.shape), tuple(teacher_logits.shape),
                )
                return torch.tensor(0.0, device=_device_of(model), requires_grad=True)
            return generalized_jsd_loss(
                student_logits=student_logits,
                teacher_logits=teacher_logits,
                labels=inputs.get("sdpo_loss_mask"),
                beta=self.sdpo_jsd_beta,
                temperature=self.sdpo_temperature,
                token_clip=self.sdpo_token_clip,
                reduction="batchmean",
            )

        # Validate the index tensors describe a consistent 1:1 alignment.
        if s_idx.shape != t_idx.shape:
            raise ValueError(
                f"SDPO alignment index shape mismatch: student_response_idx="
                f"{tuple(s_idx.shape)} vs teacher_response_idx={tuple(t_idx.shape)}. "
                "They must select the same number of aligned response tokens."
            )
        # Gather the provably-aligned response logits from each sequence, then
        # JSD only those positions (this is the masked error-turn distillation).
        # gather over the sequence dim (dim=1): expand index to the vocab dim.
        #
        # ADR-011: ragged-K rows are padded with a sentinel (-1) and a per-row
        # *_valid mask. Negative indices are illegal for torch.gather, so clamp
        # to 0 before gathering, then neutralize those positions by feeding
        # labels=-100 (the standard HF ignore convention that generalized_jsd_loss
        # already honors). This makes sentinel/padding positions contribute 0.
        #
        # Final-verify 2026-05-29: combine BOTH valid masks (not just student's)
        # AND the sentinel guard. If a future collator ever emits divergent
        # student/teacher valid tails, a teacher sentinel clamped to 0 would
        # otherwise be silently distilled against teacher position 0. Belt-and-
        # suspenders: valid iff student-valid AND teacher-valid AND both indices
        # non-sentinel.
        s_valid = inputs.get("student_response_valid")
        t_valid = inputs.get("teacher_response_valid")
        aligned_mask = (s_idx >= 0) & (t_idx >= 0)
        if s_valid is not None:
            aligned_mask = aligned_mask & s_valid.bool()
        if t_valid is not None:
            aligned_mask = aligned_mask & t_valid.bool()

        vocab = student_logits.size(-1)
        s_safe = s_idx.clamp_min(0)
        t_safe = t_idx.clamp_min(0)
        s_gather = s_safe.unsqueeze(-1).expand(-1, -1, vocab)
        t_gather = t_safe.unsqueeze(-1).expand(-1, -1, vocab)
        student_aligned = torch.gather(student_logits, 1, s_gather)
        teacher_aligned = torch.gather(teacher_logits, 1, t_gather)

        # Build (B, K) labels: 1 at valid aligned positions, -100 (ignore) at
        # sentinel/padding positions so they drop out of the JSD reduction.
        aligned_labels = torch.where(
            aligned_mask,
            torch.ones_like(s_idx),
            torch.full_like(s_idx, -100),
        )

        return generalized_jsd_loss(
            student_logits=student_aligned,
            teacher_logits=teacher_aligned,
            labels=aligned_labels,  # sentinel-masked aligned error-turn positions
            beta=self.sdpo_jsd_beta,
            temperature=self.sdpo_temperature,
            token_clip=self.sdpo_token_clip,
            reduction="batchmean",
        )

    # ----------------------------------------------------------------------
    # Channel 3: trace-replay DPO
    # ----------------------------------------------------------------------

    def _compute_trace_replay_loss(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Trace-replay channel: DPO over (chosen, rejected) teacher-disagreement pairs.

        Implements the DPO objective (Rafailov et al. 2023):
            L = -log σ(β · [(logπ(chosen) - logπ_ref(chosen))
                            - (logπ(rejected) - logπ_ref(rejected))])

        The policy terms are computed here with one forward pass per side of
        the pair. The reference terms are read from ``inputs``; they are
        expected to be precomputed by the data collator using the reference
        (init student) policy.

        Args:
            model: Policy being trained.
            inputs: Batch dict. Reads ``dpo_{chosen,rejected}_input_ids``,
                ``dpo_{chosen,rejected}_response_mask`` and
                ``dpo_{chosen,rejected}_ref_logprobs``.

        Returns:
            The mean DPO loss over the pairs, or a zero scalar (with
            ``requires_grad=True``) when the channel is inactive.
        """
        # Active only when the channel weight is non-zero AND the batch actually
        # carries at least one chosen token; short-circuits left to right.
        replay_active = (
            self.beta_replay != 0.0
            and "dpo_chosen_input_ids" in inputs
            and inputs["dpo_chosen_input_ids"].numel() != 0
        )
        if not replay_active:
            return torch.tensor(0.0, device=_device_of(model), requires_grad=True)

        # Policy log-likelihood of each response (summed over response tokens).
        policy_chosen = self._sequence_logprobs(
            model,
            inputs["dpo_chosen_input_ids"],
            inputs["dpo_chosen_response_mask"],
        )
        policy_rejected = self._sequence_logprobs(
            model,
            inputs["dpo_rejected_input_ids"],
            inputs["dpo_rejected_response_mask"],
        )

        # Log-ratio of policy to reference for each side of the pair.
        chosen_log_ratio = policy_chosen - inputs["dpo_chosen_ref_logprobs"]
        rejected_log_ratio = policy_rejected - inputs["dpo_rejected_ref_logprobs"]

        preference_logits = self.replay_dpo_beta * (chosen_log_ratio - rejected_log_ratio)
        return -F.logsigmoid(preference_logits).mean()

    @staticmethod
    def _sequence_logprobs(
        model: torch.nn.Module,
        input_ids: torch.Tensor,
        response_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Sum logprob of response tokens given the prompt prefix.

        Standard DPO accounting: we only score the response tokens (where
        response_mask == 1), not the prompt tokens.
        """
        outputs = model(input_ids=input_ids)
        # Shift for next-token prediction: logits[t] predicts input_ids[t+1]
        logits = outputs.logits[:, :-1, :]
        targets = input_ids[:, 1:]
        log_probs = F.log_softmax(logits, dim=-1)
        token_logprobs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        # Mask out prompt + padding; sum response-token logprobs
        masked = token_logprobs * response_mask[:, 1:].float()
        return masked.sum(dim=-1)


def _device_of(model: torch.nn.Module) -> torch.device:
    """Return the device of any parameter of the model — robust to FSDP/DDP wrappers."""
    return next(model.parameters()).device


def validate_kl_in_reward_config(
    *,
    kl_estimator: str,
    beta: float,
    scale_rewards: Any,
) -> None:
    """Check that (kl_estimator, beta, scale_rewards) is usable for k1-in-reward.

    Kept as a free function so the preconditions can be unit-tested without
    constructing a real GRPOTrainer (which needs a model + dataset).

    The checks run in this order and the first failure wins (see
    ``kl_in_reward.py`` for the algebra):

      1. ``kl_estimator`` must be ``"k1"`` or ``"k3"``.
      2. ``beta`` must be non-zero. TRL only builds the reference model and
         computes ref logprobs when beta>0, and the in-reward penalty needs
         those ref logps. beta doubles as the in-reward KL coefficient (the
         in-loss k3 term is suppressed per step).
      3. ``scale_rewards`` must stringify (case-insensitively) to ``"none"``
         or ``"false"``. The advantage-adjustment identity is exact only
         without per-group std-normalization (the Dr.GRPO / Composer regime).

    Returns:
        None when the configuration is sound.

    Raises:
        ValueError: on the first precondition that does not hold.
    """
    known_estimators = ("k1", "k3")
    if kl_estimator not in known_estimators:
        raise ValueError(f"kl_estimator must be 'k1' or 'k3', got {kl_estimator!r}.")

    kl_coefficient = float(beta)
    if kl_coefficient == 0.0:
        zero_beta_problem = (
            "kl_in_reward=True requires a non-zero `beta` (the KL coefficient): "
            "TRL only creates the reference model and computes ref logprobs when "
            "beta>0, and k1-in-reward needs those ref logps. Set beta to your KL "
            "coefficient (e.g. make_po_config('dr_grpo', beta=0.04)); the in-loss "
            "k3 term is suppressed automatically so beta acts purely as the "
            "in-reward k1 coefficient."
        )
        raise ValueError(zero_beta_problem)

    # Compare the stringified setting so "none", None and False all count as
    # "std-normalization disabled".
    std_norm_disabled = str(scale_rewards).lower() in ("none", "false")
    if std_norm_disabled:
        return

    std_norm_problem = (
        "kl_in_reward=True requires scale_rewards in {none,false} "
        f"(got {scale_rewards!r}). The advantage-adjustment identity "
        "adv -= beta·(KL - group_mean(KL)) is EXACT only without per-group "
        "std-normalization (the Dr.GRPO / Composer regime). With std-norm, "
        "folding KL into the reward also shifts the group std, so the linear "
        "correction no longer matches true in-reward KL. Use "
        "make_po_config('dr_grpo', beta=…) (scale_rewards='none')."
    )
    raise ValueError(std_norm_problem)


def make_dr_grpo_config(**overrides: Any):
    """Build a `trl.GRPOConfig` configured to the **Dr. GRPO** recipe.

    Per the Composer 2 technical report (arXiv:2603.24477,
    research/10-composer2-techreport-mining.md) the RL base is Dr. GRPO
    (Liu et al., arXiv:2503.20783):

      - ``loss_type="dr_grpo"``  — removes GRPO's length-standardization term
        (which injects a length bias). TRL's own help text cites the Dr. GRPO
        paper for this.
      - ``scale_rewards="none"`` — NO std-dev advantage normalization. TRL docs:
        "The Dr. GRPO paper recommends not scaling rewards, as scaling by the
        standard deviation introduces a question-level difficulty bias."
      - ``num_iterations=1``     — single-epoch regime (a prompt is never
        trained on twice), matching the tech report.
      - ``beta`` (KL-to-ref coef) kept. NOTE on the KL estimator (ADR-012
        finding #1, verified against the installed trl==1.5.0 source):
        ``GRPOTrainer._compute_loss`` uses the **k3** estimator
        ``exp(ref_logp - logp) - (ref_logp - logp) - 1``
        (trl/trainer/grpo_trainer.py ~L2513), NOT the k1 estimator
        ``-log r == (ref_logp - logp)``. k3 is Schulman's low-variance,
        always-non-negative KL approximation; k1 is its unbiased but
        higher-variance counterpart. The Dr. GRPO / Composer 2 report discusses
        KL in k1 terms, but the delta is small for r≈1 (k3 = k1 + O((Δlogp)^2))
        and TRL's k3 choice is the production reality. We do NOT monkeypatch TRL
        to force k1; we document the honest delta. See
        ``test_dr_grpo_config_and_alignment.py::test_trl_kl_estimator_is_k3_not_k1``.

    Args:
        **overrides: Any ``GRPOConfig`` field (e.g. ``learning_rate=...``,
            ``output_dir=...``). Overrides win over the three Dr. GRPO
            defaults, but a ``scale_rewards`` override that re-enables
            std-normalization is rejected by the guard below.

    Returns:
        The constructed ``trl.GRPOConfig``.

    Raises:
        AssertionError: the built config does not carry the requested
            ``loss_type`` / ``num_iterations``, or ``scale_rewards`` is not
            disabled. Raised explicitly (not via ``assert``) so the guard is
            still enforced under ``python -O``.
    """
    from trl import GRPOConfig  # local import: only when actually building a config

    dr_grpo_defaults: dict[str, Any] = {
        "loss_type": "dr_grpo",
        "scale_rewards": "none",
        "num_iterations": 1,
    }
    merged = {**dr_grpo_defaults, **overrides}
    cfg = GRPOConfig(**merged)

    # Guard: fail loudly if a future TRL renames/repurposes these knobs.
    # Explicit raises rather than `assert` statements: asserts are compiled
    # away under `python -O`, which would silently disable the drift guard.
    if cfg.loss_type != merged["loss_type"]:
        raise AssertionError(
            f"GRPOConfig loss_type drifted: requested {merged['loss_type']!r}, "
            f"got {cfg.loss_type!r} — TRL may have renamed/repurposed the knob."
        )
    # Dr. GRPO requires NO std-dev advantage normalization. TRL accepts either
    # the string "none" or the bool False to disable it; normalize before
    # comparing so a future TRL that switches the representation still passes
    # (and a genuinely-wrong value like "batch"/"group"/True fails loudly).
    # (Cross-family review 2026-05-29: the prior literal `("none","False","False")`
    # had a duplicated "False" and did a brittle case-sensitive str compare.)
    if str(cfg.scale_rewards).lower() not in ("none", "false"):
        raise AssertionError(
            f"Dr. GRPO requires scale_rewards disabled (no std-norm); got "
            f"{cfg.scale_rewards!r}. TRL knob may have drifted — re-verify against trl version."
        )
    if cfg.num_iterations != merged["num_iterations"]:
        raise AssertionError("GRPOConfig dropped num_iterations")
    return cfg


# ---------------------------------------------------------------------------
# Policy-optimization objective MENU (ADR-014)
# ---------------------------------------------------------------------------
#
# The base RL objective used to be hardcoded to Dr.GRPO (make_dr_grpo_config).
# make_po_config gives RL a real menu: GRPO-family objectives selectable by name.
# Verified against the installed trl==1.5.0 (introspected 2026-05-30): its
# GRPOTrainer already implements these as `loss_type` branches + knobs, so EVERY
# preset below is pure config — no custom _compute_loss override needed.
#
# Knobs the presets draw from (all real GRPOConfig fields in trl 1.5.0; each
# preset sets only the subset it needs):
#   loss_type ∈ {grpo, dr_grpo, bnpo, dapo, cispo}   (gspo = grpo loss +
#       importance_sampling_level="sequence"; trl has no literal "gspo")
#   scale_rewards ∈ {"group"(std-norm), "batch", "none"(no std-norm, Dr.GRPO)}
#   epsilon / epsilon_high   — symmetric vs decoupled "clip-higher" (DAPO)
#   importance_sampling_level ∈ {"token", "sequence"(GSPO)}
#   beta                     — KL-to-ref coef (0.0 = reference-free)
#   mask_truncated_completions — DAPO overlong masking
#   num_iterations           — on-policy reuse (1 = strict on-policy)

#: Knobs shared by every preset that keeps a token-level importance ratio and
#: strict on-policy updates. Spread LAST in each preset so key order is stable.
_TOKEN_LEVEL_ON_POLICY: dict[str, Any] = {
    "importance_sampling_level": "token",
    "num_iterations": 1,
}

#: Selectable base policy-optimization objectives (named presets over trl knobs).
PO_OBJECTIVES: dict[str, dict[str, Any]] = {
    # grpo — vanilla GRPO (DeepSeekMath, arXiv 2402.03300): group-relative
    # advantage with std normalization and per-sequence length normalization;
    # KL stays on.
    "grpo": {
        "loss_type": "grpo",
        "scale_rewards": "group",
        **_TOKEN_LEVEL_ON_POLICY,
    },
    # dr_grpo — Dr.GRPO (arXiv 2503.20783): drops the length/std normalization
    # bias (no advantage /std, length-independent aggregation). The framework's
    # historical default (== make_dr_grpo_config) and the Composer base objective.
    "dr_grpo": {
        "loss_type": "dr_grpo",
        "scale_rewards": "none",
        **_TOKEN_LEVEL_ON_POLICY,
    },
    # bnpo — batch-normalized variant (a trl loss_type); std taken over the batch.
    "bnpo": {
        "loss_type": "bnpo",
        "scale_rewards": "batch",
        **_TOKEN_LEVEL_ON_POLICY,
    },
    # dapo — DAPO (arXiv 2503.14476): decoupled "clip-higher"
    # (epsilon_high > epsilon), token-level loss, overlong masking, KL removed.
    # High-value, low-cost guard against entropy collapse. epsilon_high=0.28
    # per the paper.
    "dapo": {
        "loss_type": "dapo",
        "scale_rewards": "none",
        "epsilon": 0.2,
        "epsilon_high": 0.28,
        "mask_truncated_completions": True,
        "beta": 0.0,
        **_TOKEN_LEVEL_ON_POLICY,
    },
    # gspo — GSPO (Qwen, arXiv 2507.18071): SEQUENCE-level importance ratio (one
    # length-normalized ratio per response); stabilizes long-CoT and especially
    # MoE RL. trl has no literal "gspo": it is the grpo loss with
    # importance_sampling_level="sequence".
    "gspo": {
        "loss_type": "grpo",
        "scale_rewards": "group",
        "importance_sampling_level": "sequence",
        "num_iterations": 1,
    },
    # cispo — CISPO (MiniMax-M1, arXiv 2506.13585): clip the IS weight and
    # detach it as a constant coefficient on log π, so every token keeps a
    # gradient (fixes the "rare reasoning tokens get zeroed by the clip"
    # pathology). eps_max≈5 (ScaleRL).
    "cispo": {
        "loss_type": "cispo",
        "scale_rewards": "none",
        "epsilon_high": 5.0,
        **_TOKEN_LEVEL_ON_POLICY,
    },
}


def make_po_config(objective: str = "dr_grpo", **overrides: Any):
    """Build a `trl.GRPOConfig` for a NAMED policy-optimization objective.

    The menu that gives RL real options beyond the single hardcoded Dr.GRPO
    recipe. ``objective`` selects a preset from ``PO_OBJECTIVES`` (grpo /
    dr_grpo / bnpo / dapo / gspo / cispo); ``**overrides`` set or override any
    GRPOConfig field on top (e.g. ``output_dir=...``, ``beta=...``,
    ``learning_rate=...``).

    All presets are PURE CONFIG over trl 1.5.0's GRPOTrainer (verified by
    introspecting the installed package 2026-05-30): the trainer already
    implements each ``loss_type`` branch and the ``importance_sampling_level`` /
    ``epsilon_high`` knobs, so no custom ``_compute_loss`` is needed. See ADR-014.

    Args:
        objective: Preset name, matched case-insensitively. A falsy value
            (``None`` / ``""``) selects ``"dr_grpo"``.
        **overrides: GRPOConfig fields layered on top of the preset.

    Returns:
        The constructed ``trl.GRPOConfig``.

    Raises:
        ValueError: unknown objective (lists the valid menu).
        AssertionError: a requested knob silently failed to apply (drift
            guard). Raised explicitly rather than via ``assert`` so the guard
            is still enforced under ``python -O``.
    """
    from trl import GRPOConfig  # local import: only when actually building a config

    key = (objective or "dr_grpo").lower()
    if key not in PO_OBJECTIVES:
        raise ValueError(
            f"Unknown PO objective {objective!r}. Choose from: "
            f"{sorted(PO_OBJECTIVES)}. (Each is a named preset over trl 1.5.0's "
            f"GRPOConfig knobs — see PO_OBJECTIVES / ADR-014.)"
        )

    # Copy the preset so the shared PO_OBJECTIVES entry is never mutated.
    preset = dict(PO_OBJECTIVES[key])
    merged = {**preset, **overrides}
    cfg = GRPOConfig(**merged)

    # Drift guards: fail loudly if a future trl renamed/repurposed a knob we set,
    # so a preset can never silently degrade to a different objective. These are
    # explicit raises because `assert` statements vanish under `python -O`.
    if "loss_type" in merged and str(cfg.loss_type) != str(merged["loss_type"]):
        raise AssertionError(
            f"GRPOConfig.loss_type drifted: requested {merged['loss_type']!r}, "
            f"got {cfg.loss_type!r} — trl may have renamed the knob."
        )
    if (
        "importance_sampling_level" in merged
        and hasattr(cfg, "importance_sampling_level")
        and str(cfg.importance_sampling_level) != str(merged["importance_sampling_level"])
    ):
        raise AssertionError(
            f"importance_sampling_level drifted for objective {key!r}: requested "
            f"{merged['importance_sampling_level']!r}, got {cfg.importance_sampling_level!r}."
        )
    # GSPO is *defined* by the sequence-level ratio, so an override back to
    # "token" is rejected even though it applied cleanly.
    if key == "gspo" and str(getattr(cfg, "importance_sampling_level", "token")) != "sequence":
        raise AssertionError(
            "GSPO requires importance_sampling_level='sequence'; it was overridden "
            "to token, which silently degrades GSPO to GRPO. Drop that override."
        )
    if merged.get("epsilon_high") is not None:
        requested_eps_high = float(merged["epsilon_high"])
        # If the config has no such attribute, fall back to the requested
        # value so the comparison passes (same as the getattr default before).
        applied_eps_high = float(getattr(cfg, "epsilon_high", merged["epsilon_high"]))
        # `not (... < tol)` rather than `>= tol` so a NaN difference also fails.
        if not (abs(applied_eps_high - requested_eps_high) < 1e-9):
            raise AssertionError(f"epsilon_high (decoupled clip) drifted for {key!r}.")
    return cfg


# Public API of this module: the names exported by a star-import. Helpers
# whose names start with an underscore are deliberately left out.
__all__ = [
    "ComposerReplicationTrainer",
    "make_dr_grpo_config",
    "make_po_config",
    "PO_OBJECTIVES",
    "validate_kl_in_reward_config",
]