MCSD implementation guide for this TRL fork
This guide describes how to implement Multi-Criteria Self-Distillation (MCSD) on top of the current experimental SDPO stack in this repository.
The current OPSD path has one student context and one privileged teacher context. MCSD keeps the student path unchanged, replaces the single teacher with multiple criterion-specific privileged teachers, and merges those teachers with an LSC-style product of experts.
Do not change the default behavior. mcsd_enable=False must preserve the existing OPSD behavior.
1. Current code path to extend
Use these files as the implementation map:
- `examples/scripts/sdpo_rar_science.py`
  - Builds the local RAR Science dataset.
  - Currently creates one `privileged_context` string from the full rubric.
  - The student sees only `prompt`; the teacher sees `privileged_context + prompt + completion`.
- `trl/experimental/self_distillation/self_distillation_mixin.py`
  - `_set_signature_columns_if_needed` currently keeps `prompt` and `privileged_context`.
  - `_split_prompt_and_privileged_context` currently returns one privileged context per example.
  - `_compute_self_distillation_loss` builds student logits, teacher logits, top-k distributions, reverse-KL token loss, importance clipping, and final masked aggregation.
- `trl/experimental/sdpo/sdpo_trainer.py`
  - `SuccessfulRolloutTeacherContextBuilder.build` turns `privileged_context` into `teacher_input_ids`, `teacher_attention_mask`, and `self_distillation_mask`.
  - `SDPOTrainer._generate_and_score_completions` calls the teacher-context builder and attaches those tensors to the training batch.
- `trl/experimental/sdpo/sdpo_config.py`
  - `SDPOConfig` owns the SDPO-specific CLI arguments and validation.
- `tests/experimental/test_sdpo_trainer.py`
  - Existing tests already cover top-k distillation, teacher prompt construction, old-log-prob clipping setup, and teacher attention masks.
  - Add MCSD coverage here.
Repository-specific constraint: this is experimental code, but keep duplicated SDPO/self-distillation patterns consistent. Add small local helpers only where they make the MCSD branch testable and readable.
2. Objective to implement
For each generated response token position t, keep the same student top-k token set used by OPSD:
K_t = TopK_M(p_s,t), where M = args.distillation_topk
Renormalize the student and every criterion teacher over the same top-k set:
log_q_s = log p_s over K_t, renormalized over K_t
log_q_t[:, j] = log p_teacher_j over K_t, renormalized over K_t
Build the merged MCSD teacher as:
log_q_base = log_q_s.detach()
advantage = log_q_t - log_q_base.unsqueeze(1)
log_gate = F.logsigmoid(advantage - args.mcsd_gate_bias)
merged_log_unnorm = log_q_base + log_gate.sum(dim=1)
log_q_mcsd = merged_log_unnorm - torch.logsumexp(merged_log_unnorm, dim=-1, keepdim=True)
log_q_mcsd = log_q_mcsd.detach()
Then replace the OPSD teacher distribution in the existing reverse-KL token loss:
token_loss = (log_q_s.exp() * (log_q_s - log_q_mcsd)).sum(dim=-1)
Keep the existing completion mask, self_distillation_mask, loss_type aggregation, and distillation_is_clip logic.
Important details:
- Select top-k from the student distribution, not from teacher distributions.
- Use the same top-k indices for all teachers.
- Detach `log_q_s` before using it as the MCSD base measure.
- Apply `F.logsigmoid` to each criterion first, then sum the transformed criterion values in log-space.
- Do not average teacher distributions.
- Do not sum raw teacher log-probs.
- Detach `log_q_mcsd` before computing the reverse KL.
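The merge can be sanity-checked on toy tensors before wiring it into the trainer. This standalone sketch only assumes PyTorch and uses arbitrary shapes (B=1 sample, K=2 criteria, T=1 token position, M=4 top-k entries):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
log_q_s = torch.log_softmax(torch.randn(1, 1, 4), dim=-1)     # student over its top-k set
log_q_t = torch.log_softmax(torch.randn(1, 2, 1, 4), dim=-1)  # one teacher per criterion
mcsd_gate_bias = 0.0

log_q_base = log_q_s.detach()
advantage = log_q_t - log_q_base.unsqueeze(1)
log_gate = F.logsigmoid(advantage - mcsd_gate_bias)
merged_log_unnorm = log_q_base + log_gate.sum(dim=1)
log_q_mcsd = merged_log_unnorm - torch.logsumexp(merged_log_unnorm, dim=-1, keepdim=True)
log_q_mcsd = log_q_mcsd.detach()

# The merged teacher is a proper distribution over the top-k set.
assert torch.allclose(log_q_mcsd.exp().sum(dim=-1), torch.ones(1, 1))
# Reverse-KL token loss against the merged teacher, as in the objective above.
token_loss = (log_q_s.exp() * (log_q_s - log_q_mcsd)).sum(dim=-1)
assert float(token_loss) >= 0.0
```

The reverse KL is nonnegative because both distributions are normalized over the same top-k support.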
3. Add configuration arguments
Edit trl/experimental/sdpo/sdpo_config.py.
Add fields to SDPOConfig:
mcsd_enable: bool = field(
default=False,
metadata={"help": "Enable Multi-Criteria Self-Distillation."},
)
mcsd_merge_mode: str = field(
default="lsc_poe",
metadata={"help": "MCSD teacher merge mode. Supported: `lsc_poe`."},
)
mcsd_gate_bias: float = field(
default=0.0,
metadata={"help": "Scalar bias subtracted from each MCSD criterion advantage before logsigmoid gating."},
)
Add these entries to the class docstring using the repository docstring format.
Extend __post_init__ with narrow first-pass validation:
if self.mcsd_enable:
if self.mcsd_merge_mode != "lsc_poe":
raise ValueError("mcsd_merge_mode must be `lsc_poe` when MCSD is enabled.")
if self.sdpo_policy_loss_mode != "distillation_only":
raise ValueError("MCSD currently supports `sdpo_policy_loss_mode='distillation_only'`.")
if not self.full_logit_distillation:
raise ValueError("MCSD requires `full_logit_distillation=True`.")
if self.distillation_topk is None:
raise ValueError("MCSD requires `distillation_topk`.")
if self.distillation_alpha != 1.0:
raise ValueError("MCSD requires reverse KL, so `distillation_alpha` must be 1.0.")
if self.distillation_add_tail:
raise ValueError("MCSD does not support `distillation_add_tail=True`.")
The intended CLI is:
--mcsd_enable true \
--mcsd_merge_mode lsc_poe \
--mcsd_gate_bias 0.0 \
--sdpo_policy_loss_mode distillation_only \
--full_logit_distillation true \
--distillation_topk 20 \
--distillation_alpha 1.0
4. Extend the dataset format
Keep the old single-context field for OPSD:
"privileged_context": str
Add a plural field for MCSD:
"privileged_contexts": list[str]
Each item in privileged_contexts is one criterion-specific teacher-only context. The number of criteria may differ by example.
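For concreteness, one mapped example carries both fields side by side; the strings below are illustrative placeholders, not real dataset content:

```python
example = {
    "prompt": [{"role": "user", "content": "Why does ice float on water?"}],
    # Legacy OPSD field: one context built from the full rubric.
    "privileged_context": "For this question, please consider the following evaluation criteria: ...",
    # New MCSD field: one teacher-only context per rubric criterion.
    "privileged_contexts": [
        "For this question, please consider the following evaluation criteria: ...",
        "For this question, please consider the following evaluation pitfall to avoid: ...",
    ],
}
```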
For the current RAR data, data/rar_science/*.jsonl and data/rar_medicine/*.jsonl already contain rubric fields such as:
"rubric": [{"description": "...", "title": "...", "weight": ...}, ...]
"rubric_list": ["Essential Criteria: ...", ...]
"rubric_count": 7
In examples/scripts/sdpo_rar_science.py, keep _build_privileged_context unchanged for OPSD and add separate criterion-level templates for MCSD:
POSITIVE_RUBRIC_TEACHER_TEMPLATE = """For this question, please consider the following evaluation criteria:
{positive_rubric}
Please provide a comprehensive and helpful response that addresses the question while following the above guidelines.
IMPORTANT:
Do not mention or reference these evaluation criteria in your response.
Do not indicate that you have seen any scoring rubric or evaluation guidelines.
Your response should appear natural and spontaneous. Revealing that you have access to evaluation criteria would be considered cheating and is strictly prohibited.
"""
NEGATIVE_RUBRIC_TEACHER_TEMPLATE = """For this question, please consider the following evaluation pitfall to avoid:
{negative_rubric}
Please provide a comprehensive and helpful response that addresses the question while carefully avoiding the issue described above.
IMPORTANT:
Do not mention or reference this evaluation criterion in your response.
Do not indicate that you have seen any scoring rubric, pitfall list, or evaluation guidelines.
Your response should appear natural and spontaneous. Revealing that you have access to evaluation criteria would be considered cheating and is strictly prohibited.
"""
Then add a criterion-level builder:
def _build_privileged_contexts_by_criterion(
example: dict[str, Any], max_items_per_section: int | None
) -> list[str]:
include_items, avoid_items = _collect_rubric_items(example, max_items_per_section)
contexts = []
for item in include_items:
contexts.append(
POSITIVE_RUBRIC_TEACHER_TEMPLATE.format(positive_rubric=item)
)
for item in avoid_items:
contexts.append(
NEGATIVE_RUBRIC_TEACHER_TEMPLATE.format(negative_rubric=item)
)
return contexts
Then update _make_conversation:
return {
"prompt": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": example["question"]},
],
"solution": example["reference_answer"],
"privileged_context": _build_privileged_context(example, max_items_per_section),
"privileged_contexts": _build_privileged_contexts_by_criterion(example, max_items_per_section),
}
Also update SelfDistillationMixin._set_signature_columns_if_needed:
self._signature_columns = ["prompt", "privileged_context", "privileged_contexts"]
Keep _split_prompt_and_privileged_context unchanged for the old OPSD path. Add a new helper for MCSD:
@staticmethod
def _split_prompt_and_privileged_contexts(
inputs: list[dict[str, Any]],
) -> tuple[list[Any], list[Any], list[Any]]:
prompts = [example["prompt"] for example in inputs]
privileged_contexts = [example.get("privileged_context") for example in inputs]
criterion_contexts = [
example.get("privileged_contexts", example.get("privileged_context"))
for example in inputs
]
return prompts, privileged_contexts, criterion_contexts
Then update SDPOTrainer._generate_and_score_completions:
prompts, privileged_contexts, criterion_contexts = self._split_prompt_and_privileged_contexts(inputs)
teacher_feedbacks = criterion_contexts if self.args.mcsd_enable else privileged_contexts
output = super()._generate_and_score_completions(inputs)
output.update(
self.teacher_context_builder.build(output, prompts, output["rewards"], feedbacks=teacher_feedbacks)
)
This keeps the single-teacher field as the source for OPSD and the plural field as the source for MCSD.
Recommended batch convention:
prompts: list[Any]
privileged_contexts: list[str | None] # old OPSD
criterion_contexts: list[list[str] | str | None] # MCSD
Treat a plain string in criterion_contexts as a one-criterion list so existing datasets can still be smoke-tested with MCSD.
5. Build criterion-specific teacher contexts
Edit SuccessfulRolloutTeacherContextBuilder in trl/experimental/sdpo/sdpo_trainer.py.
Do not rewrite the existing single-teacher path. Add a branch near the start of build:
if self.trainer.args.mcsd_enable:
return self._build_mcsd(output, prompts, rewards, feedbacks=feedbacks)
Keep the existing build body as the OPSD implementation.
The MCSD builder returns:
{
"teacher_input_ids_by_criterion": teacher_input_ids_by_criterion, # [B, K_max, L]
"teacher_attention_mask_by_criterion": teacher_attention_mask_by_criterion,# [B, K_max, L]
"mcsd_criterion_mask": mcsd_criterion_mask, # [B, K_max]
"mcsd_num_criteria": mcsd_num_criteria, # [B]
"self_distillation_mask": local_self_distillation_mask, # [B]
}
Implementation outline:
- Gather rewards, completions, prompts, and feedbacks exactly like the current builder does.
- For each global sample, construct a list of active criterion contexts.
- Use `privileged_contexts` when available.
- If only the old `privileged_context` string is available, wrap it as a one-element list.
- If a sample has no active criteria, use one placeholder teacher prompt and set its criterion mask to 0.
- For active criteria, set the criterion mask to 1.
- For variable K, pad criteria to the local K_max.
- Tokenize the flattened list of teacher messages with the existing `_tokenize_teacher_messages`.
- Repeat `completion_ids` and `completion_mask` by K_max.
- Concatenate teacher prompts and completions exactly like OPSD does.
- Reshape flattened teacher tensors back to `[B, K_max, L]`.
Sketch:
contexts_by_sample = []
mask_rows = []
mcsd_num_criteria = []
for local_idx, global_idx in enumerate(range(process_start, process_start + num_local)):
raw_contexts = all_feedbacks[global_idx]
contexts = _normalize_criterion_contexts(raw_contexts)
if len(contexts) == 0:
contexts_by_sample.append([None])
mask_rows.append([0.0])
mcsd_num_criteria.append(0)
else:
contexts_by_sample.append(contexts)
mask_rows.append([1.0] * len(contexts))
mcsd_num_criteria.append(len(contexts))
max_criteria = max(len(contexts) for contexts in contexts_by_sample)
flat_teacher_messages = []
for local_idx, global_idx in enumerate(range(process_start, process_start + num_local)):
original_prompt = all_prompts[global_idx]
contexts = list(contexts_by_sample[local_idx])
while len(contexts) < max_criteria:
contexts.append(None)
mask_rows[local_idx].append(0.0)
for context in contexts:
flat_teacher_messages.append(_build_teacher_message(original_prompt, context))
mcsd_criterion_mask = torch.tensor(mask_rows, device=device)
mcsd_num_criteria = torch.tensor(mcsd_num_criteria, device=device, dtype=torch.long)
local_self_distillation_mask = (mcsd_num_criteria > 0).float()
Tokenize flat_teacher_messages, concatenate the repeated completions, and reshape:
teacher_batch = self._tokenize_teacher_messages(flat_teacher_messages)
flat_completion_ids = completion_ids.repeat_interleave(max_criteria, dim=0)
flat_completion_mask = completion_mask.repeat_interleave(max_criteria, dim=0)
flat_teacher_input_ids = torch.cat([teacher_batch.prompt_ids, flat_completion_ids], dim=1)
flat_teacher_attention_mask = torch.cat([teacher_batch.prompt_mask, flat_completion_mask], dim=1)
teacher_input_ids_by_criterion = flat_teacher_input_ids.reshape(num_local, max_criteria, -1)
teacher_attention_mask_by_criterion = flat_teacher_attention_mask.reshape(num_local, max_criteria, -1)
A masked criterion contributes 0 to log_gate.sum(dim=1), which is the neutral log-space contribution.
For RAR Science/Medicine MCSD, set:
--include_environment_feedback true \
--use_successful_as_teacher false
That makes each rubric item a criterion-specific hidden teacher context and avoids mixing successful-rollout mining into the first MCSD implementation.
6. Add the MCSD loss branch
Edit SelfDistillationMixin._compute_self_distillation_loss in trl/experimental/self_distillation/self_distillation_mixin.py.
The current method does this:
- Builds `response_mask`.
- Builds student logits.
- Builds one teacher logits tensor.
- Computes sampled-token logps for diagnostics and clipping.
- Computes top-k reverse KL when `distillation_topk` and `full_logit_distillation` are active.
- Applies importance clipping.
- Aggregates with `_aggregate_self_distillation_loss`.
Add an MCSD branch after student logits are available and before the single-teacher forward pass:
if self.args.mcsd_enable:
return self._compute_mcsd_self_distillation_loss(
model=model,
inputs=inputs,
completion_ids=completion_ids,
completion_mask=completion_mask,
response_mask=response_mask,
logits_to_keep=logits_to_keep,
student_logits=student_logits,
)
Keep all existing code below that branch for OPSD.
6.1 Forward all criterion teachers
Inside _compute_mcsd_self_distillation_loss, read the new tensors:
teacher_input_ids = inputs["teacher_input_ids_by_criterion"] # [B, K, L]
teacher_attention_mask = inputs["teacher_attention_mask_by_criterion"] # [B, K, L]
criterion_mask = inputs["mcsd_criterion_mask"].to(student_logits.dtype) # [B, K]
Flatten the criterion dimension:
batch_size, num_criteria, teacher_seq_len = teacher_input_ids.shape
flat_teacher_input_ids = teacher_input_ids.reshape(batch_size * num_criteria, teacher_seq_len)
flat_teacher_attention_mask = teacher_attention_mask.reshape(batch_size * num_criteria, teacher_seq_len)
Run the teacher model under the same no-grad/context manager as OPSD:
teacher_model_inputs = {
"input_ids": flat_teacher_input_ids,
"attention_mask": flat_teacher_attention_mask,
"use_cache": False,
}
if "logits_to_keep" in self.model_kwarg_keys:
teacher_model_inputs["logits_to_keep"] = logits_to_keep + 1
teacher_model = self._get_teacher_model_for_self_distillation(model)
with torch.no_grad(), self._get_teacher_context_for_self_distillation(model):
flat_teacher_logits = teacher_model(**teacher_model_inputs).logits
Then align logits exactly like OPSD:
flat_teacher_logits = flat_teacher_logits[:, :-1, :]
flat_teacher_logits = flat_teacher_logits[:, -logits_to_keep:, :]
flat_teacher_logits = flat_teacher_logits / self.temperature
teacher_logits = flat_teacher_logits.reshape(batch_size, num_criteria, logits_to_keep, -1)
This alignment is required because teacher prompts are longer than student prompts, but both tensors end with the same generated completion.
6.2 Compute top-k MCSD reverse KL
Use this code shape for the core math:
student_logsumexp = torch.logsumexp(student_logits, dim=-1, keepdim=True)
student_log_probs = student_logits - student_logsumexp
topk_student_log_probs, topk_indices = torch.topk(
student_log_probs,
k=self.args.distillation_topk,
dim=-1,
)
log_q_s = self._renorm_topk_log_probs(topk_student_log_probs)
q_s = log_q_s.exp()
teacher_logsumexp = torch.logsumexp(teacher_logits, dim=-1, keepdim=True)
teacher_log_probs = teacher_logits - teacher_logsumexp
teacher_topk_indices = topk_indices.unsqueeze(1).expand(-1, num_criteria, -1, -1)
teacher_topk_log_probs = torch.gather(teacher_log_probs, dim=-1, index=teacher_topk_indices)
log_q_t = self._renorm_topk_log_probs(teacher_topk_log_probs)
log_q_base = log_q_s.detach()
advantage = log_q_t - log_q_base.unsqueeze(1)
log_gate = F.logsigmoid(advantage - self.args.mcsd_gate_bias)
log_gate = log_gate * criterion_mask[:, :, None, None]
merged_log_unnorm = log_q_base + log_gate.sum(dim=1)
log_q_mcsd = merged_log_unnorm - torch.logsumexp(merged_log_unnorm, dim=-1, keepdim=True)
log_q_mcsd = log_q_mcsd.detach()
per_token_loss = (q_s * (log_q_s - log_q_mcsd)).sum(dim=-1)
Notes:
- `criterion_mask == 0` makes the criterion contribute 0 to the log-space expert sum.
- `self_distillation_mask` already zeros out samples with no active teacher supervision through `response_mask`.
- Reject `distillation_add_tail=True` for MCSD in this implementation (read the flags through `self.args`, consistent with the rest of the mixin):
if self.args.mcsd_enable and self.args.distillation_add_tail:
    raise ValueError("MCSD does not support `distillation_add_tail=True` in the first implementation.")
6.3 Keep sampled-token logps for clipping
The existing clipping path uses sampled-token student log-probs and old_per_token_logps.
Compute sampled-token logps from student_logits:
idx = completion_ids.unsqueeze(-1)
student_per_token_logps = torch.gather(student_log_probs, dim=-1, index=idx).squeeze(-1)
Then reuse the existing helper:
if self.args.distillation_is_clip is not None:
old_log_probs = inputs.get("old_per_token_logps")
if old_log_probs is not None:
per_token_loss = self._apply_importance_sampling_clipping(
per_token_loss,
student_per_token_logps,
old_log_probs,
self.args.distillation_is_clip,
)
Keep the existing pointwise KL diagnostics if possible. Use the same code from the OPSD branch so clipping metrics remain aligned.
6.4 Aggregate exactly like OPSD
Use the existing aggregation helper:
loss = self._aggregate_self_distillation_loss(per_token_loss, response_mask)
Log the distillation loss through the existing metric helper so train/loss_opsd continues to be populated by the base trainer:
mean_distill_loss = (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0)
self._log_self_distillation_metric(
mode,
"distillation_loss",
self.accelerator.gather(mean_distill_loss).mean().item(),
)
Also log MCSD-specific diagnostics.
7. Log MCSD diagnostics
Keep the existing OPSD TensorBoard keys alive.
The base trainer maps self_distillation/distillation_loss to:
train/loss_opsd
Therefore the MCSD loss branch must still call:
self._log_self_distillation_metric(mode, "distillation_loss", ...)
This preserves the old train/loss_opsd panel. It also keeps train/loss_total, train/grad_norm, progress/*, rollout/*, and system/* unchanged because those are logged outside the teacher-merge logic.
For the old opsd/* diagnostic keys, continue emitting them in MCSD mode so existing dashboards do not go blank:
opsd/student_logp_mean
opsd/teacher_logp_mean
opsd/teacher_student_logp_gap_mean
opsd/token_advantage_mean
opsd/token_advantage_std
opsd/pointwise_kl_mean
opsd/pointwise_kl_clip_frac
opsd/pointwise_kl_clipped_mean
In MCSD mode, interpret the teacher-side opsd/* keys as diagnostics against the merged MCSD teacher, not against a single privileged teacher. For sampled-token teacher metrics, gather the sampled response token from log_q_mcsd only when the sampled token is inside the student top-k set; otherwise skip that token for those compatibility diagnostics. The primary MCSD diagnostics below are the authoritative metrics for the merged teacher.
Use slash-style metric names to match the repository:
mcsd/loss
mcsd/num_criteria_mean
mcsd/gate_bias
mcsd/teacher_entropy_mean
mcsd/student_entropy_mean
mcsd/merged_entropy_mean
mcsd/gate_mean
mcsd/gate_min
mcsd/gate_max
Compute them under torch.no_grad() and apply masks where practical:
gate = log_gate.exp()
student_entropy = -(q_s * log_q_s).sum(dim=-1)
merged_entropy = -(log_q_mcsd.exp() * log_q_mcsd).sum(dim=-1)
teacher_entropy = -(log_q_t.exp() * log_q_t).sum(dim=-1)
For criterion-level metrics, mask with:
criterion_token_mask = criterion_mask[:, :, None] * response_mask[:, None, :]
For token-level metrics, mask with:
response_mask
Gather metric tensors through self.accelerator.gather(...) before appending scalar .item() values, following the style already used in self_distillation_mixin.py.
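A small local helper keeps the masked scalar reductions uniform across these metrics; the name `masked_mean` is a suggestion, not an existing repository function:

```python
import torch

def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean of `values` over positions where `mask` is nonzero; safe on empty masks."""
    mask = mask.to(values.dtype)
    return (values * mask).sum() / mask.sum().clamp(min=1.0)
```

For example, `mcsd/student_entropy_mean` would be `masked_mean(student_entropy, response_mask)`, gathered through `self.accelerator.gather(...)` before calling `.item()`.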
8. Update callbacks and artifacts only as needed
The existing callback on_teacher_context_built receives:
teacher_input_ids
teacher_attention_mask
completion_mask
self_distillation_mask
For MCSD, pass the new tensors too:
teacher_input_ids_by_criterion=output["teacher_input_ids_by_criterion"]
teacher_attention_mask_by_criterion=output["teacher_attention_mask_by_criterion"]
mcsd_criterion_mask=output["mcsd_criterion_mask"]
mcsd_num_criteria=output["mcsd_num_criteria"]
Do not remove the old callback payload for OPSD.
For rollout artifacts, add the number of active criteria per sample. Avoid dumping full teacher logits.
9. Update the launch script
In run_sdpo.sh, add the MCSD flags:
--mcsd_enable true \
--mcsd_merge_mode lsc_poe \
--mcsd_gate_bias 0.0 \
--include_environment_feedback true \
--use_successful_as_teacher false \
--sdpo_policy_loss_mode distillation_only \
--full_logit_distillation true \
--distillation_topk 20 \
--distillation_alpha 1.0 \
Keep --distillation_is_clip 2 if you want the same importance-clipping behavior as OPSD.
10. Add tests
Add focused tests in tests/experimental/test_sdpo_trainer.py.
Minimum test list:
Config defaults
- `SDPOConfig(output_dir=...).mcsd_enable` is `False`.
- Existing tests still pass without setting any MCSD flag.
Config validation
- `mcsd_enable=True` rejects `full_logit_distillation=False`.
- `mcsd_enable=True` rejects `distillation_topk=None`.
- `mcsd_enable=True` rejects `distillation_alpha != 1.0`.
- `mcsd_enable=True` rejects unsupported `mcsd_merge_mode`.
Dataset retention
- A dataset with `privileged_contexts` survives trainer column filtering.
Teacher context construction
- A sample with three criterion contexts returns `teacher_input_ids_by_criterion` with shape `[B, 3, L]`.
- `mcsd_criterion_mask` for that sample is `[1, 1, 1]`.
- A sample with fewer criteria is padded and has zeros in the padded criterion mask.
- Teacher completion attention still equals the original `completion_mask` on every active criterion.
MCSD math unit test
- Use tiny synthetic `student_logits` and `teacher_logits`.
- Verify the merged teacher equals `softmax(log_q_s.detach() + F.logsigmoid(log_q_t - log_q_s.detach()).sum(dim=1))`.
- Verify it is not equal to an arithmetic average of teacher distributions.
- Verify inactive criteria do not change the merged teacher.
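A sketch of that unit test, with a test-local reference implementation of the merge (the helper name `merged_mcsd_teacher` is local to the test, not a trainer API):

```python
import torch
import torch.nn.functional as F

def merged_mcsd_teacher(log_q_s, log_q_t, gate_bias=0.0):
    # Test-local reference implementation of the LSC product-of-experts merge.
    log_q_base = log_q_s.detach()
    log_gate = F.logsigmoid(log_q_t - log_q_base.unsqueeze(1) - gate_bias)
    merged = log_q_base + log_gate.sum(dim=1)
    return merged - torch.logsumexp(merged, dim=-1, keepdim=True)

torch.manual_seed(0)
log_q_s = torch.log_softmax(torch.randn(2, 3, 5), dim=-1)     # [B, T, M] student top-k
log_q_t = torch.log_softmax(torch.randn(2, 4, 3, 5), dim=-1)  # [B, K, T, M] criterion teachers

log_q_mcsd = merged_mcsd_teacher(log_q_s, log_q_t)
# Proper distribution over the top-k set.
assert torch.allclose(log_q_mcsd.exp().sum(dim=-1), torch.ones(2, 3), atol=1e-5)
# Not an arithmetic average of the teacher distributions.
avg_teacher_log = log_q_t.exp().mean(dim=1).log()
assert not torch.allclose(log_q_mcsd, avg_teacher_log, atol=1e-3)
# With zero active criteria the log-gate sum is empty and the merge reduces to the student.
assert torch.allclose(merged_mcsd_teacher(log_q_s, log_q_t[:, :0]), log_q_s, atol=1e-6)
```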
End-to-end tiny training
- Use `trl-internal-testing/tiny-Qwen2ForCausalLM-2.5`.
- Dataset has `prompt`, `privileged_context`, and `privileged_contexts`.
- Args include:
  mcsd_enable=True,
  mcsd_merge_mode="lsc_poe",
  mcsd_gate_bias=0.0,
  include_environment_feedback=True,
  use_successful_as_teacher=False,
  sdpo_policy_loss_mode="distillation_only",
  full_logit_distillation=True,
  distillation_topk=5,
  distillation_alpha=1.0,
  distillation_is_clip=None,
  max_steps=1
- Assert `train_loss` is present.
- Assert at least one `mcsd/*` metric is logged.
Run the targeted tests:
pytest tests/experimental/test_sdpo_trainer.py -k "mcsd or training"
Then run the full SDPO test file:
pytest tests/experimental/test_sdpo_trainer.py
11. Paper index requirement
If MCSD is being implemented from a paper, add a subsection to paper_index.md. Per repo guidance, use Hugging Face paper links:
https://huggingface.co/papers/<paper-id>
Do this in the same PR as the MCSD implementation.
12. Implementation checklist
- Add `mcsd_enable`, `mcsd_merge_mode`, and `mcsd_gate_bias` to `SDPOConfig`.
- Validate MCSD only for top-k full-logit reverse-KL distillation.
- Preserve `privileged_context` and add `privileged_contexts` to the RAR dataset mapping.
- Keep `privileged_contexts` during trainer column filtering.
- Add an MCSD branch to the teacher-context builder.
- Return `[B, K_max, L]` teacher tensors plus `mcsd_criterion_mask`.
- Add an MCSD branch to `_compute_self_distillation_loss`.
- Forward criterion teachers by flattening `[B, K, L]` to `[B*K, L]`.
- Gather every teacher over the student top-k token indices.
- Build the merged teacher with `log_q_s.detach() + sum_j(logsigmoid(log_q_t_j - log_q_s.detach() - bias))`.
- Compute reverse KL from the current student top-k distribution to the detached merged teacher.
- Reuse existing importance clipping and loss aggregation.
- Log `mcsd/*` diagnostics.
- Update callbacks/artifacts with criterion counts and masks.
- Add tests for config, data retention, context construction, MCSD math, and tiny training.
- Update `paper_index.md` if MCSD corresponds to a paper implementation.
Final behavior: when MCSD is disabled, this repository trains exactly as current OPSD; when MCSD is enabled, the single teacher distribution is replaced by an LSC-transformed product-of-experts teacher over the same student-selected top-k tokens.