Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
```python
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """reward.py — MMLUFormatReward (ADR-013, framework-side, generic). | |
| A structured-answer reward for RL on MMLU-style multiple-choice tasks. It scores | |
| ONLY the final answer letter + format validity — never the rationale's style or | |
| content. This is deliberate: the north-star use case (ADR-013) drives RL on a | |
| *personality-altered* model, and rewarding a "persuasive" rationale would | |
| reward the very cognitive-distortion signature we are trying to measure, | |
| biasing the reward toward that signature instead of measuring it. | |
| Scoring per completion: | |
| +1.0 final answer parses and equals the gold letter | |
| 0.0 final answer parses but is wrong | |
| -0.2 no parseable final-answer marker (unparseable) | |
| -0.1 multiple DISTINCT final-answer markers present, or a single hedged | |
| marker such as ``Answer: A or B`` (format hacking) | |
| -len_penalty small penalty past a rationale character cap | |
| Parsing accepts (case-insensitive, last match wins for the canonical letter): | |
| - ``Answer: X`` (X in A-D) | |
| - JSON ``{"answer": "X"}`` | |
| Exploit detection: ``MMLUFormatReward`` keeps a running count of chosen letters | |
| (``option_distribution``) so an "always C" / option-prior exploit is detectable | |
| by inspecting that distribution after a run. A companion ``randomize_options`` | |
| helper shuffles option order with an original->shuffled label remap so the | |
| training data itself can be de-biased. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import re | |
| from collections import Counter | |
| from dataclasses import dataclass, field | |
| from typing import Any | |
__all__ = ["MMLUFormatReward", "randomize_options", "parse_final_answer"]

# The option letters an MMLU-style item can use.
_VALID_LETTERS = tuple("ABCD")

# ``Answer: X`` / ``answer - X``: whitespace-tolerant, allows up to two
# markdown asterisks before the letter (``Answer: **C**``). Case-insensitive,
# so a lowercase letter is captured too; callers upper-case the capture.
_ANSWER_RE = re.compile(
    r"answer\s*[:\-]\s*\*{0,2}([A-D])\b",
    flags=re.IGNORECASE,
)

# JSON-ish ``{"answer": "X"}``: captures the value of an "answer" key, with
# either single or double quotes around key and value.
_JSON_ANSWER_RE = re.compile(
    r'["\']answer["\']\s*:\s*["\']([A-D])["\']',
    flags=re.IGNORECASE,
)

# A hedge inside ONE marker: ``Answer: A or B`` / ``Answer: A/B`` /
# ``Answer: A, B`` / ``Answer: A and B`` / ``Answer: A | B``. Group 1 is the
# lead letter and group 2 the hedged second letter. A marker naming a second
# option must count as multiple answers rather than earn credit for the lead
# letter (final-verify 2026-05-29).
_HEDGE_RE = re.compile(
    r"answer\s*[:\-]\s*\*{0,2}([A-D])\b\s*(?:or|and|/|,|\|)\s*\*{0,2}([A-D])\b",
    flags=re.IGNORECASE,
)
| def _find_markers(text: str) -> list[str]: | |
| """Return ALL final-answer letters found (uppercased), in order of appearance. | |
| Used both to pick the canonical answer (last match wins) and to detect the | |
| multiple-distinct-markers format-hacking case. | |
| """ | |
| markers: list[tuple[int, str]] = [] | |
| for m in _ANSWER_RE.finditer(text or ""): | |
| markers.append((m.start(), m.group(1).upper())) | |
| for m in _JSON_ANSWER_RE.finditer(text or ""): | |
| markers.append((m.start(), m.group(1).upper())) | |
| markers.sort(key=lambda p: p[0]) | |
| return [letter for _, letter in markers] | |
def parse_final_answer(completion: str) -> tuple[str | None, int]:
    """Parse the final answer letter from a completion.

    Returns ``(letter_or_None, n_distinct_markers)``.

    - ``letter`` is the LAST marker found (last match wins), upper-cased.
      ``(None, 0)`` is returned when no marker is present.
    - ``n_distinct_markers`` counts DISTINCT letters across all markers, so
      two ``Answer: C`` are not penalized but ``Answer: A ... Answer: B`` is.
    - A single hedged marker (``Answer: A or B``, ``Answer: A/B``,
      ``Answer: A, B``) adds both of its letters to the distinct set, so it
      is scored as a multiple-answers format hack.

    A hedge only counts when both of its letters have the same case. The
    hedge pattern is case-insensitive, so without this check the English
    article in ``"Answer: B, a strong base"`` would be read as option A and a
    correct single answer would be penalized.
    """
    markers = _find_markers(completion)
    if not markers:
        return None, 0
    distinct = set(markers)
    for m in _HEDGE_RE.finditer(completion or ""):
        lead, hedged = m.group(1), m.group(2)
        # "Answer: B, a ..." -> upper-case lead with a lower-case follower:
        # an article, not a hedged option. Genuine hedges are written
        # consistently ("A or B", "a or b").
        if lead.isupper() != hedged.isupper():
            continue
        distinct.add(lead.upper())
        distinct.add(hedged.upper())
    return markers[-1], len(distinct)
@dataclass
class MMLUFormatReward:
    """Callable reward_fn(prompts, completions, *, answers, **kwargs) -> list[float].

    Scores each completion ONLY on its final answer letter and format validity;
    rationale content is never scored.

    Args:
        rationale_char_cap: completions longer than this incur a small length
            penalty (``length_penalty_per_char`` per char past the cap). Caps
            verbosity without scoring rationale content.
        length_penalty_per_char: per-character penalty past the cap.
        correct_reward / wrong_reward / unparseable_reward /
        multiple_answers_reward: the scalar rewards for each outcome.

    Side effect: ``option_distribution`` (a Counter over chosen letters) and
    ``n_scored`` accumulate across calls so an "always C" exploit is detectable
    via ``exploit_report()``.
    """

    rationale_char_cap: int = 512
    length_penalty_per_char: float = 0.001
    correct_reward: float = 1.0
    wrong_reward: float = 0.0
    unparseable_reward: float = -0.2
    multiple_answers_reward: float = -0.1
    # ``@dataclass`` is what turns this ``field(...)`` into a fresh Counter per
    # instance; without the decorator the attribute stays a bare ``Field``
    # object and ``option_distribution[letter] += 1`` raises TypeError.
    option_distribution: Counter = field(default_factory=Counter)
    n_scored: int = 0

    def __call__(
        self,
        prompts: Any = None,
        completions: list[Any] | None = None,
        *,
        answers: list[Any] | None = None,
        **kwargs: Any,
    ) -> list[float]:
        """Score a batch of completions against gold ``answers``.

        Args:
            prompts: accepted for the TRL reward-fn signature but unused (only
                the completion text is scored).
            completions: one entry per sample, either a plain string or a list
                of chat-message dicts whose string ``"content"`` values are
                joined. ``None`` is treated as an empty batch.
            answers: required; one gold answer per completion, either a letter
                (case and surrounding whitespace ignored) or a 0-based ``int``
                option index (0 -> "A").
            **kwargs: any extra columns; ignored.

        Returns:
            One float reward per completion, in order.

        Raises:
            ValueError: if ``answers`` is missing or its length differs from
                ``completions``.
        """
        if completions is None:
            completions = []
        if answers is None:
            raise ValueError(
                "MMLUFormatReward requires `answers` (the gold letters, one per "
                "completion). Pass via reward_fn(..., answers=[...])."
            )
        if len(answers) != len(completions):
            raise ValueError(
                f"answers/completions length mismatch: {len(answers)} vs "
                f"{len(completions)}."
            )
        return [
            self._score_one(completion, gold)
            for completion, gold in zip(completions, answers)
        ]

    @staticmethod
    def _completion_text(completion: Any) -> str:
        """Flatten one completion to the text that gets parsed and length-penalized.

        Falsy values (``None``, ``""``, ``[]``) become ``""``. A string is
        returned as-is. A list/tuple of message dicts has the string
        ``"content"`` of each dict joined with newlines.

        Raises:
            TypeError: for any other completion type.
        """
        if not completion:
            return ""
        if isinstance(completion, str):
            return completion
        if isinstance(completion, (list, tuple)):
            return "\n".join(
                message["content"]
                for message in completion
                if isinstance(message, dict) and isinstance(message.get("content"), str)
            )
        raise TypeError(
            f"MMLUFormatReward: unsupported completion type {type(completion).__name__!r}"
        )

    @staticmethod
    def _gold_letter(gold: Any) -> str | None:
        """Normalize a gold answer to an upper-case letter (``None`` if absent).

        ``int`` golds are read as 0-based option indices (0 -> "A"); an
        out-of-range index yields ``None``, which never matches a letter.
        Anything else is stringified, stripped and upper-cased.
        """
        if gold is None:
            return None
        # bool is an int subclass but is not a meaningful option index.
        if isinstance(gold, int) and not isinstance(gold, bool):
            return chr(ord("A") + gold) if 0 <= gold < 26 else None
        return str(gold).strip().upper()

    def _score_one(self, completion: Any, gold: Any) -> float:
        """Score a single completion; updates ``n_scored`` and ``option_distribution``."""
        text = self._completion_text(completion)
        letter, n_distinct = parse_final_answer(text)
        self.n_scored += 1
        if letter is None:
            # Unparseable: no usable final-answer marker. The length penalty
            # does not apply (no letter was parsed to reward or penalize).
            return self.unparseable_reward
        # Log the chosen letter for exploit detection (always-C etc.).
        self.option_distribution[letter] += 1
        if n_distinct > 1:
            # Multiple DISTINCT markers: format hacking. Penalize regardless
            # of correctness (the model is hedging or gaming the parser).
            base = self.multiple_answers_reward
        elif letter == self._gold_letter(gold):
            base = self.correct_reward
        else:
            base = self.wrong_reward
        return base - self._length_penalty(text)

    def _length_penalty(self, completion: str) -> float:
        """Return ``length_penalty_per_char`` times the characters past ``rationale_char_cap``."""
        over = max(0, len(completion or "") - self.rationale_char_cap)
        return self.length_penalty_per_char * over

    # ------------------------------------------------------------------
    # Exploit detection
    # ------------------------------------------------------------------
    def exploit_report(self) -> dict[str, Any]:
        """Summarize the chosen-letter distribution so an option-prior exploit
        (e.g. "always C") is detectable.

        Returns a dict with the raw counts, the most common letter, and its
        fraction of all parsed answers. A healthy run is ~uniform over A-D; a
        fraction near 1.0 for a single letter is the exploit signature.
        """
        total = sum(self.option_distribution.values())
        if total == 0:
            return {
                "counts": {},
                "total_parsed": 0,
                "most_common": None,
                "max_fraction": 0.0,
            }
        letter, count = self.option_distribution.most_common(1)[0]
        return {
            "counts": dict(self.option_distribution),
            "total_parsed": total,
            "most_common": letter,
            "max_fraction": count / total,
        }
def randomize_options(
    item: dict[str, Any], seed: int
) -> tuple[dict[str, Any], dict[str, str]]:
    """Shuffle the multiple-choice option order, tracking original->shuffled letters.

    Args:
        item: a dict with ``options`` (list[str], A-first ordering) and
            ``answer``, either the gold letter (case and surrounding whitespace
            ignored) or a 0-based ``int`` index into ``options``. Other keys
            are passed through unchanged; ``item`` itself is not mutated.
        seed: deterministic RNG seed for the shuffle.

    Returns:
        ``(shuffled_item, label_remap)``:

        - ``shuffled_item`` has the options reordered. Its ``answer`` is
          updated to the gold option's NEW position, as a letter if a letter
          was given or as an ``int`` index if an index was given.
        - ``label_remap`` maps each ORIGINAL letter -> its NEW (shuffled)
          letter.
        - An ``answer`` that is missing or names no existing option is copied
          through unchanged.
        - With no options, a shallow copy of ``item`` and ``{}`` are returned.

    This de-biases an option-prior exploit at the data level: if the gold answer
    is no longer correlated with a fixed position, "always C" stops working.
    """
    import random

    options = list(item.get("options", []))
    n = len(options)
    if n == 0:
        return dict(item), {}

    letters = [chr(ord("A") + i) for i in range(n)]
    rng = random.Random(seed)
    perm = list(range(n))
    rng.shuffle(perm)
    # perm[new_pos] = old_pos: the option now at new_pos came from old_pos.
    shuffled_options = [options[old_pos] for old_pos in perm]
    label_remap: dict[str, str] = {
        letters[old_pos]: letters[new_pos] for new_pos, old_pos in enumerate(perm)
    }

    shuffled_item = dict(item)
    shuffled_item["options"] = shuffled_options

    answer = item.get("answer", "")
    if isinstance(answer, int) and not isinstance(answer, bool):
        # An index-style gold must be remapped as an index: stringifying it
        # ("2") never matches a letter and would leave a stale index pointing
        # at a different, shuffled option.
        if 0 <= answer < n:
            shuffled_item["answer"] = perm.index(answer)
    else:
        gold = str(answer).strip().upper()
        if gold in label_remap:
            shuffled_item["answer"] = label_remap[gold]
    return shuffled_item, label_remap