|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """
|
| Usage:
|
|
|
| ```bash
|
| accelerate launch \
|
| --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
| examples/scripts/sdpo_rar_science.py \
|
| --model_name_or_path Qwen/Qwen3-4B \
|
| --output_dir outputs/rar-science-qwen3-4b \
|
| --dataset_dir data/rar_science \
|
| --dataset_test_split val \
|
| --learning_rate 5e-6 \
|
| --dtype bfloat16 \
|
| --bf16 true \
|
| --max_prompt_length 2048 \
|
| --max_completion_length 768 \
|
| --per_device_train_batch_size 1 \
|
| --gradient_accumulation_steps 8 \
|
| --num_generations 4 \
|
| --use_peft \
|
| --lora_target_modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj \
|
| --sdpo_policy_loss_mode distillation_only \
|
| --eval_strategy steps \
|
| --eval_steps 200 \
|
| --save_strategy steps \
|
| --save_steps 200 \
|
| --report_to none
|
| ```
|
|
|
| This example trains SDPO on the local `rar_science` JSONL dataset. The student sees only the question, while the
|
| teacher gets a rubric scaffold in `privileged_context`, built by placing positive-weight rubric items under
|
| "IMPORTANT POINTS TO INCLUDE" and negative-weight rubric items under "IMPORTANT POINTS TO AVOID".
|
| """
|
|
|
| import re
|
| import unicodedata
|
| from collections import Counter
|
| from dataclasses import dataclass, field
|
| from difflib import SequenceMatcher
|
| from pathlib import Path
|
| from typing import Any
|
|
|
| import torch
|
| from datasets import DatasetDict, load_dataset
|
| from transformers import AutoTokenizer, GenerationConfig
|
|
|
| from trl import (
|
| LogCompletionsCallback,
|
| ModelConfig,
|
| ScriptArguments,
|
| TrlParser,
|
| get_kbit_device_map,
|
| get_peft_config,
|
| get_quantization_config,
|
| )
|
| from trl.data_utils import maybe_apply_chat_template
|
| from trl.experimental.sdpo import SDPOConfig, SDPOTrainer
|
|
|
| try:
|
| from latex2sympy2_extended import NormalizationConfig
|
| from math_verify import LatexExtractionConfig, parse, verify
|
|
|
| HAS_MATH_VERIFY = True
|
| except ImportError:
|
| HAS_MATH_VERIFY = False
|
|
|
|
|
# System message placed first in every student prompt built by `_make_conversation`.
SYSTEM_PROMPT = "You are a helpful assistant."

# Teacher-only scaffold listing every rubric item at once. `{include_block}` and `{avoid_block}` are filled with
# bullet lists by `_build_privileged_context`.
PRIVILEGED_RUBRIC_TEMPLATE = """For this question, please consider the following evaluation criteria:

IMPORTANT POINTS TO INCLUDE (you should aim to address these):
{include_block}

IMPORTANT POINTS TO AVOID (you should not do these):
{avoid_block}

Please provide a comprehensive and helpful response that addresses the question while following the above guidelines.

IMPORTANT:
Do not mention or reference these evaluation criteria in your response.
Do not indicate that you have seen any scoring rubric or evaluation guidelines.
Your response should appear natural and spontaneous. Revealing that you have access to evaluation criteria would be considered cheating and is strictly prohibited.
"""

# Teacher-only scaffold for a single "include" rubric item; `{positive_rubric}` is filled by
# `_build_privileged_contexts_by_criterion`.
POSITIVE_RUBRIC_TEACHER_TEMPLATE = """For this question, please consider the following evaluation criteria:

{positive_rubric}

Please provide a comprehensive and helpful response that addresses the question while following the above guidelines.

IMPORTANT:
Do not mention or reference these evaluation criteria in your response.
Do not indicate that you have seen any scoring rubric or evaluation guidelines.
Your response should appear natural and spontaneous. Revealing that you have access to evaluation criteria would be considered cheating and is strictly prohibited.
"""

# Teacher-only scaffold for a single "avoid" rubric item; `{negative_rubric}` is filled by
# `_build_privileged_contexts_by_criterion`.
NEGATIVE_RUBRIC_TEACHER_TEMPLATE = """For this question, please consider the following evaluation pitfall to avoid:

{negative_rubric}

Please provide a comprehensive and helpful response that addresses the question while carefully avoiding the issue described above.

IMPORTANT:
Do not mention or reference this evaluation criterion in your response.
Do not indicate that you have seen any scoring rubric, pitfall list, or evaluation guidelines.
Your response should appear natural and spontaneous. Revealing that you have access to evaluation criteria would be considered cheating and is strictly prohibited.
"""
|
|
|
|
|
@dataclass
class RarScienceScriptArguments(ScriptArguments):
    """Command-line options specific to this script, layered on top of TRL's `ScriptArguments`.

    They select the local dataset directory and evaluation split, cap how many examples and rubric items are
    used, and size the greedy reward evaluation run before and after training.
    """

    # Where the JSONL splits live.
    dataset_dir: str = field(
        default="data/rar_science",
        metadata=dict(help="Directory containing local `train.jsonl`, `val.jsonl`, and optional `test.jsonl` files."),
    )
    dataset_test_split: str = field(
        default="val",
        metadata=dict(help="Dataset split to use for evaluation. Defaults to `val` for the local RAR Science files."),
    )

    # Completion logging and greedy reward evaluation.
    eval_num_prompts: int | None = field(
        default=8,
        metadata=dict(help="Number of evaluation prompts to log with sampled completions. Set to 0 to disable."),
    )
    reward_eval_num_examples: int | None = field(
        default=128,
        metadata=dict(
            help="Optional number of eval examples to score with greedy decoding before and after training."
        ),
    )
    reward_eval_max_new_tokens: int = field(
        default=256,
        metadata=dict(help="Maximum number of tokens for greedy reward evaluation generations."),
    )

    # Sub-sampling of the loaded splits.
    max_train_examples: int | None = field(
        default=None,
        metadata=dict(help="Optional cap on the number of training examples loaded from the selected train split."),
    )
    max_eval_examples: int | None = field(
        default=None,
        metadata=dict(help="Optional cap on the number of evaluation examples loaded from the selected eval split."),
    )
    dataset_shuffle_seed: int = field(
        default=42,
        metadata=dict(help="Random seed used before applying example caps."),
    )

    # Rubric truncation.
    max_rubric_items_per_section: int | None = field(
        default=None,
        metadata=dict(help="Optional cap on the number of rubric items kept in each include/avoid section."),
    )
|
|
|
|
|
@dataclass
class RarScienceSDPOConfig(SDPOConfig):
    """`SDPOConfig` with its defaults overridden for rubric-scaffolded, distillation-only training.

    Only default values and help texts are changed here; the meaning of each option is defined by `SDPOConfig`.
    """

    # Sequence length limits.
    max_prompt_length: int | None = field(
        default=2048,
        metadata=dict(help="Maximum prompt length for the student prompt."),
    )
    max_completion_length: int | None = field(
        default=768,
        metadata=dict(help="Maximum completion length for the generated answer."),
    )

    # Teacher construction and loss selection.
    include_environment_feedback: bool = field(
        default=True,
        metadata=dict(help="Always pass the rubric scaffold as teacher-only privileged context."),
    )
    use_successful_as_teacher: bool = field(
        default=False,
        metadata=dict(
            help="Disable successful-rollout teacher mining and use only rubric-conditioned teacher prompts."
        ),
    )
    sdpo_policy_loss_mode: str = field(
        default="distillation_only",
        metadata=dict(help="Use only the rubric-conditioned self-distillation loss."),
    )
    diagnostics_warning_interval: int = field(
        default=0,
        metadata=dict(
            help="Disable flat-reward and no-success warnings for intentional zero-reward distillation-only training."
        ),
    )

    # Prompt templates used to assemble the teacher prompt.
    feedback_template: str = field(
        default="{feedback_raw}",
        metadata=dict(help="Inject the rubric scaffold directly without wrapping it as failure feedback."),
    )
    reprompt_template: str = field(
        default=(
            "{feedback}\n\n"
            "Original question:\n"
            "{prompt}\n"
            "{solution}\n"
            "Respond to the original question while silently following the hidden guidance above.\n"
        ),
        metadata=dict(help="Teacher prompt template for rubric-scaffolded SDPO."),
    )
    solution_template: str = field(
        default=(
            "\n\nA successful earlier attempt is available below. Reuse the useful parts if they help:\n\n"
            "{successful_previous_attempt}\n"
        ),
        metadata=dict(help="How successful sibling rollouts are inserted into the teacher prompt."),
    )
|
|
|
|
|
def _load_local_dataset(dataset_dir: str) -> DatasetDict:
    """Load the JSONL splits found in `dataset_dir` as a `DatasetDict`.

    Args:
        dataset_dir: Directory that is searched for `train.jsonl`, `val.jsonl` and `test.jsonl`.

    Returns:
        The result of `load_dataset("json", ...)` over every split file that exists.

    Raises:
        ValueError: If `train.jsonl` is not present in `dataset_dir`.
    """
    root = Path(dataset_dir)
    candidates = {name: root / f"{name}.jsonl" for name in ("train", "val", "test")}
    # Only splits that are actually on disk are handed to the loader.
    data_files = {name: str(path) for name, path in candidates.items() if path.exists()}

    if "train" not in data_files:
        raise ValueError(f"No `train.jsonl` found in dataset_dir={dataset_dir!r}.")

    return load_dataset("json", data_files=data_files)
|
|
|
|
|
| def _clean_text(text: Any) -> str:
|
| return " ".join(str(text).strip().split())
|
|
|
|
|
def _collect_rubric_items(
    example: dict[str, Any], max_items_per_section: int | None
) -> tuple[list[str], list[str]]:
    """Split an example's rubric into "include" and "avoid" item texts.

    The structured `rubric` field (a list of dicts) is read first: each entry contributes its `description`
    (or `title` as a fallback) to the include list when its `weight` is positive and to the avoid list when it is
    negative. Entries with a zero, missing, null or non-numeric weight are ignored.

    If that yields nothing, the flat `rubric_list` field is used instead: an entry goes to the avoid list when it
    starts with "pitfall criteria:" or contains " should not " (case-insensitive), otherwise to the include list.

    Args:
        example: One raw dataset row.
        max_items_per_section: If not `None`, each returned list is truncated to this many items.

    Returns:
        `(include_items, avoid_items)`, both in dataset order.
    """
    include_items: list[str] = []
    avoid_items: list[str] = []

    rubric = example.get("rubric") or []
    if isinstance(rubric, list):
        for item in rubric:
            if not isinstance(item, dict):
                continue
            description = _clean_text(item.get("description") or item.get("title") or "")
            if not description:
                continue
            # A null or string weight would make the comparisons below raise `TypeError`; coerce it to a number
            # and treat anything unusable as "no weight".
            try:
                weight = float(item.get("weight") or 0)
            except (TypeError, ValueError):
                weight = 0.0
            if weight > 0:
                include_items.append(description)
            elif weight < 0:
                avoid_items.append(description)

    if not include_items and not avoid_items:
        rubric_list = example.get("rubric_list") or []
        if isinstance(rubric_list, str):
            # A bare string would otherwise be iterated character by character.
            rubric_list = [rubric_list]
        for item in rubric_list:
            if item is None:
                # `_clean_text(None)` would produce the literal text "None".
                continue
            text = _clean_text(item)
            if not text:
                continue
            lower = text.lower()
            if lower.startswith("pitfall criteria:") or " should not " in lower:
                avoid_items.append(text)
            else:
                include_items.append(text)

    if max_items_per_section is not None:
        include_items = include_items[:max_items_per_section]
        avoid_items = avoid_items[:max_items_per_section]

    return include_items, avoid_items
|
|
|
|
|
| def _format_rubric_block(items: list[str]) -> str:
|
| if not items:
|
| return "- None."
|
| return "\n".join(f"- {item}" for item in items)
|
|
|
|
|
def _build_privileged_context(example: dict[str, Any], max_items_per_section: int | None) -> str:
    """Fill `PRIVILEGED_RUBRIC_TEMPLATE` with the example's include and avoid rubric items as bullet lists."""
    sections = _collect_rubric_items(example, max_items_per_section)
    include_block, avoid_block = (_format_rubric_block(section) for section in sections)
    return PRIVILEGED_RUBRIC_TEMPLATE.format(include_block=include_block, avoid_block=avoid_block)
|
|
|
|
|
def _build_privileged_contexts_by_criterion(
    example: dict[str, Any], max_items_per_section: int | None
) -> list[str]:
    """Build one teacher scaffold per rubric item.

    Include items are rendered with `POSITIVE_RUBRIC_TEACHER_TEMPLATE` and come first; avoid items are rendered
    with `NEGATIVE_RUBRIC_TEACHER_TEMPLATE` and follow.
    """
    positives, negatives = _collect_rubric_items(example, max_items_per_section)
    return [
        *(POSITIVE_RUBRIC_TEACHER_TEMPLATE.format(positive_rubric=criterion) for criterion in positives),
        *(NEGATIVE_RUBRIC_TEACHER_TEMPLATE.format(negative_rubric=criterion) for criterion in negatives),
    ]
|
|
|
|
|
def _make_conversation(example: dict[str, Any], max_items_per_section: int | None) -> dict[str, Any]:
    """Convert a raw dataset row into the columns used for training and evaluation.

    Returns a dict with:
        prompt: System message plus the row's `question` as the user message.
        solution: The row's `reference_answer`.
        privileged_context: A single scaffold containing all rubric items.
        privileged_contexts: One scaffold per rubric item.
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": example["question"]},
    ]
    reference = example["reference_answer"]
    combined_context = _build_privileged_context(example, max_items_per_section)
    per_criterion_contexts = _build_privileged_contexts_by_criterion(example, max_items_per_section)
    return {
        "prompt": messages,
        "solution": reference,
        "privileged_context": combined_context,
        "privileged_contexts": per_criterion_contexts,
    }
|
|
|
|
|
| def _completion_text(completion: list[dict[str, str]] | str) -> str:
|
| if isinstance(completion, list):
|
| return completion[0]["content"]
|
| return completion
|
|
|
|
|
| def _extract_final_answer(text: str) -> str:
|
| text = text.strip()
|
| marker_match = re.search(r"final answer:\s*(.*)$", text, flags=re.IGNORECASE | re.DOTALL)
|
| if marker_match:
|
| return marker_match.group(1).strip()
|
|
|
| boxed_match = re.search(r"####\s*(.*)$", text, flags=re.DOTALL)
|
| if boxed_match:
|
| return boxed_match.group(1).strip()
|
|
|
| if "</think>" in text:
|
| return text.split("</think>", 1)[1].strip()
|
|
|
| return text
|
|
|
|
|
| def _normalize_answer(text: str) -> str:
|
| text = unicodedata.normalize("NFKC", text)
|
| text = text.replace("\u2212", "-").replace("\u2013", "-").replace("\u2014", "-")
|
| text = re.sub(r"\\boxed\s*\{([^{}]+)\}", r"\1", text)
|
| text = re.sub(r"<[^>]+>", " ", text)
|
| text = re.sub(r"`+", " ", text)
|
| text = re.sub(r"\$(.*?)\$", r"\1", text)
|
| text = re.sub(r"\s+", " ", text)
|
| return text.strip().lower()
|
|
|
|
|
def _tokenize_for_similarity(text: str) -> list[str]:
    """Normalize `text` and split it into tokens made of word characters and the symbols `. + - / =`."""
    # Any other character becomes a separator before the tokens are collected.
    cleaned = re.sub(r"[^\w\s.+\-/=]", " ", _normalize_answer(text), flags=re.UNICODE)
    return re.findall(r"[\w.+\-/=]+", cleaned, flags=re.UNICODE)
|
|
|
|
|
def _token_f1(prediction: str, reference: str) -> float:
    """Return the token-level F1 between `prediction` and `reference`.

    Tokens come from `_tokenize_for_similarity`; overlap is counted as a multiset intersection. The score is 0.0
    when either side has no tokens or when no token is shared.
    """
    pred_tokens = _tokenize_for_similarity(prediction)
    ref_tokens = _tokenize_for_similarity(reference)
    if not (pred_tokens and ref_tokens):
        return 0.0

    shared = Counter(pred_tokens) & Counter(ref_tokens)
    overlap = sum(shared.values())
    if not overlap:
        return 0.0

    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)
|
|
|
|
|
| def _numeric_match_reward(prediction: str, reference: str) -> float | None:
|
| def _extract_numbers(text: str) -> list[float]:
|
| matches = re.findall(r"[-+]?\d+(?:\.\d+)?(?:e[-+]?\d+)?", text.lower())
|
| return [float(match) for match in matches]
|
|
|
| prediction_numbers = _extract_numbers(prediction)
|
| reference_numbers = _extract_numbers(reference)
|
| if len(prediction_numbers) != 1 or len(reference_numbers) != 1:
|
| return None
|
|
|
| pred_value = prediction_numbers[0]
|
| ref_value = reference_numbers[0]
|
| tolerance = max(1e-6, 1e-3 * max(abs(ref_value), 1.0))
|
| return 1.0 if abs(pred_value - ref_value) <= tolerance else 0.0
|
|
|
|
|
def _math_verify_reward(prediction: str, reference: str) -> float | None:
    """Score a prediction against a short reference with `math_verify`, when that is applicable.

    Returns:
        `None` when `math_verify` is not installed, when the reference has more than 20 whitespace-separated
        words, when the reference cannot be parsed, or when no LaTeX expression can be extracted from the
        prediction. In all of those cases the caller is expected to fall back to another scorer. Otherwise the
        result of `verify` as 0.0 or 1.0.
    """
    if not HAS_MATH_VERIFY:
        return None
    # Long references are treated as free text rather than as a mathematical expression.
    if len(reference.split()) > 20:
        return None

    gold_parsed = parse(reference, parsing_timeout=None)
    if len(gold_parsed) == 0:
        return None

    answer_parsed = parse(
        prediction,
        extraction_config=[
            LatexExtractionConfig(
                boxed_match_priority=0,
                normalization_config=NormalizationConfig(units=True),
                try_extract_without_anchor=False,
            )
        ],
        extraction_mode="first_match",
        parsing_timeout=None,
    )
    # Only LaTeX is extracted from the prediction, so a plain-text answer such as "42" parses to nothing.
    # Returning 0.0 here would score it as wrong and skip the numeric and text fallbacks; abstain instead.
    if len(answer_parsed) == 0:
        return None

    return float(verify(gold_parsed, answer_parsed, timeout_seconds=None))
|
|
|
|
|
def _text_similarity_reward(prediction: str, reference: str) -> float:
    """Score free-text answers by normalized string similarity, in `[0.0, 1.0]`.

    After `_normalize_answer` on both sides:
        - 0.0 if either side is empty;
        - 1.0 on an exact match, or when the reference (at least 12 characters) occurs inside the prediction;
        - 0.8 when the prediction (at least 12 characters) occurs inside the reference;
        - otherwise the larger of token F1 and `SequenceMatcher` ratio, mapped linearly from `[0.45, 1.0]` onto
          `[0.0, 1.0]`, with anything below 0.45 scored as 0.0.
    """
    pred = _normalize_answer(prediction)
    ref = _normalize_answer(reference)
    if not (pred and ref):
        return 0.0

    if pred == ref:
        return 1.0
    if len(ref) >= 12 and ref in pred:
        return 1.0
    if len(pred) >= 12 and pred in ref:
        return 0.8

    best = max(_token_f1(pred, ref), SequenceMatcher(None, pred, ref).ratio())
    return 0.0 if best < 0.45 else min(1.0, (best - 0.45) / 0.55)
|
|
|
|
|
def rar_science_answer_reward(completions, solution, **kwargs) -> list[float]:
    """Score each completion against its reference answer.

    The final answer is extracted from the completion, then scorers are tried in order: `_math_verify_reward`,
    `_numeric_match_reward`, and finally `_text_similarity_reward`. The first scorer that does not return `None`
    decides the reward. `completions` and `solution` must have the same length. Extra keyword arguments are
    accepted and ignored.
    """

    def _score(completion, reference) -> float:
        prediction = _extract_final_answer(_completion_text(completion))
        for scorer in (_math_verify_reward, _numeric_match_reward):
            reward = scorer(prediction, reference)
            if reward is not None:
                return reward
        return _text_similarity_reward(prediction, reference)

    return [_score(completion, reference) for completion, reference in zip(completions, solution, strict=True)]
|
|
|
|
|
def constant_zero_reward(completions, **kwargs) -> list[float]:
    """Return a reward of 0.0 for every completion; extra keyword arguments are ignored."""
    return [0.0 for _ in range(len(completions))]
|
|
|
|
|
def _run_reward_eval(
    trainer: SDPOTrainer,
    eval_dataset,
    max_new_tokens: int,
    num_examples: int | None,
    metric_prefix: str = "rar_science_eval",
    batch_size: int = 8,
) -> dict[str, float]:
    """Greedy-decode answers for evaluation prompts and score them with `rar_science_answer_reward`.

    Args:
        trainer: Trainer providing the model, tokenizer (`processing_class`), device and `max_prompt_length`.
        eval_dataset: Dataset with `prompt` and `solution` columns.
        max_new_tokens: Generation budget per completion.
        num_examples: If not `None`, only the first `num_examples` rows are scored.
        metric_prefix: Prefix for the returned metric names.
        batch_size: Number of prompts generated per forward batch; values below 1 are treated as 1.

    Returns:
        `{prefix}/answer_reward` (mean reward), `{prefix}/answer_exact_rate` (share of rewards >= 0.999) and
        `{prefix}/num_scored` (number of scored examples).
    """
    if num_examples is not None:
        eval_dataset = eval_dataset.select(range(min(num_examples, len(eval_dataset))))

    if len(eval_dataset) == 0:
        # Nothing to generate; report empty metrics rather than passing an empty batch to the tokenizer.
        return {
            f"{metric_prefix}/answer_reward": 0.0,
            f"{metric_prefix}/answer_exact_rate": 0.0,
            f"{metric_prefix}/num_scored": 0.0,
        }

    tokenizer = trainer.processing_class
    device = trainer.accelerator.device
    prompt_texts = [
        maybe_apply_chat_template({"prompt": prompt}, tokenizer)["prompt"] for prompt in eval_dataset["prompt"]
    ]

    # NOTE(review): generation runs on the unwrapped model; confirm this behaves correctly when parameters are
    # sharded across ranks (e.g. DeepSpeed ZeRO-3).
    model = trainer.accelerator.unwrap_model(trainer.model)
    was_training = model.training
    model.eval()

    step = max(1, batch_size)
    completions: list[str] = []
    try:
        with torch.no_grad():
            # Generate in small batches so memory use does not grow with the number of evaluated examples.
            for start in range(0, len(prompt_texts), step):
                tokenized = tokenizer(
                    text=prompt_texts[start : start + step],
                    return_tensors="pt",
                    padding=True,
                    padding_side="left",
                    truncation=True,
                    max_length=trainer.max_prompt_length,
                    add_special_tokens=False,
                )
                tokenized = {key: value.to(device) for key, value in tokenized.items()}
                generated = model.generate(
                    **tokenized,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )
                # Prompts are left-padded to a common length, so the completion starts at the same column
                # for every row of the batch.
                prompt_length = tokenized["input_ids"].shape[1]
                completions.extend(tokenizer.batch_decode(generated[:, prompt_length:], skip_special_tokens=True))
    finally:
        # Restore the previous mode even if generation fails.
        if was_training:
            model.train()

    completion_messages = [[{"role": "assistant", "content": completion}] for completion in completions]

    answer_rewards = rar_science_answer_reward(completion_messages, solution=eval_dataset["solution"])
    total = max(len(answer_rewards), 1)
    exact_total = sum(1 for reward in answer_rewards if reward >= 0.999)

    return {
        f"{metric_prefix}/answer_reward": sum(answer_rewards) / total,
        f"{metric_prefix}/answer_exact_rate": exact_total / total,
        f"{metric_prefix}/num_scored": float(len(answer_rewards)),
    }
|
|
|
|
|
if __name__ == "__main__":
    parser = TrlParser((RarScienceScriptArguments, RarScienceSDPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()

    # Model loading options: "auto"/None are passed through, any other dtype name is resolved on `torch`.
    dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
    training_args.model_init_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        dtype=dtype,
    )
    quantization_config = get_quantization_config(model_args)
    if quantization_config is not None:
        training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
        training_args.model_init_kwargs["quantization_config"] = quantization_config

    raw_dataset = _load_local_dataset(script_args.dataset_dir)

    # Training split: optionally shuffled and capped, then converted to prompt/solution/rubric columns.
    train_split = raw_dataset[script_args.dataset_train_split]
    if script_args.max_train_examples is not None:
        train_split = train_split.shuffle(seed=script_args.dataset_shuffle_seed).select(
            range(min(script_args.max_train_examples, len(train_split)))
        )

    train_dataset = train_split.map(
        lambda example: _make_conversation(example, script_args.max_rubric_items_per_section),
        remove_columns=train_split.column_names,
    )

    # Evaluation split: only prepared when evaluation is enabled and the requested split exists on disk.
    eval_dataset = None
    if training_args.eval_strategy != "no" and script_args.dataset_test_split in raw_dataset:
        eval_split = raw_dataset[script_args.dataset_test_split]
        if script_args.max_eval_examples is not None:
            eval_split = eval_split.shuffle(seed=script_args.dataset_shuffle_seed).select(
                range(min(script_args.max_eval_examples, len(eval_split)))
            )
        eval_dataset = eval_split.map(
            lambda example: _make_conversation(example, script_args.max_rubric_items_per_section),
            remove_columns=eval_split.column_names,
        )

    # Load the tokenizer from the same revision as the model so the two cannot drift apart.
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # The reward is constant zero on purpose: training relies on the distillation loss only.
    trainer = SDPOTrainer(
        model=model_args.model_name_or_path,
        args=training_args,
        reward_funcs=[constant_zero_reward],
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=get_peft_config(model_args),
        processing_class=tokenizer,
    )

    if eval_dataset is not None and script_args.eval_num_prompts:
        generation_config = GenerationConfig(
            max_new_tokens=training_args.max_completion_length,
            do_sample=True,
            temperature=training_args.temperature,
        )
        trainer.add_callback(
            LogCompletionsCallback(trainer, generation_config, num_prompts=script_args.eval_num_prompts)
        )

    # Baseline answer quality before any training step.
    pre_metrics = None
    before_metrics = {}
    if eval_dataset is not None:
        pre_metrics = _run_reward_eval(
            trainer,
            eval_dataset,
            max_new_tokens=script_args.reward_eval_max_new_tokens,
            num_examples=script_args.reward_eval_num_examples,
        )
        before_metrics = {f"before_{key}": value for key, value in pre_metrics.items()}
        trainer.log_metrics("eval", before_metrics)
        # Saved right away so the baseline survives a failed training run.
        trainer.save_metrics("eval", before_metrics)

    trainer.train()

    trainer.save_model(training_args.output_dir)

    if pre_metrics is not None:
        post_metrics = _run_reward_eval(
            trainer,
            eval_dataset,
            max_new_tokens=script_args.reward_eval_max_new_tokens,
            num_examples=script_args.reward_eval_num_examples,
        )
        after_metrics = {f"after_{key}": value for key, value in post_metrics.items()}
        # Delta names drop the metric prefix, e.g. "delta_answer_reward".
        delta_metrics = {
            f"delta_{key.split('/', 1)[1]}": post_metrics[key] - pre_metrics[key] for key in pre_metrics
        }
        trainer.log_metrics("eval", after_metrics | delta_metrics)
        # The per-split results file is rewritten on every save, so the baseline is included again here
        # to keep before/after/delta together.
        trainer.save_metrics("eval", before_metrics | after_metrics | delta_metrics)

    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_dir)
|
|
|