# MCSD implementation guide for this TRL fork

This guide describes how to implement Multi-Criteria Self-Distillation (MCSD) on top of the current experimental SDPO stack in this repository.

The current OPSD path has one student context and one privileged teacher context. MCSD keeps the student path unchanged, replaces the single teacher with multiple criterion-specific privileged teachers, and merges those teachers with an LSC-style product of experts.

Do not change the default behavior: with `mcsd_enable=False`, training must behave exactly like the existing OPSD path.
## 1. Current code path to extend

Use these files as the implementation map:

- `examples/scripts/sdpo_rar_science.py`
  - Builds the local RAR Science dataset.
  - Currently creates one `privileged_context` string from the full rubric.
  - The student sees only `prompt`; the teacher sees `privileged_context + prompt + completion`.
- `trl/experimental/self_distillation/self_distillation_mixin.py`
  - `_set_signature_columns_if_needed` currently keeps `prompt` and `privileged_context`.
  - `_split_prompt_and_privileged_context` currently returns one privileged context per example.
  - `_compute_self_distillation_loss` builds student logits, teacher logits, top-k distributions, the reverse-KL token loss, importance clipping, and the final masked aggregation.
- `trl/experimental/sdpo/sdpo_trainer.py`
  - `SuccessfulRolloutTeacherContextBuilder.build` turns `privileged_context` into `teacher_input_ids`, `teacher_attention_mask`, and `self_distillation_mask`.
  - `SDPOTrainer._generate_and_score_completions` calls the teacher-context builder and attaches those tensors to the training batch.
- `trl/experimental/sdpo/sdpo_config.py`
  - `SDPOConfig` owns the SDPO-specific CLI arguments and validation.
- `tests/experimental/test_sdpo_trainer.py`
  - Existing tests already cover top-k distillation, teacher prompt construction, old-log-prob clipping setup, and teacher attention masks.
  - Add MCSD coverage here.

Repository-specific constraint: this is experimental code, but keep duplicated SDPO/self-distillation patterns consistent. Add small local helpers only where they make the MCSD branch testable and readable.
## 2. Objective to implement

For each generated response token position `t`, keep the same student top-k token set used by OPSD:

```text
K_t = TopK_M(p_{s,t}), where M = args.distillation_topk
```

Renormalize the student and every criterion teacher over that same top-k set:

```text
log_q_s       = log p_s restricted to K_t, renormalized over K_t
log_q_t[:, j] = log p_teacher_j restricted to K_t, renormalized over K_t
```

Build the merged MCSD teacher as:

```python
# Shapes: log_q_s [B, T, M], log_q_t [B, K, T, M] (K = number of criteria, M = top-k size).
# Requires `import torch` and `import torch.nn.functional as F`.
log_q_base = log_q_s.detach()  # the student acts as the base measure
advantage = log_q_t - log_q_base.unsqueeze(1)  # per-criterion log-ratio, [B, K, T, M]
log_gate = F.logsigmoid(advantage - args.mcsd_gate_bias)  # log-sigmoid gate per criterion
merged_log_unnorm = log_q_base + log_gate.sum(dim=1)  # product of experts = sum in log-space
log_q_mcsd = merged_log_unnorm - torch.logsumexp(merged_log_unnorm, dim=-1, keepdim=True)
log_q_mcsd = log_q_mcsd.detach()  # the merged teacher is a fixed target
```
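
In equation form, with `b = mcsd_gate_bias`, `q_s` the renormalized student top-k distribution (detached inside the merge), `q_j` the renormalized distribution of criterion `j`, and `v` ranging over `K_t`, the snippet above computes the following. The second form of the first line rewrites each gate with the identity `sigmoid(log a - log c - b) = a / (a + e^b c)`; it is an algebraic restatement for intuition, not a different merge.

```latex
\begin{aligned}
\tilde q(v) &= q_s(v)\prod_{j}\sigma\big(\log q_j(v) - \log q_s(v) - b\big)
             = q_s(v)\prod_{j}\frac{q_j(v)}{q_j(v) + e^{b}\,q_s(v)}, \qquad v \in K_t \\
q_{\mathrm{mcsd}}(v) &= \frac{\tilde q(v)}{\sum_{u \in K_t} \tilde q(u)}
\end{aligned}
```

Each factor lies in `(0, 1)` and grows with how strongly criterion `j` prefers token `v` relative to the student. A criterion that matches the student exactly scales every token by the same constant `1 / (1 + e^b)`, which the renormalization removes.
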
Then replace the OPSD teacher distribution in the existing reverse-KL token loss:

```python
# Reverse KL(q_s || q_mcsd) per token over the top-k set; gradients flow only through log_q_s.
token_loss = (log_q_s.exp() * (log_q_s - log_q_mcsd)).sum(dim=-1)
```

Keep the existing completion mask, `self_distillation_mask`, `loss_type` aggregation, and `distillation_is_clip` logic.

Important details:

- Select top-k from the student distribution, not from teacher distributions.
- Use the same top-k indices for all teachers.
- Detach `log_q_s` before using it as the MCSD base measure.
- Apply `F.logsigmoid` to each criterion's advantage (after subtracting `mcsd_gate_bias`) first, then sum the transformed criterion values in log-space.
- Do not average teacher distributions.
- Do not sum raw teacher log-probs.
- Detach `log_q_mcsd` before computing the reverse KL.
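
A toy, standalone check of the rules above on a single token position with a 3-token top-k set. `merge_mcsd` is a hypothetical helper used only for illustration. It shows that a criterion identical to the student leaves the student unchanged (every gate equals `sigmoid(0)`), that a criterion shifts mass toward the tokens it prefers, and that the result is not the average of the teacher distributions.

```python
import torch
import torch.nn.functional as F


def merge_mcsd(log_q_s: torch.Tensor, log_q_t: torch.Tensor, gate_bias: float = 0.0) -> torch.Tensor:
    """Hypothetical toy helper: LSC product-of-experts merge for one token position.

    log_q_s: [M] student log-probs, renormalized over the top-k set.
    log_q_t: [J, M] criterion-teacher log-probs over the same top-k set.
    Returns the merged teacher's log-probs, shape [M].
    """
    log_q_base = log_q_s.detach()
    log_gate = F.logsigmoid(log_q_t - log_q_base.unsqueeze(0) - gate_bias)  # one gate row per criterion
    merged = log_q_base + log_gate.sum(dim=0)  # product of experts = sum in log-space
    return (merged - torch.logsumexp(merged, dim=-1, keepdim=True)).detach()


q_s = torch.tensor([0.5, 0.3, 0.2])  # student over a 3-token top-k set
q_t = torch.tensor([[0.2, 0.6, 0.2]])  # one criterion that prefers token 1

merged = merge_mcsd(q_s.log(), q_t.log()).exp()  # ≈ [10/31, 14/31, 7/31] ≈ [0.323, 0.452, 0.226]
teacher_average = q_t.mean(dim=0)  # what averaging the teachers would give instead
same = merge_mcsd(q_s.log(), q_s.log().unsqueeze(0)).exp()  # criterion identical to the student
```

With `b = 0`, each gate is `q_t / (q_t + q_s)`, so token 1 gets weight `0.3 * 2/3` before renormalization and rises from `0.3` to about `0.452`.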
## 3. Add configuration arguments

Edit `trl/experimental/sdpo/sdpo_config.py` and add these fields to `SDPOConfig`:

```python
mcsd_enable: bool = field(
    default=False,
    metadata={"help": "Enable Multi-Criteria Self-Distillation."},
)
mcsd_merge_mode: str = field(
    default="lsc_poe",
    metadata={"help": "MCSD teacher merge mode. Supported: `lsc_poe`."},
)
mcsd_gate_bias: float = field(
    default=0.0,
    metadata={"help": "Scalar bias subtracted from each MCSD criterion advantage before logsigmoid gating."},
)
```

Add these entries to the class docstring using the repository docstring format.

Extend `__post_init__` with narrow first-pass validation:

```python
if self.mcsd_enable:
    if self.mcsd_merge_mode != "lsc_poe":
        raise ValueError("mcsd_merge_mode must be `lsc_poe` when MCSD is enabled.")
    if self.sdpo_policy_loss_mode != "distillation_only":
        raise ValueError("MCSD currently supports only `sdpo_policy_loss_mode='distillation_only'`.")
    if not self.full_logit_distillation:
        raise ValueError("MCSD requires `full_logit_distillation=True`.")
    if self.distillation_topk is None:
        raise ValueError("MCSD requires `distillation_topk`.")
    if self.distillation_alpha != 1.0:
        raise ValueError("MCSD requires reverse KL, so `distillation_alpha` must be 1.0.")
    if self.distillation_add_tail:
        raise ValueError("MCSD does not support `distillation_add_tail=True`.")
```

The intended CLI is:

```bash
--mcsd_enable true \
--mcsd_merge_mode lsc_poe \
--mcsd_gate_bias 0.0 \
--sdpo_policy_loss_mode distillation_only \
--full_logit_distillation true \
--distillation_topk 20 \
--distillation_alpha 1.0
```
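
The validation rules can be exercised in isolation before wiring them into `SDPOConfig`. The sketch below is a hypothetical stand-in dataclass (`MCSDConfigSketch`, plus an `is_rejected` helper) that mirrors the checks above. Its non-MCSD defaults are chosen only so that enabling MCSD passes; they are not `SDPOConfig`'s real defaults.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MCSDConfigSketch:
    """Hypothetical stand-in for the MCSD-related subset of SDPOConfig.

    The non-MCSD defaults below are illustrative only, not SDPOConfig's real defaults.
    """

    mcsd_enable: bool = False
    mcsd_merge_mode: str = "lsc_poe"
    mcsd_gate_bias: float = 0.0
    sdpo_policy_loss_mode: str = "distillation_only"
    full_logit_distillation: bool = True
    distillation_topk: Optional[int] = 20
    distillation_alpha: float = 1.0
    distillation_add_tail: bool = False

    def __post_init__(self):
        if not self.mcsd_enable:
            return  # OPSD path: MCSD settings are ignored, nothing to validate
        if self.mcsd_merge_mode != "lsc_poe":
            raise ValueError("mcsd_merge_mode must be `lsc_poe` when MCSD is enabled.")
        if self.sdpo_policy_loss_mode != "distillation_only":
            raise ValueError("MCSD currently supports only `sdpo_policy_loss_mode='distillation_only'`.")
        if not self.full_logit_distillation:
            raise ValueError("MCSD requires `full_logit_distillation=True`.")
        if self.distillation_topk is None:
            raise ValueError("MCSD requires `distillation_topk`.")
        if self.distillation_alpha != 1.0:
            raise ValueError("MCSD requires reverse KL, so `distillation_alpha` must be 1.0.")
        if self.distillation_add_tail:
            raise ValueError("MCSD does not support `distillation_add_tail=True`.")


def is_rejected(**overrides) -> bool:
    """Return True when enabling MCSD together with these overrides raises ValueError."""
    try:
        MCSDConfigSketch(mcsd_enable=True, **overrides)
    except ValueError:
        return True
    return False
```

For example, `is_rejected(distillation_topk=None)` is `True`, while `MCSDConfigSketch(full_logit_distillation=False)` constructs without error because MCSD is disabled.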
## 4. Extend the dataset format

Keep the old single-context field for OPSD:

```python
"privileged_context": str
```

Add a plural field for MCSD:

```python
"privileged_contexts": list[str]
```

Each item in `privileged_contexts` is one criterion-specific, teacher-only context. The number of criteria may differ between examples.

For the current RAR data, `data/rar_science/*.jsonl` and `data/rar_medicine/*.jsonl` already contain rubric fields such as:

```python
"rubric": [{"description": "...", "title": "...", "weight": ...}, ...]
"rubric_list": ["Essential Criteria: ...", ...]
"rubric_count": 7
```

In `examples/scripts/sdpo_rar_science.py`, keep `_build_privileged_context` unchanged for OPSD and add separate criterion-level templates for MCSD:

```python
POSITIVE_RUBRIC_TEACHER_TEMPLATE = """For this question, please consider the following evaluation criteria:
{positive_rubric}
Please provide a comprehensive and helpful response that addresses the question while following the above guidelines.
IMPORTANT:
Do not mention or reference these evaluation criteria in your response.
Do not indicate that you have seen any scoring rubric or evaluation guidelines.
Your response should appear natural and spontaneous. Revealing that you have access to evaluation criteria would be considered cheating and is strictly prohibited.
"""

NEGATIVE_RUBRIC_TEACHER_TEMPLATE = """For this question, please consider the following evaluation pitfall to avoid:
{negative_rubric}
Please provide a comprehensive and helpful response that addresses the question while carefully avoiding the issue described above.
IMPORTANT:
Do not mention or reference this evaluation criterion in your response.
Do not indicate that you have seen any scoring rubric, pitfall list, or evaluation guidelines.
Your response should appear natural and spontaneous. Revealing that you have access to evaluation criteria would be considered cheating and is strictly prohibited.
"""
```

Then add a criterion-level builder that emits one teacher context per rubric item:

```python
def _build_privileged_contexts_by_criterion(
    example: dict[str, Any], max_items_per_section: int | None
) -> list[str]:
    include_items, avoid_items = _collect_rubric_items(example, max_items_per_section)
    contexts = []
    # One positive-template context per criterion to satisfy.
    for item in include_items:
        contexts.append(POSITIVE_RUBRIC_TEACHER_TEMPLATE.format(positive_rubric=item))
    # One negative-template context per pitfall to avoid.
    for item in avoid_items:
        contexts.append(NEGATIVE_RUBRIC_TEACHER_TEMPLATE.format(negative_rubric=item))
    return contexts
```
Then update `_make_conversation`:

```python
return {
    "prompt": [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": example["question"]},
    ],
    "solution": example["reference_answer"],
    "privileged_context": _build_privileged_context(example, max_items_per_section),  # OPSD
    "privileged_contexts": _build_privileged_contexts_by_criterion(example, max_items_per_section),  # MCSD
}
```

Also update `SelfDistillationMixin._set_signature_columns_if_needed` so that trainer column filtering keeps the new field:

```python
self._signature_columns = ["prompt", "privileged_context", "privileged_contexts"]
```

Keep `_split_prompt_and_privileged_context` unchanged for the old OPSD path, and add a new helper for MCSD:

```python
@staticmethod
def _split_prompt_and_privileged_contexts(
    inputs: list[dict[str, Any]],
) -> tuple[list[Any], list[Any], list[Any]]:
    prompts = [example["prompt"] for example in inputs]
    privileged_contexts = [example.get("privileged_context") for example in inputs]
    # Fall back to the single OPSD context when the plural field is missing.
    criterion_contexts = [
        example.get("privileged_contexts", example.get("privileged_context"))
        for example in inputs
    ]
    return prompts, privileged_contexts, criterion_contexts
```

Then update `SDPOTrainer._generate_and_score_completions`:

```python
prompts, privileged_contexts, criterion_contexts = self._split_prompt_and_privileged_contexts(inputs)
teacher_feedbacks = criterion_contexts if self.args.mcsd_enable else privileged_contexts
output = super()._generate_and_score_completions(inputs)
output.update(
    self.teacher_context_builder.build(output, prompts, output["rewards"], feedbacks=teacher_feedbacks)
)
```

This keeps the singular field as the teacher source for OPSD and the plural field as the teacher source for MCSD.

Recommended batch convention:

```python
prompts: list[Any]
privileged_contexts: list[str | None]  # old OPSD
criterion_contexts: list[list[str] | str | None]  # MCSD
```

Treat a plain string in `criterion_contexts` as a one-criterion list so existing datasets can still be smoke-tested with MCSD.
## 5. Build criterion-specific teacher contexts

Edit `SuccessfulRolloutTeacherContextBuilder` in `trl/experimental/sdpo/sdpo_trainer.py`.

Do not rewrite the existing single-teacher path. Add a branch near the start of `build`:

```python
if self.trainer.args.mcsd_enable:
    return self._build_mcsd(output, prompts, rewards, feedbacks=feedbacks)
```

Keep the existing `build` body as the OPSD implementation.

The MCSD builder returns:

```python
{
    "teacher_input_ids_by_criterion": teacher_input_ids_by_criterion,  # [B, K_max, L]
    "teacher_attention_mask_by_criterion": teacher_attention_mask_by_criterion,  # [B, K_max, L]
    "mcsd_criterion_mask": mcsd_criterion_mask,  # [B, K_max]
    "mcsd_num_criteria": mcsd_num_criteria,  # [B]
    "self_distillation_mask": local_self_distillation_mask,  # [B]
}
```

Implementation outline:

1. Gather rewards, completions, prompts, and feedbacks exactly as the current builder does.
2. For each global sample, construct a list of active criterion contexts.
3. Use `privileged_contexts` when available.
4. If only the old `privileged_context` string is available, wrap it as a one-element list.
5. If a sample has no active criteria, use one placeholder teacher prompt, set its criterion mask to `0`, and set its `self_distillation_mask` to `0`.
6. For active criteria, set the criterion mask to `1`.
7. Because `K` varies per sample, pad every sample's criteria to the local `K_max`, with criterion mask `0` on the padded slots.
8. Tokenize the flattened list of teacher messages with the existing `_tokenize_teacher_messages`.
9. Repeat `completion_ids` and `completion_mask` `K_max` times per sample (`repeat_interleave`), matching the sample-major order of the flattened messages.
10. Concatenate teacher prompts and completions exactly as OPSD does.
11. Reshape the flattened teacher tensors back to `[B, K_max, L]`.

Sketch (`_normalize_criterion_contexts` and `_build_teacher_message` are not defined in this guide; `_build_teacher_message` must return a placeholder teacher prompt when `context` is `None`):

```python
contexts_by_sample = []
mask_rows = []
num_criteria_list = []
for local_idx, global_idx in enumerate(range(process_start, process_start + num_local)):
    raw_contexts = all_feedbacks[global_idx]
    contexts = _normalize_criterion_contexts(raw_contexts)  # None / str / list[str] -> list[str]
    if len(contexts) == 0:
        contexts_by_sample.append([None])  # one placeholder teacher prompt
        mask_rows.append([0.0])
        num_criteria_list.append(0)
    else:
        contexts_by_sample.append(contexts)
        mask_rows.append([1.0] * len(contexts))
        num_criteria_list.append(len(contexts))

max_criteria = max(len(contexts) for contexts in contexts_by_sample)

flat_teacher_messages = []
for local_idx, global_idx in enumerate(range(process_start, process_start + num_local)):
    original_prompt = all_prompts[global_idx]
    contexts = list(contexts_by_sample[local_idx])
    # Pad to K_max with masked placeholder criteria.
    while len(contexts) < max_criteria:
        contexts.append(None)
        mask_rows[local_idx].append(0.0)
    # Sample-major, criterion-minor order; must match repeat_interleave below.
    for context in contexts:
        flat_teacher_messages.append(_build_teacher_message(original_prompt, context))

mcsd_criterion_mask = torch.tensor(mask_rows, device=device)  # [B, K_max], float
mcsd_num_criteria = torch.tensor(num_criteria_list, device=device, dtype=torch.long)  # [B]
local_self_distillation_mask = (mcsd_num_criteria > 0).float()  # no criteria -> no distillation
```

Tokenize `flat_teacher_messages`, concatenate the repeated completions, and reshape:

```python
teacher_batch = self._tokenize_teacher_messages(flat_teacher_messages)
flat_completion_ids = completion_ids.repeat_interleave(max_criteria, dim=0)  # [B * K_max, T]
flat_completion_mask = completion_mask.repeat_interleave(max_criteria, dim=0)
flat_teacher_input_ids = torch.cat([teacher_batch.prompt_ids, flat_completion_ids], dim=1)
flat_teacher_attention_mask = torch.cat([teacher_batch.prompt_mask, flat_completion_mask], dim=1)
teacher_input_ids_by_criterion = flat_teacher_input_ids.reshape(num_local, max_criteria, -1)
teacher_attention_mask_by_criterion = flat_teacher_attention_mask.reshape(num_local, max_criteria, -1)
```

Once the loss multiplies `log_gate` by the criterion mask, a masked criterion contributes `0` to `log_gate.sum(dim=1)`. That is the neutral log-space contribution (a factor of `exp(0) = 1`); an unmasked placeholder would instead contribute `logsigmoid(...) < 0` and distort the merged teacher.
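
The padding, flattening, and reshape steps can be checked without a tokenizer. In the toy sketch below, `pad_criteria` and `fake_tokenize` are hypothetical stand-ins (the latter replaces `_tokenize_teacher_messages` with fixed-length, left-padded dummy ids). It verifies the key ordering invariant: flattening sample-major matches `repeat_interleave(K_max, dim=0)`, so after reshaping, every criterion row of sample `b` ends with sample `b`'s completion.

```python
import torch


def pad_criteria(contexts_by_sample):
    """Pad each sample's criterion list to K_max; None marks a placeholder (masked) criterion."""
    max_criteria = max(max(len(contexts), 1) for contexts in contexts_by_sample)
    padded, mask_rows, counts = [], [], []
    for contexts in contexts_by_sample:
        num_pad = max_criteria - len(contexts)
        padded.append(list(contexts) + [None] * num_pad)
        mask_rows.append([1.0] * len(contexts) + [0.0] * num_pad)
        counts.append(len(contexts))
    return padded, torch.tensor(mask_rows), torch.tensor(counts, dtype=torch.long), max_criteria


def fake_tokenize(flat_contexts, prompt_len=4):
    """Hypothetical stand-in for `_tokenize_teacher_messages`: left-padded dummy prompt ids.

    Flattened row i gets token id 100 + i so the row order can be checked after reshaping.
    """
    ids = torch.zeros(len(flat_contexts), prompt_len, dtype=torch.long)
    mask = torch.zeros_like(ids)
    for i, ctx in enumerate(flat_contexts):
        n = 1 if ctx is None else min(len(ctx), prompt_len)
        ids[i, -n:] = 100 + i
        mask[i, -n:] = 1
    return ids, mask


contexts_by_sample = [["crit A", "crit B", "crit C"], ["crit D"], []]
padded, criterion_mask, num_criteria, k_max = pad_criteria(contexts_by_sample)
flat_contexts = [ctx for row in padded for ctx in row]  # sample-major, criterion-minor
prompt_ids, prompt_mask = fake_tokenize(flat_contexts)

completion_ids = torch.tensor([[7, 8], [9, 10], [11, 12]])
completion_mask = torch.ones_like(completion_ids)
flat_input_ids = torch.cat([prompt_ids, completion_ids.repeat_interleave(k_max, dim=0)], dim=1)
flat_attention_mask = torch.cat([prompt_mask, completion_mask.repeat_interleave(k_max, dim=0)], dim=1)

batch_size = completion_ids.shape[0]
teacher_input_ids_by_criterion = flat_input_ids.reshape(batch_size, k_max, -1)  # [B, K_max, L]
teacher_attention_mask_by_criterion = flat_attention_mask.reshape(batch_size, k_max, -1)
self_distillation_mask = (num_criteria > 0).float()
```

Here `criterion_mask` comes out as `[[1, 1, 1], [1, 0, 0], [0, 0, 0]]` and the empty third sample gets `self_distillation_mask = 0`.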

For RAR Science/Medicine MCSD, set:

```bash
--include_environment_feedback true \
--use_successful_as_teacher false
```

This makes each rubric item a criterion-specific hidden teacher context and keeps successful-rollout mining out of the first MCSD implementation.

## 6. Add the MCSD loss branch

Edit `SelfDistillationMixin._compute_self_distillation_loss` in `trl/experimental/self_distillation/self_distillation_mixin.py`.

The current method:

1. Builds `response_mask`.
2. Builds student logits.
3. Builds one teacher logits tensor.
4. Computes sampled-token log-probs for diagnostics and clipping.
5. Computes the top-k reverse KL when `distillation_topk` and `full_logit_distillation` are active.
6. Applies importance clipping.
7. Aggregates with `_aggregate_self_distillation_loss`.

Add an MCSD branch after the student logits are available and before the single-teacher forward pass:

```python
if self.args.mcsd_enable:
    return self._compute_mcsd_self_distillation_loss(
        model=model,
        inputs=inputs,
        completion_ids=completion_ids,
        completion_mask=completion_mask,
        response_mask=response_mask,
        logits_to_keep=logits_to_keep,
        student_logits=student_logits,
    )
```

Keep all existing code below that branch for OPSD.
### 6.1 Forward all criterion teachers

Inside `_compute_mcsd_self_distillation_loss`, read the new tensors:

```python
teacher_input_ids = inputs["teacher_input_ids_by_criterion"]  # [B, K, L]
teacher_attention_mask = inputs["teacher_attention_mask_by_criterion"]  # [B, K, L]
criterion_mask = inputs["mcsd_criterion_mask"].to(student_logits.dtype)  # [B, K]
```

Flatten the criterion dimension:

```python
batch_size, num_criteria, teacher_seq_len = teacher_input_ids.shape
flat_teacher_input_ids = teacher_input_ids.reshape(batch_size * num_criteria, teacher_seq_len)
flat_teacher_attention_mask = teacher_attention_mask.reshape(batch_size * num_criteria, teacher_seq_len)
```

Run the teacher model under the same no-grad and teacher context manager as OPSD (this is a single forward pass over `B * K` sequences):

```python
teacher_model_inputs = {
    "input_ids": flat_teacher_input_ids,
    "attention_mask": flat_teacher_attention_mask,
    "use_cache": False,
}
if "logits_to_keep" in self.model_kwarg_keys:
    # One extra position: the logit just before the first completion token predicts that token.
    teacher_model_inputs["logits_to_keep"] = logits_to_keep + 1
teacher_model = self._get_teacher_model_for_self_distillation(model)
with torch.no_grad(), self._get_teacher_context_for_self_distillation(model):
    flat_teacher_logits = teacher_model(**teacher_model_inputs).logits
```

Then align the logits exactly as OPSD does:

```python
flat_teacher_logits = flat_teacher_logits[:, :-1, :]  # drop the prediction past the final token
flat_teacher_logits = flat_teacher_logits[:, -logits_to_keep:, :]  # keep completion positions only
flat_teacher_logits = flat_teacher_logits / self.temperature
teacher_logits = flat_teacher_logits.reshape(batch_size, num_criteria, logits_to_keep, -1)  # [B, K, T, V]
```

This alignment is required because teacher prompts are longer than student prompts. Both sequences end with the same generated completion, so keeping the last `logits_to_keep` positions after dropping the final one aligns position `t` with completion token `t` for the student and for every teacher.
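
A toy illustration of that alignment, using a hypothetical stand-in "model" whose logits at each position are a one-hot of the token it was fed. After slicing, each kept row reveals the last token the model was conditioned on, so the student and teacher rows line up on the completion even though the prompts differ in length. The stand-in ignores `logits_to_keep` and returns every position, which is exactly the case the final slice protects against.

```python
import torch
import torch.nn.functional as F


def align_completion_logits(logits: torch.Tensor, logits_to_keep: int) -> torch.Tensor:
    # Position i predicts token i + 1, so the final position predicts past the sequence: drop it.
    logits = logits[:, :-1, :]
    # Keep the last `logits_to_keep` positions: these predict the completion tokens.
    return logits[:, -logits_to_keep:, :]


def fake_logits(input_ids: torch.Tensor, vocab_size: int = 16) -> torch.Tensor:
    # Hypothetical stand-in model: position i "remembers" the token it was fed (one-hot).
    return F.one_hot(input_ids, vocab_size).float()


completion = torch.tensor([[5, 6, 7]])
student_ids = torch.cat([torch.tensor([[1, 2]]), completion], dim=1)  # short prompt
teacher_ids = torch.cat([torch.tensor([[3, 3, 3, 4, 9]]), completion], dim=1)  # longer privileged prompt
T = completion.shape[1]

# For each kept row, the last input token the model had seen.
student_rows = align_completion_logits(fake_logits(student_ids), T).argmax(dim=-1)  # [[2, 5, 6]]
teacher_rows = align_completion_logits(fake_logits(teacher_ids), T).argmax(dim=-1)  # [[9, 5, 6]]
```

Row 0 conditions on each side's last prompt token and predicts completion token 0; rows 1 and 2 condition on the same completion tokens in both sequences.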
### 6.2 Compute top-k MCSD reverse KL

Use this code shape for the core math (`F` is `torch.nn.functional`):

```python
# Student: full-vocabulary log-probs, then the top-k set K_t.
student_logsumexp = torch.logsumexp(student_logits, dim=-1, keepdim=True)
student_log_probs = student_logits - student_logsumexp  # [B, T, V]
topk_student_log_probs, topk_indices = torch.topk(
    student_log_probs,
    k=self.args.distillation_topk,
    dim=-1,
)  # [B, T, M]
log_q_s = self._renorm_topk_log_probs(topk_student_log_probs)
q_s = log_q_s.exp()

# Teachers: gather every criterion over the same student top-k indices.
teacher_logsumexp = torch.logsumexp(teacher_logits, dim=-1, keepdim=True)
teacher_log_probs = teacher_logits - teacher_logsumexp  # [B, K, T, V]
teacher_topk_indices = topk_indices.unsqueeze(1).expand(-1, num_criteria, -1, -1)  # [B, K, T, M]
teacher_topk_log_probs = torch.gather(teacher_log_probs, dim=-1, index=teacher_topk_indices)
log_q_t = self._renorm_topk_log_probs(teacher_topk_log_probs)

# LSC product of experts over the detached student base measure.
log_q_base = log_q_s.detach()
advantage = log_q_t - log_q_base.unsqueeze(1)
log_gate = F.logsigmoid(advantage - self.args.mcsd_gate_bias)
log_gate = log_gate * criterion_mask[:, :, None, None]  # inactive criteria -> 0 (neutral)
merged_log_unnorm = log_q_base + log_gate.sum(dim=1)
log_q_mcsd = merged_log_unnorm - torch.logsumexp(merged_log_unnorm, dim=-1, keepdim=True)
log_q_mcsd = log_q_mcsd.detach()

# Reverse KL(q_s || q_mcsd) per token.
per_token_loss = (q_s * (log_q_s - log_q_mcsd)).sum(dim=-1)  # [B, T]
```

Notes:

- `criterion_mask == 0` makes the criterion contribute `0` to the log-space expert sum.
- `self_distillation_mask` already zeros out samples with no active teacher supervision through `response_mask`.
- Reject `distillation_add_tail=True` for MCSD in this implementation. The `SDPOConfig.__post_init__` check in Section 3 already does this; if you also want a defensive check inside the mixin, read the flags from `self.args`:

  ```python
  if self.args.mcsd_enable and self.args.distillation_add_tail:
      raise ValueError("MCSD does not support `distillation_add_tail=True` in the first implementation.")
  ```
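
The same math can be unit-tested outside the trainer. The function below is a hypothetical standalone sketch (`mcsd_topk_reverse_kl` is not an existing helper); it assumes `_renorm_topk_log_probs` is equivalent to a `log_softmax` over the gathered top-k values, and it expects logits that are already aligned and divided by the temperature.

```python
import torch
import torch.nn.functional as F


def mcsd_topk_reverse_kl(
    student_logits: torch.Tensor,  # [B, T, V]
    teacher_logits: torch.Tensor,  # [B, K, T, V]
    criterion_mask: torch.Tensor,  # [B, K], 1.0 for active criteria
    topk: int,
    gate_bias: float = 0.0,
):
    """Hypothetical sketch of the top-k MCSD loss; returns (per_token_loss [B, T], log_q_mcsd [B, T, M])."""
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    topk_student, topk_indices = torch.topk(student_log_probs, k=topk, dim=-1)  # [B, T, M]
    # Assumption: renormalizing over the top-k set == log_softmax over the gathered values.
    log_q_s = F.log_softmax(topk_student, dim=-1)

    teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
    idx = topk_indices.unsqueeze(1).expand(-1, teacher_logits.shape[1], -1, -1)  # [B, K, T, M]
    log_q_t = F.log_softmax(torch.gather(teacher_log_probs, -1, idx), dim=-1)

    log_q_base = log_q_s.detach()
    log_gate = F.logsigmoid(log_q_t - log_q_base.unsqueeze(1) - gate_bias)
    log_gate = log_gate * criterion_mask[:, :, None, None].to(log_gate.dtype)  # inactive -> 0
    merged = log_q_base + log_gate.sum(dim=1)
    log_q_mcsd = (merged - torch.logsumexp(merged, dim=-1, keepdim=True)).detach()

    per_token_loss = (log_q_s.exp() * (log_q_s - log_q_mcsd)).sum(dim=-1)
    return per_token_loss, log_q_mcsd


torch.manual_seed(0)
B, K, T, V = 2, 3, 4, 11
student_logits = torch.randn(B, T, V, requires_grad=True)
teacher_logits = torch.randn(B, K, T, V)
criterion_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
loss, log_q_mcsd = mcsd_topk_reverse_kl(student_logits, teacher_logits, criterion_mask, topk=5)
loss.mean().backward()  # gradients reach only the student side
```

Multiplying by the mask relies on `log_gate` being finite. If teacher log-probs can ever be `-inf` (for example, with masked vocabulary entries), `torch.where(mask.bool(), log_gate, 0.0)` avoids `0 * -inf = nan`.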
### 6.3 Keep sampled-token log-probs for clipping

The existing clipping path uses sampled-token student log-probs and `old_per_token_logps`. Compute the sampled-token log-probs from the full-vocabulary `student_log_probs`, not from the top-k renormalized `log_q_s`:

```python
idx = completion_ids.unsqueeze(-1)  # [B, T, 1]
student_per_token_logps = torch.gather(student_log_probs, dim=-1, index=idx).squeeze(-1)  # [B, T]
```

Then reuse the existing helper:

```python
if self.args.distillation_is_clip is not None:
    old_log_probs = inputs.get("old_per_token_logps")
    if old_log_probs is not None:
        per_token_loss = self._apply_importance_sampling_clipping(
            per_token_loss,
            student_per_token_logps,
            old_log_probs,
            self.args.distillation_is_clip,
        )
```

Keep the existing pointwise KL diagnostics if possible, reusing the OPSD branch's code so that the clipping metrics stay aligned.
### 6.4 Aggregate exactly like OPSD

Use the existing aggregation helper:

```python
loss = self._aggregate_self_distillation_loss(per_token_loss, response_mask)
```

Log the distillation loss through the existing metric helper so that the base trainer keeps populating `train/loss_opsd`:

```python
# `mode` is assumed to be in scope, as in the OPSD branch.
mean_distill_loss = (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1.0)
self._log_self_distillation_metric(
    mode,
    "distillation_loss",
    self.accelerator.gather(mean_distill_loss).mean().item(),
)
```

Also log the MCSD-specific diagnostics described in Section 7.

## 7. Log MCSD diagnostics

Keep the existing OPSD TensorBoard keys alive.

The base trainer maps `self_distillation/distillation_loss` to:

```text
train/loss_opsd
```

Therefore the MCSD loss branch must still call:

```python
self._log_self_distillation_metric(mode, "distillation_loss", ...)
```

This preserves the old `train/loss_opsd` panel. `train/loss_total`, `train/grad_norm`, `progress/*`, `rollout/*`, and `system/*` also stay unchanged, because they are logged outside the teacher-merge logic.

Keep emitting the old `opsd/*` diagnostic keys in MCSD mode so existing dashboards do not go blank:

```text
opsd/student_logp_mean
opsd/teacher_logp_mean
opsd/teacher_student_logp_gap_mean
opsd/token_advantage_mean
opsd/token_advantage_std
opsd/pointwise_kl_mean
opsd/pointwise_kl_clip_frac
opsd/pointwise_kl_clipped_mean
```

In MCSD mode, interpret the teacher-side `opsd/*` keys as diagnostics against the merged MCSD teacher, not against a single privileged teacher. For sampled-token teacher metrics, look up the sampled response token in `log_q_mcsd` only when that token is inside the student top-k set, and skip the token for these compatibility diagnostics otherwise. Because `log_q_mcsd` is renormalized over the top-k set, these values may not be directly comparable with OPSD's teacher log-probs. The primary MCSD diagnostics below are the authoritative metrics for the merged teacher.
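
A minimal sketch of that lookup, assuming the shapes from Section 6.2 (`log_q_mcsd` and `topk_indices` are `[B, T, M]`, `completion_ids` is `[B, T]`). `sampled_token_logp_in_topk` and `masked_mean` are hypothetical helper names. Tokens outside the top-k set are excluded from the mean by folding `in_topk` into the mask.

```python
import torch


def sampled_token_logp_in_topk(log_q: torch.Tensor, topk_indices: torch.Tensor, completion_ids: torch.Tensor):
    """Look up each sampled token's log-prob in a top-k-restricted distribution.

    log_q:          [B, T, M] log-probs over the student top-k set (for example, log_q_mcsd)
    topk_indices:   [B, T, M] vocabulary ids of that top-k set
    completion_ids: [B, T]    sampled response tokens
    Returns (logp [B, T], in_topk [B, T]); logp is 0.0 where the token is outside the top-k set.
    """
    hits = topk_indices == completion_ids.unsqueeze(-1)  # [B, T, M], at most one True per row
    in_topk = hits.any(dim=-1)
    slot = hits.float().argmax(dim=-1, keepdim=True)  # index of the hit (0 when there is none)
    logp = torch.gather(log_q, -1, slot).squeeze(-1)
    return torch.where(in_topk, logp, torch.zeros_like(logp)), in_topk


def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    return (values * mask).sum() / mask.sum().clamp(min=1.0)


# Toy batch: token 1 is in the first position's top-k set, token 5 is not in the second's.
log_q = torch.tensor([[[0.5, 0.3, 0.2], [0.6, 0.3, 0.1]]]).log()
topk_indices = torch.tensor([[[4, 1, 7], [2, 9, 3]]])
completion_ids = torch.tensor([[1, 5]])
response_mask = torch.ones(1, 2)

logp, in_topk = sampled_token_logp_in_topk(log_q, topk_indices, completion_ids)
diag_mask = response_mask * in_topk.to(response_mask.dtype)
teacher_logp_mean = masked_mean(logp, diag_mask)  # log(0.3): only the first token counts
```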

Use slash-style metric names to match the repository:

```text
mcsd/loss
mcsd/num_criteria_mean
mcsd/gate_bias
mcsd/teacher_entropy_mean
mcsd/student_entropy_mean
mcsd/merged_entropy_mean
mcsd/gate_mean
mcsd/gate_min
mcsd/gate_max
```

Compute them under `torch.no_grad()` and apply masks where practical:

```python
gate = log_gate.exp()  # [B, K, T, M]; masked criteria have log_gate == 0, so gate == 1 there
student_entropy = -(q_s * log_q_s).sum(dim=-1)  # [B, T]
merged_entropy = -(log_q_mcsd.exp() * log_q_mcsd).sum(dim=-1)  # [B, T]
teacher_entropy = -(log_q_t.exp() * log_q_t).sum(dim=-1)  # [B, K, T]
```

For criterion-level metrics (`teacher_entropy` and `gate`), mask with the following `[B, K, T]` tensor, adding a trailing dimension when masking `gate`. Without this mask, inactive criteria would report a gate of exactly `1` and inflate `mcsd/gate_max`:

```python
criterion_token_mask = criterion_mask[:, :, None] * response_mask[:, None, :]
```

For token-level metrics (`student_entropy` and `merged_entropy`), mask with `response_mask`.

Gather metric tensors through `self.accelerator.gather(...)` before appending scalar `.item()` values, following the style already used in `self_distillation_mixin.py`.
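
A single-process sketch of the `mcsd/*` scalars, assuming the Section 6.2 shapes. `mcsd_diagnostics` and `masked_mean` are hypothetical helpers, `accelerator.gather` is omitted, and `mcsd/loss` is left out because it is the same masked mean as the distillation loss. Reducing each gate to one value per (criterion, token) by averaging over the top-k entries is an assumption of this sketch; weighting by `q_s` would be another reasonable choice.

```python
import torch
import torch.nn.functional as F


def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    return (values * mask).sum() / mask.sum().clamp(min=1.0)


@torch.no_grad()
def mcsd_diagnostics(log_q_s, log_q_t, log_q_mcsd, log_gate, criterion_mask, response_mask, gate_bias):
    """Hypothetical single-process sketch of the mcsd/* scalars (no accelerator.gather).

    Shapes: log_q_s, log_q_mcsd [B, T, M]; log_q_t, log_gate [B, K, T, M];
    criterion_mask [B, K]; response_mask [B, T].
    """
    criterion_token_mask = criterion_mask[:, :, None] * response_mask[:, None, :]  # [B, K, T]
    active = criterion_token_mask > 0
    # Assumption: one gate per (criterion, token) = mean over the top-k entries.
    # Masked criteria have log_gate == 0 (gate == 1), so they are excluded via `active`.
    gate = log_gate.exp().mean(dim=-1)  # [B, K, T]
    student_entropy = -(log_q_s.exp() * log_q_s).sum(dim=-1)  # [B, T]
    merged_entropy = -(log_q_mcsd.exp() * log_q_mcsd).sum(dim=-1)  # [B, T]
    teacher_entropy = -(log_q_t.exp() * log_q_t).sum(dim=-1)  # [B, K, T]
    active_gates = gate[active]
    nan = float("nan")
    return {
        "mcsd/num_criteria_mean": criterion_mask.sum(dim=1).mean().item(),
        "mcsd/gate_bias": float(gate_bias),
        "mcsd/student_entropy_mean": masked_mean(student_entropy, response_mask).item(),
        "mcsd/merged_entropy_mean": masked_mean(merged_entropy, response_mask).item(),
        "mcsd/teacher_entropy_mean": masked_mean(teacher_entropy, criterion_token_mask).item(),
        "mcsd/gate_mean": masked_mean(gate, criterion_token_mask).item(),
        "mcsd/gate_min": active_gates.min().item() if active_gates.numel() else nan,
        "mcsd/gate_max": active_gates.max().item() if active_gates.numel() else nan,
    }


# Usage on random inputs: sample 0 has 2 active criteria, sample 1 has 1.
torch.manual_seed(0)
B, K, T, M = 2, 3, 4, 5
log_q_s = F.log_softmax(torch.randn(B, T, M), dim=-1)
log_q_t = F.log_softmax(torch.randn(B, K, T, M), dim=-1)
criterion_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
response_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0]])
log_gate = F.logsigmoid(log_q_t - log_q_s.unsqueeze(1)) * criterion_mask[:, :, None, None]
merged = log_q_s + log_gate.sum(dim=1)
log_q_mcsd = merged - torch.logsumexp(merged, dim=-1, keepdim=True)
stats = mcsd_diagnostics(log_q_s, log_q_t, log_q_mcsd, log_gate, criterion_mask, response_mask, 0.0)
```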
## 8. Update callbacks and artifacts only as needed

The existing callback `on_teacher_context_built` receives:

```python
teacher_input_ids
teacher_attention_mask
completion_mask
self_distillation_mask
```

For MCSD, pass the new tensors too:

```python
teacher_input_ids_by_criterion=output["teacher_input_ids_by_criterion"]
teacher_attention_mask_by_criterion=output["teacher_attention_mask_by_criterion"]
mcsd_criterion_mask=output["mcsd_criterion_mask"]
mcsd_num_criteria=output["mcsd_num_criteria"]
```

Do not remove the old callback payload for OPSD. Note that the MCSD builder does not return the singular `teacher_input_ids` and `teacher_attention_mask`, so the callback invocation must not index those keys unconditionally in MCSD mode.

For rollout artifacts, add the number of active criteria per sample. Avoid dumping full teacher logits.
## 9. Update the launch script

In `run_sdpo.sh`, add the MCSD flags:

```bash
--mcsd_enable true \
--mcsd_merge_mode lsc_poe \
--mcsd_gate_bias 0.0 \
--include_environment_feedback true \
--use_successful_as_teacher false \
--sdpo_policy_loss_mode distillation_only \
--full_logit_distillation true \
--distillation_topk 20 \
--distillation_alpha 1.0 \
```

Keep `--distillation_is_clip 2` if you want the same importance-clipping behavior as OPSD.
## 10. Add tests

Add focused tests in `tests/experimental/test_sdpo_trainer.py`.

Minimum test list:

1. Config defaults
   - `SDPOConfig(output_dir=...).mcsd_enable is False`
   - Existing tests still pass without setting any MCSD flag.
2. Config validation
   - `mcsd_enable=True` rejects `full_logit_distillation=False`.
   - `mcsd_enable=True` rejects `distillation_topk=None`.
   - `mcsd_enable=True` rejects `distillation_alpha != 1.0`.
   - `mcsd_enable=True` rejects an unsupported `mcsd_merge_mode`.
   - `mcsd_enable=True` rejects `sdpo_policy_loss_mode != "distillation_only"` and `distillation_add_tail=True`.
3. Dataset retention
   - A dataset with `privileged_contexts` survives trainer column filtering.
4. Teacher context construction
   - A sample with three criterion contexts returns `teacher_input_ids_by_criterion` with shape `[B, 3, L]`.
   - Its `mcsd_criterion_mask` row is `[1, 1, 1]`.
   - A sample with fewer criteria is padded and has zeros in the padded slots of its criterion mask.
   - The completion part of the teacher attention mask still equals the original `completion_mask` for every active criterion.
5. MCSD math unit test
   - Use tiny synthetic `student_logits` and `teacher_logits`.
   - With every criterion active, verify the merged teacher equals:

     ```python
     torch.softmax(
         log_q_s.detach()
         + F.logsigmoid(log_q_t - log_q_s.detach().unsqueeze(1) - gate_bias).sum(dim=1),
         dim=-1,
     )
     ```

   - Verify it is not equal to an arithmetic average of the teacher distributions.
   - Verify inactive criteria do not change the merged teacher.
6. End-to-end tiny training
   - Use `trl-internal-testing/tiny-Qwen2ForCausalLM-2.5`.
   - The dataset has `prompt`, `privileged_context`, and `privileged_contexts`.
   - Args include:

     ```python
     mcsd_enable=True,
     mcsd_merge_mode="lsc_poe",
     mcsd_gate_bias=0.0,
     include_environment_feedback=True,
     use_successful_as_teacher=False,
     sdpo_policy_loss_mode="distillation_only",
     full_logit_distillation=True,
     distillation_topk=5,
     distillation_alpha=1.0,
     distillation_is_clip=None,
     max_steps=1,
     ```

   - Assert `train_loss` is present.
   - Assert at least one `mcsd/*` metric is logged.
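
A standalone sketch of test 5 as a plain pytest function. `_merge` is a hypothetical local reference; in the real test it would be replaced by whatever helper the MCSD branch exposes. To keep the sketch small, it draws renormalized top-k log-probs directly instead of gathering them from synthetic logits.

```python
import torch
import torch.nn.functional as F


def _merge(log_q_s, log_q_t, criterion_mask, gate_bias=0.0):
    """Hypothetical stand-in for the trainer's MCSD merge; [B, T, M] and [B, K, T, M] inputs."""
    log_q_base = log_q_s.detach()
    log_gate = F.logsigmoid(log_q_t - log_q_base.unsqueeze(1) - gate_bias)
    log_gate = log_gate * criterion_mask[:, :, None, None]
    merged = log_q_base + log_gate.sum(dim=1)
    return merged - torch.logsumexp(merged, dim=-1, keepdim=True)


def test_mcsd_merge_math():
    torch.manual_seed(0)
    B, K, T, M = 2, 3, 4, 5
    log_q_s = F.log_softmax(torch.randn(B, T, M), dim=-1)
    log_q_t = F.log_softmax(torch.randn(B, K, T, M), dim=-1)

    # 1. Matches the reference formula when every criterion is active.
    merged = _merge(log_q_s, log_q_t, torch.ones(B, K)).exp()
    reference = torch.softmax(
        log_q_s.detach() + F.logsigmoid(log_q_t - log_q_s.detach().unsqueeze(1)).sum(dim=1), dim=-1
    )
    assert torch.allclose(merged, reference, atol=1e-6)

    # 2. Differs from the arithmetic average of the teacher distributions.
    assert not torch.allclose(merged, log_q_t.exp().mean(dim=1), atol=1e-3)

    # 3. A masked criterion is equivalent to dropping it.
    mask = torch.tensor([1.0, 1.0, 0.0]).expand(B, K)
    masked = _merge(log_q_s, log_q_t, mask).exp()
    dropped = _merge(log_q_s, log_q_t[:, :2], torch.ones(B, 2)).exp()
    assert torch.allclose(masked, dropped, atol=1e-6)
```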

Run the targeted tests:

```bash
pytest tests/experimental/test_sdpo_trainer.py -k "mcsd or training"
```

Then run the full SDPO test file:

```bash
pytest tests/experimental/test_sdpo_trainer.py
```
## 11. Paper index requirement

If MCSD is being implemented from a paper, add a subsection to `paper_index.md`. Per repository guidance, use Hugging Face paper links:

```text
https://huggingface.co/papers/<paper-id>
```

Do this in the same PR as the MCSD implementation.
## 12. Implementation checklist

1. Add `mcsd_enable`, `mcsd_merge_mode`, and `mcsd_gate_bias` to `SDPOConfig`.
2. Reject MCSD configurations other than top-k, full-logit, reverse-KL distillation.
3. Preserve `privileged_context` and add `privileged_contexts` to the RAR dataset mapping.
4. Keep `privileged_contexts` during trainer column filtering.
5. Add an MCSD branch to the teacher-context builder.
6. Return `[B, K_max, L]` teacher tensors plus `mcsd_criterion_mask`.
7. Add an MCSD branch to `_compute_self_distillation_loss`.
8. Forward criterion teachers by flattening `[B, K, L]` to `[B*K, L]`.
9. Gather every teacher over the student top-k token indices.
10. Build the merged teacher as `log_q_s.detach() + sum_j(logsigmoid(log_q_t_j - log_q_s.detach() - bias))` over active criteria, then renormalize over the top-k set.
11. Compute the reverse KL from the current student top-k distribution to the detached merged teacher.
12. Reuse the existing importance clipping and loss aggregation.
13. Log `mcsd/*` diagnostics.
14. Update callbacks/artifacts with criterion counts and masks.
15. Add tests for config, data retention, context construction, MCSD math, and tiny training.
16. Update `paper_index.md` if MCSD corresponds to a paper implementation.

Final behavior: when MCSD is disabled, this repository trains exactly like the current OPSD path; when MCSD is enabled, the single teacher distribution is replaced by an LSC-transformed product-of-experts teacher over the same student-selected top-k tokens.