Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
feat(wave-a): close ADR-011 (SDPO alignment indices) + ADR-012 (review findings)
Browse files
B1/ADR-011: collator emits student/teacher_response_idx + valid masks via
_mask_to_padded_indices; loss sentinel-masks padding. Strict SDPO no longer
raises against the real collator (the regression my review-fix introduced).
B2/ADR-012:
- k1-KL: found TRL 1.5.0 uses k3 not k1; corrected docstring honestly + documenting test.
- hint routing: style/communication/effort sites now reach the judge (error-kind aware).
- HackMonitor: added patch-provenance layer defeating string-concat obfuscation; ADR-010 language corrected AST->signature+patch-provenance.
- curriculum: optional turns/think_tokens effort signals, backward-compatible.
210 passed / 16 skipped (was 192). Two Opus-4.8 workers in parallel.
- composer_replication/datagen/curriculum.py +64 -3
- composer_replication/datagen/monitor.py +108 -9
- composer_replication/datagen/tests/test_feature_deletion.py +68 -0
- composer_replication/hint_generator.py +73 -2
- composer_replication/tests/test_hint_routing.py +97 -0
- composer_replication/trainer/composer_trainer.py +36 -5
- composer_replication/trainer/data_collator.py +60 -0
- composer_replication/trainer/tests/test_dr_grpo_config_and_alignment.py +64 -0
- composer_replication/trainer/tests/test_sdpo_alignment_indices.py +274 -0
- docs/adrs/ADR-010-feature-deletion-datagen.md +14 -12
|
@@ -25,6 +25,23 @@ from dataclasses import dataclass, field
|
|
| 25 |
class _TaskStats:
|
| 26 |
n_pass: float = 0.0
|
| 27 |
n_total: int = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
@property
|
| 30 |
def p_hat(self) -> float:
|
|
@@ -45,10 +62,21 @@ class DifficultyCurriculum:
|
|
| 45 |
tau_easy: float = 0.95 # above this => retired
|
| 46 |
tau_hard: float = 0.02 # below this (after min_exposures) => quarantined
|
| 47 |
min_exposures: int = 8 # before a task can be quarantined as impossible
|
|
|
|
|
|
|
|
|
|
| 48 |
_stats: dict[str, _TaskStats] = field(default_factory=dict)
|
| 49 |
_quarantined: set[str] = field(default_factory=set)
|
| 50 |
|
| 51 |
-
def update(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
"""Record `n_pass` successes over `n_total` exposures.
|
| 53 |
|
| 54 |
`n_pass` is a FLOAT so multi-feature tasks can record fractional credit
|
|
@@ -57,10 +85,17 @@ class DifficultyCurriculum:
|
|
| 57 |
`int(reward > 0)`, which logged a 0.5 partial as a full pass and let
|
| 58 |
`p_hat` cross `tau_easy` so the task was retired before the policy ever
|
| 59 |
learned the remaining features.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
"""
|
| 61 |
st = self._stats.setdefault(task_id, _TaskStats())
|
| 62 |
st.n_pass += n_pass
|
| 63 |
st.n_total += n_total
|
|
|
|
| 64 |
if (
|
| 65 |
st.n_total >= self.min_exposures
|
| 66 |
and st.raw_rate < self.tau_hard
|
|
@@ -71,13 +106,39 @@ class DifficultyCurriculum:
|
|
| 71 |
return self._stats.get(task_id, _TaskStats()).p_hat
|
| 72 |
|
| 73 |
def weight(self, task_id: str) -> float:
|
| 74 |
-
"""Sampling weight. Retired/quarantined => 0; else frontier-variance
|
|
|
|
| 75 |
if task_id in self._quarantined:
|
| 76 |
return 0.0
|
| 77 |
p = self.p_hat(task_id)
|
| 78 |
if p > self.tau_easy:
|
| 79 |
return 0.0 # retired — model has aced it
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
def weights(self, task_ids: list[str]) -> list[float]:
|
| 83 |
return [self.weight(t) for t in task_ids]
|
|
|
|
| 25 |
class _TaskStats:
|
| 26 |
n_pass: float = 0.0
|
| 27 |
n_total: int = 0
|
| 28 |
+
# Running means of effort signals (ADR-012 finding #4). `n_effort` counts
|
| 29 |
+
# exposures that supplied an effort signal (may differ from n_total since
|
| 30 |
+
# turns/think_tokens are optional per update).
|
| 31 |
+
mean_turns: float = 0.0
|
| 32 |
+
mean_think: float = 0.0
|
| 33 |
+
n_effort: int = 0
|
| 34 |
+
|
| 35 |
+
def observe_effort(self, turns: float | None, think_tokens: float | None) -> None:
|
| 36 |
+
"""Fold optional turn / think-token signals into running means."""
|
| 37 |
+
if turns is None and think_tokens is None:
|
| 38 |
+
return
|
| 39 |
+
self.n_effort += 1
|
| 40 |
+
k = self.n_effort
|
| 41 |
+
if turns is not None:
|
| 42 |
+
self.mean_turns += (turns - self.mean_turns) / k
|
| 43 |
+
if think_tokens is not None:
|
| 44 |
+
self.mean_think += (think_tokens - self.mean_think) / k
|
| 45 |
|
| 46 |
@property
|
| 47 |
def p_hat(self) -> float:
|
|
|
|
| 62 |
tau_easy: float = 0.95 # above this => retired
|
| 63 |
tau_hard: float = 0.02 # below this (after min_exposures) => quarantined
|
| 64 |
min_exposures: int = 8 # before a task can be quarantined as impossible
|
| 65 |
+
# Strength of the effort (turns/think-token) difficulty tilt (ADR-012 #4).
|
| 66 |
+
# 0.0 reproduces pre-ADR-012 behavior exactly.
|
| 67 |
+
effort_gain: float = 0.1
|
| 68 |
_stats: dict[str, _TaskStats] = field(default_factory=dict)
|
| 69 |
_quarantined: set[str] = field(default_factory=set)
|
| 70 |
|
| 71 |
+
def update(
|
| 72 |
+
self,
|
| 73 |
+
task_id: str,
|
| 74 |
+
n_pass: float,
|
| 75 |
+
n_total: int,
|
| 76 |
+
*,
|
| 77 |
+
turns: float | None = None,
|
| 78 |
+
think_tokens: float | None = None,
|
| 79 |
+
) -> None:
|
| 80 |
"""Record `n_pass` successes over `n_total` exposures.
|
| 81 |
|
| 82 |
`n_pass` is a FLOAT so multi-feature tasks can record fractional credit
|
|
|
|
| 85 |
`int(reward > 0)`, which logged a 0.5 partial as a full pass and let
|
| 86 |
`p_hat` cross `tau_easy` so the task was retired before the policy ever
|
| 87 |
learned the remaining features.
|
| 88 |
+
|
| 89 |
+
`turns` / `think_tokens` (ADR-012 finding #4) are OPTIONAL per-exposure
|
| 90 |
+
effort signals. The Composer 2 tech report keys the curriculum on rollout
|
| 91 |
+
#turns + thinking-token count: at equal pass-rate, a task that takes more
|
| 92 |
+
turns / thinking is HARDER and should stay on the frontier longer. Both
|
| 93 |
+
default to None => identical behavior to the pre-ADR-012 curriculum.
|
| 94 |
"""
|
| 95 |
st = self._stats.setdefault(task_id, _TaskStats())
|
| 96 |
st.n_pass += n_pass
|
| 97 |
st.n_total += n_total
|
| 98 |
+
st.observe_effort(turns, think_tokens)
|
| 99 |
if (
|
| 100 |
st.n_total >= self.min_exposures
|
| 101 |
and st.raw_rate < self.tau_hard
|
|
|
|
| 106 |
return self._stats.get(task_id, _TaskStats()).p_hat
|
| 107 |
|
| 108 |
def weight(self, task_id: str) -> float:
|
| 109 |
+
"""Sampling weight. Retired/quarantined => 0; else frontier-variance,
|
| 110 |
+
tilted up for higher-effort (more turns / think-tokens) tasks."""
|
| 111 |
if task_id in self._quarantined:
|
| 112 |
return 0.0
|
| 113 |
p = self.p_hat(task_id)
|
| 114 |
if p > self.tau_easy:
|
| 115 |
return 0.0 # retired — model has aced it
|
| 116 |
+
base = p * (1.0 - p) # max at p=0.5
|
| 117 |
+
return base * self._effort_factor(task_id)
|
| 118 |
+
|
| 119 |
+
def _effort_factor(self, task_id: str) -> float:
|
| 120 |
+
"""Mild multiplicative difficulty tilt from the turn/think-token signals
|
| 121 |
+
(ADR-012 finding #4). Returns 1.0 when no effort signals are recorded
|
| 122 |
+
anywhere (so weight() is identical to the pre-ADR-012 behavior), else
|
| 123 |
+
``1 + effort_gain * z`` where z in [0,1] normalizes this task's mean
|
| 124 |
+
effort against the max observed across all tracked tasks. Monotone
|
| 125 |
+
nondecreasing in effort => a higher-turn task weighs >= a lower-turn one
|
| 126 |
+
at equal pass-rate."""
|
| 127 |
+
st = self._stats.get(task_id)
|
| 128 |
+
if st is None or st.n_effort == 0:
|
| 129 |
+
return 1.0
|
| 130 |
+
max_turns = max(
|
| 131 |
+
(s.mean_turns for s in self._stats.values() if s.n_effort), default=0.0
|
| 132 |
+
)
|
| 133 |
+
max_think = max(
|
| 134 |
+
(s.mean_think for s in self._stats.values() if s.n_effort), default=0.0
|
| 135 |
+
)
|
| 136 |
+
z_turns = st.mean_turns / max_turns if max_turns > 0 else 0.0
|
| 137 |
+
z_think = st.mean_think / max_think if max_think > 0 else 0.0
|
| 138 |
+
# Combine the two normalized effort signals (mean of those present).
|
| 139 |
+
components = [z for z, mx in ((z_turns, max_turns), (z_think, max_think)) if mx > 0]
|
| 140 |
+
z = sum(components) / len(components) if components else 0.0
|
| 141 |
+
return 1.0 + self.effort_gain * z
|
| 142 |
|
| 143 |
def weights(self, task_ids: list[str]) -> list[float]:
|
| 144 |
return [self.weight(t) for t in task_ids]
|
|
@@ -1,14 +1,21 @@
|
|
| 1 |
-
"""monitor.py — reward-hacking provenance monitor (ADR-010 §3c).
|
| 2 |
|
| 3 |
Composer's blog reports the model recovering deleted signatures by decompiling
|
| 4 |
Java bytecode and reverse-engineering Python type-check caches. The PRIMARY
|
| 5 |
control is the sandbox lockdown (no decompilers, scrubbed caches — see
|
| 6 |
-
sandbox.py). This monitor is defense-in-depth
|
| 7 |
-
deleted symbol reappears via a non-implementation path (reading a cache,
|
| 8 |
-
importing a sibling that smuggles the cached bytecode, decompiling), so the
|
| 9 |
-
grader can MASK the reward.
|
| 10 |
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
"""
|
| 13 |
from __future__ import annotations
|
| 14 |
|
|
@@ -35,18 +42,56 @@ _HACK_SIGNATURES: tuple[str, ...] = (
|
|
| 35 |
"git cat-file",
|
| 36 |
)
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
@dataclass
|
| 40 |
class HackMonitor:
|
| 41 |
"""Flags a trajectory as a suspected reward-hack.
|
| 42 |
|
| 43 |
-
`flag(trajectory, deleted_symbols)` returns True if any action
|
| 44 |
-
recovered a deleted symbol via a non-implementation path.
|
|
|
|
|
|
|
| 45 |
"""
|
| 46 |
|
| 47 |
extra_signatures: tuple[str, ...] = field(default_factory=tuple)
|
| 48 |
|
| 49 |
-
def flag(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
sigs = _HACK_SIGNATURES + tuple(self.extra_signatures)
|
| 51 |
for action in trajectory:
|
| 52 |
blob = " ".join(
|
|
@@ -61,4 +106,58 @@ class HackMonitor:
|
|
| 61 |
r"(cache|\.pyc|\.class|decompil|disassembl)", blob
|
| 62 |
):
|
| 63 |
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
return False
|
|
|
|
| 1 |
+
"""monitor.py — reward-hacking signature + patch-provenance monitor (ADR-010 §3c).
|
| 2 |
|
| 3 |
Composer's blog reports the model recovering deleted signatures by decompiling
|
| 4 |
Java bytecode and reverse-engineering Python type-check caches. The PRIMARY
|
| 5 |
control is the sandbox lockdown (no decompilers, scrubbed caches — see
|
| 6 |
+
sandbox.py). This monitor is defense-in-depth with two layers:
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
+
1. signature layer — flag trajectory actions containing cache/decompiler
|
| 9 |
+
substrings (defeatable by string-concat obfuscation, kept for cheap
|
| 10 |
+
defense-in-depth);
|
| 11 |
+
2. patch-provenance layer — if a deleted symbol reappears VERBATIM in the
|
| 12 |
+
agent's submitted patch/diff AND the trajectory contains a cache/decompiler/
|
| 13 |
+
file-read action targeting a bytecode/cache artifact (even string-concat
|
| 14 |
+
obfuscated, e.g. ``"__py"+"cache__"``), flag it. This catches the
|
| 15 |
+
obfuscation bypass that pure substring matching misses.
|
| 16 |
+
|
| 17 |
+
It is heuristic (will have false negatives); it is not the only safeguard. It is
|
| 18 |
+
NOT a full AST analysis — it is a signature + patch-provenance monitor.
|
| 19 |
"""
|
| 20 |
from __future__ import annotations
|
| 21 |
|
|
|
|
| 42 |
"git cat-file",
|
| 43 |
)
|
| 44 |
|
| 45 |
+
# Cache/bytecode/decompiler artifact markers for the patch-provenance check.
|
| 46 |
+
# Matched AFTER normalizing away quotes + '+' (string-concat obfuscation) and
|
| 47 |
+
# whitespace, so ``"__py" + "cache__"`` collapses to ``__pycache__``.
|
| 48 |
+
_ARTIFACT_MARKERS: tuple[str, ...] = (
|
| 49 |
+
"__pycache__",
|
| 50 |
+
".pyc",
|
| 51 |
+
".mypy_cache",
|
| 52 |
+
".class",
|
| 53 |
+
"javap",
|
| 54 |
+
"decompile",
|
| 55 |
+
"uncompyle",
|
| 56 |
+
"decompyle",
|
| 57 |
+
".pyo",
|
| 58 |
+
"bytecode",
|
| 59 |
+
"disassemb",
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
# Actions that "read" something (a cache/bytecode artifact, for provenance).
|
| 63 |
+
_READ_MARKERS: tuple[str, ...] = (
|
| 64 |
+
"cat", "read", "open", "load", "strings", "unzip", "jar", "javap",
|
| 65 |
+
"decompile", "uncompyle", "decompyle", "dump", "show", "less", "head",
|
| 66 |
+
"tail", "import",
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _normalize(blob: str) -> str:
|
| 71 |
+
"""Strip quotes, '+' (concat), and whitespace so string-concat obfuscation
|
| 72 |
+
like ``"__py" + "cache__"`` collapses to a contiguous ``__pycache__``."""
|
| 73 |
+
return re.sub(r"[\"'+\s]", "", blob)
|
| 74 |
+
|
| 75 |
|
| 76 |
@dataclass
|
| 77 |
class HackMonitor:
|
| 78 |
"""Flags a trajectory as a suspected reward-hack.
|
| 79 |
|
| 80 |
+
`flag(trajectory, deleted_symbols, patch=...)` returns True if any action
|
| 81 |
+
looks like it recovered a deleted symbol via a non-implementation path. Two
|
| 82 |
+
layers: a cheap signature substring matcher, and a patch-provenance check
|
| 83 |
+
that defeats string-concat obfuscation of cache/bytecode reads.
|
| 84 |
"""
|
| 85 |
|
| 86 |
extra_signatures: tuple[str, ...] = field(default_factory=tuple)
|
| 87 |
|
| 88 |
+
def flag(
|
| 89 |
+
self,
|
| 90 |
+
trajectory: list[dict],
|
| 91 |
+
deleted_symbols: tuple[str, ...],
|
| 92 |
+
patch: str | None = None,
|
| 93 |
+
) -> bool:
|
| 94 |
+
# --- layer 1: signature substring matcher (defense-in-depth) ---------
|
| 95 |
sigs = _HACK_SIGNATURES + tuple(self.extra_signatures)
|
| 96 |
for action in trajectory:
|
| 97 |
blob = " ".join(
|
|
|
|
| 106 |
r"(cache|\.pyc|\.class|decompil|disassembl)", blob
|
| 107 |
):
|
| 108 |
return True
|
| 109 |
+
|
| 110 |
+
# --- layer 2: patch-provenance ---------------------------------------
|
| 111 |
+
# If no patch was threaded in, try to recover it from a submit_patch /
|
| 112 |
+
# patch / diff action so this works straight off a trajectory.
|
| 113 |
+
if patch is None:
|
| 114 |
+
patch = self._extract_patch(trajectory)
|
| 115 |
+
if patch and self._patch_provenance_hack(trajectory, deleted_symbols, patch):
|
| 116 |
+
return True
|
| 117 |
+
|
| 118 |
+
return False
|
| 119 |
+
|
| 120 |
+
@staticmethod
|
| 121 |
+
def _extract_patch(trajectory: list[dict]) -> str | None:
|
| 122 |
+
"""Pull the agent's patch/diff text out of the trajectory, if present."""
|
| 123 |
+
parts: list[str] = []
|
| 124 |
+
for action in trajectory:
|
| 125 |
+
if action.get("type") == "submit_patch":
|
| 126 |
+
for key in ("patch", "diff", "content"):
|
| 127 |
+
v = action.get(key)
|
| 128 |
+
if isinstance(v, str):
|
| 129 |
+
parts.append(v)
|
| 130 |
+
for key in ("patch", "diff"):
|
| 131 |
+
v = action.get(key)
|
| 132 |
+
if isinstance(v, str):
|
| 133 |
+
parts.append(v)
|
| 134 |
+
return "\n".join(parts) if parts else None
|
| 135 |
+
|
| 136 |
+
def _patch_provenance_hack(
|
| 137 |
+
self,
|
| 138 |
+
trajectory: list[dict],
|
| 139 |
+
deleted_symbols: tuple[str, ...],
|
| 140 |
+
patch: str,
|
| 141 |
+
) -> bool:
|
| 142 |
+
"""Flag if a deleted symbol reappears verbatim in the patch AND the
|
| 143 |
+
trajectory shows a cache/bytecode artifact being read — even when the
|
| 144 |
+
artifact reference is string-concat obfuscated."""
|
| 145 |
+
reappeared = [
|
| 146 |
+
s for s in deleted_symbols if s and s in patch
|
| 147 |
+
]
|
| 148 |
+
if not reappeared:
|
| 149 |
+
return False
|
| 150 |
+
for action in trajectory:
|
| 151 |
+
raw = " ".join(
|
| 152 |
+
str(v) for v in action.values()
|
| 153 |
+
if isinstance(v, (str, int, float))
|
| 154 |
+
)
|
| 155 |
+
low = raw.lower()
|
| 156 |
+
norm = _normalize(low)
|
| 157 |
+
reads_artifact = (
|
| 158 |
+
any(m in norm for m in _ARTIFACT_MARKERS)
|
| 159 |
+
and any(rm in low for rm in _READ_MARKERS)
|
| 160 |
+
)
|
| 161 |
+
if reads_artifact:
|
| 162 |
+
return True
|
| 163 |
return False
|
|
@@ -319,3 +319,71 @@ def test_swebench_adapter_handles_list_or_jsonstr_tests():
|
|
| 319 |
"FAIL_TO_PASS": ftp, "PASS_TO_PASS": "[]"}
|
| 320 |
task = SweBenchAdapter().to_task(inst)
|
| 321 |
assert task.fail_to_pass == ("t1", "t2")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
"FAIL_TO_PASS": ftp, "PASS_TO_PASS": "[]"}
|
| 320 |
task = SweBenchAdapter().to_task(inst)
|
| 321 |
assert task.fail_to_pass == ("t1", "t2")
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
# --- ADR-012 #3: patch-provenance monitor (obfuscation-resistant) -----------
|
| 325 |
+
|
| 326 |
+
def test_monitor_flags_string_concat_obfuscated_cache_read_via_patch_provenance():
|
| 327 |
+
"""A cache read obfuscated by string-concat ('__py'+'cache__') defeats the
|
| 328 |
+
pure-substring signature layer, but the patch-provenance layer flags it when
|
| 329 |
+
the deleted symbol reappears verbatim in the submitted patch."""
|
| 330 |
+
mon = HackMonitor()
|
| 331 |
+
traj = [
|
| 332 |
+
# obfuscated artifact reference — NOT caught by substring signatures
|
| 333 |
+
{"type": "shell", "command": "cat build/" + "'__py' + 'cache__'" + "/feature_a"},
|
| 334 |
+
{"type": "submit_patch",
|
| 335 |
+
"patch": "def feature_a():\n return 42 # lifted from cache"},
|
| 336 |
+
]
|
| 337 |
+
assert mon.flag(traj, deleted_symbols=("feature_a",)) is True
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def test_monitor_flags_obfuscated_cache_read_with_explicit_patch_arg():
|
| 341 |
+
mon = HackMonitor()
|
| 342 |
+
traj = [{"type": "shell", "command": 'open("__py" + "cache__" + "/x.pyc")'}]
|
| 343 |
+
patch = "def feature_a():\n return 7"
|
| 344 |
+
assert mon.flag(traj, deleted_symbols=("feature_a",), patch=patch) is True
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def test_monitor_patch_provenance_not_triggered_for_clean_reimpl():
|
| 348 |
+
"""Reintroducing the symbol in the patch is fine when there is NO
|
| 349 |
+
cache/bytecode read in the trajectory — that's a legitimate reimplementation."""
|
| 350 |
+
mon = HackMonitor()
|
| 351 |
+
traj = [
|
| 352 |
+
{"type": "edit", "path": "src/widget.py",
|
| 353 |
+
"content": "def feature_a(): return 42"},
|
| 354 |
+
{"type": "submit_patch", "patch": "def feature_a():\n return 42"},
|
| 355 |
+
]
|
| 356 |
+
assert mon.flag(traj, deleted_symbols=("feature_a",)) is False
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
# --- ADR-012 #4: curriculum turn/think-token signals ------------------------
|
| 360 |
+
|
| 361 |
+
def test_curriculum_higher_turn_task_weighted_at_least_as_high():
|
| 362 |
+
"""Two tasks, IDENTICAL pass-rate, different mean turns => the higher-turn
|
| 363 |
+
(harder) task must weight >= the lower-turn one."""
|
| 364 |
+
cur = DifficultyCurriculum()
|
| 365 |
+
for _ in range(10):
|
| 366 |
+
cur.update("low", n_pass=1, n_total=2, turns=3.0)
|
| 367 |
+
cur.update("high", n_pass=1, n_total=2, turns=30.0)
|
| 368 |
+
assert cur.p_hat("low") == cur.p_hat("high") # same pass-rate
|
| 369 |
+
assert cur.weight("high") >= cur.weight("low")
|
| 370 |
+
assert cur.weight("high") > cur.weight("low") # strictly, given the gap
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def test_curriculum_think_tokens_also_tilt_weight():
|
| 374 |
+
cur = DifficultyCurriculum()
|
| 375 |
+
for _ in range(10):
|
| 376 |
+
cur.update("cheap", n_pass=1, n_total=2, think_tokens=100.0)
|
| 377 |
+
cur.update("expensive", n_pass=1, n_total=2, think_tokens=5000.0)
|
| 378 |
+
assert cur.weight("expensive") >= cur.weight("cheap")
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def test_curriculum_backward_compatible_without_effort_signals():
|
| 382 |
+
"""No turns/think_tokens => weight identical to the pre-ADR-012 formula
|
| 383 |
+
p*(1-p), so existing behavior and tests are unchanged."""
|
| 384 |
+
cur = DifficultyCurriculum()
|
| 385 |
+
for _ in range(10):
|
| 386 |
+
cur.update("A", n_pass=1, n_total=2)
|
| 387 |
+
p = cur.p_hat("A")
|
| 388 |
+
assert cur.weight("A") == p * (1.0 - p)
|
| 389 |
+
|
|
@@ -178,6 +178,72 @@ class RawErrorHintGenerator:
|
|
| 178 |
return f"Reminder: the previous action produced this error:\n{truncated}\nReconsider and retry."
|
| 179 |
|
| 180 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
class LLMJudgeHintGenerator:
|
| 182 |
"""Layer 3: an LLM produces a short corrective hint.
|
| 183 |
|
|
@@ -321,11 +387,14 @@ def default_composite(
|
|
| 321 |
) -> CompositeHintGenerator:
|
| 322 |
"""Build the recommended layered generator: templates -> raw-error -> judge.
|
| 323 |
|
| 324 |
-
The
|
|
|
|
|
|
|
|
|
|
| 325 |
"""
|
| 326 |
layers: list[HintGenerator] = [TemplateHintGenerator()]
|
| 327 |
if enable_raw_error:
|
| 328 |
-
layers.append(RawErrorHintGenerator())
|
| 329 |
if llm_complete is not None:
|
| 330 |
layers.append(LLMJudgeHintGenerator(llm_complete, cache_dir=cache_dir))
|
| 331 |
return CompositeHintGenerator(layers)
|
|
@@ -340,6 +409,8 @@ __all__ = [
|
|
| 340 |
"HintGenerator",
|
| 341 |
"TemplateHintGenerator",
|
| 342 |
"RawErrorHintGenerator",
|
|
|
|
|
|
|
| 343 |
"LLMJudgeHintGenerator",
|
| 344 |
"CompositeHintGenerator",
|
| 345 |
"default_composite",
|
|
|
|
| 178 |
return f"Reminder: the previous action produced this error:\n{truncated}\nReconsider and retry."
|
| 179 |
|
| 180 |
|
| 181 |
+
# ---------------------------------------------------------------------------
|
| 182 |
+
# Error-kind routing (ADR-012 finding #2)
|
| 183 |
+
# ---------------------------------------------------------------------------
|
| 184 |
+
#
|
| 185 |
+
# The default composite is template -> raw-error -> judge. The raw-error layer
|
| 186 |
+
# fires for ANY kind carrying a message — including style/communication/effort
|
| 187 |
+
# sites, which are EXACTLY what the LLM judge exists to cover. So we route:
|
| 188 |
+
# tool/runtime error kinds may use the raw-error layer; style/communication/
|
| 189 |
+
# effort kinds skip it and fall through to the judge.
|
| 190 |
+
|
| 191 |
+
# Error kinds that genuinely describe a tool/runtime failure whose raw text is a
|
| 192 |
+
# useful, self-contained hint. The explicit registry-template kinds are included
|
| 193 |
+
# so behavior is unchanged for them.
|
| 194 |
+
_TOOL_RUNTIME_KINDS: frozenset[str] = frozenset({
|
| 195 |
+
"tool_not_found",
|
| 196 |
+
"json_decode",
|
| 197 |
+
"type_error",
|
| 198 |
+
"runtime_error",
|
| 199 |
+
"repeated_failure",
|
| 200 |
+
})
|
| 201 |
+
|
| 202 |
+
# Substrings marking a kind as tool/runtime-ish even if not explicitly listed
|
| 203 |
+
# (keeps generic "*_error"/"*_exception" sites flowing through raw-error, which
|
| 204 |
+
# is where their raw text belongs).
|
| 205 |
+
_TOOL_RUNTIME_MARKERS: tuple[str, ...] = (
|
| 206 |
+
"error", "exception", "fail", "decode", "timeout", "traceback",
|
| 207 |
+
"exit_code", "nonzero", "syntax", "import", "assertion", "tool",
|
| 208 |
+
"runtime", "crash", "exec",
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
# Substrings marking a kind as a style/communication/effort site — the judge's
|
| 212 |
+
# domain. These take precedence: a kind matching one of these skips raw-error.
|
| 213 |
+
_STYLE_KINDS_MARKERS: tuple[str, ...] = (
|
| 214 |
+
"style", "communic", "verbose", "effort", "concise", "tone",
|
| 215 |
+
"format", "wordy", "rambl", "explanation", "etiquette", "clarity",
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def is_tool_runtime_kind(error_kind: str) -> bool:
|
| 220 |
+
"""True if `error_kind` is a tool/runtime failure that the raw-error layer
|
| 221 |
+
may serve. Style/communication/effort kinds return False (-> judge)."""
|
| 222 |
+
k = (error_kind or "").lower()
|
| 223 |
+
if any(m in k for m in _STYLE_KINDS_MARKERS):
|
| 224 |
+
return False
|
| 225 |
+
if k in _TOOL_RUNTIME_KINDS:
|
| 226 |
+
return True
|
| 227 |
+
return any(m in k for m in _TOOL_RUNTIME_MARKERS)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class RoutingHintGenerator:
|
| 231 |
+
"""Wraps an inner layer (the raw-error layer) and only lets it fire for
|
| 232 |
+
tool/runtime error kinds. For style/communication/effort kinds it returns
|
| 233 |
+
None so the composite falls through to the judge — the layer those sites
|
| 234 |
+
were always meant to reach (ADR-012 finding #2).
|
| 235 |
+
"""
|
| 236 |
+
|
| 237 |
+
def __init__(self, inner: HintGenerator, route=is_tool_runtime_kind) -> None:
|
| 238 |
+
self.inner = inner
|
| 239 |
+
self.route = route
|
| 240 |
+
|
| 241 |
+
def generate(self, error_kind: str, error_meta: dict) -> str | None:
|
| 242 |
+
if not self.route(error_kind):
|
| 243 |
+
return None
|
| 244 |
+
return self.inner.generate(error_kind, error_meta)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
class LLMJudgeHintGenerator:
|
| 248 |
"""Layer 3: an LLM produces a short corrective hint.
|
| 249 |
|
|
|
|
| 387 |
) -> CompositeHintGenerator:
|
| 388 |
"""Build the recommended layered generator: templates -> raw-error -> judge.
|
| 389 |
|
| 390 |
+
The raw-error layer is wrapped in a RoutingHintGenerator so it only fires for
|
| 391 |
+
tool/runtime error kinds; style/communication/effort kinds skip it and fall
|
| 392 |
+
through to the LLM judge (ADR-012 finding #2). The LLM-judge layer is
|
| 393 |
+
included only when `llm_complete` is provided.
|
| 394 |
"""
|
| 395 |
layers: list[HintGenerator] = [TemplateHintGenerator()]
|
| 396 |
if enable_raw_error:
|
| 397 |
+
layers.append(RoutingHintGenerator(RawErrorHintGenerator()))
|
| 398 |
if llm_complete is not None:
|
| 399 |
layers.append(LLMJudgeHintGenerator(llm_complete, cache_dir=cache_dir))
|
| 400 |
return CompositeHintGenerator(layers)
|
|
|
|
| 409 |
"HintGenerator",
|
| 410 |
"TemplateHintGenerator",
|
| 411 |
"RawErrorHintGenerator",
|
| 412 |
+
"RoutingHintGenerator",
|
| 413 |
+
"is_tool_runtime_kind",
|
| 414 |
"LLMJudgeHintGenerator",
|
| 415 |
"CompositeHintGenerator",
|
| 416 |
"default_composite",
|
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for error-kind hint routing on the DEFAULT composite (ADR-012 #2).
|
| 2 |
+
|
| 3 |
+
The default composite is template -> raw-error -> judge. Before ADR-012 the
|
| 4 |
+
raw-error layer consumed ANY site carrying an `error_message`, including
|
| 5 |
+
style/communication/effort sites — exactly the sites the LLM judge exists to
|
| 6 |
+
cover. These tests validate the DEFAULT path (raw-error NOT disabled): a
|
| 7 |
+
style/communication site WITH an error_message routes through to the judge,
|
| 8 |
+
while tool/runtime sites still use the raw-error layer.
|
| 9 |
+
"""
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from composer_replication.hint_generator import (
|
| 13 |
+
RoutingHintGenerator,
|
| 14 |
+
RawErrorHintGenerator,
|
| 15 |
+
default_composite,
|
| 16 |
+
is_tool_runtime_kind,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# --- the headline acceptance: style site reaches judge on the DEFAULT path ---
|
| 21 |
+
|
| 22 |
+
def test_style_site_with_error_message_reaches_judge_on_default_composite():
|
| 23 |
+
calls = {"n": 0}
|
| 24 |
+
|
| 25 |
+
def fake_complete(prompt: str) -> str:
|
| 26 |
+
calls["n"] += 1
|
| 27 |
+
return "Be more concise; you repeated the same explanation twice."
|
| 28 |
+
|
| 29 |
+
# NOTE: raw-error is ENABLED (the default). Pre-ADR-012 this would have been
|
| 30 |
+
# eaten by the raw-error layer and the judge never called.
|
| 31 |
+
comp = default_composite(llm_complete=fake_complete) # enable_raw_error=True
|
| 32 |
+
hint = comp.generate(
|
| 33 |
+
"verbose_communication",
|
| 34 |
+
{"error_message": "The agent restated the plan three times."},
|
| 35 |
+
)
|
| 36 |
+
assert hint == "Be more concise; you repeated the same explanation twice."
|
| 37 |
+
assert calls["n"] == 1, "style site must reach the judge, not the raw-error layer"
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def test_effort_site_with_message_routes_to_judge():
|
| 41 |
+
calls = {"n": 0}
|
| 42 |
+
|
| 43 |
+
def fake_complete(prompt: str) -> str:
|
| 44 |
+
calls["n"] += 1
|
| 45 |
+
return "Don't pad the answer; one example suffices."
|
| 46 |
+
|
| 47 |
+
comp = default_composite(llm_complete=fake_complete)
|
| 48 |
+
hint = comp.generate("low_effort_style", {"error_message": "padding detected"})
|
| 49 |
+
assert hint == "Don't pad the answer; one example suffices."
|
| 50 |
+
assert calls["n"] == 1
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# --- tool/runtime sites still served by raw-error (no regression) -----------
|
| 54 |
+
|
| 55 |
+
def test_tool_runtime_site_still_served_by_raw_error_no_judge():
|
| 56 |
+
calls = {"n": 0}
|
| 57 |
+
|
| 58 |
+
def fake_complete(prompt: str) -> str:
|
| 59 |
+
calls["n"] += 1
|
| 60 |
+
return "JUDGE (should not be called)"
|
| 61 |
+
|
| 62 |
+
comp = default_composite(llm_complete=fake_complete)
|
| 63 |
+
# an unmapped *runtime* error (no template) -> raw-error layer, not judge.
|
| 64 |
+
hint = comp.generate("weird_runtime_error", {"error_message": "Segfault at 0x0"})
|
| 65 |
+
assert hint is not None
|
| 66 |
+
assert "Segfault at 0x0" in hint
|
| 67 |
+
assert calls["n"] == 0, "tool/runtime sites must be served by raw-error, not judge"
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def test_template_site_unaffected_by_routing():
|
| 71 |
+
comp = default_composite() # no judge
|
| 72 |
+
hint = comp.generate("tool_not_found", {"available_tools": ["read", "write"]})
|
| 73 |
+
assert hint is not None and "Available tools" in hint
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# --- the route predicate ----------------------------------------------------
|
| 77 |
+
|
| 78 |
+
def test_route_predicate_classifies_kinds():
|
| 79 |
+
# tool/runtime
|
| 80 |
+
for k in ("tool_not_found", "json_decode", "type_error", "runtime_error",
|
| 81 |
+
"repeated_failure", "weird_runtime_error", "some_exception",
|
| 82 |
+
"weird_unmapped_error"):
|
| 83 |
+
assert is_tool_runtime_kind(k) is True, k
|
| 84 |
+
# style/communication/effort
|
| 85 |
+
for k in ("verbose_communication", "low_effort_style", "tone_violation",
|
| 86 |
+
"rambling_explanation", "bad_formatting"):
|
| 87 |
+
assert is_tool_runtime_kind(k) is False, k
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def test_routing_generator_returns_none_for_style_kind():
    """The routing wrapper defers style kinds (None) and passes runtime kinds through."""
    router = RoutingHintGenerator(RawErrorHintGenerator())

    # A style kind carrying a message: the wrapped raw-error layer would have
    # produced a hint, but the router must return None so the judge gets it.
    style_result = router.generate("verbose_style", {"error_message": "too long"})
    assert style_result is None

    # A tool/runtime kind carrying a message: the wrapped generator fires.
    runtime_result = router.generate("runtime_error", {"error_message": "boom"})
    assert runtime_result is not None
    assert "boom" in runtime_result
|
|
@@ -225,16 +225,37 @@ class ComposerReplicationTrainer(GRPOTrainer): # type: ignore[misc, valid-type]
|
|
| 225 |
# Gather the provably-aligned response logits from each sequence, then
|
| 226 |
# JSD only those positions (this is the masked error-turn distillation).
|
| 227 |
# gather over the sequence dim (dim=1): expand index to the vocab dim.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
vocab = student_logits.size(-1)
|
| 229 |
-
|
| 230 |
-
|
|
|
|
|
|
|
| 231 |
student_aligned = torch.gather(student_logits, 1, s_gather)
|
| 232 |
teacher_aligned = torch.gather(teacher_logits, 1, t_gather)
|
| 233 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
return generalized_jsd_loss(
|
| 235 |
student_logits=student_aligned,
|
| 236 |
teacher_logits=teacher_aligned,
|
| 237 |
-
labels=
|
| 238 |
beta=self.sdpo_jsd_beta,
|
| 239 |
temperature=self.sdpo_temperature,
|
| 240 |
token_clip=self.sdpo_token_clip,
|
|
@@ -325,8 +346,18 @@ def make_dr_grpo_config(**overrides: Any):
|
|
| 325 |
standard deviation introduces a question-level difficulty bias."
|
| 326 |
- ``num_iterations=1`` — single-epoch regime (a prompt is never
|
| 327 |
trained on twice), matching the tech report.
|
| 328 |
-
- ``beta`` (KL-to-ref coef) kept
|
| 329 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
|
| 331 |
Any field can be overridden via kwargs (e.g. ``learning_rate=...``,
|
| 332 |
``output_dir=...``). The three Dr. GRPO-defining knobs are forced unless
|
|
|
|
| 225 |
# Gather the provably-aligned response logits from each sequence, then
|
| 226 |
# JSD only those positions (this is the masked error-turn distillation).
|
| 227 |
# gather over the sequence dim (dim=1): expand index to the vocab dim.
|
| 228 |
+
#
|
| 229 |
+
# ADR-011: ragged-K rows are padded with a sentinel (-1) and a per-row
|
| 230 |
+
# *_valid mask. Negative indices are illegal for torch.gather, so clamp
|
| 231 |
+
# to 0 before gathering, then neutralize those positions by feeding
|
| 232 |
+
# labels=-100 (the standard HF ignore convention that generalized_jsd_loss
|
| 233 |
+
# already honors). This makes sentinel/padding positions contribute 0.
|
| 234 |
+
if "student_response_valid" in inputs and inputs["student_response_valid"] is not None:
|
| 235 |
+
aligned_mask = inputs["student_response_valid"].bool()
|
| 236 |
+
else:
|
| 237 |
+
aligned_mask = (s_idx >= 0) & (t_idx >= 0)
|
| 238 |
+
|
| 239 |
vocab = student_logits.size(-1)
|
| 240 |
+
s_safe = s_idx.clamp_min(0)
|
| 241 |
+
t_safe = t_idx.clamp_min(0)
|
| 242 |
+
s_gather = s_safe.unsqueeze(-1).expand(-1, -1, vocab)
|
| 243 |
+
t_gather = t_safe.unsqueeze(-1).expand(-1, -1, vocab)
|
| 244 |
student_aligned = torch.gather(student_logits, 1, s_gather)
|
| 245 |
teacher_aligned = torch.gather(teacher_logits, 1, t_gather)
|
| 246 |
|
| 247 |
+
# Build (B, K) labels: 1 at valid aligned positions, -100 (ignore) at
|
| 248 |
+
# sentinel/padding positions so they drop out of the JSD reduction.
|
| 249 |
+
aligned_labels = torch.where(
|
| 250 |
+
aligned_mask,
|
| 251 |
+
torch.ones_like(s_idx),
|
| 252 |
+
torch.full_like(s_idx, -100),
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
return generalized_jsd_loss(
|
| 256 |
student_logits=student_aligned,
|
| 257 |
teacher_logits=teacher_aligned,
|
| 258 |
+
labels=aligned_labels, # sentinel-masked aligned error-turn positions
|
| 259 |
beta=self.sdpo_jsd_beta,
|
| 260 |
temperature=self.sdpo_temperature,
|
| 261 |
token_clip=self.sdpo_token_clip,
|
|
|
|
| 346 |
standard deviation introduces a question-level difficulty bias."
|
| 347 |
- ``num_iterations=1`` — single-epoch regime (a prompt is never
|
| 348 |
trained on twice), matching the tech report.
|
| 349 |
+
- ``beta`` (KL-to-ref coef) kept. NOTE on the KL estimator (ADR-012
|
| 350 |
+
finding #1, verified against the installed trl==1.5.0 source):
|
| 351 |
+
``GRPOTrainer._compute_loss`` uses the **k3** estimator
|
| 352 |
+
``exp(ref_logp - logp) - (ref_logp - logp) - 1``
|
| 353 |
+
(trl/trainer/grpo_trainer.py ~L2513), NOT the k1 estimator
|
| 354 |
+
``-log r == (logp - ref_logp)`` (with ``r = pi_ref / pi``). k3 is Schulman's low-variance,
|
| 355 |
+
always-non-negative KL approximation; k1 is its unbiased but
|
| 356 |
+
higher-variance counterpart. The Dr. GRPO / Composer 2 report discusses
|
| 357 |
+
KL in k1 terms, but the two share the same expectation (k3 = k1 + (r - 1), and E[r - 1] = 0 under the policy)
|
| 358 |
+
and TRL's k3 choice is the production reality. We do NOT monkeypatch TRL
|
| 359 |
+
to force k1; we document the honest delta. See
|
| 360 |
+
``test_dr_grpo_config_and_alignment.py::test_trl_kl_estimator_is_k3_not_k1``.
|
| 361 |
|
| 362 |
Any field can be overridden via kwargs (e.g. ``learning_rate=...``,
|
| 363 |
``output_dir=...``). The three Dr. GRPO-defining knobs are forced unless
|
|
@@ -118,6 +118,47 @@ def _pad_or_truncate(seq: list[int], target_len: int, pad_id: int) -> list[int]:
|
|
| 118 |
return seq + [pad_id] * (target_len - len(seq))
|
| 119 |
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
# ---------------------------------------------------------------------------
|
| 122 |
# The collator
|
| 123 |
# ---------------------------------------------------------------------------
|
|
@@ -190,6 +231,25 @@ class ComposerDataCollator:
|
|
| 190 |
out["attention_mask"] = aligned["attention_mask"]
|
| 191 |
out["response_mask"] = aligned["response_mask"]
|
| 192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
# --- Channel 3: trace-replay DPO fields ---
|
| 194 |
if self.config.enable_replay_dpo:
|
| 195 |
dpo = self._build_dpo_fields(batch)
|
|
|
|
| 118 |
return seq + [pad_id] * (target_len - len(seq))
|
| 119 |
|
| 120 |
|
| 121 |
+
def _mask_to_padded_indices(
|
| 122 |
+
mask: torch.Tensor, # (B, T) where nonzero/True == valid position
|
| 123 |
+
pad_sentinel: int = -1,
|
| 124 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 125 |
+
"""Convert a (B,T) bool/0-1 mask → (B,K_max) index tensor + (B,K_max) validity mask.
|
| 126 |
+
|
| 127 |
+
Each row's K valid positions are written left-aligned into ``idx``; the
|
| 128 |
+
ragged tail (rows with fewer than K_max positions) is padded with
|
| 129 |
+
``pad_sentinel`` (default -1). ``valid`` is True exactly where ``idx``
|
| 130 |
+
holds a real position.
|
| 131 |
+
|
| 132 |
+
ADR-011: the SDPO loss gathers post-hint response logits via these indices,
|
| 133 |
+
then masks the sentinel/padding positions so they contribute 0. K_max=0
|
| 134 |
+
(no valid positions anywhere) returns (B,0) tensors.
|
| 135 |
+
"""
|
| 136 |
+
B, T = mask.shape
|
| 137 |
+
bool_mask = mask != 0
|
| 138 |
+
counts = bool_mask.sum(dim=1).long() # (B,) — K per row
|
| 139 |
+
K_max = int(counts.max().item()) if counts.numel() else 0
|
| 140 |
+
if K_max == 0:
|
| 141 |
+
return (
|
| 142 |
+
torch.full((B, 0), pad_sentinel, dtype=torch.long, device=mask.device),
|
| 143 |
+
torch.zeros(B, 0, dtype=torch.bool, device=mask.device),
|
| 144 |
+
)
|
| 145 |
+
idx = torch.full((B, K_max), pad_sentinel, dtype=torch.long, device=mask.device)
|
| 146 |
+
valid = torch.zeros(B, K_max, dtype=torch.bool, device=mask.device)
|
| 147 |
+
# torch.nonzero on a 2D bool tensor yields (total_K, 2): (batch_idx, pos_idx),
|
| 148 |
+
# row-major so positions are already in per-row, ascending order.
|
| 149 |
+
nz = torch.nonzero(bool_mask, as_tuple=False) # (total_K, 2)
|
| 150 |
+
pos_idx = nz[:, 1]
|
| 151 |
+
offsets = torch.zeros(B + 1, dtype=torch.long, device=mask.device)
|
| 152 |
+
offsets[1:] = counts.cumsum(dim=0)
|
| 153 |
+
for b in range(B):
|
| 154 |
+
start, end = int(offsets[b].item()), int(offsets[b + 1].item())
|
| 155 |
+
k = end - start
|
| 156 |
+
if k > 0:
|
| 157 |
+
idx[b, :k] = pos_idx[start:end]
|
| 158 |
+
valid[b, :k] = True
|
| 159 |
+
return idx, valid
|
| 160 |
+
|
| 161 |
+
|
| 162 |
# ---------------------------------------------------------------------------
|
| 163 |
# The collator
|
| 164 |
# ---------------------------------------------------------------------------
|
|
|
|
| 231 |
out["attention_mask"] = aligned["attention_mask"]
|
| 232 |
out["response_mask"] = aligned["response_mask"]
|
| 233 |
|
| 234 |
+
# --- ADR-011: emit SDPO alignment indices ---
|
| 235 |
+
# The loss (strict mode, default) requires explicit per-token
|
| 236 |
+
# alignment indices into each sequence so the JSD compares
|
| 237 |
+
# corresponding post-hint response tokens. Derive them from the
|
| 238 |
+
# already-aligned masks: teacher positions from sdpo_loss_mask==1,
|
| 239 |
+
# student positions from response_mask==1. Both masks are placed
|
| 240 |
+
# on content tokens by _build_chat_aligned_mask, and the
|
| 241 |
+
# placeholder-system-message trick makes them land at the SAME
|
| 242 |
+
# logical token, so at valid positions s_idx == t_idx.
|
| 243 |
+
if "sdpo_loss_mask" in out and "response_mask" in out:
|
| 244 |
+
t_mask = out["sdpo_loss_mask"] == 1
|
| 245 |
+
s_mask = out["response_mask"] == 1
|
| 246 |
+
t_idx, t_valid = _mask_to_padded_indices(t_mask)
|
| 247 |
+
s_idx, s_valid = _mask_to_padded_indices(s_mask)
|
| 248 |
+
out["student_response_idx"] = s_idx
|
| 249 |
+
out["teacher_response_idx"] = t_idx
|
| 250 |
+
out["student_response_valid"] = s_valid
|
| 251 |
+
out["teacher_response_valid"] = t_valid
|
| 252 |
+
|
| 253 |
# --- Channel 3: trace-replay DPO fields ---
|
| 254 |
if self.config.enable_replay_dpo:
|
| 255 |
dpo = self._build_dpo_fields(batch)
|
|
@@ -57,6 +57,70 @@ def test_make_dr_grpo_config_override_does_not_silently_break_guard(tmp_path):
|
|
| 57 |
assert cfg.loss_type == "grpo"
|
| 58 |
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
# ---------------------------------------------------------------------------
|
| 61 |
# Gate 2 — SDPO strict-alignment guard (no real GRPOTrainer needed)
|
| 62 |
# ---------------------------------------------------------------------------
|
|
|
|
| 57 |
assert cfg.loss_type == "grpo"
|
| 58 |
|
| 59 |
|
| 60 |
+
# ---------------------------------------------------------------------------
|
| 61 |
+
# ADR-012 finding #1 — TRL's native KL estimator (k1 vs k3)
|
| 62 |
+
# ---------------------------------------------------------------------------
|
| 63 |
+
|
| 64 |
+
def test_trl_kl_estimator_is_k3_not_k1():
    """Document, honestly, which KL estimator TRL's GRPOTrainer actually uses.

    Both estimators below target KL(pi || pi_ref) from samples drawn under pi.
    Using the TRL convention Δ = ref_logp - logp, i.e. r = pi_ref/pi = exp(Δ):

        k1 = -log r          = -Δ = logp - ref_logp   (unbiased, higher var,
                                                        can be negative)
        k3 = (r - 1) - log r = exp(Δ) - Δ - 1         (Schulman, low var, >= 0)

    Note k3 = k1 + (r - 1); the extra term has zero mean under pi, so the two
    agree in expectation even though they differ per token.

    make_dr_grpo_config's docstring previously *claimed* TRL uses k1. Inspecting
    the installed trl==1.5.0 source (grpo_trainer.py ~L2513) shows it actually
    computes k3: `torch.exp(ref - logp) - (ref - logp) - 1`. This test pins
    that finding so the docstring stays honest and a future TRL change is caught.
    """
    # Known logprob pairs (student logp, reference logp).
    logp = torch.tensor([-1.0, -2.0, -0.5, -3.0])
    ref_logp = torch.tensor([-1.2, -1.5, -0.7, -2.4])
    delta = ref_logp - logp

    k1 = -delta  # -log r with r = pi_ref / pi
    k3 = torch.exp(delta) - delta - 1.0

    # k3 is always non-negative; k1 can be negative — a structural difference.
    assert (k3 >= -1e-6).all(), "k3 must be non-negative (Schulman estimator)"
    assert (k1 < 0).any(), "k1 (= -Δ) can be negative; the test data exercises that"

    # The TRL 1.5.0 source uses k3 (verified by grepping the installed package).
    import inspect

    from trl import GRPOTrainer

    src = inspect.getsource(GRPOTrainer)
    # The k3 signature: exp(ref - logp) - (ref - logp) - 1. We assert the
    # distinctive `torch.exp(` of the ratio appears in the per-token KL block.
    assert "per_token_kl" in src, "TRL GRPOTrainer no longer has a per_token_kl block"
    uses_k3 = "torch.exp(ref_per_token_logps - per_token_logps)" in src
    uses_k1_only = (
        "per_token_kl = ref_per_token_logps - per_token_logps" in src and not uses_k3
    )
    assert uses_k3, (
        "Expected TRL 1.5.0 to compute the k3 KL estimator "
        "exp(ref - logp) - (ref - logp) - 1. If this fails, TRL changed its "
        "estimator — re-verify make_dr_grpo_config's docstring (which documents "
        f"k3, not k1). uses_k1_only={uses_k1_only}"
    )

    # Sanity: per token, k3 is second-order in Δ (≈ Δ²/2) while k1 is
    # first-order (|k1| = |Δ|), so near r≈1 the k3 penalty is far smaller in
    # magnitude than any single k1 sample.
    small = torch.tensor([0.01, -0.02, 0.005])
    k1_small = -small
    k3_small = torch.exp(small) - small - 1.0
    assert torch.allclose(k3_small, 0.5 * small**2, atol=1e-4), (
        "k3 should be ~Δ²/2 for small Δ (its leading order)"
    )
    assert (k3_small.abs() < k1_small.abs()).all(), (
        "for small Δ, |k3| << |k1| (k3 is second-order in Δ, k1 first-order)"
    )
|
| 122 |
+
|
| 123 |
+
|
| 124 |
# ---------------------------------------------------------------------------
|
| 125 |
# Gate 2 — SDPO strict-alignment guard (no real GRPOTrainer needed)
|
| 126 |
# ---------------------------------------------------------------------------
|
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ADR-011 — collator-emitted SDPO alignment indices + loss sentinel-masking.
|
| 2 |
+
|
| 3 |
+
These tests close the strict-SDPO-raises regression: the SDPO loss requires
|
| 4 |
+
explicit `student_response_idx`/`teacher_response_idx` (B,K) LongTensors, and
|
| 5 |
+
the production collator must emit them. Covered acceptance gates:
|
| 6 |
+
|
| 7 |
+
1. `_mask_to_padded_indices` ragged-K shape + sentinel/valid semantics.
|
| 8 |
+
2. Real `ComposerDataCollator` emits the 4 alignment keys with correct
|
| 9 |
+
shapes; student_response_idx == teacher_response_idx at valid positions.
|
| 10 |
+
3. THE REGRESSION: real collator → batch → `_compute_sdpo_loss` in STRICT
|
| 11 |
+
mode (default) runs WITHOUT raising and returns a finite positive loss.
|
| 12 |
+
4. Ragged-K: a 2-row batch with different K per row → finite loss, the K=1
|
| 13 |
+
row's sentinel padding does not leak into the JSD.
|
| 14 |
+
|
| 15 |
+
All CPU-only and fast (stub tokenizer + tiny model — no model download).
|
| 16 |
+
"""
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import pytest
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from composer_replication.trainer.data_collator import (
|
| 23 |
+
CollatorConfig,
|
| 24 |
+
ComposerDataCollator,
|
| 25 |
+
_mask_to_padded_indices,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
# Stubs (mirror the patterns in test_chat_template_alignment.py /
|
| 31 |
+
# test_dr_grpo_config_and_alignment.py so these tests need no model cache).
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
|
| 34 |
+
class _StubTok:
|
| 35 |
+
"""Word-level deterministic tokenizer; apply_chat_template space-joins."""
|
| 36 |
+
|
| 37 |
+
pad_token_id = 0
|
| 38 |
+
|
| 39 |
+
def __init__(self) -> None:
|
| 40 |
+
self._v: dict[str, int] = {"<pad>": 0, "<bos>": 1, "<eos>": 2}
|
| 41 |
+
|
| 42 |
+
def _id(self, w: str) -> int:
|
| 43 |
+
if w not in self._v:
|
| 44 |
+
self._v[w] = len(self._v)
|
| 45 |
+
return self._v[w]
|
| 46 |
+
|
| 47 |
+
def __call__(self, text, **_k):
|
| 48 |
+
return {"input_ids": [self._id(w) for w in text.split()] if text else []}
|
| 49 |
+
|
| 50 |
+
def apply_chat_template(self, messages, tokenize=True, **_k): # noqa: ARG002
|
| 51 |
+
return [self._id(w) for w in " ".join(m.get("content", "") for m in messages).split()]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class _TinyLM(torch.nn.Module):
|
| 55 |
+
"""Minimal HF-style model: model(input_ids=...).logits.
|
| 56 |
+
|
| 57 |
+
Position-DEPENDENT: adds a learned positional bias so identical token ids at
|
| 58 |
+
DIFFERENT sequence positions produce DIFFERENT logits. This matters for the
|
| 59 |
+
SDPO regression test — student and teacher share the same response token ids
|
| 60 |
+
but at different absolute positions (the hint/placeholder shifts them), so a
|
| 61 |
+
position-independent model would give JSD≈0 and mask a real misalignment bug.
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
def __init__(self, vocab: int = 64, hidden: int = 8, max_pos: int = 512):
|
| 65 |
+
super().__init__()
|
| 66 |
+
self.embed = torch.nn.Embedding(vocab, hidden)
|
| 67 |
+
self.pos = torch.nn.Embedding(max_pos, hidden)
|
| 68 |
+
self.head = torch.nn.Linear(hidden, vocab)
|
| 69 |
+
|
| 70 |
+
def forward(self, input_ids: torch.Tensor):
|
| 71 |
+
T = input_ids.size(1)
|
| 72 |
+
positions = torch.arange(T, device=input_ids.device).unsqueeze(0)
|
| 73 |
+
h = self.embed(input_ids) + self.pos(positions)
|
| 74 |
+
logits = self.head(h)
|
| 75 |
+
|
| 76 |
+
class _Out:
|
| 77 |
+
pass
|
| 78 |
+
|
| 79 |
+
out = _Out()
|
| 80 |
+
out.logits = logits
|
| 81 |
+
return out
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _hint_gen(kind, _meta):
|
| 85 |
+
return "HINT search before reading"
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _make_sdpo_trainer():
    """Build a ComposerReplicationTrainer without running GRPOTrainer.__init__.

    Only ``_compute_sdpo_loss`` is exercised by these tests, so the instance is
    created with ``__new__`` and given just the SDPO attributes set below, with
    strict alignment enabled (the default / production setting).
    """
    from composer_replication.trainer.composer_trainer import ComposerReplicationTrainer

    trainer = ComposerReplicationTrainer.__new__(ComposerReplicationTrainer)
    sdpo_settings = {
        "alpha_sdpo": 1.0,
        "sdpo_jsd_beta": 0.5,
        "sdpo_temperature": 1.0,
        "sdpo_token_clip": None,
        "strict_sdpo_alignment": True,  # the default / production setting
    }
    for attr_name, attr_value in sdpo_settings.items():
        setattr(trainer, attr_name, attr_value)
    return trainer
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def _error_trace(trace_id: str, recovery: str = "let me use a real tool instead"):
|
| 103 |
+
return {
|
| 104 |
+
"trace_id": trace_id,
|
| 105 |
+
"turns": [
|
| 106 |
+
{"role": "user", "content": "do the task now"},
|
| 107 |
+
{"role": "user", "content": "tool not found error occurred"},
|
| 108 |
+
{
|
| 109 |
+
"role": "assistant",
|
| 110 |
+
"content": recovery,
|
| 111 |
+
"tool_error": "tool_not_found",
|
| 112 |
+
"error_meta": {},
|
| 113 |
+
},
|
| 114 |
+
],
|
| 115 |
+
"final_reward": 0.0,
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# ---------------------------------------------------------------------------
|
| 120 |
+
# Gate 1 — _mask_to_padded_indices ragged-K semantics
|
| 121 |
+
# ---------------------------------------------------------------------------
|
| 122 |
+
|
| 123 |
+
def test_mask_to_padded_indices_ragged_k():
    """Rows with K=3 and K=1 give (2,3) outputs; the short row's tail is -1 / False."""
    long_row = [0, 1, 1, 0, 1]  # K=3 at positions 1,2,4
    short_row = [0, 0, 1, 0, 0]  # K=1 at position 2
    mask = torch.tensor([long_row, short_row], dtype=torch.long)

    idx, valid = _mask_to_padded_indices(mask)

    # Shapes and dtypes.
    assert idx.shape == (2, 3)
    assert valid.shape == (2, 3)
    assert idx.dtype == torch.long
    assert valid.dtype == torch.bool

    # Row contents: left-aligned positions, sentinel-padded tail.
    assert idx[0].tolist() == [1, 2, 4]
    assert idx[1].tolist() == [2, -1, -1]
    assert valid[0].tolist() == [True, True, True]
    assert valid[1].tolist() == [True, False, False]
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def test_mask_to_padded_indices_empty_returns_b0():
    """An all-zero mask (K_max == 0) produces (B,0) index and validity tensors."""
    batch_size = 3
    all_zero = torch.zeros(batch_size, 5, dtype=torch.long)
    idx, valid = _mask_to_padded_indices(all_zero)
    for tensor in (idx, valid):
        assert tensor.shape == (batch_size, 0)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# ---------------------------------------------------------------------------
|
| 153 |
+
# Gate 2 — collator emits the 4 alignment keys with correct shapes
|
| 154 |
+
# ---------------------------------------------------------------------------
|
| 155 |
+
|
| 156 |
+
def test_collator_emits_alignment_indices_keys():
    """The collator emits all four alignment keys with matching shapes/dtypes,
    and student/teacher indices coincide wherever the student slot is valid."""
    collator = ComposerDataCollator(
        tokenizer=_StubTok(),
        config=CollatorConfig(hint_generator=_hint_gen, enable_replay_dpo=False),
    )
    batch = collator([_error_trace("align-1")])

    required_keys = (
        "student_response_idx",
        "teacher_response_idx",
        "student_response_valid",
        "teacher_response_valid",
    )
    for key in required_keys:
        assert key in batch, f"collator did not emit {key!r}"

    student_idx = batch["student_response_idx"]
    teacher_idx = batch["teacher_response_idx"]
    student_valid = batch["student_response_valid"]

    assert student_idx.shape == teacher_idx.shape
    assert student_idx.shape == student_valid.shape
    assert student_idx.dtype == torch.long
    assert student_valid.dtype == torch.bool

    # There must be at least one valid aligned position.
    assert int(student_valid.sum()) > 0

    # At valid positions the placeholder-trick makes the two indices identical.
    same_at_valid = torch.equal(student_idx[student_valid], teacher_idx[student_valid])
    assert same_at_valid, (
        "student/teacher indices diverge at valid positions; the placeholder "
        "alignment trick is broken."
    )
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# ---------------------------------------------------------------------------
|
| 188 |
+
# Gate 3 — THE REGRESSION TEST: real collator → strict _compute_sdpo_loss
|
| 189 |
+
# ---------------------------------------------------------------------------
|
| 190 |
+
|
| 191 |
+
def test_strict_sdpo_loss_runs_on_real_collator_batch():
    """Real ComposerDataCollator batch → _compute_sdpo_loss in STRICT mode
    (default) runs WITHOUT raising and returns a finite, non-negative,
    differentiable loss. This is the whole point of ADR-011.

    Strict positivity is deliberately NOT asserted: see the comment above the
    non-negativity check for why ~0 is the expected value with this stub model.
    """
    tok = _StubTok()
    cfg = CollatorConfig(hint_generator=_hint_gen, enable_replay_dpo=False)
    collator = ComposerDataCollator(tokenizer=tok, config=cfg)
    batch = collator([_error_trace("regression-1")])

    # vocab must cover every token id the stub tokenizer produced.
    vocab = int(max(batch["input_ids"].max(), batch["ctx_teacher_input_ids"].max())) + 1
    model = _TinyLM(vocab=max(vocab, 8))

    obj = _make_sdpo_trainer()
    loss = obj._compute_sdpo_loss(model, batch)  # must NOT raise

    detached = loss.detach()
    assert not torch.isnan(detached).any(), "SDPO loss is NaN"
    assert not torch.isinf(detached).any(), "SDPO loss is infinite"
    val = float(detached)

    # JSD is always >= 0. The stub model has no attention: its logits depend
    # only on (token id, absolute position). At correctly-aligned positions the
    # student and teacher see the same token id at the same absolute position,
    # so the gathered logits are identical and the JSD floors at ~0 — that is
    # the CORRECT answer for a perfectly-aligned identical model, not a bug.
    # The whole-point assertion is that strict mode RAN (no raise) and produced
    # a real finite scalar on a grad path; positivity needs an attention model
    # (covered by examples/composer_grpo_sdpo_smoke on Qwen2.5-0.5B).
    assert val >= -1e-6, f"JSD must be non-negative, got {val}"
    assert loss.requires_grad, "SDPO loss must be differentiable (grad path)"
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
# ---------------------------------------------------------------------------
|
| 222 |
+
# Gate 4 — ragged-K batch: K=1 row padding must not leak into the loss
|
| 223 |
+
# ---------------------------------------------------------------------------
|
| 224 |
+
|
| 225 |
+
def test_ragged_k_batch_finite_loss_no_padding_leak():
    """A 2-row batch with different recovery lengths → ragged K.

    What this test actually verifies:
      * every invalid slot of ``student_response_idx`` holds the sentinel -1;
      * the strict SDPO loss on the ragged batch is finite and non-negative,
        i.e. no sentinel reached ``torch.gather`` (illegal index → error) and
        no padding slot injected NaN/inf;
      * the loss stays finite when the index/valid tensors are truncated to
        the per-batch minimum K (no sentinel slots left at all).

    It does NOT assert numerical equality between the full and truncated
    losses: the full batch includes the long row's extra valid tokens, so the
    two values legitimately differ.
    """
    tok = _StubTok()
    cfg = CollatorConfig(hint_generator=_hint_gen, enable_replay_dpo=False)
    collator = ComposerDataCollator(tokenizer=tok, config=cfg)
    batch = collator([
        _error_trace("ragged-long", recovery="recover with a real working tool now please"),
        _error_trace("ragged-short", recovery="ok"),
    ])

    s_idx = batch["student_response_idx"]
    s_valid = batch["student_response_valid"]
    # Wherever a slot is invalid it must carry the sentinel. (Vacuously true
    # if the two rows happen to have equal K.)
    assert (s_idx == -1)[~s_valid].all(), "invalid positions must hold sentinel -1"

    vocab = int(max(batch["input_ids"].max(), batch["ctx_teacher_input_ids"].max())) + 1
    model = _TinyLM(vocab=max(vocab, 8))
    obj = _make_sdpo_trainer()

    loss = obj._compute_sdpo_loss(model, batch)
    assert torch.isfinite(loss.detach()).all()
    # Non-negative (JSD floor). A finite, non-negative scalar shows the
    # clamp-to-0 + label=-100 sentinel masking did not error out or produce
    # garbage values.
    assert float(loss.detach()) >= -1e-6

    # Truncate to the shortest row's K so that no sentinel slot remains, then
    # recompute. Same model state (no grad step taken) → deterministic forward.
    min_k = int(s_valid.sum(dim=1).min())
    if min_k < s_idx.shape[1]:
        truncated = dict(batch)
        for key in (
            "student_response_idx",
            "teacher_response_idx",
            "student_response_valid",
            "teacher_response_valid",
        ):
            truncated[key] = batch[key][:, :min_k]
        loss_trunc = obj._compute_sdpo_loss(model, truncated)
        # Only finiteness is checked; see the docstring for why equality with
        # the full-batch loss is not expected.
        assert torch.isfinite(loss_trunc.detach()).all()
|
|
@@ -48,7 +48,7 @@ package.
|
|
| 48 |
|
| 49 |
## Considered Options
|
| 50 |
|
| 51 |
-
- **A. `FeatureDeletionEnv` that inverts OSS SWE substrates (revert gold patch) + online pass-rate difficulty gate + sandbox/
|
| 52 |
- **B. Greenfield repo-scraping generator (clone arbitrary GitHub repos, delete AST nodes, hope tests cover them)**
|
| 53 |
- **C. Skip generation; reuse SWE-bench-lite tasks as-is without a deletion/inversion layer**
|
| 54 |
|
|
@@ -62,8 +62,9 @@ adapters that invert the 5 OSS datasets by reverting their gold patch, a
|
|
| 62 |
`PASS_TO_PASS`, gold patch restores green, deletion is reachable from tests),
|
| 63 |
an online pass-rate difficulty gate, and reward-hacking safeguards
|
| 64 |
(pre-task scrub of `__pycache__`/`.mypy_cache`/`.class`/`.git`; allowlisted
|
| 65 |
-
sandbox without `find`/`strings`/`unzip`/decompilers;
|
| 66 |
-
that masks reward when deleted symbols reappear via non-implementation
|
|
|
|
| 67 |
A TRL `reward_fn(prompts, completions, **kwargs) -> list[float]` adapter wires
|
| 68 |
it to the RL loop.
|
| 69 |
|
|
@@ -74,7 +75,7 @@ it to the RL loop.
|
|
| 74 |
- **Positive**: Online difficulty gate matches the actual recipe.
|
| 75 |
- **Negative**: Bounded to what the OSS substrates cover (Python-dominant; SWE-bench is Python/JS-heavy). Other languages need new substrates. Documented as a known coverage limit.
|
| 76 |
- **Negative**: Running tests in a sandbox requires Docker images per substrate; CPU-pool generation has real wall-clock cost (~15 node-days to invert all 21k SWE-rebench tasks per research/06). Mitigated by reusing the substrates' published Docker images and generating lazily.
|
| 77 |
-
- **Negative**: Reward-hacking safeguards are a moving target; the
|
| 78 |
- **Neutral**: Adds a `[datagen]` optional extra (datasets, docker SDK).
|
| 79 |
|
| 80 |
## Pros and Cons of the Options
|
|
@@ -142,14 +143,15 @@ remediated where possible without Docker:
|
|
| 142 |
(coverage of the changed region by the failing tests, or revert-provenance)
|
| 143 |
needs the live Docker materializers. **This is the same `[~]` gate as the
|
| 144 |
substrate-inversion e2e — see below.**
|
| 145 |
-
- **[
|
| 146 |
-
the ADR
|
| 147 |
-
trajectory but
|
| 148 |
-
string-concat. With the scrub now in place as
|
| 149 |
-
is correctly-scoped defense-in-depth
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
|
|
|
| 153 |
- **[OPEN — recipe fidelity] Curriculum ignores rollout-turns and
|
| 154 |
thinking-token count** (DeepSeek, GPT-5.5). The Composer 2 tech report keys the
|
| 155 |
curriculum on these; the implementation tracks only pass-rate. Follow-up:
|
|
|
|
| 48 |
|
| 49 |
## Considered Options
|
| 50 |
|
| 51 |
+
- **A. `FeatureDeletionEnv` that inverts OSS SWE substrates (revert gold patch) + online pass-rate difficulty gate + sandbox + signature/patch-provenance reward-hacking safeguards** (chosen)
|
| 52 |
- **B. Greenfield repo-scraping generator (clone arbitrary GitHub repos, delete AST nodes, hope tests cover them)**
|
| 53 |
- **C. Skip generation; reuse SWE-bench-lite tasks as-is without a deletion/inversion layer**
|
| 54 |
|
|
|
|
| 62 |
`PASS_TO_PASS`, gold patch restores green, deletion is reachable from tests),
|
| 63 |
an online pass-rate difficulty gate, and reward-hacking safeguards
|
| 64 |
(pre-task scrub of `__pycache__`/`.mypy_cache`/`.class`/`.git`; allowlisted
|
| 65 |
+
sandbox without `find`/`strings`/`unzip`/decompilers; signature + patch-provenance
|
| 66 |
+
monitor that masks reward when deleted symbols reappear via non-implementation
|
| 67 |
+
paths — including string-concat-obfuscated cache reads).
|
| 68 |
A TRL `reward_fn(prompts, completions, **kwargs) -> list[float]` adapter wires
|
| 69 |
it to the RL loop.
|
| 70 |
|
|
|
|
| 75 |
- **Positive**: Online difficulty gate matches the actual recipe.
|
| 76 |
- **Negative**: Bounded to what the OSS substrates cover (Python-dominant; SWE-bench is Python/JS-heavy). Other languages need new substrates. Documented as a known coverage limit.
|
| 77 |
- **Negative**: Running tests in a sandbox requires Docker images per substrate; CPU-pool generation has real wall-clock cost (~15 node-days to invert all 21k SWE-rebench tasks per research/06). Mitigated by reusing the substrates' published Docker images and generating lazily.
|
| 78 |
+
- **Negative**: Reward-hacking safeguards are a moving target; the signature + patch-provenance monitor is heuristic and will have false negatives. Mitigated by treating it as defense-in-depth (sandbox lockdown is the primary control) and logging suspected hacks for review.
|
| 79 |
- **Neutral**: Adds a `[datagen]` optional extra (datasets, docker SDK).
|
| 80 |
|
| 81 |
## Pros and Cons of the Options
|
|
|
|
| 143 |
(coverage of the changed region by the failing tests, or revert-provenance)
|
| 144 |
needs the live Docker materializers. **This is the same `[~]` gate as the
|
| 145 |
substrate-inversion e2e — see below.**
|
| 146 |
+
- **[RESOLVED — ADR-012] `HackMonitor` was a substring matcher, not the
|
| 147 |
+
AST-provenance monitor the ADR advertised** (DeepSeek P0). It flagged
|
| 148 |
+
cache/decompiler signatures in the trajectory but did no symbol-reappearance
|
| 149 |
+
analysis, and was bypassable by string-concat. With the scrub now in place as
|
| 150 |
+
the primary control, the monitor is correctly-scoped defense-in-depth. ADR-012
|
| 151 |
+
re-scoped the language to "signature + patch-provenance monitor" (not "AST")
|
| 152 |
+
and added a patch-provenance layer: a deleted symbol reappearing verbatim in
|
| 153 |
+
the agent's patch alongside a cache/bytecode read — normalized to defeat
|
| 154 |
+
string-concat obfuscation (`"__py"+"cache__"`) — is now flagged.
|
| 155 |
- **[OPEN — recipe fidelity] Curriculum ignores rollout-turns and
|
| 156 |
thinking-token count** (DeepSeek, GPT-5.5). The Composer 2 tech report keys the
|
| 157 |
curriculum on these; the implementation tracks only pass-rate. Follow-up:
|