"""reward.py — MMLUFormatReward (ADR-013, framework-side, generic).

A structured-answer reward for RL on MMLU-style multiple-choice tasks. It
scores ONLY the final answer letter + format validity — never the rationale's
style or content. This is deliberate: the north-star use case (ADR-013) drives
RL on a *personality-altered* model, and rewarding a "persuasive" rationale
would reward the very cognitive-distortion signature we are trying to measure,
biasing the reward toward it instead of keeping it neutral.

Scoring per completion:

    +1.0          final answer parses and equals the gold letter
     0.0          final answer parses but is wrong
    -0.2          no parseable final-answer marker (unparseable)
    -0.1          multiple DISTINCT final-answer letters present, including a
                  hedge inside a single marker such as ``Answer: A or B``
                  (format hacking)
    -len_penalty  small penalty per character past a completion-length cap
                  (applied to parsed completions only)

Parsing accepts (case-insensitive; last match wins for the canonical letter):

- ``Answer: X`` (X in A-D), tolerant of whitespace and of markdown bold in the
  common placements ``Answer: **X**``, ``**Answer:** X`` and ``**Answer**: X``;
  a hyphen may replace the colon.
- JSON ``{"answer": "X"}``

Exploit detection: ``MMLUFormatReward`` keeps a running count of chosen
letters (``option_distribution``) so an "always C" / option-prior exploit is
detectable by inspecting that distribution after a run. A companion
``randomize_options`` helper shuffles option order with an original->shuffled
label remap so the training data itself can be de-biased.
"""

from __future__ import annotations

import json
import random
import re
from collections import Counter
from dataclasses import dataclass, field
from typing import Any

__all__ = ["MMLUFormatReward", "randomize_options", "parse_final_answer"]

_VALID_LETTERS = ("A", "B", "C", "D")

# Shared ``Answer: X`` marker; group 1 is the answer letter. Whitespace and an
# optional markdown-bold run are allowed on each side of the colon/hyphen, so
# ``Answer: **C**``, ``**Answer:** C`` and ``**Answer**: C`` all match. Each
# optional bold run is ``(?:\*{1,2}\s*)?`` rather than ``\*{0,2}\s*`` so two
# ``\s*`` never compete for the same whitespace (no quadratic backtracking on
# long whitespace runs).
# NOTE: with IGNORECASE the letter class also accepts a-d, so prose such as
# "the answer: a ..." registers as a marker.
# _ANSWER_RE and _HEDGE_RE are both built from this one string so the hedge
# check can never be narrower than the answer parser: a hedge the parser
# accepts but the hedge check misses would be scored on its lead letter.
_ANSWER_MARKER = r"answer\s*(?:\*{1,2}\s*)?[:\-]\s*(?:\*{1,2}\s*)?([A-D])\b"

_ANSWER_RE = re.compile(_ANSWER_MARKER, re.IGNORECASE)

# JSON ``{"answer": "X"}`` — extract the value of an "answer" key.
_JSON_ANSWER_RE = re.compile(
    r'["\']answer["\']\s*:\s*["\']([A-D])["\']', re.IGNORECASE
)

# Hedge AFTER an answer marker: ``Answer: A or B`` / ``Answer: A/B`` /
# ``Answer: A, B`` / ``Answer: **A** or **B**`` — a single marker that names a
# SECOND distinct option is a format hedge and must be treated as
# multiple-answers, not full credit for the first letter (final-verify
# 2026-05-29). Group 1 is the lead letter, group 2 the hedged letter joined by
# or / and / slash / comma / pipe; closing bold after the lead letter is
# allowed so markdown cannot be used to slip a hedge past this check.
_HEDGE_RE = re.compile(
    _ANSWER_MARKER + r"\*{0,2}\s*(?:or|and|/|,|\|)\s*(?:\*{1,2}\s*)?([A-D])\b",
    re.IGNORECASE,
)


def _completion_text(completion: Any) -> str:
    """Return the plain text of one completion.

    Accepts a string (returned unchanged), ``None`` (-> ``""``), a chat message
    dict (its ``content``, falling back to ``text`` for content-part dicts), or
    a list/tuple of messages/parts (their texts joined with newlines) — the
    list-of-messages shape is what TRL passes for conversational datasets.

    Raises:
        TypeError: for any other type.
    """
    if completion is None:
        return ""
    if isinstance(completion, str):
        return completion
    if isinstance(completion, dict):
        content = completion.get("content")
        if content is None:
            content = completion.get("text")
        return _completion_text(content)
    if isinstance(completion, (list, tuple)):
        return "\n".join(_completion_text(part) for part in completion)
    raise TypeError(
        "completion must be a str, a chat message dict, or a list of messages; "
        f"got {type(completion).__name__}."
    )


def _find_markers(text: str) -> list[str]:
    """Return ALL final-answer letters found (uppercased), in order of appearance.

    Used both to pick the canonical answer (last match wins) and to detect the
    multiple-distinct-markers format-hacking case. ``Answer: X`` hits are
    collected before JSON hits and the sort is stable, so equal start offsets
    keep that order.
    """
    text = text or ""
    hits = [
        (m.start(), m.group(1).upper())
        for pattern in (_ANSWER_RE, _JSON_ANSWER_RE)
        for m in pattern.finditer(text)
    ]
    hits.sort(key=lambda hit: hit[0])
    return [letter for _, letter in hits]


def parse_final_answer(completion: str) -> tuple[str | None, int]:
    """Parse the final answer letter from a completion.

    Returns ``(letter_or_None, n_distinct_markers)``. ``letter`` is the LAST
    marker found (last match wins). ``n_distinct_markers`` counts DISTINCT
    letters across all markers (so two ``Answer: C`` are not penalized, but
    ``Answer: A ... Answer: B`` is), plus any letters named by an in-marker
    hedge such as ``Answer: A or B``. With no marker at all the result is
    ``(None, 0)``.
    """
    text = completion or ""
    markers = _find_markers(text)
    if not markers:
        return None, 0
    distinct = set(markers)
    # Hedge detection: a single marker naming a second distinct option
    # ("Answer: A or B") adds the hedged letter to the distinct set, so it is
    # penalized as a multiple-answers format hack instead of scoring the lead.
    for m in _HEDGE_RE.finditer(text):
        distinct.add(m.group(1).upper())
        distinct.add(m.group(2).upper())
    return markers[-1], len(distinct)


@dataclass
class MMLUFormatReward:
    """Callable reward_fn(prompts, completions, *, answers, **kwargs) -> list[float].

    Args:
        rationale_char_cap: completions longer than this many characters incur
            a small length penalty (``length_penalty_per_char`` per character
            past the cap). The whole completion text is measured. Caps
            verbosity without scoring rationale content.
        length_penalty_per_char: per-character penalty past the cap.
        correct_reward / wrong_reward / unparseable_reward /
        multiple_answers_reward: the scalar rewards for each outcome.
        max_length_penalty: optional upper bound on the length penalty.
            ``None`` (default) leaves it unbounded. Unparseable completions
            never pay a length penalty, so with an unbounded penalty a long
            parsed completion can score below ``unparseable_reward`` (with the
            defaults, a wrong answer more than 200 characters past the cap
            already does), which rewards dropping the answer marker. Setting
            e.g. ``max_length_penalty=0.05`` keeps the outcome ordering intact.

    Side effect: ``option_distribution`` (a Counter over chosen letters) and
    ``n_scored`` (completions scored, parsed or not) accumulate across calls so
    an "always C" exploit is detectable via ``exploit_report()``.
    """

    rationale_char_cap: int = 512
    length_penalty_per_char: float = 0.001
    correct_reward: float = 1.0
    wrong_reward: float = 0.0
    unparseable_reward: float = -0.2
    multiple_answers_reward: float = -0.1
    option_distribution: Counter = field(default_factory=Counter)
    n_scored: int = 0
    # Appended last so existing positional construction keeps working.
    max_length_penalty: float | None = None

    def __call__(
        self,
        prompts: Any = None,
        completions: list[Any] | None = None,
        *,
        answers: list[str] | None = None,
        **kwargs: Any,
    ) -> list[float]:
        """Score a batch of completions against gold ``answers`` (letters A-D).

        ``prompts`` is accepted for the TRL reward-fn signature but unused (we
        score the completion text only). Each completion may be a plain string
        or a list of chat messages (``[{"role": ..., "content": ...}]``, as TRL
        passes for conversational datasets). ``answers`` is required, one per
        completion; extra ``kwargs`` are ignored.

        Raises:
            ValueError: if ``answers`` is missing or its length differs from
                that of ``completions``.
        """
        if completions is None:
            completions = []
        if answers is None:
            raise ValueError(
                "MMLUFormatReward requires `answers` (the gold letters, one per "
                "completion). Pass via reward_fn(..., answers=[...])."
            )
        if len(answers) != len(completions):
            raise ValueError(
                f"answers/completions length mismatch: {len(answers)} vs "
                f"{len(completions)}."
            )
        return [
            self._score_one(completion, gold)
            for completion, gold in zip(completions, answers)
        ]

    def _score_one(self, completion: Any, gold: Any) -> float:
        """Score one completion and update the exploit-detection counters."""
        text = _completion_text(completion)
        letter, n_distinct = parse_final_answer(text)
        self.n_scored += 1
        if letter is None:
            # Unparseable: no usable final-answer marker. Length penalty does
            # not apply (we never even parsed a letter to reward/penalize).
            return self.unparseable_reward
        # Log the chosen letter for exploit detection (always-C etc.).
        self.option_distribution[letter] += 1
        if n_distinct > 1:
            # Multiple DISTINCT markers — format hacking. Penalize regardless
            # of correctness (the model is hedging / gaming the parser).
            base = self.multiple_answers_reward
        elif gold is not None and letter == str(gold).strip().upper():
            base = self.correct_reward
        else:
            base = self.wrong_reward
        return base - self._length_penalty(text)

    def _length_penalty(self, completion: Any) -> float:
        """Per-character penalty past ``rationale_char_cap``, optionally capped."""
        over = max(0, len(_completion_text(completion)) - self.rationale_char_cap)
        penalty = self.length_penalty_per_char * over
        if self.max_length_penalty is not None:
            penalty = min(penalty, self.max_length_penalty)
        return penalty

    # ------------------------------------------------------------------
    # Exploit detection
    # ------------------------------------------------------------------

    def exploit_report(self) -> dict[str, Any]:
        """Summarize the chosen-letter distribution so an option-prior exploit
        (e.g. "always C") is detectable.

        Returns a dict with the raw counts, the number of parsed answers, the
        most common letter, its fraction of all parsed answers, and
        ``n_scored`` (all completions scored, including unparseable ones, so
        ``n_scored - total_parsed`` is the unparseable count). A healthy run is
        ~uniform over A-D; a fraction near 1.0 for a single letter is the
        exploit signature.
        """
        total = sum(self.option_distribution.values())
        if total == 0:
            return {
                "counts": {},
                "total_parsed": 0,
                "most_common": None,
                "max_fraction": 0.0,
                "n_scored": self.n_scored,
            }
        letter, count = self.option_distribution.most_common(1)[0]
        return {
            "counts": dict(self.option_distribution),
            "total_parsed": total,
            "most_common": letter,
            "max_fraction": count / total,
            "n_scored": self.n_scored,
        }


def randomize_options(
    item: dict[str, Any], seed: int
) -> tuple[dict[str, Any], dict[str, str]]:
    """Shuffle the multiple-choice option order, tracking original->shuffled letters.

    Args:
        item: a dict with ``options`` (list[str], A-first ordering) and,
            optionally, ``answer``: either the gold letter (case- and
            whitespace-insensitive) or a 0-based integer index into
            ``options``. Other keys are passed through unchanged.
        seed: deterministic RNG seed for the shuffle.

    Returns:
        ``(shuffled_item, label_remap)`` where ``shuffled_item`` has the
        options reordered and its ``answer`` updated to the gold option's NEW
        position (an uppercase letter for letter answers, an int for index
        answers), and ``label_remap`` maps each ORIGINAL letter -> its NEW
        (shuffled) letter. An item without options is returned as a shallow
        copy with an empty remap; an item without an ``answer`` (or with
        ``answer=None``) is shuffled and left unlabeled.

        This de-biases an option-prior exploit at the data level: if the gold
        answer is no longer correlated with a fixed position, "always C" stops
        working.

    Raises:
        ValueError: if ``answer`` is present but does not name one of the
            item's options. Returning shuffled options with the stale answer
            would silently corrupt the gold label.
    """
    options = list(item.get("options", []))
    n = len(options)
    if n == 0:
        return dict(item), {}

    orig_letters = [chr(ord("A") + i) for i in range(n)]
    rng = random.Random(seed)
    perm = list(range(n))
    rng.shuffle(perm)
    # perm[new_pos] = old_pos: the option now at new_pos came from old_pos.
    new_pos_of = {old_pos: new_pos for new_pos, old_pos in enumerate(perm)}
    label_remap: dict[str, str] = {
        orig_letters[old_pos]: orig_letters[new_pos]
        for old_pos, new_pos in new_pos_of.items()
    }

    shuffled_item = dict(item)
    shuffled_item["options"] = [options[old_pos] for old_pos in perm]

    gold = item.get("answer")
    if gold is None:
        # Unlabeled item: nothing to remap.
        return shuffled_item, label_remap
    if isinstance(gold, int) and not isinstance(gold, bool):
        if not 0 <= gold < n:
            raise ValueError(
                f"answer index {gold} is out of range for {n} options."
            )
        shuffled_item["answer"] = new_pos_of[gold]
    else:
        letter = str(gold).strip().upper()
        if letter not in label_remap:
            raise ValueError(
                f"answer {gold!r} does not name one of the {n} options "
                f"({', '.join(orig_letters)})."
            )
        shuffled_item["answer"] = label_remap[letter]
    return shuffled_item, label_remap