# MCSD implementation guide for this TRL fork

This guide describes how to implement Multi-Criteria Self-Distillation (MCSD) on top of the current experimental SDPO stack in this repository.

The current OPSD path has one student context and one privileged teacher context. MCSD keeps the student path unchanged, replaces the single teacher with multiple criterion-specific privileged teachers, and merges those teachers with an LSC-style product of experts.

Do not change the default behavior. `mcsd_enable=False` must preserve the existing OPSD behavior.

## 1. Current code path to extend

Use these files as the implementation map:

- `examples/scripts/sdpo_rar_science.py`
  - Builds the local RAR Science dataset.
  - Currently creates one `privileged_context` string from the full rubric.
  - The student sees only `prompt`; the teacher sees `privileged_context + prompt + completion`.

- `trl/experimental/self_distillation/self_distillation_mixin.py`
  - `_set_signature_columns_if_needed` currently keeps `prompt` and `privileged_context`.
  - `_split_prompt_and_privileged_context` currently returns one privileged context per example.
  - `_compute_self_distillation_loss` builds student logits, teacher logits, top-k distributions, reverse-KL token loss, importance clipping, and final masked aggregation.

- `trl/experimental/sdpo/sdpo_trainer.py`
  - `SuccessfulRolloutTeacherContextBuilder.build` turns `privileged_context` into `teacher_input_ids`, `teacher_attention_mask`, and `self_distillation_mask`.
  - `SDPOTrainer._generate_and_score_completions` calls the teacher-context builder and attaches those tensors to the training batch.

- `trl/experimental/sdpo/sdpo_config.py`
  - `SDPOConfig` owns the SDPO-specific CLI arguments and validation.

- `tests/experimental/test_sdpo_trainer.py`
  - Existing tests already cover top-k distillation, teacher prompt construction, old-log-prob clipping setup, and teacher attention masks.
  - Add MCSD coverage here.

Repository-specific constraint: this is experimental code, but keep duplicated SDPO/self-distillation patterns consistent. Add small local helpers only where they make the MCSD branch testable and readable.

## 2. Objective to implement

For each generated response token position `t`, keep the same student top-k token set used by OPSD:

```text
K_t = TopK_M(p_{s,t}), where M = args.distillation_topk
```

Renormalize the student and every criterion teacher over the same top-k set:

```text
log_q_s       = log p_s over K_t, renormalized over K_t
log_q_t[:, j] = log p_teacher_j over K_t, renormalized over K_t
```

Build the merged MCSD teacher as:

```python
# Shapes: log_q_s is [B, T, M]; log_q_t is [B, K, T, M] (criterion dimension = 1).
log_q_base = log_q_s.detach()                                   # [B, T, M]
advantage = log_q_t - log_q_base.unsqueeze(1)                   # [B, K, T, M]
log_gate = F.logsigmoid(advantage - args.mcsd_gate_bias)        # per-criterion gate in log-space
merged_log_unnorm = log_q_base + log_gate.sum(dim=1)            # [B, T, M]
log_q_mcsd = merged_log_unnorm - torch.logsumexp(merged_log_unnorm, dim=-1, keepdim=True)
log_q_mcsd = log_q_mcsd.detach()
```

Then replace the OPSD teacher distribution in the existing reverse-KL token loss:

```python
token_loss = (log_q_s.exp() * (log_q_s - log_q_mcsd)).sum(dim=-1)
```

Keep the existing completion mask, `self_distillation_mask`, `loss_type` aggregation, and `distillation_is_clip` logic.

Important details:

- Select top-k from the student distribution, not from teacher distributions.
- Use the same top-k indices for all teachers.
- Detach `log_q_s` before using it as the MCSD base measure.
- Apply `F.logsigmoid` to each criterion first, then sum the transformed criterion values in log-space.
- Do not average teacher distributions.
- Do not sum raw teacher log-probs.
- Detach `log_q_mcsd` before computing the reverse KL.
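The objective can be checked end to end on synthetic tensors before touching the trainer. The sketch below is a standalone function (`mcsd_token_loss` is a hypothetical name, not an existing repository API). It renormalizes top-k log-probs by subtracting their `logsumexp`, which is assumed to match what the repository's `_renorm_topk_log_probs` helper does. Shapes follow the text: `student_logits` is `[B, T, V]`, `teacher_logits` is `[B, K, T, V]`, and `criterion_mask` is `[B, K]`.

```python
import torch
import torch.nn.functional as F


def mcsd_token_loss(
    student_logits: torch.Tensor,  # [B, T, V], already temperature-scaled
    teacher_logits: torch.Tensor,  # [B, K, T, V]
    criterion_mask: torch.Tensor,  # [B, K], 1.0 for active criteria, 0.0 for padding
    topk: int,
    gate_bias: float = 0.0,
):
    """Return (per_token_loss [B, T], log_q_mcsd [B, T, topk])."""
    # Top-k support K_t is selected from the student distribution only.
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    topk_student, topk_idx = torch.topk(student_log_probs, k=topk, dim=-1)
    log_q_s = topk_student - torch.logsumexp(topk_student, dim=-1, keepdim=True)

    # Every criterion teacher is gathered at the same student top-k ids, then renormalized.
    num_criteria = teacher_logits.shape[1]
    teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
    idx = topk_idx.unsqueeze(1).expand(-1, num_criteria, -1, -1)
    topk_teacher = torch.gather(teacher_log_probs, dim=-1, index=idx)
    log_q_t = topk_teacher - torch.logsumexp(topk_teacher, dim=-1, keepdim=True)

    # LSC product of experts: detached student base measure times one sigmoid gate per criterion.
    log_q_base = log_q_s.detach()
    log_gate = F.logsigmoid(log_q_t - log_q_base.unsqueeze(1) - gate_bias)
    log_gate = log_gate * criterion_mask[:, :, None, None]  # inactive criteria add log(1) = 0
    merged = log_q_base + log_gate.sum(dim=1)
    log_q_mcsd = (merged - torch.logsumexp(merged, dim=-1, keepdim=True)).detach()

    # Reverse KL(q_s || q_mcsd) on the shared top-k support; only the student receives gradients.
    per_token_loss = (log_q_s.exp() * (log_q_s - log_q_mcsd)).sum(dim=-1)
    return per_token_loss, log_q_mcsd
```

Setting every entry of `criterion_mask` to `0` makes the merged teacher equal to the detached student, so the loss is zero; that is a cheap sanity check for any implementation of the branch.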

## 3. Add configuration arguments

Edit `trl/experimental/sdpo/sdpo_config.py`.

Add fields to `SDPOConfig`:

```python
mcsd_enable: bool = field(
    default=False,
    metadata={"help": "Enable Multi-Criteria Self-Distillation."},
)

mcsd_merge_mode: str = field(
    default="lsc_poe",
    metadata={"help": "MCSD teacher merge mode. Supported: `lsc_poe`."},
)

mcsd_gate_bias: float = field(
    default=0.0,
    metadata={"help": "Scalar bias subtracted from each MCSD criterion advantage before logsigmoid gating."},
)
```

Add these entries to the class docstring using the repository docstring format.

Extend `__post_init__` with narrow first-pass validation:

```python
if self.mcsd_enable:
    if self.mcsd_merge_mode != "lsc_poe":
        raise ValueError("mcsd_merge_mode must be `lsc_poe` when MCSD is enabled.")
    if self.sdpo_policy_loss_mode != "distillation_only":
        raise ValueError("MCSD currently supports `sdpo_policy_loss_mode='distillation_only'`.")
    if not self.full_logit_distillation:
        raise ValueError("MCSD requires `full_logit_distillation=True`.")
    if self.distillation_topk is None:
        raise ValueError("MCSD requires `distillation_topk`.")
    if self.distillation_alpha != 1.0:
        raise ValueError("MCSD requires reverse KL, so `distillation_alpha` must be 1.0.")
    if self.distillation_add_tail:
        raise ValueError("MCSD does not support `distillation_add_tail=True`.")
```
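To exercise these rules without constructing the full config, the dataclass below mirrors only the MCSD-relevant fields. `MCSDArgsSketch` is hypothetical, and its defaults are chosen so that `mcsd_enable=True` passes validation; they are not necessarily the real `SDPOConfig` defaults.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class MCSDArgsSketch:
    """Hypothetical stand-in for the MCSD-relevant fields of `SDPOConfig`."""

    mcsd_enable: bool = False
    mcsd_merge_mode: str = "lsc_poe"
    mcsd_gate_bias: float = 0.0
    sdpo_policy_loss_mode: str = "distillation_only"
    full_logit_distillation: bool = True
    distillation_topk: int | None = 20
    distillation_alpha: float = 1.0
    distillation_add_tail: bool = False

    def __post_init__(self) -> None:
        if not self.mcsd_enable:
            return  # MCSD disabled: OPSD behavior, nothing extra to validate
        if self.mcsd_merge_mode != "lsc_poe":
            raise ValueError("mcsd_merge_mode must be `lsc_poe` when MCSD is enabled.")
        if self.sdpo_policy_loss_mode != "distillation_only":
            raise ValueError("MCSD currently supports `sdpo_policy_loss_mode='distillation_only'`.")
        if not self.full_logit_distillation:
            raise ValueError("MCSD requires `full_logit_distillation=True`.")
        if self.distillation_topk is None:
            raise ValueError("MCSD requires `distillation_topk`.")
        if self.distillation_alpha != 1.0:
            raise ValueError("MCSD requires reverse KL, so `distillation_alpha` must be 1.0.")
        if self.distillation_add_tail:
            raise ValueError("MCSD does not support `distillation_add_tail=True`.")
```

The early return is what keeps the default path untouched: with `mcsd_enable=False`, combinations such as `full_logit_distillation=False` remain valid for OPSD.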

The intended CLI is:

```bash
--mcsd_enable true \
--mcsd_merge_mode lsc_poe \
--mcsd_gate_bias 0.0 \
--sdpo_policy_loss_mode distillation_only \
--full_logit_distillation true \
--distillation_topk 20 \
--distillation_alpha 1.0
```

## 4. Extend the dataset format

Keep the old single-context field for OPSD:

```python
"privileged_context": str
```

Add a plural field for MCSD:

```python
"privileged_contexts": list[str]
```

Each item in `privileged_contexts` is one criterion-specific teacher-only context. The number of criteria may differ by example.

For the current RAR data, `data/rar_science/*.jsonl` and `data/rar_medicine/*.jsonl` already contain rubric fields such as:

```python
"rubric": [{"description": "...", "title": "...", "weight": ...}, ...]
"rubric_list": ["Essential Criteria: ...", ...]
"rubric_count": 7
```

In `examples/scripts/sdpo_rar_science.py`, keep `_build_privileged_context` unchanged for OPSD and add separate criterion-level templates for MCSD:

```python
POSITIVE_RUBRIC_TEACHER_TEMPLATE = """For this question, please consider the following evaluation criteria:

{positive_rubric}

Please provide a comprehensive and helpful response that addresses the question while following the above guidelines.

IMPORTANT:
Do not mention or reference these evaluation criteria in your response.
Do not indicate that you have seen any scoring rubric or evaluation guidelines.
Your response should appear natural and spontaneous. Revealing that you have access to evaluation criteria would be considered cheating and is strictly prohibited.
"""

NEGATIVE_RUBRIC_TEACHER_TEMPLATE = """For this question, please consider the following evaluation pitfall to avoid:

{negative_rubric}

Please provide a comprehensive and helpful response that addresses the question while carefully avoiding the issue described above.

IMPORTANT:
Do not mention or reference this evaluation criterion in your response.
Do not indicate that you have seen any scoring rubric, pitfall list, or evaluation guidelines.
Your response should appear natural and spontaneous. Revealing that you have access to evaluation criteria would be considered cheating and is strictly prohibited.
"""
```

Then add a criterion-level builder:

```python
def _build_privileged_contexts_by_criterion(
    example: dict[str, Any], max_items_per_section: int | None
) -> list[str]:
    include_items, avoid_items = _collect_rubric_items(example, max_items_per_section)
    contexts = []
    for item in include_items:
        contexts.append(POSITIVE_RUBRIC_TEACHER_TEMPLATE.format(positive_rubric=item))
    for item in avoid_items:
        contexts.append(NEGATIVE_RUBRIC_TEACHER_TEMPLATE.format(negative_rubric=item))
    return contexts
```

Then update `_make_conversation`:

```python
return {
    "prompt": [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": example["question"]},
    ],
    "solution": example["reference_answer"],
    "privileged_context": _build_privileged_context(example, max_items_per_section),
    "privileged_contexts": _build_privileged_contexts_by_criterion(example, max_items_per_section),
}
```

Also update `SelfDistillationMixin._set_signature_columns_if_needed`:

```python
self._signature_columns = ["prompt", "privileged_context", "privileged_contexts"]
```

Keep `_split_prompt_and_privileged_context` unchanged for the old OPSD path. Add a new helper for MCSD:

```python
@staticmethod
def _split_prompt_and_privileged_contexts(
    inputs: list[dict[str, Any]],
) -> tuple[list[Any], list[Any], list[Any]]:
    prompts = [example["prompt"] for example in inputs]
    privileged_contexts = [example.get("privileged_context") for example in inputs]
    criterion_contexts = [
        example.get("privileged_contexts", example.get("privileged_context"))
        for example in inputs
    ]
    return prompts, privileged_contexts, criterion_contexts
```

Then update `SDPOTrainer._generate_and_score_completions`:

```python
prompts, privileged_contexts, criterion_contexts = self._split_prompt_and_privileged_contexts(inputs)
teacher_feedbacks = criterion_contexts if self.args.mcsd_enable else privileged_contexts

output = super()._generate_and_score_completions(inputs)
output.update(
    self.teacher_context_builder.build(output, prompts, output["rewards"], feedbacks=teacher_feedbacks)
)
```

This keeps the single-teacher field as the source for OPSD and the plural field as the source for MCSD.

Recommended batch convention:

```python
prompts: list[Any]
privileged_contexts: list[str | None]             # old OPSD
criterion_contexts: list[list[str] | str | None]  # MCSD
```

Treat a plain string in `criterion_contexts` as a one-criterion list so existing datasets can still be smoke-tested with MCSD.

## 5. Build criterion-specific teacher contexts

Edit `SuccessfulRolloutTeacherContextBuilder` in `trl/experimental/sdpo/sdpo_trainer.py`.

Do not rewrite the existing single-teacher path. Add a branch near the start of `build`:

```python
if self.trainer.args.mcsd_enable:
    return self._build_mcsd(output, prompts, rewards, feedbacks=feedbacks)
```

Keep the existing `build` body as the OPSD implementation.

The MCSD builder returns:

```python
{
    "teacher_input_ids_by_criterion": teacher_input_ids_by_criterion,            # [B, K_max, L]
    "teacher_attention_mask_by_criterion": teacher_attention_mask_by_criterion,  # [B, K_max, L]
    "mcsd_criterion_mask": mcsd_criterion_mask,                                  # [B, K_max]
    "mcsd_num_criteria": mcsd_num_criteria,                                      # [B]
    "self_distillation_mask": local_self_distillation_mask,                      # [B]
}
```

Implementation outline:

1. Gather rewards, completions, prompts, and feedbacks exactly like the current builder does.
2. For each global sample, construct a list of active criterion contexts.
3. Use `privileged_contexts` when available.
4. If only the old `privileged_context` string is available, wrap it as a one-element list.
5. If a sample has no active criteria, use one placeholder teacher prompt and set its criterion mask to `0`.
6. For active criteria, set the criterion mask to `1`.
7. For variable `K`, pad criteria to the local `K_max`.
8. Tokenize the flattened list of teacher messages with the existing `_tokenize_teacher_messages`.
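9. Repeat `completion_ids` and `completion_mask` `K_max` times along the batch dimension with `repeat_interleave`, so their rows follow the same sample-major order as the flattened teacher messages.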
10. Concatenate teacher prompts and completions exactly like OPSD does.
11. Reshape flattened teacher tensors back to `[B, K_max, L]`.

Sketch:

```python
contexts_by_sample = []
mask_rows = []
mcsd_num_criteria = []

for global_idx in range(process_start, process_start + num_local):
    raw_contexts = all_feedbacks[global_idx]
    # Plain string -> one-element list; None or empty -> [] (see the batch convention above).
    contexts = _normalize_criterion_contexts(raw_contexts)

    if len(contexts) == 0:
        # Placeholder teacher prompt keeps the tensor shape; mask 0 removes it from the merge.
        contexts_by_sample.append([None])
        mask_rows.append([0.0])
        mcsd_num_criteria.append(0)
    else:
        contexts_by_sample.append(contexts)
        mask_rows.append([1.0] * len(contexts))
        mcsd_num_criteria.append(len(contexts))

max_criteria = max(len(contexts) for contexts in contexts_by_sample)
flat_teacher_messages = []

for local_idx, global_idx in enumerate(range(process_start, process_start + num_local)):
    original_prompt = all_prompts[global_idx]
    contexts = list(contexts_by_sample[local_idx])

    # Pad every sample to the local K_max; padded slots are inactive.
    while len(contexts) < max_criteria:
        contexts.append(None)
        mask_rows[local_idx].append(0.0)

    # Sample-major order: row local_idx * max_criteria + k is criterion k of this sample.
    # `_build_teacher_message` wraps the prompt with one context (None -> no privileged context).
    for context in contexts:
        flat_teacher_messages.append(_build_teacher_message(original_prompt, context))

mcsd_criterion_mask = torch.tensor(mask_rows, device=device)                           # [B, K_max]
mcsd_num_criteria = torch.tensor(mcsd_num_criteria, device=device, dtype=torch.long)  # [B]
local_self_distillation_mask = (mcsd_num_criteria > 0).float()                         # [B]
```

Tokenize `flat_teacher_messages`, concatenate the repeated completions, and reshape:

```python
teacher_batch = self._tokenize_teacher_messages(flat_teacher_messages)
flat_completion_ids = completion_ids.repeat_interleave(max_criteria, dim=0)
flat_completion_mask = completion_mask.repeat_interleave(max_criteria, dim=0)

flat_teacher_input_ids = torch.cat([teacher_batch.prompt_ids, flat_completion_ids], dim=1)
flat_teacher_attention_mask = torch.cat([teacher_batch.prompt_mask, flat_completion_mask], dim=1)

teacher_input_ids_by_criterion = flat_teacher_input_ids.reshape(num_local, max_criteria, -1)
teacher_attention_mask_by_criterion = flat_teacher_attention_mask.reshape(num_local, max_criteria, -1)
```
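Row order matters here: `repeat_interleave` repeats each completion `K_max` times consecutively, which matches the sample-major order in which `flat_teacher_messages` is built, so the final `reshape` places criterion `k` of sample `b` at `[b, k]`. The toy check below uses integer ids as stand-ins for real tokenized prompts and completions.

```python
import torch

num_local, max_criteria, prompt_len, completion_len = 2, 3, 4, 5

# Toy flattened teacher prompts in sample-major order: every id in row r equals r,
# where r = b * max_criteria + k (sample b, criterion k).
flat_prompt_ids = torch.arange(num_local * max_criteria).repeat_interleave(prompt_len).reshape(-1, prompt_len)
# One completion per sample: all 100s for sample 0, all 200s for sample 1.
completion_ids = torch.tensor([[100] * completion_len, [200] * completion_len])

flat_completion_ids = completion_ids.repeat_interleave(max_criteria, dim=0)  # [B * K_max, C]
flat_teacher_input_ids = torch.cat([flat_prompt_ids, flat_completion_ids], dim=1)
teacher_input_ids_by_criterion = flat_teacher_input_ids.reshape(num_local, max_criteria, -1)

print(teacher_input_ids_by_criterion.shape)  # torch.Size([2, 3, 9])
```

Using `completion_ids.repeat(max_criteria, 1)` instead would tile the batch criterion-major and pair teacher prompts with the wrong completions.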

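A masked criterion contributes `0` to `log_gate.sum(dim=1)` because `log_gate` is multiplied by `mcsd_criterion_mask` before the sum (Section 6.2). In log-space, `0` is a multiplicative factor of `1`, so padded criteria leave the product of experts unchanged.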
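The neutrality of masked criteria can be verified directly: merging two active criteria gives the same teacher as merging the same two plus a third, padded slot with mask `0`. This is a synthetic check; `merge` is a hypothetical local helper that reproduces the merge from Section 2 with random top-k distributions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
log_q_base = torch.log_softmax(torch.randn(1, 4, 5), dim=-1)  # [B, T, M], detached student top-k
log_q_t = torch.log_softmax(torch.randn(1, 3, 4, 5), dim=-1)  # [B, K, T, M], three teacher slots


def merge(log_q_t: torch.Tensor, criterion_mask: torch.Tensor, gate_bias: float = 0.0) -> torch.Tensor:
    """Hypothetical helper reproducing the LSC product-of-experts merge from Section 2."""
    log_gate = F.logsigmoid(log_q_t - log_q_base.unsqueeze(1) - gate_bias)
    log_gate = log_gate * criterion_mask[:, :, None, None]  # inactive slots contribute log(1) = 0
    merged = log_q_base + log_gate.sum(dim=1)
    return merged - torch.logsumexp(merged, dim=-1, keepdim=True)


two_active = merge(log_q_t[:, :2], torch.tensor([[1.0, 1.0]]))
padded = merge(log_q_t, torch.tensor([[1.0, 1.0, 0.0]]))  # same two criteria plus a padded third slot
```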

For RAR Science/Medicine MCSD, set:

```bash
--include_environment_feedback true \
--use_successful_as_teacher false
```

That makes each rubric item a criterion-specific hidden teacher context and avoids mixing successful-rollout mining into the first MCSD implementation.

## 6. Add the MCSD loss branch

Edit `SelfDistillationMixin._compute_self_distillation_loss` in `trl/experimental/self_distillation/self_distillation_mixin.py`.

The current method does this:

1. Builds `response_mask`.
2. Builds student logits.
3. Builds one teacher logits tensor.
4. Computes sampled-token logps for diagnostics and clipping.
5. Computes top-k reverse KL when `distillation_topk` and `full_logit_distillation` are active.
6. Applies importance clipping.
7. Aggregates with `_aggregate_self_distillation_loss`.

Add an MCSD branch after student logits are available and before the single-teacher forward pass:

```python
if self.args.mcsd_enable:
    return self._compute_mcsd_self_distillation_loss(
        model=model,
        inputs=inputs,
        completion_ids=completion_ids,
        completion_mask=completion_mask,
        response_mask=response_mask,
        logits_to_keep=logits_to_keep,
        student_logits=student_logits,
    )
```

Keep all existing code below that branch for OPSD.

### 6.1 Forward all criterion teachers

Inside `_compute_mcsd_self_distillation_loss`, read the new tensors:

```python
teacher_input_ids = inputs["teacher_input_ids_by_criterion"]              # [B, K, L]
teacher_attention_mask = inputs["teacher_attention_mask_by_criterion"]    # [B, K, L]
criterion_mask = inputs["mcsd_criterion_mask"].to(student_logits.dtype)   # [B, K]
```

Flatten the criterion dimension:

```python
batch_size, num_criteria, teacher_seq_len = teacher_input_ids.shape

flat_teacher_input_ids = teacher_input_ids.reshape(batch_size * num_criteria, teacher_seq_len)
flat_teacher_attention_mask = teacher_attention_mask.reshape(batch_size * num_criteria, teacher_seq_len)
```

Run the teacher model under the same no-grad/context manager as OPSD:

```python
teacher_model_inputs = {
    "input_ids": flat_teacher_input_ids,
    "attention_mask": flat_teacher_attention_mask,
    "use_cache": False,
}
if "logits_to_keep" in self.model_kwarg_keys:
    teacher_model_inputs["logits_to_keep"] = logits_to_keep + 1

teacher_model = self._get_teacher_model_for_self_distillation(model)
with torch.no_grad(), self._get_teacher_context_for_self_distillation(model):
    flat_teacher_logits = teacher_model(**teacher_model_inputs).logits
```

Then align logits exactly like OPSD:

```python
flat_teacher_logits = flat_teacher_logits[:, :-1, :]
flat_teacher_logits = flat_teacher_logits[:, -logits_to_keep:, :]
flat_teacher_logits = flat_teacher_logits / self.temperature
teacher_logits = flat_teacher_logits.reshape(batch_size, num_criteria, logits_to_keep, -1)
```

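This alignment is required because teacher prompts are longer than student prompts (they carry the privileged context), but both sequences end with the same generated completion. Dropping the final position and keeping the last `logits_to_keep` positions selects, in both tensors, the logits that predict each completion token.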
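A toy illustration of the slicing, using position indices in place of logits: after dropping the final position and keeping the last `logits_to_keep` positions, position `p` predicts the token at `p + 1`, which is exactly the completion in both the short student sequence and the longer teacher sequence. The lengths below are arbitrary; in the trainer the model is called with `logits_to_keep + 1`, so the real tensors contain only the trailing positions and the same slicing applies.

```python
import torch

logits_to_keep = 3                # completion length (toy value)
student_len, teacher_len = 6, 10  # teacher prompt is longer: it carries the privileged context


def predicting_positions(seq_len: int) -> torch.Tensor:
    """Apply the OPSD/MCSD slicing to stand-in 'logits' that store their own position index."""
    logits = torch.arange(seq_len, dtype=torch.float32).view(1, seq_len, 1)  # [1, L, 1]
    logits = logits[:, :-1, :]               # the last position predicts past the sequence end
    logits = logits[:, -logits_to_keep:, :]  # positions whose next token is a completion token
    return logits.view(-1).long()


student_pos = predicting_positions(student_len)
teacher_pos = predicting_positions(teacher_len)
print(student_pos.tolist(), teacher_pos.tolist())  # [2, 3, 4] [6, 7, 8]
```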

### 6.2 Compute top-k MCSD reverse KL

Use this code shape for the core math:

```python
student_logsumexp = torch.logsumexp(student_logits, dim=-1, keepdim=True)
student_log_probs = student_logits - student_logsumexp

topk_student_log_probs, topk_indices = torch.topk(
    student_log_probs,
    k=self.args.distillation_topk,
    dim=-1,
)
log_q_s = self._renorm_topk_log_probs(topk_student_log_probs)
q_s = log_q_s.exp()

teacher_logsumexp = torch.logsumexp(teacher_logits, dim=-1, keepdim=True)
teacher_log_probs = teacher_logits - teacher_logsumexp
teacher_topk_indices = topk_indices.unsqueeze(1).expand(-1, num_criteria, -1, -1)
teacher_topk_log_probs = torch.gather(teacher_log_probs, dim=-1, index=teacher_topk_indices)
log_q_t = self._renorm_topk_log_probs(teacher_topk_log_probs)

log_q_base = log_q_s.detach()
advantage = log_q_t - log_q_base.unsqueeze(1)
log_gate = F.logsigmoid(advantage - self.args.mcsd_gate_bias)
log_gate = log_gate * criterion_mask[:, :, None, None]

merged_log_unnorm = log_q_base + log_gate.sum(dim=1)
log_q_mcsd = merged_log_unnorm - torch.logsumexp(merged_log_unnorm, dim=-1, keepdim=True)
log_q_mcsd = log_q_mcsd.detach()

per_token_loss = (q_s * (log_q_s - log_q_mcsd)).sum(dim=-1)
```
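The core math relies on two facts that are easy to check in isolation: renormalizing top-k log-probs equals a `log_softmax` over the corresponding top-k logits, and the expanded index gathers every teacher at the same student-selected vocabulary ids. `renorm_topk_log_probs` below is an assumed equivalent of the repository's `_renorm_topk_log_probs`, not its actual source; the tensors are random.

```python
import torch


def renorm_topk_log_probs(topk_log_probs: torch.Tensor) -> torch.Tensor:
    """Assumed equivalent of `_renorm_topk_log_probs`: renormalize over the last (top-k) dim."""
    return topk_log_probs - torch.logsumexp(topk_log_probs, dim=-1, keepdim=True)


torch.manual_seed(0)
B, K, T, V, M = 2, 3, 4, 10, 5
student_logits = torch.randn(B, T, V)
teacher_logits = torch.randn(B, K, T, V)

student_log_probs = torch.log_softmax(student_logits, dim=-1)
topk_student_log_probs, topk_indices = torch.topk(student_log_probs, k=M, dim=-1)  # [B, T, M]
log_q_s = renorm_topk_log_probs(topk_student_log_probs)

teacher_log_probs = torch.log_softmax(teacher_logits, dim=-1)
teacher_topk_indices = topk_indices.unsqueeze(1).expand(-1, K, -1, -1)  # [B, K, T, M]
log_q_t = renorm_topk_log_probs(torch.gather(teacher_log_probs, dim=-1, index=teacher_topk_indices))
```

Because the full-vocabulary normalizer cancels on renormalization, only temperature scaling before `torch.topk` changes which tokens and values end up in `K_t`.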

Notes:

- `criterion_mask == 0` makes the criterion contribute `0` to the log-space expert sum.
- `self_distillation_mask` already zeros out samples with no active teacher supervision through `response_mask`.
- Reject `distillation_add_tail=True` for MCSD in this implementation. The check belongs in `SDPOConfig.__post_init__`, where the Section 3 validation already performs it:

```python
if self.mcsd_enable and self.distillation_add_tail:
    raise ValueError("MCSD does not support `distillation_add_tail=True` in the first implementation.")
```

### 6.3 Keep sampled-token logps for clipping

The existing clipping path uses sampled-token student log-probs and `old_per_token_logps`.

Compute sampled-token logps from `student_logits`:

```python
idx = completion_ids.unsqueeze(-1)
student_per_token_logps = torch.gather(student_log_probs, dim=-1, index=idx).squeeze(-1)
```

Then reuse the existing helper:

```python
if self.args.distillation_is_clip is not None:
    old_log_probs = inputs.get("old_per_token_logps")
    if old_log_probs is not None:
        per_token_loss = self._apply_importance_sampling_clipping(
            per_token_loss,
            student_per_token_logps,
            old_log_probs,
            self.args.distillation_is_clip,
        )
```

Keep the existing pointwise KL diagnostics where possible, reusing the OPSD branch's code unchanged so that clipping metrics stay comparable between OPSD and MCSD runs.

### 6.4 Aggregate exactly like OPSD

Use the existing aggregation helper:

```python
loss = self._aggregate_self_distillation_loss(per_token_loss, response_mask)
```

Log the distillation loss through the existing metric helper so `train/loss_opsd` continues to be populated by the base trainer:

```python
# `mode` ("train" or "eval") is assumed to be defined as in the OPSD branch.
mean_distill_loss = (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0)
self._log_self_distillation_metric(
    mode,
    "distillation_loss",
    self.accelerator.gather(mean_distill_loss).mean().item(),
)
```

Also log MCSD-specific diagnostics.

## 7. Log MCSD diagnostics

Keep the existing OPSD TensorBoard keys alive.

The base trainer maps `self_distillation/distillation_loss` to:

```text
train/loss_opsd
```

Therefore the MCSD loss branch must still call:

```python
self._log_self_distillation_metric(mode, "distillation_loss", ...)
```

This preserves the old `train/loss_opsd` panel. It also keeps `train/loss_total`, `train/grad_norm`, `progress/*`, `rollout/*`, and `system/*` unchanged because those are logged outside the teacher-merge logic.

For the old `opsd/*` diagnostic keys, continue emitting them in MCSD mode so existing dashboards do not go blank:

```text
opsd/student_logp_mean
opsd/teacher_logp_mean
opsd/teacher_student_logp_gap_mean
opsd/token_advantage_mean
opsd/token_advantage_std
opsd/pointwise_kl_mean
opsd/pointwise_kl_clip_frac
opsd/pointwise_kl_clipped_mean
```

In MCSD mode, interpret the teacher-side `opsd/*` keys as diagnostics against the merged MCSD teacher, not against a single privileged teacher. For sampled-token teacher metrics, gather the sampled response token from `log_q_mcsd` only when the sampled token is inside the student top-k set; otherwise skip that token for those compatibility diagnostics. The primary MCSD diagnostics below are the authoritative metrics for the merged teacher.
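For these compatibility metrics, the sampled token has to be located inside the student top-k set. A sketch with a hypothetical helper `sampled_token_logp_in_topk`: it returns the log-prob of the sampled token under a top-k distribution such as `log_q_mcsd`, plus a mask that is `0` wherever the token fell outside the top-k set, so those positions can be skipped.

```python
from __future__ import annotations

import torch


def sampled_token_logp_in_topk(
    log_q: torch.Tensor, topk_indices: torch.Tensor, completion_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper for the `opsd/*` compatibility metrics in MCSD mode.

    log_q: [B, T, M] log-probs over the top-k slots (e.g. log_q_mcsd).
    topk_indices: [B, T, M] vocabulary ids of those slots.
    completion_ids: [B, T] sampled tokens.
    Returns (logp [B, T], in_topk [B, T] as float); logp is 0 where the token is outside top-k.
    """
    matches = topk_indices == completion_ids.unsqueeze(-1)  # top-k ids are distinct: <= 1 match per row
    in_topk = matches.any(dim=-1)
    slot = matches.float().argmax(dim=-1, keepdim=True)     # matching slot (0 if none; masked below)
    logp = torch.gather(log_q, dim=-1, index=slot).squeeze(-1)
    return torch.where(in_topk, logp, torch.zeros_like(logp)), in_topk.float()
```

Combine the returned mask with `response_mask` before averaging, e.g. `(logp * mask * response_mask).sum() / (mask * response_mask).sum().clamp(min=1.0)`.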

Use slash-style metric names to match the repository:

```text
mcsd/loss
mcsd/num_criteria_mean
mcsd/gate_bias
mcsd/teacher_entropy_mean
mcsd/student_entropy_mean
mcsd/merged_entropy_mean
mcsd/gate_mean
mcsd/gate_min
mcsd/gate_max
```

Compute them under `torch.no_grad()` and apply masks where practical:

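```python
with torch.no_grad():
    gate = log_gate.exp()  # padded criteria have log_gate == 0, i.e. gate == 1; mask them out
    student_entropy = -(q_s * log_q_s).sum(dim=-1)                  # [B, T]
    merged_entropy = -(log_q_mcsd.exp() * log_q_mcsd).sum(dim=-1)   # [B, T]
    teacher_entropy = -(log_q_t.exp() * log_q_t).sum(dim=-1)        # [B, K, T]
```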

For criterion-level metrics, mask with:

```python
criterion_token_mask = criterion_mask[:, :, None] * response_mask[:, None, :]
```

For token-level metrics, mask with:

```python
response_mask
```
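A per-process sketch of these diagnostics under the masks above. `mcsd_local_metrics` is hypothetical and computes local values only; gather them across processes before calling `.item()`. One detail worth keeping: padded criteria carry `log_gate == 0`, i.e. a gate of `1`, so gate statistics are taken over active criterion/token/top-k entries only. Averaging the gate uniformly over top-k slots is an assumption about what `mcsd/gate_mean` should report.

```python
from __future__ import annotations

import torch


def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean of `values` where `mask` is 1; `mask` must broadcast to `values`."""
    mask = mask.expand_as(values).to(values.dtype)
    return (values * mask).sum() / mask.sum().clamp(min=1.0)


def mcsd_local_metrics(
    log_q_s: torch.Tensor,         # [B, T, M]
    log_q_t: torch.Tensor,         # [B, K, T, M]
    log_q_mcsd: torch.Tensor,      # [B, T, M]
    log_gate: torch.Tensor,        # [B, K, T, M], already multiplied by the criterion mask
    criterion_mask: torch.Tensor,  # [B, K]
    response_mask: torch.Tensor,   # [B, T]
) -> dict[str, torch.Tensor]:
    """Hypothetical per-process MCSD diagnostics (not gathered across processes)."""
    with torch.no_grad():
        criterion_token_mask = criterion_mask[:, :, None] * response_mask[:, None, :]  # [B, K, T]
        student_entropy = -(log_q_s.exp() * log_q_s).sum(dim=-1)       # [B, T]
        merged_entropy = -(log_q_mcsd.exp() * log_q_mcsd).sum(dim=-1)  # [B, T]
        teacher_entropy = -(log_q_t.exp() * log_q_t).sum(dim=-1)       # [B, K, T]

        gate = log_gate.exp()  # padded criteria hold gate == 1 and are excluded below
        active = criterion_token_mask.bool().unsqueeze(-1).expand_as(gate)
        active_gate = gate[active]
        has_gate = active_gate.numel() > 0
        nan = gate.new_tensor(float("nan"))

        return {
            "mcsd/num_criteria_mean": criterion_mask.sum(dim=1).float().mean(),
            "mcsd/student_entropy_mean": masked_mean(student_entropy, response_mask),
            "mcsd/merged_entropy_mean": masked_mean(merged_entropy, response_mask),
            "mcsd/teacher_entropy_mean": masked_mean(teacher_entropy, criterion_token_mask),
            "mcsd/gate_mean": active_gate.mean() if has_gate else nan,
            "mcsd/gate_min": active_gate.min() if has_gate else nan,
            "mcsd/gate_max": active_gate.max() if has_gate else nan,
        }
```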

Gather metric tensors through `self.accelerator.gather(...)` before appending scalar `.item()` values, following the style already used in `self_distillation_mixin.py`.

## 8. Update callbacks and artifacts only as needed

The existing callback `on_teacher_context_built` receives:

```python
teacher_input_ids
teacher_attention_mask
completion_mask
self_distillation_mask
```

For MCSD, pass the new tensors too:

```python
teacher_input_ids_by_criterion=output["teacher_input_ids_by_criterion"]
teacher_attention_mask_by_criterion=output["teacher_attention_mask_by_criterion"]
mcsd_criterion_mask=output["mcsd_criterion_mask"]
mcsd_num_criteria=output["mcsd_num_criteria"]
```

Do not remove the old callback payload for OPSD.

For rollout artifacts, add the number of active criteria per sample. Avoid dumping full teacher logits.

## 9. Update the launch script

In `run_sdpo.sh`, add the MCSD flags:

```bash
  --mcsd_enable true \
  --mcsd_merge_mode lsc_poe \
  --mcsd_gate_bias 0.0 \
  --include_environment_feedback true \
  --use_successful_as_teacher false \
  --sdpo_policy_loss_mode distillation_only \
  --full_logit_distillation true \
  --distillation_topk 20 \
  --distillation_alpha 1.0 \
```

Keep `--distillation_is_clip 2` if you want the same importance-clipping behavior as OPSD.

## 10. Add tests

Add focused tests in `tests/experimental/test_sdpo_trainer.py`.

Minimum test list:

1. Config defaults
   - `SDPOConfig(output_dir=...).mcsd_enable is False`
   - Existing tests still pass without setting any MCSD flag.

2. Config validation
   - `mcsd_enable=True` rejects `full_logit_distillation=False`.
   - `mcsd_enable=True` rejects `distillation_topk=None`.
   - `mcsd_enable=True` rejects `distillation_alpha != 1.0`.
   - `mcsd_enable=True` rejects unsupported `mcsd_merge_mode`.

3. Dataset retention
   - A dataset with `privileged_contexts` survives trainer column filtering.

4. Teacher context construction
   - A sample with three criterion contexts returns `teacher_input_ids_by_criterion` with shape `[B, 3, L]`.
   - `mcsd_criterion_mask` is `[1, 1, 1]`.
   - A sample with fewer criteria is padded and has zeros in the padded criterion mask.
   - Teacher completion attention still equals the original `completion_mask` on every active criterion.

5. MCSD math unit test
   - Use tiny synthetic `student_logits` and `teacher_logits`.
   - Verify the merged teacher equals:

   ```python
   torch.softmax(
       log_q_s.detach() + F.logsigmoid(log_q_t - log_q_s.detach().unsqueeze(1) - mcsd_gate_bias).sum(dim=1),
       dim=-1,
   )
   ```

   - Verify it is not equal to an arithmetic average of teacher distributions.
   - Verify inactive criteria do not change the merged teacher.

6. End-to-end tiny training
   - Use `trl-internal-testing/tiny-Qwen2ForCausalLM-2.5`.
   - Dataset has `prompt`, `privileged_context`, and `privileged_contexts`.
   - Args include:

   ```python
   mcsd_enable=True,
   mcsd_merge_mode="lsc_poe",
   mcsd_gate_bias=0.0,
   include_environment_feedback=True,
   use_successful_as_teacher=False,
   sdpo_policy_loss_mode="distillation_only",
   full_logit_distillation=True,
   distillation_topk=5,
   distillation_alpha=1.0,
   distillation_is_clip=None,
   max_steps=1,
   ```

   - Assert `train_loss` is present.
   - Assert at least one `mcsd/*` metric is logged.
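A standalone version of the math test (item 5), runnable with `pytest`. `merged_teacher` is a hypothetical local helper; in the real test it should be replaced by the trainer's MCSD code path so the test exercises repository code rather than a copy of the formula.

```python
import torch
import torch.nn.functional as F


def merged_teacher(log_q_s, log_q_t, criterion_mask, gate_bias=0.0):
    """Hypothetical stand-in for the MCSD merge; returns probabilities [B, T, M]."""
    log_q_base = log_q_s.detach()
    log_gate = F.logsigmoid(log_q_t - log_q_base.unsqueeze(1) - gate_bias)
    log_gate = log_gate * criterion_mask[:, :, None, None]
    return torch.softmax(log_q_base + log_gate.sum(dim=1), dim=-1)


def test_mcsd_merged_teacher_math():
    torch.manual_seed(0)
    log_q_s = torch.log_softmax(torch.randn(1, 2, 4), dim=-1)     # [B, T, M]
    log_q_t = torch.log_softmax(torch.randn(1, 3, 2, 4), dim=-1)  # [B, K, T, M]

    merged = merged_teacher(log_q_s, log_q_t, torch.ones(1, 3))
    expected = torch.softmax(log_q_s + F.logsigmoid(log_q_t - log_q_s.unsqueeze(1)).sum(dim=1), dim=-1)
    torch.testing.assert_close(merged, expected)

    # Product of experts, not a mixture: differs from the arithmetic mean of the teachers.
    assert not torch.allclose(merged, log_q_t.exp().mean(dim=1), atol=1e-4)

    # An inactive criterion leaves the merged teacher unchanged.
    masked = merged_teacher(log_q_s, log_q_t, torch.tensor([[1.0, 1.0, 0.0]]))
    torch.testing.assert_close(masked, merged_teacher(log_q_s, log_q_t[:, :2], torch.ones(1, 2)))
```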

Run the targeted tests:

```bash
pytest tests/experimental/test_sdpo_trainer.py -k "mcsd or training"
```

Then run the full SDPO test file:

```bash
pytest tests/experimental/test_sdpo_trainer.py
```

## 11. Paper index requirement

If MCSD is being implemented from a paper, add a subsection to `paper_index.md`. Per repo guidance, use Hugging Face paper links:

```text
https://huggingface.co/papers/<paper-id>
```

Do this in the same PR as the MCSD implementation.

## 12. Implementation checklist

1. Add `mcsd_enable`, `mcsd_merge_mode`, and `mcsd_gate_bias` to `SDPOConfig`.
2. Validate MCSD only for top-k full-logit reverse-KL distillation.
3. Preserve `privileged_context` and add `privileged_contexts` to the RAR dataset mapping.
4. Keep `privileged_contexts` during trainer column filtering.
5. Add an MCSD branch to the teacher-context builder.
6. Return `[B, K_max, L]` teacher tensors plus `mcsd_criterion_mask`.
7. Add an MCSD branch to `_compute_self_distillation_loss`.
8. Forward criterion teachers by flattening `[B, K, L]` to `[B*K, L]`.
9. Gather every teacher over the student top-k token indices.
10. Build the merged teacher with `log_q_s.detach() + sum_j(logsigmoid(log_q_t_j - log_q_s.detach() - bias))`, renormalized over the top-k set.
11. Compute reverse KL from current student top-k distribution to detached merged teacher.
12. Reuse existing importance clipping and loss aggregation.
13. Log `mcsd/*` diagnostics.
14. Update callbacks/artifacts with criterion counts and masks.
15. Add tests for config, data retention, context construction, MCSD math, and tiny training.
16. Update `paper_index.md` if MCSD corresponds to a paper implementation.

Final behavior: when MCSD is disabled, this repository trains exactly as current OPSD; when MCSD is enabled, the single teacher distribution is replaced by an LSC-transformed product-of-experts teacher over the same student-selected top-k tokens.