Commit `1fa3c6c` by ihbkaiser: Implement MCSD for experimental SDPO
# MCSD implementation guide for this TRL fork
This guide describes how to implement Multi-Criteria Self-Distillation (MCSD) on top of the current experimental SDPO stack in this repository.
The current OPSD path has one student context and one privileged teacher context. MCSD keeps the student path unchanged, replaces the single teacher with multiple criterion-specific privileged teachers, and merges those teachers with an LSC-style product of experts.
Do not change the default behavior. `mcsd_enable=False` must preserve the existing OPSD behavior.
## 1. Current code path to extend
Use these files as the implementation map:
- `examples/scripts/sdpo_rar_science.py`
  - Builds the local RAR Science dataset.
  - Currently creates one `privileged_context` string from the full rubric.
  - The student sees only `prompt`; the teacher sees `privileged_context + prompt + completion`.
- `trl/experimental/self_distillation/self_distillation_mixin.py`
  - `_set_signature_columns_if_needed` currently keeps `prompt` and `privileged_context`.
  - `_split_prompt_and_privileged_context` currently returns one privileged context per example.
  - `_compute_self_distillation_loss` builds student logits, teacher logits, top-k distributions, reverse-KL token loss, importance clipping, and final masked aggregation.
- `trl/experimental/sdpo/sdpo_trainer.py`
  - `SuccessfulRolloutTeacherContextBuilder.build` turns `privileged_context` into `teacher_input_ids`, `teacher_attention_mask`, and `self_distillation_mask`.
  - `SDPOTrainer._generate_and_score_completions` calls the teacher-context builder and attaches those tensors to the training batch.
- `trl/experimental/sdpo/sdpo_config.py`
  - `SDPOConfig` owns the SDPO-specific CLI arguments and validation.
- `tests/experimental/test_sdpo_trainer.py`
  - Existing tests already cover top-k distillation, teacher prompt construction, old-log-prob clipping setup, and teacher attention masks.
  - Add MCSD coverage here.
Repository-specific constraint: this is experimental code, but keep duplicated SDPO/self-distillation patterns consistent. Add small local helpers only where they make the MCSD branch testable and readable.
## 2. Objective to implement
For each generated response token position `t`, keep the same student top-k token set used by OPSD:
```text
K_t = TopK_M(p_s,t), where M = args.distillation_topk
```
Renormalize the student and every criterion teacher over the same top-k set:
```text
log_q_s = log p_s over K_t, renormalized over K_t
log_q_t[:, j] = log p_teacher_j over K_t, renormalized over K_t
```
Build the merged MCSD teacher as:
```python
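import torch
import torch.nn.functional as F

# log_q_s: [B, T, M] student log-probs over K_t.
# log_q_t: [B, K, T, M], one slice per criterion teacher over the same K_t.
log_q_base = log_q_s.detach()
advantage = log_q_t - log_q_base.unsqueeze(1)
log_gate = F.logsigmoid(advantage - args.mcsd_gate_bias)
merged_log_unnorm = log_q_base + log_gate.sum(dim=1)
log_q_mcsd = merged_log_unnorm - torch.logsumexp(merged_log_unnorm, dim=-1, keepdim=True)
log_q_mcsd = log_q_mcsd.detach()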
```
Then replace the OPSD teacher distribution in the existing reverse-KL token loss:
```python
token_loss = (log_q_s.exp() * (log_q_s - log_q_mcsd)).sum(dim=-1)
```
Keep the existing completion mask, `self_distillation_mask`, `loss_type` aggregation, and `distillation_is_clip` logic.
Important details:
- Select top-k from the student distribution, not from teacher distributions.
- Use the same top-k indices for all teachers.
- Detach `log_q_s` before using it as the MCSD base measure.
- Apply `F.logsigmoid` to each criterion first, then sum the transformed criterion values in log-space.
- Do not average teacher distributions.
- Do not sum raw teacher log-probs.
- Detach `log_q_mcsd` before computing the reverse KL.
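A minimal standalone sketch of the merge and loss above, assuming `log_q_s` and `log_q_t` are already renormalized over `K_t` and that the criterion dimension sits at index 1 (as implied by `log_q_base.unsqueeze(1)`). `mcsd_merge` and `reverse_kl` are hypothetical names for illustration, not functions in this fork. The checks mirror the rules listed above: the merged teacher is a detached distribution, and it is not an average of teacher distributions.

```python
import torch
import torch.nn.functional as F


def mcsd_merge(log_q_s, log_q_t, gate_bias=0.0):
    """Merged MCSD teacher over a shared top-k set (LSC-style product of experts).

    log_q_s: [B, T, M] student log-probs renormalized over the top-k set.
    log_q_t: [B, C, T, M] criterion-teacher log-probs over the same set (C criteria).
    """
    log_q_base = log_q_s.detach()                   # base measure carries no gradient
    advantage = log_q_t - log_q_base.unsqueeze(1)   # per-criterion log-ratio to the base
    log_gate = F.logsigmoid(advantage - gate_bias)  # transform each criterion first
    merged = log_q_base + log_gate.sum(dim=1)       # then sum in log-space (product of experts)
    log_q_mcsd = merged - torch.logsumexp(merged, dim=-1, keepdim=True)
    return log_q_mcsd.detach()


def reverse_kl(log_q_s, log_q_mcsd):
    """Per-token KL(q_s || q_mcsd) over the top-k set; gradients reach only log_q_s."""
    return (log_q_s.exp() * (log_q_s - log_q_mcsd)).sum(dim=-1)
```

Because each gate is a sigmoid in (0, 1), a criterion can only shrink a token's unnormalized mass. Tokens with larger teacher-to-student log-ratios keep larger gates, so renormalization moves probability toward tokens the criteria jointly favor. When every teacher matches the student, all gates are equal and the merged teacher reduces to the student.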
## 3. Add configuration arguments
Edit `trl/experimental/sdpo/sdpo_config.py`.
Add fields to `SDPOConfig`:
```python
mcsd_enable: bool = field(
    default=False,
    metadata={"help": "Enable Multi-Criteria Self-Distillation."},
)
mcsd_merge_mode: str = field(
    default="lsc_poe",
    metadata={"help": "MCSD teacher merge mode. Supported: `lsc_poe`."},
)
mcsd_gate_bias: float = field(
    default=0.0,
    metadata={"help": "Scalar bias subtracted from each MCSD criterion advantage before logsigmoid gating."},
)
```
Add these entries to the class docstring using the repository docstring format.
Extend `__post_init__` with narrow first-pass validation:
```python
if self.mcsd_enable:
    if self.mcsd_merge_mode != "lsc_poe":
        raise ValueError("mcsd_merge_mode must be `lsc_poe` when MCSD is enabled.")
    if self.sdpo_policy_loss_mode != "distillation_only":
        raise ValueError("MCSD currently supports only `sdpo_policy_loss_mode='distillation_only'`.")
    if not self.full_logit_distillation:
        raise ValueError("MCSD requires `full_logit_distillation=True`.")
    if self.distillation_topk is None:
        raise ValueError("MCSD requires `distillation_topk`.")
    if self.distillation_alpha != 1.0:
        raise ValueError("MCSD requires reverse KL, so `distillation_alpha` must be 1.0.")
    if self.distillation_add_tail:
        raise ValueError("MCSD does not support `distillation_add_tail=True`.")
```
The intended CLI is:
```bash
--mcsd_enable true \
--mcsd_merge_mode lsc_poe \
--mcsd_gate_bias 0.0 \
--sdpo_policy_loss_mode distillation_only \
--full_logit_distillation true \
--distillation_topk 20 \
--distillation_alpha 1.0
```
## 4. Extend the dataset format
Keep the old single-context field for OPSD:
```python
"privileged_context": str
```
Add a plural field for MCSD:
```python
"privileged_contexts": list[str]
```
Each item in `privileged_contexts` is one criterion-specific teacher-only context. The number of criteria may differ by example.
For the current RAR data, `data/rar_science/*.jsonl` and `data/rar_medicine/*.jsonl` already contain rubric fields such as:
```python
"rubric": [{"description": "...", "title": "...", "weight": ...}, ...]
"rubric_list": ["Essential Criteria: ...", ...]
"rubric_count": 7
```
In `examples/scripts/sdpo_rar_science.py`, keep `_build_privileged_context` unchanged for OPSD and add separate criterion-level templates for MCSD:
```python
POSITIVE_RUBRIC_TEACHER_TEMPLATE = """For this question, please consider the following evaluation criteria:
{positive_rubric}
Please provide a comprehensive and helpful response that addresses the question while following the above guidelines.
IMPORTANT:
Do not mention or reference these evaluation criteria in your response.
Do not indicate that you have seen any scoring rubric or evaluation guidelines.
Your response should appear natural and spontaneous. Revealing that you have access to evaluation criteria would be considered cheating and is strictly prohibited.
"""
NEGATIVE_RUBRIC_TEACHER_TEMPLATE = """For this question, please consider the following evaluation pitfall to avoid:
{negative_rubric}
Please provide a comprehensive and helpful response that addresses the question while carefully avoiding the issue described above.
IMPORTANT:
Do not mention or reference this evaluation criterion in your response.
Do not indicate that you have seen any scoring rubric, pitfall list, or evaluation guidelines.
Your response should appear natural and spontaneous. Revealing that you have access to evaluation criteria would be considered cheating and is strictly prohibited.
"""
```
Then add a criterion-level builder:
```python
def _build_privileged_contexts_by_criterion(
    example: dict[str, Any], max_items_per_section: int | None
) -> list[str]:
    include_items, avoid_items = _collect_rubric_items(example, max_items_per_section)
    contexts = []
    for item in include_items:
        contexts.append(POSITIVE_RUBRIC_TEACHER_TEMPLATE.format(positive_rubric=item))
    for item in avoid_items:
        contexts.append(NEGATIVE_RUBRIC_TEACHER_TEMPLATE.format(negative_rubric=item))
    return contexts
```
Then update `_make_conversation`:
```python
return {
    "prompt": [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": example["question"]},
    ],
    "solution": example["reference_answer"],
    "privileged_context": _build_privileged_context(example, max_items_per_section),
    "privileged_contexts": _build_privileged_contexts_by_criterion(example, max_items_per_section),
}
```
Also update `SelfDistillationMixin._set_signature_columns_if_needed`:
```python
self._signature_columns = ["prompt", "privileged_context", "privileged_contexts"]
```
Keep `_split_prompt_and_privileged_context` unchanged for the old OPSD path. Add a new helper for MCSD:
```python
@staticmethod
def _split_prompt_and_privileged_contexts(
    inputs: list[dict[str, Any]],
) -> tuple[list[Any], list[Any], list[Any]]:
    prompts = [example["prompt"] for example in inputs]
    privileged_contexts = [example.get("privileged_context") for example in inputs]
    criterion_contexts = [
        example.get("privileged_contexts", example.get("privileged_context"))
        for example in inputs
    ]
    return prompts, privileged_contexts, criterion_contexts
```
Then update `SDPOTrainer._generate_and_score_completions`:
```python
prompts, privileged_contexts, criterion_contexts = self._split_prompt_and_privileged_contexts(inputs)
teacher_feedbacks = criterion_contexts if self.args.mcsd_enable else privileged_contexts
output = super()._generate_and_score_completions(inputs)
output.update(
    self.teacher_context_builder.build(output, prompts, output["rewards"], feedbacks=teacher_feedbacks)
)
```
This keeps the single-teacher field as the source for OPSD and the plural field as the source for MCSD.
Recommended batch convention:
```python
prompts: list[Any]
privileged_contexts: list[str | None] # old OPSD
criterion_contexts: list[list[str] | str | None] # MCSD
```
Treat a plain string in `criterion_contexts` as a one-criterion list so existing datasets can still be smoke-tested with MCSD.
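The builder sketch in section 5 calls `_normalize_criterion_contexts`, which this guide does not define. A minimal version following the batch convention above might look like this; the name `normalize_criterion_contexts` and the choice to drop empty or non-string entries are assumptions, not existing behavior of the fork.

```python
from typing import Any


def normalize_criterion_contexts(raw: Any) -> list[str]:
    """Hypothetical helper: coerce one sample's criterion contexts to a list of strings.

    None -> [] (no active criteria), str -> [str] (old single-context datasets),
    list -> the non-empty string entries, in order.
    """
    if raw is None:
        return []
    if isinstance(raw, str):
        return [raw] if raw.strip() else []
    return [ctx for ctx in raw if isinstance(ctx, str) and ctx.strip()]
```

An empty result is what routes a sample to the placeholder teacher with criterion mask `0`.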
## 5. Build criterion-specific teacher contexts
Edit `SuccessfulRolloutTeacherContextBuilder` in `trl/experimental/sdpo/sdpo_trainer.py`.
Do not rewrite the existing single-teacher path. Add a branch near the start of `build`:
```python
if self.trainer.args.mcsd_enable:
    return self._build_mcsd(output, prompts, rewards, feedbacks=feedbacks)
```
Keep the existing `build` body as the OPSD implementation.
The MCSD builder returns:
```python
{
    "teacher_input_ids_by_criterion": teacher_input_ids_by_criterion,  # [B, K_max, L]
    "teacher_attention_mask_by_criterion": teacher_attention_mask_by_criterion,  # [B, K_max, L]
    "mcsd_criterion_mask": mcsd_criterion_mask,  # [B, K_max]
    "mcsd_num_criteria": mcsd_num_criteria,  # [B]
    "self_distillation_mask": local_self_distillation_mask,  # [B]
}
```
Implementation outline:
1. Gather rewards, completions, prompts, and feedbacks exactly like the current builder does.
2. For each global sample, construct a list of active criterion contexts.
3. Use `privileged_contexts` when available.
4. If only the old `privileged_context` string is available, wrap it as a one-element list.
5. If a sample has no active criteria, use one placeholder teacher prompt and set its criterion mask to `0`.
6. For active criteria, set the criterion mask to `1`.
7. For variable `K`, pad criteria to the local `K_max`.
8. Tokenize the flattened list of teacher messages with the existing `_tokenize_teacher_messages`.
9. Repeat `completion_ids` and `completion_mask` by `K_max`.
10. Concatenate teacher prompts and completions exactly like OPSD does.
11. Reshape flattened teacher tensors back to `[B, K_max, L]`.
Sketch:
```python
contexts_by_sample = []
mask_rows = []
mcsd_num_criteria = []
for global_idx in range(process_start, process_start + num_local):
    # `_normalize_criterion_contexts`: None -> [], str -> [str], list -> list (batch convention above).
    contexts = _normalize_criterion_contexts(all_feedbacks[global_idx])
    if len(contexts) == 0:
        # One placeholder teacher prompt keeps shapes valid; its criterion mask entry is 0.
        contexts_by_sample.append([None])
        mask_rows.append([0.0])
        mcsd_num_criteria.append(0)
    else:
        contexts_by_sample.append(contexts)
        mask_rows.append([1.0] * len(contexts))
        mcsd_num_criteria.append(len(contexts))

max_criteria = max(len(contexts) for contexts in contexts_by_sample)

flat_teacher_messages = []  # sample-major: all criteria of sample 0, then sample 1, ...
for local_idx, global_idx in enumerate(range(process_start, process_start + num_local)):
    original_prompt = all_prompts[global_idx]
    contexts = list(contexts_by_sample[local_idx])
    while len(contexts) < max_criteria:
        contexts.append(None)  # padded criterion, masked out
        mask_rows[local_idx].append(0.0)
    for context in contexts:
        # `_build_teacher_message`: teacher prompt plus the criterion context (None = no context).
        flat_teacher_messages.append(_build_teacher_message(original_prompt, context))

mcsd_criterion_mask = torch.tensor(mask_rows, device=device)  # [B, K_max], float
mcsd_num_criteria = torch.tensor(mcsd_num_criteria, device=device, dtype=torch.long)  # [B]
local_self_distillation_mask = (mcsd_num_criteria > 0).float()  # [B]
```
Tokenize `flat_teacher_messages`, concatenate the repeated completions, and reshape:
```python
teacher_batch = self._tokenize_teacher_messages(flat_teacher_messages)
flat_completion_ids = completion_ids.repeat_interleave(max_criteria, dim=0)
flat_completion_mask = completion_mask.repeat_interleave(max_criteria, dim=0)
flat_teacher_input_ids = torch.cat([teacher_batch.prompt_ids, flat_completion_ids], dim=1)
flat_teacher_attention_mask = torch.cat([teacher_batch.prompt_mask, flat_completion_mask], dim=1)
teacher_input_ids_by_criterion = flat_teacher_input_ids.reshape(num_local, max_criteria, -1)
teacher_attention_mask_by_criterion = flat_teacher_attention_mask.reshape(num_local, max_criteria, -1)
```
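The reshape is only correct if the repeated completions line up with the sample-major order of `flat_teacher_messages`. The toy check below uses stand-in integer tensors (not real tokenizer output) in place of `teacher_batch.prompt_ids` and `completion_ids`: row `r` of the flat prompt batch encodes `(sample, criterion)` as `100 + r`, and each sample's completion is filled with its sample index.

```python
import torch

num_local, max_criteria, prompt_len, completion_len = 2, 3, 4, 2

# Stand-in for teacher_batch.prompt_ids, one row per (sample, criterion) in sample-major order.
flat_prompt_ids = (100 + torch.arange(num_local * max_criteria)).unsqueeze(1).expand(-1, prompt_len)
# Stand-in for completion_ids: every completion token of sample b equals b.
completion_ids = torch.arange(num_local).unsqueeze(1).expand(-1, completion_len)

# Keep each sample's K_max copies adjacent, matching the message order.
flat_completion_ids = completion_ids.repeat_interleave(max_criteria, dim=0)  # [B*K_max, C]
flat_teacher_input_ids = torch.cat([flat_prompt_ids, flat_completion_ids], dim=1)
teacher_input_ids_by_criterion = flat_teacher_input_ids.reshape(num_local, max_criteria, -1)
```

`Tensor.repeat` would tile the whole batch instead of interleaving it, which would silently pair criterion teachers with another sample's completion.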
In the loss branch (section 6.2), `log_gate` is multiplied by `mcsd_criterion_mask`, so a masked criterion contributes `0` to `log_gate.sum(dim=1)`, which is the neutral element of the log-space product.
For RAR Science/Medicine MCSD, set:
```bash
--include_environment_feedback true \
--use_successful_as_teacher false
```
That makes each rubric item a criterion-specific hidden teacher context and avoids mixing successful-rollout mining into the first MCSD implementation.
## 6. Add the MCSD loss branch
Edit `SelfDistillationMixin._compute_self_distillation_loss` in `trl/experimental/self_distillation/self_distillation_mixin.py`.
The current method does this:
1. Builds `response_mask`.
2. Builds student logits.
3. Builds one teacher logits tensor.
4. Computes sampled-token logps for diagnostics and clipping.
5. Computes top-k reverse KL when `distillation_topk` and `full_logit_distillation` are active.
6. Applies importance clipping.
7. Aggregates with `_aggregate_self_distillation_loss`.
Add an MCSD branch after student logits are available and before the single-teacher forward pass:
```python
if self.args.mcsd_enable:
    return self._compute_mcsd_self_distillation_loss(
        model=model,
        inputs=inputs,
        completion_ids=completion_ids,
        completion_mask=completion_mask,
        response_mask=response_mask,
        logits_to_keep=logits_to_keep,
        student_logits=student_logits,
    )
```
Keep all existing code below that branch for OPSD.
### 6.1 Forward all criterion teachers
Inside `_compute_mcsd_self_distillation_loss`, read the new tensors:
```python
teacher_input_ids = inputs["teacher_input_ids_by_criterion"] # [B, K, L]
teacher_attention_mask = inputs["teacher_attention_mask_by_criterion"] # [B, K, L]
criterion_mask = inputs["mcsd_criterion_mask"].to(student_logits.dtype) # [B, K]
```
Flatten the criterion dimension:
```python
batch_size, num_criteria, teacher_seq_len = teacher_input_ids.shape
flat_teacher_input_ids = teacher_input_ids.reshape(batch_size * num_criteria, teacher_seq_len)
flat_teacher_attention_mask = teacher_attention_mask.reshape(batch_size * num_criteria, teacher_seq_len)
```
Run the teacher model under the same no-grad/context manager as OPSD:
```python
teacher_model_inputs = {
    "input_ids": flat_teacher_input_ids,
    "attention_mask": flat_teacher_attention_mask,
    "use_cache": False,
}
if "logits_to_keep" in self.model_kwarg_keys:
    teacher_model_inputs["logits_to_keep"] = logits_to_keep + 1
teacher_model = self._get_teacher_model_for_self_distillation(model)
with torch.no_grad(), self._get_teacher_context_for_self_distillation(model):
    flat_teacher_logits = teacher_model(**teacher_model_inputs).logits
```
Then align logits exactly like OPSD:
```python
flat_teacher_logits = flat_teacher_logits[:, :-1, :]
flat_teacher_logits = flat_teacher_logits[:, -logits_to_keep:, :]
flat_teacher_logits = flat_teacher_logits / self.temperature
teacher_logits = flat_teacher_logits.reshape(batch_size, num_criteria, logits_to_keep, -1)
```
This alignment is required because teacher prompts are longer than student prompts, but both tensors end with the same generated completion.
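A toy check of this alignment, assuming each sequence is laid out as `[left-padded prompt, completion]` so the completion occupies the final positions. The "logits" are a stand-in tensor whose single vocabulary entry stores the position index, and `align_completion_logits` is a hypothetical helper that applies the two slices shown above.

```python
import torch


def align_completion_logits(logits: torch.Tensor, logits_to_keep: int) -> torch.Tensor:
    """Keep the positions whose next-token predictions are the completion tokens."""
    logits = logits[:, :-1, :]             # the last position predicts a token after the completion
    return logits[:, -logits_to_keep:, :]  # one position per completion token


completion_len = 3
student_prompt_len, teacher_prompt_len = 4, 9  # the teacher prompt also holds the privileged context
# Toy "logits": position p stores the value p in its only vocabulary slot.
student_logits = torch.arange(student_prompt_len + completion_len, dtype=torch.float32).view(1, -1, 1)
teacher_logits = torch.arange(teacher_prompt_len + completion_len, dtype=torch.float32).view(1, -1, 1)

student_aligned = align_completion_logits(student_logits, completion_len)
teacher_aligned = align_completion_logits(teacher_logits, completion_len)
# Requesting only the last logits_to_keep + 1 positions gives the same slice.
teacher_trimmed = align_completion_logits(teacher_logits[:, -(completion_len + 1):, :], completion_len)
```

Position `p` predicts token `p + 1`, so both aligned slices predict completion tokens 0, 1, and 2 even though the prompts differ in length.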
### 6.2 Compute top-k MCSD reverse KL
Use this code shape for the core math:
```python
# student_logits: [B, T, V]; teacher_logits: [B, K, T, V]; criterion_mask: [B, K]
student_logsumexp = torch.logsumexp(student_logits, dim=-1, keepdim=True)
student_log_probs = student_logits - student_logsumexp
# K_t is chosen from the student only and shared by every teacher.
topk_student_log_probs, topk_indices = torch.topk(
    student_log_probs,
    k=self.args.distillation_topk,
    dim=-1,
)
log_q_s = self._renorm_topk_log_probs(topk_student_log_probs)  # [B, T, M]
q_s = log_q_s.exp()

teacher_logsumexp = torch.logsumexp(teacher_logits, dim=-1, keepdim=True)
teacher_log_probs = teacher_logits - teacher_logsumexp
teacher_topk_indices = topk_indices.unsqueeze(1).expand(-1, num_criteria, -1, -1)  # [B, K, T, M]
teacher_topk_log_probs = torch.gather(teacher_log_probs, dim=-1, index=teacher_topk_indices)
log_q_t = self._renorm_topk_log_probs(teacher_topk_log_probs)  # [B, K, T, M]

# LSC-style product of experts over K_t.
log_q_base = log_q_s.detach()
advantage = log_q_t - log_q_base.unsqueeze(1)
log_gate = F.logsigmoid(advantage - self.args.mcsd_gate_bias)
log_gate = log_gate * criterion_mask[:, :, None, None]  # padded criteria contribute 0
merged_log_unnorm = log_q_base + log_gate.sum(dim=1)
log_q_mcsd = merged_log_unnorm - torch.logsumexp(merged_log_unnorm, dim=-1, keepdim=True)
log_q_mcsd = log_q_mcsd.detach()

# Reverse KL from the current student to the detached merged teacher.
per_token_loss = (q_s * (log_q_s - log_q_mcsd)).sum(dim=-1)  # [B, T]
```
Notes:
- `criterion_mask == 0` makes the criterion contribute `0` to the log-space expert sum.
- `self_distillation_mask` already zeros out samples with no active teacher supervision through `response_mask`.
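- Reject `distillation_add_tail=True` for MCSD in this implementation. `SDPOConfig.__post_init__` already raises for this combination; if you also want a defensive runtime check inside the trainer, read both flags from `self.args`:
```python
if self.args.mcsd_enable and self.args.distillation_add_tail:
    raise ValueError("MCSD does not support `distillation_add_tail=True` in the first implementation.")
```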
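A standalone sketch of the core math starting from logits, usable as a reference for the math unit test in section 10. It assumes `self._renorm_topk_log_probs` subtracts the log-sum-exp over the top-k slice, which is how the sketch renormalizes; `mcsd_token_loss` is a hypothetical name. The check confirms that a padded criterion with mask `0` leaves the loss unchanged and that gradients reach the student logits.

```python
import torch
import torch.nn.functional as F


def mcsd_token_loss(student_logits, teacher_logits, criterion_mask, k, gate_bias=0.0):
    """Per-token MCSD reverse KL.

    student_logits: [B, T, V]; teacher_logits: [B, C, T, V]; criterion_mask: [B, C] (1 = active).
    Returns a [B, T] loss.
    """
    num_criteria = teacher_logits.shape[1]
    # Top-k set from the student, renormalized.
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    topk_student, topk_indices = torch.topk(student_log_probs, k=k, dim=-1)
    log_q_s = topk_student - torch.logsumexp(topk_student, dim=-1, keepdim=True)

    # Every teacher gathered at the same student indices, renormalized.
    teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
    index = topk_indices.unsqueeze(1).expand(-1, num_criteria, -1, -1)
    topk_teacher = torch.gather(teacher_log_probs, dim=-1, index=index)
    log_q_t = topk_teacher - torch.logsumexp(topk_teacher, dim=-1, keepdim=True)

    # Masked LSC-style merge, detached.
    log_q_base = log_q_s.detach()
    log_gate = F.logsigmoid(log_q_t - log_q_base.unsqueeze(1) - gate_bias)
    log_gate = log_gate * criterion_mask.to(log_gate.dtype)[:, :, None, None]
    merged = log_q_base + log_gate.sum(dim=1)
    log_q_mcsd = (merged - torch.logsumexp(merged, dim=-1, keepdim=True)).detach()
    return (log_q_s.exp() * (log_q_s - log_q_mcsd)).sum(dim=-1)
```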
### 6.3 Keep sampled-token logps for clipping
The existing clipping path uses sampled-token student log-probs and `old_per_token_logps`.
Compute sampled-token logps from `student_logits`:
```python
idx = completion_ids.unsqueeze(-1)
student_per_token_logps = torch.gather(student_log_probs, dim=-1, index=idx).squeeze(-1)
```
Then reuse the existing helper:
```python
if self.args.distillation_is_clip is not None:
    old_log_probs = inputs.get("old_per_token_logps")
    if old_log_probs is not None:
        per_token_loss = self._apply_importance_sampling_clipping(
            per_token_loss,
            student_per_token_logps,
            old_log_probs,
            self.args.distillation_is_clip,
        )
```
Keep the existing pointwise KL diagnostics if possible. Use the same code from the OPSD branch so clipping metrics remain aligned.
### 6.4 Aggregate exactly like OPSD
Use the existing aggregation helper:
```python
loss = self._aggregate_self_distillation_loss(per_token_loss, response_mask)
```
Log the distillation loss through the existing metric helper so `train/loss_opsd` continues to be populated by the base trainer:
```python
mean_distill_loss = (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0)
# `mode` is the same train/eval key the OPSD branch passes to this helper.
self._log_self_distillation_metric(
    mode,
    "distillation_loss",
    self.accelerator.gather(mean_distill_loss).mean().item(),
)
```
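The scalar logged as `distillation_loss` is a masked token mean. A minimal sketch with hand-checkable numbers; `masked_token_mean` is a hypothetical helper, and it does not reproduce whatever `loss_type` normalization `_aggregate_self_distillation_loss` applies to the training loss itself.

```python
import torch


def masked_token_mean(per_token_loss: torch.Tensor, response_mask: torch.Tensor) -> torch.Tensor:
    """Mean of per-token losses over unmasked response tokens."""
    mask = response_mask.to(per_token_loss.dtype)
    # clamp(min=1.0) avoids 0 / 0 when every token is masked.
    return (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)


per_token_loss = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
response_mask = torch.tensor([[1, 0], [1, 1]])
value = masked_token_mean(per_token_loss, response_mask)  # (1 + 3 + 4) / 3
```

The clamp keeps the metric at `0` instead of `nan` when every response token is masked, for example when no sample in the batch has an active criterion.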
Also log MCSD-specific diagnostics.
## 7. Log MCSD diagnostics
Keep the existing OPSD TensorBoard keys alive.
The base trainer maps `self_distillation/distillation_loss` to:
```text
train/loss_opsd
```
Therefore the MCSD loss branch must still call:
```python
self._log_self_distillation_metric(mode, "distillation_loss", ...)
```
This preserves the old `train/loss_opsd` panel. It also keeps `train/loss_total`, `train/grad_norm`, `progress/*`, `rollout/*`, and `system/*` unchanged because those are logged outside the teacher-merge logic.
For the old `opsd/*` diagnostic keys, continue emitting them in MCSD mode so existing dashboards do not go blank:
```text
opsd/student_logp_mean
opsd/teacher_logp_mean
opsd/teacher_student_logp_gap_mean
opsd/token_advantage_mean
opsd/token_advantage_std
opsd/pointwise_kl_mean
opsd/pointwise_kl_clip_frac
opsd/pointwise_kl_clipped_mean
```
In MCSD mode, interpret the teacher-side `opsd/*` keys as diagnostics against the merged MCSD teacher, not against a single privileged teacher. For sampled-token teacher metrics, gather the sampled response token from `log_q_mcsd` only when the sampled token is inside the student top-k set; otherwise skip that token for those compatibility diagnostics. The primary MCSD diagnostics below are the authoritative metrics for the merged teacher.
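One way to implement the "only when the sampled token is inside the student top-k set" rule is to match `completion_ids` against `topk_indices` and keep a membership mask. This is a sketch under the shapes of section 6.2; `gather_sampled_from_topk` is a hypothetical helper, and the `0` fill value is arbitrary because those tokens are meant to be excluded.

```python
import torch


def gather_sampled_from_topk(log_q_topk, topk_indices, completion_ids):
    """Look up each sampled token's log-prob inside a top-k distribution.

    log_q_topk: [B, T, M]; topk_indices: [B, T, M] vocabulary ids; completion_ids: [B, T].
    Returns (logps [B, T], in_topk [B, T]); logps is 0 where the token is outside the top-k set.
    """
    matches = topk_indices == completion_ids.unsqueeze(-1)       # [B, T, M]
    in_topk = matches.any(dim=-1)                                 # [B, T]
    position = matches.float().argmax(dim=-1, keepdim=True)       # slot of the match (0 if none)
    logps = torch.gather(log_q_topk, dim=-1, index=position).squeeze(-1)
    return logps.masked_fill(~in_topk, 0.0), in_topk


log_q = torch.log(torch.tensor([[[0.5, 0.3, 0.2], [0.6, 0.3, 0.1]]]))
idx = torch.tensor([[[7, 2, 9], [4, 5, 6]]])
ids = torch.tensor([[2, 8]])  # token 2 is in the first top-k set; token 8 is not in the second
logps, in_topk = gather_sampled_from_topk(log_q, idx, ids)
```

Average these compatibility metrics over `response_mask * in_topk` rather than `response_mask` alone, so skipped tokens do not count as zeros.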
Use slash-style metric names to match the repository:
```text
mcsd/loss
mcsd/num_criteria_mean
mcsd/gate_bias
mcsd/teacher_entropy_mean
mcsd/student_entropy_mean
mcsd/merged_entropy_mean
mcsd/gate_mean
mcsd/gate_min
mcsd/gate_max
```
Compute them under `torch.no_grad()` and apply masks where practical:
```python
gate = log_gate.exp()
student_entropy = -(q_s * log_q_s).sum(dim=-1)
merged_entropy = -(log_q_mcsd.exp() * log_q_mcsd).sum(dim=-1)
teacher_entropy = -(log_q_t.exp() * log_q_t).sum(dim=-1)
```
For criterion-level metrics, mask with:
```python
criterion_token_mask = criterion_mask[:, :, None] * response_mask[:, None, :]
```
For token-level metrics, mask with:
```python
response_mask
```
Gather metric tensors through `self.accelerator.gather(...)` before appending scalar `.item()` values, following the style already used in `self_distillation_mixin.py`.
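A sketch of the masked diagnostics under the shapes used in section 6.2, leaving out the `self.accelerator.gather(...)` step. `mcsd_diagnostics` and `masked_mean` are hypothetical helpers. Gate statistics are taken only over active criteria on response tokens: the loss branch multiplies `log_gate` by `criterion_mask`, so a padded criterion shows up as a gate of exactly `1` and would otherwise inflate `mcsd/gate_max`.

```python
import math

import torch


def masked_mean(values, mask):
    mask = mask.to(values.dtype)
    return (values * mask).sum() / mask.sum().clamp(min=1.0)


@torch.no_grad()
def mcsd_diagnostics(log_q_s, log_q_t, log_q_mcsd, log_gate, criterion_mask, response_mask):
    """log_q_s, log_q_mcsd: [B, T, M]; log_q_t, log_gate: [B, C, T, M]; masks: [B, C] and [B, T]."""
    criterion_token_mask = criterion_mask[:, :, None] * response_mask[:, None, :]  # [B, C, T]
    student_entropy = -(log_q_s.exp() * log_q_s).sum(dim=-1)        # [B, T]
    merged_entropy = -(log_q_mcsd.exp() * log_q_mcsd).sum(dim=-1)   # [B, T]
    teacher_entropy = -(log_q_t.exp() * log_q_t).sum(dim=-1)        # [B, C, T]
    gate = log_gate.exp()                                             # [B, C, T, M]
    active_gates = gate[criterion_token_mask.bool().unsqueeze(-1).expand_as(gate)]
    if active_gates.numel() == 0:
        gate_mean = gate_min = gate_max = float("nan")
    else:
        gate_mean = active_gates.mean().item()
        gate_min = active_gates.min().item()
        gate_max = active_gates.max().item()
    return {
        "mcsd/student_entropy_mean": masked_mean(student_entropy, response_mask).item(),
        "mcsd/merged_entropy_mean": masked_mean(merged_entropy, response_mask).item(),
        "mcsd/teacher_entropy_mean": masked_mean(teacher_entropy, criterion_token_mask).item(),
        "mcsd/gate_mean": gate_mean,
        "mcsd/gate_min": gate_min,
        "mcsd/gate_max": gate_max,
    }
```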
## 8. Update callbacks and artifacts only as needed
The existing callback `on_teacher_context_built` receives:
```python
teacher_input_ids
teacher_attention_mask
completion_mask
self_distillation_mask
```
For MCSD, pass the new tensors too:
```python
teacher_input_ids_by_criterion=output["teacher_input_ids_by_criterion"]
teacher_attention_mask_by_criterion=output["teacher_attention_mask_by_criterion"]
mcsd_criterion_mask=output["mcsd_criterion_mask"]
mcsd_num_criteria=output["mcsd_num_criteria"]
```
Do not remove the old callback payload for OPSD.
For rollout artifacts, add the number of active criteria per sample. Avoid dumping full teacher logits.
## 9. Update the launch script
In `run_sdpo.sh`, add the MCSD flags:
```bash
--mcsd_enable true \
--mcsd_merge_mode lsc_poe \
--mcsd_gate_bias 0.0 \
--include_environment_feedback true \
--use_successful_as_teacher false \
--sdpo_policy_loss_mode distillation_only \
--full_logit_distillation true \
--distillation_topk 20 \
--distillation_alpha 1.0 \
```
Keep `--distillation_is_clip 2` if you want the same importance-clipping behavior as OPSD.
## 10. Add tests
Add focused tests in `tests/experimental/test_sdpo_trainer.py`.
Minimum test list:
1. Config defaults
   - `SDPOConfig(output_dir=...).mcsd_enable is False`
   - Existing tests still pass without setting any MCSD flag.
2. Config validation
   - `mcsd_enable=True` rejects `full_logit_distillation=False`.
   - `mcsd_enable=True` rejects `distillation_topk=None`.
   - `mcsd_enable=True` rejects `distillation_alpha != 1.0`.
   - `mcsd_enable=True` rejects unsupported `mcsd_merge_mode`.
3. Dataset retention
   - A dataset with `privileged_contexts` survives trainer column filtering.
4. Teacher context construction
   - A sample with three criterion contexts returns `teacher_input_ids_by_criterion` with shape `[B, 3, L]`.
   - `mcsd_criterion_mask` is `[1, 1, 1]`.
   - A sample with fewer criteria is padded and has zeros in the padded criterion mask.
   - Teacher completion attention still equals the original `completion_mask` on every active criterion.
5. MCSD math unit test
   - Use tiny synthetic `student_logits` and `teacher_logits`.
   - With `mcsd_gate_bias=0.0`, verify the merged teacher equals:
     ```python
     F.softmax(log_q_s.detach() + F.logsigmoid(log_q_t - log_q_s.detach().unsqueeze(1)).sum(dim=1), dim=-1)
     ```
   - Verify it is not equal to an arithmetic average of teacher distributions.
   - Verify inactive criteria do not change the merged teacher.
6. End-to-end tiny training
   - Use `trl-internal-testing/tiny-Qwen2ForCausalLM-2.5`.
   - Dataset has `prompt`, `privileged_context`, and `privileged_contexts`.
   - Args include:
     ```python
     mcsd_enable=True,
     mcsd_merge_mode="lsc_poe",
     mcsd_gate_bias=0.0,
     include_environment_feedback=True,
     use_successful_as_teacher=False,
     sdpo_policy_loss_mode="distillation_only",
     full_logit_distillation=True,
     distillation_topk=5,
     distillation_alpha=1.0,
     distillation_is_clip=None,
     max_steps=1,
     ```
   - Assert `train_loss` is present.
   - Assert at least one `mcsd/*` metric is logged.
Run the targeted tests:
```bash
pytest tests/experimental/test_sdpo_trainer.py -k "mcsd or training"
```
Then run the full SDPO test file:
```bash
pytest tests/experimental/test_sdpo_trainer.py
```
## 11. Paper index requirement
If MCSD is being implemented from a paper, add a subsection to `paper_index.md`. Per repo guidance, use Hugging Face paper links:
```text
https://huggingface.co/papers/<paper-id>
```
Do this in the same PR as the MCSD implementation.
## 12. Implementation checklist
1. Add `mcsd_enable`, `mcsd_merge_mode`, and `mcsd_gate_bias` to `SDPOConfig`.
2. Validate MCSD only for top-k full-logit reverse-KL distillation.
3. Preserve `privileged_context` and add `privileged_contexts` to the RAR dataset mapping.
4. Keep `privileged_contexts` during trainer column filtering.
5. Add an MCSD branch to the teacher-context builder.
6. Return `[B, K_max, L]` teacher tensors plus `mcsd_criterion_mask`.
7. Add an MCSD branch to `_compute_self_distillation_loss`.
8. Forward criterion teachers by flattening `[B, K, L]` to `[B*K, L]`.
9. Gather every teacher over the student top-k token indices.
10. Build the merged teacher with `log_q_s.detach() + sum_j(logsigmoid(log_q_t_j - log_q_s.detach() - bias))`.
11. Compute reverse KL from current student top-k distribution to detached merged teacher.
12. Reuse existing importance clipping and loss aggregation.
13. Log `mcsd/*` diagnostics.
14. Update callbacks/artifacts with criterion counts and masks.
15. Add tests for config, data retention, context construction, MCSD math, and tiny training.
16. Update `paper_index.md` if MCSD corresponds to a paper implementation.
Final behavior: when MCSD is disabled, this repository trains exactly as current OPSD; when MCSD is enabled, the single teacher distribution is replaced by an LSC-transformed product-of-experts teacher over the same student-selected top-k tokens.