Codeseys's picture
fix(phase-8): close all 5 cross-family final-verify findings + regression tests
678d10b
Raw
History Blame Contribute Delete
10 kB
"""reward.py — MMLUFormatReward (ADR-013, framework-side, generic).
A structured-answer reward for RL on MMLU-style multiple-choice tasks. It scores
ONLY the final answer letter + format validity — never the rationale's style or
content. This is deliberate: the north-star use case (ADR-013) drives RL on a
*personality-altered* model, and rewarding a "persuasive" rationale would
reinforce the very cognitive-distortion signature we are trying to measure,
biasing the reward toward it.
Scoring per completion:
+1.0 final answer parses and equals the gold letter
0.0 final answer parses but is wrong
-0.2 no parseable final-answer marker (unparseable)
-0.1 multiple DISTINCT final-answer markers present (format hacking)
-len_penalty small penalty past a rationale character cap
Parsing accepts (case-insensitive, last match wins for the canonical letter):
- ``Answer: X`` or ``Answer - X`` (X in A-D; an optional ``**`` bold prefix is tolerated)
- JSON ``{"answer": "X"}``
Exploit detection: ``MMLUFormatReward`` keeps a running count of chosen letters
(``option_distribution``) so an "always C" / option-prior exploit is detectable
by inspecting that distribution after a run. A companion ``randomize_options``
helper shuffles option order with an original->shuffled label remap so the
training data itself can be de-biased.
"""
from __future__ import annotations
import json
import re
from collections import Counter
from dataclasses import dataclass, field
from typing import Any
__all__ = ["MMLUFormatReward", "randomize_options", "parse_final_answer"]
# Closed set of option labels (A-D); the same range as the ``[A-D]`` class used
# by every pattern below.
_VALID_LETTERS = ("A", "B", "C", "D")
# ``Answer: X`` / ``Answer - X`` — tolerant of whitespace around the separator
# and of up to two leading asterisks (markdown bold) before the letter. The
# trailing ``\b`` keeps a letter that merely starts a longer word from matching.
# NOTE(review): re.IGNORECASE also applies to ``[A-D]``, so a lowercase letter
# (e.g. the article in "answer: a ...") is captured as an option — confirm
# this is intended.
_ANSWER_RE = re.compile(r"answer\s*[:\-]\s*\*{0,2}([A-D])\b", re.IGNORECASE)
# JSON ``{"answer": "X"}`` — extract the value of an "answer" key. Both the key
# and the value may be wrapped in single or double quotes.
_JSON_ANSWER_RE = re.compile(
r'["\']answer["\']\s*:\s*["\']([A-D])["\']', re.IGNORECASE
)
# Hedge AFTER an answer marker: ``Answer: A or B`` / ``Answer: A/B`` /
# ``Answer: A, B`` — a single marker that names a SECOND distinct option is a
# format hedge and must be treated as multiple-answers, not full credit for the
# first letter (final-verify 2026-05-29). Captures the lead letter (group 1) +
# the hedged second letter (group 2) immediately following via 'or' / 'and' /
# slash / comma / pipe.
# NOTE(review): nothing between the lead letter and the separator consumes a
# closing ``**``, so ``Answer: **A** or **B**`` does not match this pattern —
# confirm whether bold-wrapped hedges should be caught.
_HEDGE_RE = re.compile(
r"answer\s*[:\-]\s*\*{0,2}([A-D])\b\s*(?:or|and|/|,|\|)\s*\*{0,2}([A-D])\b",
re.IGNORECASE,
)
def _find_markers(text: str) -> list[str]:
    """Collect every final-answer letter in ``text``, uppercased, in text order.

    Both the ``Answer: X`` pattern and the JSON ``"answer": "X"`` pattern are
    scanned; hits from the two scans are merged and ordered by the offset at
    which each match starts. A falsy ``text`` (``None`` or ``""``) yields an
    empty list.

    Callers take the last element as the canonical answer and use the full
    list to spot completions that name several different letters.
    """
    haystack = text or ""
    # Scan order (plain pattern first, JSON second) is kept so that the stable
    # sort below resolves equal offsets exactly as separate scans would.
    hits = [
        (match.start(), match.group(1).upper())
        for pattern in (_ANSWER_RE, _JSON_ANSWER_RE)
        for match in pattern.finditer(haystack)
    ]
    hits.sort(key=lambda hit: hit[0])
    return [found_letter for _offset, found_letter in hits]
def parse_final_answer(completion: str) -> tuple[str | None, int]:
    """Extract the canonical answer letter and count the distinct letters named.

    Returns:
        ``(letter, n_distinct)``. ``letter`` is the last marker in the
        completion (last match wins), or ``None`` when no marker exists, in
        which case ``n_distinct`` is ``0``. ``n_distinct`` is the number of
        DIFFERENT letters seen, so repeating ``Answer: C`` counts once while
        ``Answer: A ... Answer: B`` counts twice.

    A hedged marker such as ``Answer: A or B`` contributes both of its letters
    to the distinct count, so it is reported as a multi-answer completion
    rather than as a clean answer for the lead letter.
    """
    found = _find_markers(completion)
    if len(found) == 0:
        return None, 0
    # Letters named by hedged markers (lead letter and the alternative after it).
    hedged = {
        hedge.group(group_index).upper()
        for hedge in _HEDGE_RE.finditer(completion or "")
        for group_index in (1, 2)
    }
    seen_letters = set(found) | hedged
    return found[-1], len(seen_letters)
@dataclass
class MMLUFormatReward:
    """Batch reward callable: ``reward_fn(prompts, completions, *, answers, **kwargs)``.

    Returns one float per completion, based only on the parsed final-answer
    letter, the number of distinct letters named, and the completion length.

    Args:
        rationale_char_cap: number of characters a completion may have before
            the length penalty starts to apply.
        length_penalty_per_char: amount subtracted for every character beyond
            ``rationale_char_cap``.
        correct_reward: base reward when the parsed letter equals the gold one.
        wrong_reward: base reward when a letter parses but differs from gold.
        unparseable_reward: reward when no final-answer marker is found.
        multiple_answers_reward: base reward when more than one distinct
            letter is named.

    Accumulated state: ``option_distribution`` (Counter of parsed letters) and
    ``n_scored`` (number of completions scored) grow across calls; they feed
    ``exploit_report()`` so an "always C" style exploit can be spotted.
    """

    rationale_char_cap: int = 512
    length_penalty_per_char: float = 0.001
    correct_reward: float = 1.0
    wrong_reward: float = 0.0
    unparseable_reward: float = -0.2
    multiple_answers_reward: float = -0.1
    option_distribution: Counter = field(default_factory=Counter)
    n_scored: int = 0

    def __call__(
        self,
        prompts: Any = None,
        completions: list[str] | None = None,
        *,
        answers: list[str] | None = None,
        **kwargs: Any,
    ) -> list[float]:
        """Return one reward per completion, compared against gold ``answers``.

        ``prompts`` and extra keyword arguments are accepted so the object fits
        a reward-function call signature, but they are not read. A ``None``
        ``completions`` is treated as an empty batch.

        Raises:
            ValueError: if ``answers`` is missing, or its length differs from
                the number of completions.
        """
        batch = [] if completions is None else completions
        if answers is None:
            raise ValueError(
                "MMLUFormatReward requires `answers` (the gold letters, one per "
                "completion). Pass via reward_fn(..., answers=[...])."
            )
        n_gold, n_out = len(answers), len(batch)
        if n_gold != n_out:
            raise ValueError(
                f"answers/completions length mismatch: {n_gold} vs {n_out}."
            )
        return [self._score_one(text, gold) for text, gold in zip(batch, answers)]

    def _score_one(self, completion: str, gold: str) -> float:
        """Score a single completion and update the running counters."""
        chosen, distinct_count = parse_final_answer(completion)
        self.n_scored += 1
        if chosen is None:
            # Nothing parsed: flat unparseable reward, no length penalty and
            # no entry in the letter distribution.
            return self.unparseable_reward

        # Every parsed letter is tallied, whatever reward it ends up earning.
        self.option_distribution[chosen] += 1

        if distinct_count > 1:
            # Several different letters named: penalized even if the last one
            # happens to be the gold letter.
            outcome = self.multiple_answers_reward
        else:
            matches_gold = gold is not None and chosen == str(gold).strip().upper()
            outcome = self.correct_reward if matches_gold else self.wrong_reward
        return outcome - self._length_penalty(completion)

    def _length_penalty(self, completion: str) -> float:
        """Return the penalty for characters beyond ``rationale_char_cap``."""
        excess = len(completion or "") - self.rationale_char_cap
        if excess < 0:
            excess = 0
        return self.length_penalty_per_char * excess

    # ------------------------------------------------------------------
    # Exploit detection
    # ------------------------------------------------------------------
    def exploit_report(self) -> dict[str, Any]:
        """Describe how parsed answers are spread over the option letters.

        Returns a dict with ``counts`` (letter -> times chosen),
        ``total_parsed``, ``most_common`` (the most frequent letter, or
        ``None`` when nothing has been parsed) and ``max_fraction`` (that
        letter's share of all parsed answers). A share close to 1.0 for one
        letter is the signature of an option-prior exploit.
        """
        tally = self.option_distribution
        parsed = sum(tally.values())
        if not parsed:
            return {
                "counts": {},
                "total_parsed": 0,
                "most_common": None,
                "max_fraction": 0.0,
            }
        top_letter, top_count = tally.most_common(1)[0]
        return {
            "counts": dict(tally),
            "total_parsed": parsed,
            "most_common": top_letter,
            "max_fraction": top_count / parsed,
        }
def randomize_options(
    item: dict[str, Any], seed: int
) -> tuple[dict[str, Any], dict[str, str]]:
    """Shuffle the multiple-choice option order, tracking original->shuffled letters.

    Args:
        item: a dict with ``options`` (a sequence of option texts, A-first
            ordering) and ``answer`` (the gold label). Other keys are passed
            through unchanged. ``item`` itself is never mutated.
        seed: deterministic RNG seed for the shuffle; the same seed always
            produces the same permutation.

    Returns:
        ``(shuffled_item, label_remap)`` where ``shuffled_item`` is a shallow
        copy of ``item`` with the options reordered and ``answer`` updated to
        the gold option's NEW label, and ``label_remap`` maps each ORIGINAL
        letter -> its NEW (shuffled) letter. When ``options`` is missing,
        ``None`` or empty, the result is ``(dict(item), {})``.

    The gold label may be given either as a letter (``"C"``, any case,
    surrounding whitespace ignored) or as an integer 0-based option index; the
    remapped answer keeps the same representation. An answer that matches
    neither form (or an out-of-range index) is passed through unchanged.

    This de-biases an option-prior exploit at the data level: if the gold answer
    is no longer correlated with a fixed position, "always C" stops working.
    """
    import random

    raw_options = item.get("options")
    # A present-but-None ``options`` is treated like a missing key instead of
    # failing inside ``list(None)``.
    options = [] if raw_options is None else list(raw_options)
    n = len(options)
    if n == 0:
        return dict(item), {}

    letters = [chr(ord("A") + i) for i in range(n)]
    rng = random.Random(seed)
    # perm[new_pos] = old_pos: the option now shown at new_pos is the one that
    # originally sat at old_pos.
    perm = list(range(n))
    rng.shuffle(perm)

    shuffled_item = dict(item)
    shuffled_item["options"] = [options[old_pos] for old_pos in perm]
    # original letter -> new letter: old index `perm[new]` moved to position `new`.
    label_remap: dict[str, str] = {
        letters[old_pos]: letters[new_pos] for new_pos, old_pos in enumerate(perm)
    }

    raw_gold = item.get("answer")
    if isinstance(raw_gold, int) and not isinstance(raw_gold, bool):
        # Integer gold index: without this branch the options would move while
        # the index stayed put, silently pointing the label at a wrong option.
        if 0 <= raw_gold < n:
            shuffled_item["answer"] = perm.index(raw_gold)
    else:
        gold = str(item.get("answer", "")).strip().upper()
        if gold in label_remap:
            shuffled_item["answer"] = label_remap[gold]
    return shuffled_item, label_remap