# Source: trl-mcsd/examples/scripts/sdpo_rar_science.py
# Commit 1fa3c6c (ihbkaiser): Implement MCSD for experimental SDPO
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
# "trl[peft]",
# "math-verify",
# "latex2sympy2_extended",
# "trackio",
# "kernels",
# ]
# ///
"""
Usage:
```bash
accelerate launch \
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
examples/scripts/sdpo_rar_science.py \
--model_name_or_path Qwen/Qwen3-4B \
--output_dir outputs/rar-science-qwen3-4b \
--dataset_dir data/rar_science \
--dataset_test_split val \
--learning_rate 5e-6 \
--dtype bfloat16 \
--bf16 true \
--max_prompt_length 2048 \
--max_completion_length 768 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 8 \
--num_generations 4 \
--use_peft \
--lora_target_modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj \
--sdpo_policy_loss_mode distillation_only \
--eval_strategy steps \
--eval_steps 200 \
--save_strategy steps \
--save_steps 200 \
--report_to none
```
This example trains SDPO on the local `rar_science` JSONL dataset. The student sees only the question, while the
teacher gets a rubric scaffold in `privileged_context`, built by placing positive-weight rubric items under
"IMPORTANT POINTS TO INCLUDE" and negative-weight rubric items under "IMPORTANT POINTS TO AVOID". Each example also
carries `privileged_contexts`, a list holding one single-criterion teacher prompt per rubric item.
"""
import re
import unicodedata
from collections import Counter
from dataclasses import dataclass, field
from difflib import SequenceMatcher
from pathlib import Path
from typing import Any
import torch
from datasets import DatasetDict, load_dataset
from transformers import AutoTokenizer, GenerationConfig
from trl import (
LogCompletionsCallback,
ModelConfig,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
from trl.data_utils import maybe_apply_chat_template
from trl.experimental.sdpo import SDPOConfig, SDPOTrainer
try:
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
HAS_MATH_VERIFY = True
except ImportError:
HAS_MATH_VERIFY = False
# System message placed at the start of every student prompt (see `_make_conversation`).
SYSTEM_PROMPT = "You are a helpful assistant."
# Combined rubric scaffold. `{include_block}` and `{avoid_block}` are filled with bullet lists by
# `_build_privileged_context`.
PRIVILEGED_RUBRIC_TEMPLATE = """For this question, please consider the following evaluation criteria:
IMPORTANT POINTS TO INCLUDE (you should aim to address these):
{include_block}
IMPORTANT POINTS TO AVOID (you should not do these):
{avoid_block}
Please provide a comprehensive and helpful response that addresses the question while following the above guidelines.
IMPORTANT:
Do not mention or reference these evaluation criteria in your response.
Do not indicate that you have seen any scoring rubric or evaluation guidelines.
Your response should appear natural and spontaneous. Revealing that you have access to evaluation criteria would be considered cheating and is strictly prohibited.
"""
# Single-criterion scaffold for one "include" rubric item; `{positive_rubric}` is filled by
# `_build_privileged_contexts_by_criterion`.
POSITIVE_RUBRIC_TEACHER_TEMPLATE = """For this question, please consider the following evaluation criteria:
{positive_rubric}
Please provide a comprehensive and helpful response that addresses the question while following the above guidelines.
IMPORTANT:
Do not mention or reference these evaluation criteria in your response.
Do not indicate that you have seen any scoring rubric or evaluation guidelines.
Your response should appear natural and spontaneous. Revealing that you have access to evaluation criteria would be considered cheating and is strictly prohibited.
"""
# Single-criterion scaffold for one "avoid" rubric item; `{negative_rubric}` is filled by
# `_build_privileged_contexts_by_criterion`.
NEGATIVE_RUBRIC_TEACHER_TEMPLATE = """For this question, please consider the following evaluation pitfall to avoid:
{negative_rubric}
Please provide a comprehensive and helpful response that addresses the question while carefully avoiding the issue described above.
IMPORTANT:
Do not mention or reference this evaluation criterion in your response.
Do not indicate that you have seen any scoring rubric, pitfall list, or evaluation guidelines.
Your response should appear natural and spontaneous. Revealing that you have access to evaluation criteria would be considered cheating and is strictly prohibited.
"""
@dataclass
class RarScienceScriptArguments(ScriptArguments):
    """
    Script-level arguments for the local RAR Science SDPO example.

    Extends `ScriptArguments` with the dataset directory, the evaluation split name, optional caps on the number of
    loaded examples and rubric items, and the sizes used for completion logging and reward evaluation.
    """

    # Directory scanned for `train.jsonl`, `val.jsonl` and `test.jsonl` by `_load_local_dataset`.
    dataset_dir: str = field(
        default="data/rar_science",
        metadata={"help": "Directory containing local `train.jsonl`, `val.jsonl`, and optional `test.jsonl` files."},
    )
    # Redeclared here with `val` as the default (presumably overriding the inherited field; confirm against
    # `trl.ScriptArguments`).
    dataset_test_split: str = field(
        default="val",
        metadata={"help": "Dataset split to use for evaluation. Defaults to `val` for the local RAR Science files."},
    )
    # Forwarded as `num_prompts` to `LogCompletionsCallback`; a falsy value skips registering the callback.
    eval_num_prompts: int | None = field(
        default=8,
        metadata={"help": "Number of evaluation prompts to log with sampled completions. Set to 0 to disable."},
    )
    # Passed as `num_examples` to `_run_reward_eval`; `None` scores the whole evaluation dataset.
    reward_eval_num_examples: int | None = field(
        default=128,
        metadata={"help": "Optional number of eval examples to score with greedy decoding before and after training."},
    )
    # Passed as `max_new_tokens` to `_run_reward_eval`.
    reward_eval_max_new_tokens: int = field(
        default=256,
        metadata={"help": "Maximum number of tokens for greedy reward evaluation generations."},
    )
    # When set, the train split is shuffled with `dataset_shuffle_seed` and truncated to this many rows.
    max_train_examples: int | None = field(
        default=None,
        metadata={"help": "Optional cap on the number of training examples loaded from the selected train split."},
    )
    # When set, the eval split is shuffled with `dataset_shuffle_seed` and truncated to this many rows.
    max_eval_examples: int | None = field(
        default=None,
        metadata={"help": "Optional cap on the number of evaluation examples loaded from the selected eval split."},
    )
    dataset_shuffle_seed: int = field(
        default=42,
        metadata={"help": "Random seed used before applying example caps."},
    )
    # Applied independently to the include list and the avoid list in `_collect_rubric_items`.
    max_rubric_items_per_section: int | None = field(
        default=None,
        metadata={"help": "Optional cap on the number of rubric items kept in each include/avoid section."},
    )
@dataclass
class RarScienceSDPOConfig(SDPOConfig):
    """
    `SDPOConfig` subclass carrying the defaults used by the RAR Science rubric-distillation example.

    NOTE(review): every field below is presumably a redeclaration of a same-named `SDPOConfig` field with a
    different default; confirm against `trl.experimental.sdpo.SDPOConfig`.
    """

    max_prompt_length: int | None = field(
        default=2048,
        metadata={"help": "Maximum prompt length for the student prompt."},
    )
    max_completion_length: int | None = field(
        default=768,
        metadata={"help": "Maximum completion length for the generated answer."},
    )
    include_environment_feedback: bool = field(
        default=True,
        metadata={"help": "Always pass the rubric scaffold as teacher-only privileged context."},
    )
    use_successful_as_teacher: bool = field(
        default=False,
        metadata={"help": "Disable successful-rollout teacher mining and use only rubric-conditioned teacher prompts."},
    )
    sdpo_policy_loss_mode: str = field(
        default="distillation_only",
        metadata={"help": "Use only the rubric-conditioned self-distillation loss."},
    )
    diagnostics_warning_interval: int = field(
        default=0,
        metadata={"help": "Disable flat-reward and no-success warnings for intentional zero-reward distillation-only training."},
    )
    # The template is just the raw feedback placeholder, so the rubric scaffold is inserted without extra wrapping.
    feedback_template: str = field(
        default="{feedback_raw}",
        metadata={"help": "Inject the rubric scaffold directly without wrapping it as failure feedback."},
    )
    # Placeholders used: `{feedback}`, `{prompt}` and `{solution}`.
    reprompt_template: str = field(
        default=(
            "{feedback}\n\n"
            "Original question:\n"
            "{prompt}\n"
            "{solution}\n"
            "Respond to the original question while silently following the hidden guidance above.\n"
        ),
        metadata={"help": "Teacher prompt template for rubric-scaffolded SDPO."},
    )
    # Placeholder used: `{successful_previous_attempt}`.
    solution_template: str = field(
        default=(
            "\n\nA successful earlier attempt is available below. Reuse the useful parts if they help:\n\n"
            "{successful_previous_attempt}\n"
        ),
        metadata={"help": "How successful sibling rollouts are inserted into the teacher prompt."},
    )
def _load_local_dataset(dataset_dir: str) -> DatasetDict:
    """
    Load the JSONL splits found in `dataset_dir`.

    Looks for `train.jsonl`, `val.jsonl` and `test.jsonl`; only the files that exist are loaded.

    Args:
        dataset_dir: Directory that holds the split files.

    Returns:
        The result of `load_dataset("json", ...)` keyed by the split names that were found.

    Raises:
        ValueError: If `train.jsonl` is missing from `dataset_dir`.
    """
    root = Path(dataset_dir)
    candidates = {name: root / f"{name}.jsonl" for name in ("train", "val", "test")}
    data_files = {name: str(path) for name, path in candidates.items() if path.exists()}
    if "train" not in data_files:
        raise ValueError(f"No `train.jsonl` found in dataset_dir={dataset_dir!r}.")
    return load_dataset("json", data_files=data_files)
def _clean_text(text: Any) -> str:
    """Return `str(text)` with every whitespace run collapsed to one space and both ends trimmed."""
    words = str(text).split()
    return " ".join(words)
def _collect_rubric_items(
    example: dict[str, Any], max_items_per_section: int | None
) -> tuple[list[str], list[str]]:
    """
    Split an example's rubric into "include" and "avoid" descriptions.

    The structured `rubric` field is read first: each dict item contributes its cleaned `description` (or `title`
    when the description is empty) to the include list when its `weight` is positive and to the avoid list when it
    is negative. If that yields nothing, the plain-text `rubric_list` field is used instead, where an entry goes to
    the avoid list when it starts with "pitfall criteria:" or contains " should not " (case-insensitive).

    Args:
        example: Raw dataset row.
        max_items_per_section: Optional cap applied to each of the two lists independently.

    Returns:
        A `(include_items, avoid_items)` tuple.
    """
    include_items: list[str] = []
    avoid_items: list[str] = []

    rubric = example.get("rubric") or []
    if isinstance(rubric, list):
        for item in rubric:
            if not isinstance(item, dict):
                continue
            description = _clean_text(item.get("description") or item.get("title") or "")
            if not description:
                continue
            # The key can be present with a null value (so `.get("weight", 0)` would return None) or hold a
            # numeric string; comparing either with 0 raises TypeError. Coerce, and skip items whose weight is
            # not a number at all.
            try:
                weight = float(item.get("weight") or 0)
            except (TypeError, ValueError):
                continue
            if weight > 0:
                include_items.append(description)
            elif weight < 0:
                avoid_items.append(description)

    if not include_items and not avoid_items:
        for item in example.get("rubric_list") or []:
            # `str(None)` is the non-empty string "None", which would otherwise become a bogus rubric item.
            if item is None:
                continue
            text = _clean_text(item)
            if not text:
                continue
            lower = text.lower()
            if lower.startswith("pitfall criteria:") or " should not " in lower:
                avoid_items.append(text)
            else:
                include_items.append(text)

    if max_items_per_section is not None:
        include_items = include_items[:max_items_per_section]
        avoid_items = avoid_items[:max_items_per_section]
    return include_items, avoid_items
def _format_rubric_block(items: list[str]) -> str:
    """Render `items` as one `- ` bullet per line, or the placeholder `- None.` when there are no items."""
    bullets = [f"- {entry}" for entry in items]
    return "\n".join(bullets) if bullets else "- None."
def _build_privileged_context(example: dict[str, Any], max_items_per_section: int | None) -> str:
    """
    Build the combined rubric scaffold for one example.

    Collects the include/avoid rubric items, renders each list as a bullet block and substitutes both into
    `PRIVILEGED_RUBRIC_TEMPLATE`.
    """
    to_include, to_avoid = _collect_rubric_items(example, max_items_per_section)
    include_block = _format_rubric_block(to_include)
    avoid_block = _format_rubric_block(to_avoid)
    return PRIVILEGED_RUBRIC_TEMPLATE.format(include_block=include_block, avoid_block=avoid_block)
def _build_privileged_contexts_by_criterion(
    example: dict[str, Any], max_items_per_section: int | None
) -> list[str]:
    """
    Build one teacher scaffold per rubric item.

    Include items are rendered with `POSITIVE_RUBRIC_TEACHER_TEMPLATE` and come first; avoid items are rendered
    with `NEGATIVE_RUBRIC_TEACHER_TEMPLATE` and follow. The list is empty when the example has no rubric items.
    """
    to_include, to_avoid = _collect_rubric_items(example, max_items_per_section)
    positive_prompts = [
        POSITIVE_RUBRIC_TEACHER_TEMPLATE.format(positive_rubric=criterion) for criterion in to_include
    ]
    negative_prompts = [
        NEGATIVE_RUBRIC_TEACHER_TEMPLATE.format(negative_rubric=criterion) for criterion in to_avoid
    ]
    return positive_prompts + negative_prompts
def _make_conversation(example: dict[str, Any], max_items_per_section: int | None) -> dict[str, Any]:
    """
    Convert a raw dataset row into the columns used for training.

    Returns a dict with the conversational `prompt` (system message plus the row's `question`), the `solution`
    (the row's `reference_answer`), the combined rubric scaffold `privileged_context`, and the per-criterion
    scaffolds `privileged_contexts`.
    """
    system_message = {"role": "system", "content": SYSTEM_PROMPT}
    user_message = {"role": "user", "content": example["question"]}
    reference = example["reference_answer"]
    combined_scaffold = _build_privileged_context(example, max_items_per_section)
    per_criterion_scaffolds = _build_privileged_contexts_by_criterion(example, max_items_per_section)
    return {
        "prompt": [system_message, user_message],
        "solution": reference,
        "privileged_context": combined_scaffold,
        "privileged_contexts": per_criterion_scaffolds,
    }
def _completion_text(completion: list[dict[str, str]] | str) -> str:
if isinstance(completion, list):
return completion[0]["content"]
return completion
def _extract_final_answer(text: str) -> str:
    """
    Extract the answer portion of a completion.

    The reasoning section is removed first: when the text contains `</think>`, only what follows its first
    occurrence is considered. Within the remaining text, everything after the first case-insensitive
    "final answer:" marker is returned; failing that, everything after the first "####" marker; failing that, the
    remaining text itself.

    Args:
        text: Raw completion text.

    Returns:
        The extracted answer, stripped of surrounding whitespace.
    """
    text = text.strip()
    # Strip the reasoning before looking for markers. Otherwise a "final answer:" or "####" written inside the
    # reasoning would match first and the captured span would drag the rest of the reasoning and the `</think>`
    # tag into the answer.
    if "</think>" in text:
        text = text.split("</think>", 1)[1].strip()
    marker_match = re.search(r"final answer:\s*(.*)$", text, flags=re.IGNORECASE | re.DOTALL)
    if marker_match:
        return marker_match.group(1).strip()
    hash_match = re.search(r"####\s*(.*)$", text, flags=re.DOTALL)
    if hash_match:
        return hash_match.group(1).strip()
    return text
def _normalize_answer(text: str) -> str:
    """
    Normalize an answer string for comparison.

    Applies, in order: NFKC normalization; mapping of the Unicode minus sign, en dash and em dash to `-`;
    unwrapping of brace-free `\\boxed{...}`; removal of markup tags; removal of backticks; unwrapping of `$...$`;
    whitespace collapsing; trimming and lower-casing.

    Args:
        text: Answer text to normalize.

    Returns:
        The normalized, lower-cased text.
    """
    text = unicodedata.normalize("NFKC", text)
    text = text.replace("\u2212", "-").replace("\u2013", "-").replace("\u2014", "-")
    text = re.sub(r"\\boxed\s*\{([^{}]+)\}", r"\1", text)
    # Only strip spans that look like a tag (`<` immediately followed by an optional `/` and a letter). The looser
    # `<[^>]+>` also swallowed everything between a "less than" and a later "greater than" sign, e.g. it reduced
    # "pH < 7 and T > 300" to "pH 300".
    text = re.sub(r"</?[A-Za-z][^<>]*>", " ", text)
    text = re.sub(r"`+", " ", text)
    text = re.sub(r"\$(.*?)\$", r"\1", text)
    text = re.sub(r"\s+", " ", text)
    return text.strip().lower()
def _tokenize_for_similarity(text: str) -> list[str]:
    """
    Split normalized text into tokens for overlap scoring.

    The text is normalized with `_normalize_answer`; every character that is not a word character, whitespace or
    one of `. + - / =` is blanked out, and the remaining runs of word characters and those symbols are returned.
    """
    separators_blanked = re.sub(r"[^\w\s.+\-/=]", " ", _normalize_answer(text), flags=re.UNICODE)
    return re.findall(r"[\w.+\-/=]+", separators_blanked, flags=re.UNICODE)
def _token_f1(prediction: str, reference: str) -> float:
    """
    Compute the bag-of-tokens F1 score between `prediction` and `reference`.

    Token multiplicity is respected when counting the overlap. Returns 0.0 when either side has no tokens or when
    no token is shared.
    """
    pred_tokens = _tokenize_for_similarity(prediction)
    ref_tokens = _tokenize_for_similarity(reference)
    if not (pred_tokens and ref_tokens):
        return 0.0

    # Multiset intersection keeps, for each token, the smaller of its two counts.
    shared = Counter(pred_tokens) & Counter(ref_tokens)
    num_shared = sum(shared.values())
    if num_shared == 0:
        return 0.0

    precision = num_shared / len(pred_tokens)
    recall = num_shared / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)
def _numeric_match_reward(prediction: str, reference: str) -> float | None:
def _extract_numbers(text: str) -> list[float]:
matches = re.findall(r"[-+]?\d+(?:\.\d+)?(?:e[-+]?\d+)?", text.lower())
return [float(match) for match in matches]
prediction_numbers = _extract_numbers(prediction)
reference_numbers = _extract_numbers(reference)
if len(prediction_numbers) != 1 or len(reference_numbers) != 1:
return None
pred_value = prediction_numbers[0]
ref_value = reference_numbers[0]
tolerance = max(1e-6, 1e-3 * max(abs(ref_value), 1.0))
return 1.0 if abs(pred_value - ref_value) <= tolerance else 0.0
def _math_verify_reward(prediction: str, reference: str) -> float | None:
    """
    Score `prediction` against `reference` with `math_verify`.

    Returns:
        `None` when `math_verify` is not installed, when the reference has more than 20 whitespace-separated words,
        or when parsing the reference yields nothing. Otherwise the result of `verify` converted to `float`.
    """
    if not HAS_MATH_VERIFY or len(reference.split()) > 20:
        return None

    gold_parsed = parse(reference, parsing_timeout=None)
    if len(gold_parsed) == 0:
        return None

    latex_config = LatexExtractionConfig(
        boxed_match_priority=0,
        normalization_config=NormalizationConfig(units=True),
        try_extract_without_anchor=False,
    )
    answer_parsed = parse(
        prediction,
        extraction_config=[latex_config],
        extraction_mode="first_match",
        parsing_timeout=None,
    )
    is_equivalent = verify(gold_parsed, answer_parsed, timeout_seconds=None)
    return float(is_equivalent)
def _text_similarity_reward(prediction: str, reference: str) -> float:
    """
    Score free-text answers by normalized string similarity.

    Both texts are normalized with `_normalize_answer`. The score is 0.0 when either is empty, 1.0 on equality or
    when a reference of at least 12 characters is contained in the prediction, and 0.8 when a prediction of at
    least 12 characters is contained in the reference. Otherwise the larger of the token F1 and the
    `SequenceMatcher` ratio is rescaled linearly from [0.45, 1.0] to [0.0, 1.0], with anything below 0.45 giving 0.0.
    """
    pred = _normalize_answer(prediction)
    ref = _normalize_answer(reference)
    if not pred or not ref:
        return 0.0
    if pred == ref:
        return 1.0

    # Containment shortcuts only apply to strings long enough not to match by accident.
    if len(ref) >= 12 and ref in pred:
        return 1.0
    if len(pred) >= 12 and pred in ref:
        return 0.8

    best = max(_token_f1(pred, ref), SequenceMatcher(None, pred, ref).ratio())
    if best < 0.45:
        return 0.0
    return min(1.0, (best - 0.45) / 0.55)
def rar_science_answer_reward(completions, solution, **kwargs) -> list[float]:
    """
    Score each completion against its reference answer.

    For every pair, the final answer is extracted from the completion and scored by the first scorer that returns
    a value: `_math_verify_reward`, then `_numeric_match_reward`, then `_text_similarity_reward` as the fallback.

    Args:
        completions: Completions, each either a string or a message list.
        solution: Reference answers, one per completion (lengths must match).
        **kwargs: Ignored.

    Returns:
        One reward per completion.
    """
    rewards = []
    for completion, reference in zip(completions, solution, strict=True):
        prediction = _extract_final_answer(_completion_text(completion))
        for scorer in (_math_verify_reward, _numeric_match_reward):
            score = scorer(prediction, reference)
            if score is not None:
                break
        else:
            # Neither specialized scorer applied.
            score = _text_similarity_reward(prediction, reference)
        rewards.append(score)
    return rewards
def constant_zero_reward(completions, **kwargs) -> list[float]:
    """Return a reward of 0.0 for every completion."""
    # SDPOTrainer requires a reward function even in distillation-only mode. We keep rollout rewards inert and rely
    # exclusively on rubric-conditioned teacher distillation by pairing this with `use_successful_as_teacher=False`.
    return [0.0 for _ in range(len(completions))]
def _run_reward_eval(
    trainer: SDPOTrainer,
    eval_dataset,
    max_new_tokens: int,
    num_examples: int | None,
    metric_prefix: str = "rar_science_eval",
) -> dict[str, float]:
    """
    Generate answers for evaluation prompts with sampling disabled and score them against the references.

    Args:
        trainer: Trainer providing the model, the processing class, the accelerator and `max_prompt_length`.
        eval_dataset: Dataset with `prompt` and `solution` columns.
        max_new_tokens: Generation budget per prompt.
        num_examples: When not `None`, only the first `num_examples` rows are scored.
        metric_prefix: Prefix of the returned metric names.

    Returns:
        A dict with `<prefix>/answer_reward` (mean reward), `<prefix>/answer_exact_rate` (share of rewards
        >= 0.999) and `<prefix>/num_scored`. All three are 0.0 when there is nothing to score.
    """
    if num_examples is not None:
        eval_dataset = eval_dataset.select(range(min(num_examples, len(eval_dataset))))
    if len(eval_dataset) == 0:
        # Nothing to tokenize or generate; report empty metrics instead of failing on an empty batch.
        return {
            f"{metric_prefix}/answer_reward": 0.0,
            f"{metric_prefix}/answer_exact_rate": 0.0,
            f"{metric_prefix}/num_scored": 0.0,
        }

    prompts = eval_dataset["prompt"]
    prompt_texts = [
        maybe_apply_chat_template({"prompt": prompt}, trainer.processing_class)["prompt"] for prompt in prompts
    ]
    # Left padding keeps every prompt flush against its generated tokens, so slicing at the padded prompt length
    # below isolates the completion for all rows.
    tokenized = trainer.processing_class(
        text=prompt_texts,
        return_tensors="pt",
        padding=True,
        padding_side="left",
        truncation=True,
        max_length=trainer.max_prompt_length,
        add_special_tokens=False,
    )
    tokenized = {key: value.to(trainer.accelerator.device) for key, value in tokenized.items()}

    model = trainer.accelerator.unwrap_model(trainer.model)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            generated = model.generate(
                **tokenized,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=trainer.processing_class.pad_token_id,
                eos_token_id=trainer.processing_class.eos_token_id,
            )
    finally:
        # Restore the previous mode even when generation fails, so a caught error cannot leave the model in
        # eval mode for subsequent training.
        if was_training:
            model.train()

    prompt_length = tokenized["input_ids"].shape[1]
    completions = trainer.processing_class.batch_decode(generated[:, prompt_length:], skip_special_tokens=True)
    completion_messages = [[{"role": "assistant", "content": completion}] for completion in completions]
    answer_rewards = rar_science_answer_reward(completion_messages, solution=eval_dataset["solution"])

    total = max(len(answer_rewards), 1)
    exact_total = sum(1 for reward in answer_rewards if reward >= 0.999)
    return {
        f"{metric_prefix}/answer_reward": sum(answer_rewards) / total,
        f"{metric_prefix}/answer_exact_rate": exact_total / total,
        f"{metric_prefix}/num_scored": float(len(answer_rewards)),
    }
if __name__ == "__main__":
    parser = TrlParser((RarScienceScriptArguments, RarScienceSDPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()

    # Model loading options forwarded to the trainer.
    dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
    training_args.model_init_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        dtype=dtype,
    )
    quantization_config = get_quantization_config(model_args)
    if quantization_config is not None:
        training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
        training_args.model_init_kwargs["quantization_config"] = quantization_config

    # Datasets: optionally shuffle and cap each split, then map rows to prompt/solution/rubric columns.
    raw_dataset = _load_local_dataset(script_args.dataset_dir)
    train_split = raw_dataset[script_args.dataset_train_split]
    if script_args.max_train_examples is not None:
        train_split = train_split.shuffle(seed=script_args.dataset_shuffle_seed).select(
            range(min(script_args.max_train_examples, len(train_split)))
        )
    train_dataset = train_split.map(
        lambda example: _make_conversation(example, script_args.max_rubric_items_per_section),
        remove_columns=train_split.column_names,
    )

    eval_dataset = None
    if training_args.eval_strategy != "no" and script_args.dataset_test_split in raw_dataset:
        eval_split = raw_dataset[script_args.dataset_test_split]
        if script_args.max_eval_examples is not None:
            eval_split = eval_split.shuffle(seed=script_args.dataset_shuffle_seed).select(
                range(min(script_args.max_eval_examples, len(eval_split)))
            )
        eval_dataset = eval_split.map(
            lambda example: _make_conversation(example, script_args.max_rubric_items_per_section),
            remove_columns=eval_split.column_names,
        )

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # The reward function is a constant zero: training relies on the rubric-conditioned distillation signal.
    trainer = SDPOTrainer(
        model=model_args.model_name_or_path,
        args=training_args,
        reward_funcs=[constant_zero_reward],
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=get_peft_config(model_args),
        processing_class=tokenizer,
    )

    if eval_dataset is not None and script_args.eval_num_prompts:
        generation_config = GenerationConfig(
            max_new_tokens=training_args.max_completion_length,
            do_sample=True,
            temperature=training_args.temperature,
        )
        trainer.add_callback(
            LogCompletionsCallback(trainer, generation_config, num_prompts=script_args.eval_num_prompts)
        )

    # Answer-reward evaluation before training.
    pre_metrics = None
    if eval_dataset is not None:
        pre_metrics = _run_reward_eval(
            trainer,
            eval_dataset,
            max_new_tokens=script_args.reward_eval_max_new_tokens,
            num_examples=script_args.reward_eval_num_examples,
        )
        before_metrics = {f"before_{key}": value for key, value in pre_metrics.items()}
        trainer.log_metrics("eval", before_metrics)
        trainer.save_metrics("eval", before_metrics)

    trainer.train()
    trainer.save_model(training_args.output_dir)

    # Same evaluation after training, plus per-metric deltas.
    if pre_metrics is not None:
        post_metrics = _run_reward_eval(
            trainer,
            eval_dataset,
            max_new_tokens=script_args.reward_eval_max_new_tokens,
            num_examples=script_args.reward_eval_num_examples,
        )
        after_metrics = {f"after_{key}": value for key, value in post_metrics.items()}
        delta_metrics = {
            f"delta_{key.split('/', 1)[1]}": post_metrics[key] - pre_metrics[key] for key in pre_metrics
        }
        trainer.log_metrics("eval", after_metrics | delta_metrics)
        # `save_metrics` rewrites the per-split results file on every call, so the "before" values are saved again
        # here; otherwise this second save would discard them from that file.
        trainer.save_metrics("eval", before_metrics | after_metrics | delta_metrics)

    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_dir)