Codeseys commited on
Commit
d02d724
·
1 Parent(s): b4a584a

feat(wave-a): close ADR-011 (SDPO alignment indices) + ADR-012 (review findings)

Browse files

B1/ADR-011: collator emits student/teacher_response_idx + valid masks via
_mask_to_padded_indices; loss sentinel-masks padding. Strict SDPO no longer
raises against the real collator (the regression my review-fix introduced).

B2/ADR-012:
- k1-KL: found TRL 1.5.0 uses k3 not k1; corrected docstring honestly + documenting test.
- hint routing: style/communication/effort sites now reach the judge (error-kind aware).
- HackMonitor: added patch-provenance layer defeating string-concat obfuscation; ADR-010 language corrected AST->signature+patch-provenance.
- curriculum: optional turns/think_tokens effort signals, backward-compatible.

210 passed / 16 skipped (was 192). Two Opus-4.8 workers in parallel.

composer_replication/datagen/curriculum.py CHANGED
@@ -25,6 +25,23 @@ from dataclasses import dataclass, field
25
  class _TaskStats:
26
  n_pass: float = 0.0
27
  n_total: int = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  @property
30
  def p_hat(self) -> float:
@@ -45,10 +62,21 @@ class DifficultyCurriculum:
45
  tau_easy: float = 0.95 # above this => retired
46
  tau_hard: float = 0.02 # below this (after min_exposures) => quarantined
47
  min_exposures: int = 8 # before a task can be quarantined as impossible
 
 
 
48
  _stats: dict[str, _TaskStats] = field(default_factory=dict)
49
  _quarantined: set[str] = field(default_factory=set)
50
 
51
- def update(self, task_id: str, n_pass: float, n_total: int) -> None:
 
 
 
 
 
 
 
 
52
  """Record `n_pass` successes over `n_total` exposures.
53
 
54
  `n_pass` is a FLOAT so multi-feature tasks can record fractional credit
@@ -57,10 +85,17 @@ class DifficultyCurriculum:
57
  `int(reward > 0)`, which logged a 0.5 partial as a full pass and let
58
  `p_hat` cross `tau_easy` so the task was retired before the policy ever
59
  learned the remaining features.
 
 
 
 
 
 
60
  """
61
  st = self._stats.setdefault(task_id, _TaskStats())
62
  st.n_pass += n_pass
63
  st.n_total += n_total
 
64
  if (
65
  st.n_total >= self.min_exposures
66
  and st.raw_rate < self.tau_hard
@@ -71,13 +106,39 @@ class DifficultyCurriculum:
71
  return self._stats.get(task_id, _TaskStats()).p_hat
72
 
73
  def weight(self, task_id: str) -> float:
74
- """Sampling weight. Retired/quarantined => 0; else frontier-variance."""
 
75
  if task_id in self._quarantined:
76
  return 0.0
77
  p = self.p_hat(task_id)
78
  if p > self.tau_easy:
79
  return 0.0 # retired — model has aced it
80
- return p * (1.0 - p) # max at p=0.5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  def weights(self, task_ids: list[str]) -> list[float]:
83
  return [self.weight(t) for t in task_ids]
 
25
  class _TaskStats:
26
  n_pass: float = 0.0
27
  n_total: int = 0
28
+ # Running means of effort signals (ADR-012 finding #4). `n_effort` counts
29
+ # exposures that supplied an effort signal (may differ from n_total since
30
+ # turns/think_tokens are optional per update).
31
+ mean_turns: float = 0.0
32
+ mean_think: float = 0.0
33
+ n_effort: int = 0
34
+
35
+ def observe_effort(self, turns: float | None, think_tokens: float | None) -> None:
36
+ """Fold optional turn / think-token signals into running means."""
37
+ if turns is None and think_tokens is None:
38
+ return
39
+ self.n_effort += 1
40
+ k = self.n_effort
41
+ if turns is not None:
42
+ self.mean_turns += (turns - self.mean_turns) / k
43
+ if think_tokens is not None:
44
+ self.mean_think += (think_tokens - self.mean_think) / k
45
 
46
  @property
47
  def p_hat(self) -> float:
 
62
  tau_easy: float = 0.95 # above this => retired
63
  tau_hard: float = 0.02 # below this (after min_exposures) => quarantined
64
  min_exposures: int = 8 # before a task can be quarantined as impossible
65
+ # Strength of the effort (turns/think-token) difficulty tilt (ADR-012 #4).
66
+ # 0.0 reproduces pre-ADR-012 behavior exactly.
67
+ effort_gain: float = 0.1
68
  _stats: dict[str, _TaskStats] = field(default_factory=dict)
69
  _quarantined: set[str] = field(default_factory=set)
70
 
71
+ def update(
72
+ self,
73
+ task_id: str,
74
+ n_pass: float,
75
+ n_total: int,
76
+ *,
77
+ turns: float | None = None,
78
+ think_tokens: float | None = None,
79
+ ) -> None:
80
  """Record `n_pass` successes over `n_total` exposures.
81
 
82
  `n_pass` is a FLOAT so multi-feature tasks can record fractional credit
 
85
  `int(reward > 0)`, which logged a 0.5 partial as a full pass and let
86
  `p_hat` cross `tau_easy` so the task was retired before the policy ever
87
  learned the remaining features.
88
+
89
+ `turns` / `think_tokens` (ADR-012 finding #4) are OPTIONAL per-exposure
90
+ effort signals. The Composer 2 tech report keys the curriculum on rollout
91
+ #turns + thinking-token count: at equal pass-rate, a task that takes more
92
+ turns / thinking is HARDER and should stay on the frontier longer. Both
93
+ default to None => identical behavior to the pre-ADR-012 curriculum.
94
  """
95
  st = self._stats.setdefault(task_id, _TaskStats())
96
  st.n_pass += n_pass
97
  st.n_total += n_total
98
+ st.observe_effort(turns, think_tokens)
99
  if (
100
  st.n_total >= self.min_exposures
101
  and st.raw_rate < self.tau_hard
 
106
  return self._stats.get(task_id, _TaskStats()).p_hat
107
 
108
  def weight(self, task_id: str) -> float:
109
+ """Sampling weight. Retired/quarantined => 0; else frontier-variance,
110
+ tilted up for higher-effort (more turns / think-tokens) tasks."""
111
  if task_id in self._quarantined:
112
  return 0.0
113
  p = self.p_hat(task_id)
114
  if p > self.tau_easy:
115
  return 0.0 # retired — model has aced it
116
+ base = p * (1.0 - p) # max at p=0.5
117
+ return base * self._effort_factor(task_id)
118
+
119
+ def _effort_factor(self, task_id: str) -> float:
120
+ """Mild multiplicative difficulty tilt from the turn/think-token signals
121
+ (ADR-012 finding #4). Returns 1.0 when no effort signals are recorded
122
+ anywhere (so weight() is identical to the pre-ADR-012 behavior), else
123
+ ``1 + effort_gain * z`` where z in [0,1] normalizes this task's mean
124
+ effort against the max observed across all tracked tasks. Monotone
125
+ nondecreasing in effort => a higher-turn task weighs >= a lower-turn one
126
+ at equal pass-rate."""
127
+ st = self._stats.get(task_id)
128
+ if st is None or st.n_effort == 0:
129
+ return 1.0
130
+ max_turns = max(
131
+ (s.mean_turns for s in self._stats.values() if s.n_effort), default=0.0
132
+ )
133
+ max_think = max(
134
+ (s.mean_think for s in self._stats.values() if s.n_effort), default=0.0
135
+ )
136
+ z_turns = st.mean_turns / max_turns if max_turns > 0 else 0.0
137
+ z_think = st.mean_think / max_think if max_think > 0 else 0.0
138
+ # Combine the two normalized effort signals (mean of those present).
139
+ components = [z for z, mx in ((z_turns, max_turns), (z_think, max_think)) if mx > 0]
140
+ z = sum(components) / len(components) if components else 0.0
141
+ return 1.0 + self.effort_gain * z
142
 
143
  def weights(self, task_ids: list[str]) -> list[float]:
144
  return [self.weight(t) for t in task_ids]
composer_replication/datagen/monitor.py CHANGED
@@ -1,14 +1,21 @@
1
- """monitor.py — reward-hacking provenance monitor (ADR-010 §3c).
2
 
3
  Composer's blog reports the model recovering deleted signatures by decompiling
4
  Java bytecode and reverse-engineering Python type-check caches. The PRIMARY
5
  control is the sandbox lockdown (no decompilers, scrubbed caches — see
6
- sandbox.py). This monitor is defense-in-depth: it flags trajectories where a
7
- deleted symbol reappears via a non-implementation path (reading a cache,
8
- importing a sibling that smuggles the cached bytecode, decompiling), so the
9
- grader can MASK the reward.
10
 
11
- It is heuristic (will have false negatives); it is not the only safeguard.
 
 
 
 
 
 
 
 
 
 
12
  """
13
  from __future__ import annotations
14
 
@@ -35,18 +42,56 @@ _HACK_SIGNATURES: tuple[str, ...] = (
35
  "git cat-file",
36
  )
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  @dataclass
40
  class HackMonitor:
41
  """Flags a trajectory as a suspected reward-hack.
42
 
43
- `flag(trajectory, deleted_symbols)` returns True if any action looks like it
44
- recovered a deleted symbol via a non-implementation path.
 
 
45
  """
46
 
47
  extra_signatures: tuple[str, ...] = field(default_factory=tuple)
48
 
49
- def flag(self, trajectory: list[dict], deleted_symbols: tuple[str, ...]) -> bool:
 
 
 
 
 
 
50
  sigs = _HACK_SIGNATURES + tuple(self.extra_signatures)
51
  for action in trajectory:
52
  blob = " ".join(
@@ -61,4 +106,58 @@ class HackMonitor:
61
  r"(cache|\.pyc|\.class|decompil|disassembl)", blob
62
  ):
63
  return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  return False
 
1
+ """monitor.py — reward-hacking signature + patch-provenance monitor (ADR-010 §3c).
2
 
3
  Composer's blog reports the model recovering deleted signatures by decompiling
4
  Java bytecode and reverse-engineering Python type-check caches. The PRIMARY
5
  control is the sandbox lockdown (no decompilers, scrubbed caches — see
6
+ sandbox.py). This monitor is defense-in-depth with two layers:
 
 
 
7
 
8
+ 1. signature layer — flag trajectory actions containing cache/decompiler
9
+ substrings (defeatable by string-concat obfuscation, kept for cheap
10
+ defense-in-depth);
11
+ 2. patch-provenance layer — if a deleted symbol reappears VERBATIM in the
12
+ agent's submitted patch/diff AND the trajectory contains a cache/decompiler/
13
+ file-read action targeting a bytecode/cache artifact (even string-concat
14
+ obfuscated, e.g. ``"__py"+"cache__"``), flag it. This catches the
15
+ obfuscation bypass that pure substring matching misses.
16
+
17
+ It is heuristic (will have false negatives); it is not the only safeguard. It is
18
+ NOT a full AST analysis — it is a signature + patch-provenance monitor.
19
  """
20
  from __future__ import annotations
21
 
 
42
  "git cat-file",
43
  )
44
 
45
+ # Cache/bytecode/decompiler artifact markers for the patch-provenance check.
46
+ # Matched AFTER normalizing away quotes + '+' (string-concat obfuscation) and
47
+ # whitespace, so ``"__py" + "cache__"`` collapses to ``__pycache__``.
48
+ _ARTIFACT_MARKERS: tuple[str, ...] = (
49
+ "__pycache__",
50
+ ".pyc",
51
+ ".mypy_cache",
52
+ ".class",
53
+ "javap",
54
+ "decompile",
55
+ "uncompyle",
56
+ "decompyle",
57
+ ".pyo",
58
+ "bytecode",
59
+ "disassemb",
60
+ )
61
+
62
+ # Actions that "read" something (a cache/bytecode artifact, for provenance).
63
+ _READ_MARKERS: tuple[str, ...] = (
64
+ "cat", "read", "open", "load", "strings", "unzip", "jar", "javap",
65
+ "decompile", "uncompyle", "decompyle", "dump", "show", "less", "head",
66
+ "tail", "import",
67
+ )
68
+
69
+
70
+ def _normalize(blob: str) -> str:
71
+ """Strip quotes, '+' (concat), and whitespace so string-concat obfuscation
72
+ like ``"__py" + "cache__"`` collapses to a contiguous ``__pycache__``."""
73
+ return re.sub(r"[\"'+\s]", "", blob)
74
+
75
 
76
  @dataclass
77
  class HackMonitor:
78
  """Flags a trajectory as a suspected reward-hack.
79
 
80
+ `flag(trajectory, deleted_symbols, patch=...)` returns True if any action
81
+ looks like it recovered a deleted symbol via a non-implementation path. Two
82
+ layers: a cheap signature substring matcher, and a patch-provenance check
83
+ that defeats string-concat obfuscation of cache/bytecode reads.
84
  """
85
 
86
  extra_signatures: tuple[str, ...] = field(default_factory=tuple)
87
 
88
+ def flag(
89
+ self,
90
+ trajectory: list[dict],
91
+ deleted_symbols: tuple[str, ...],
92
+ patch: str | None = None,
93
+ ) -> bool:
94
+ # --- layer 1: signature substring matcher (defense-in-depth) ---------
95
  sigs = _HACK_SIGNATURES + tuple(self.extra_signatures)
96
  for action in trajectory:
97
  blob = " ".join(
 
106
  r"(cache|\.pyc|\.class|decompil|disassembl)", blob
107
  ):
108
  return True
109
+
110
+ # --- layer 2: patch-provenance ---------------------------------------
111
+ # If no patch was threaded in, try to recover it from a submit_patch /
112
+ # patch / diff action so this works straight off a trajectory.
113
+ if patch is None:
114
+ patch = self._extract_patch(trajectory)
115
+ if patch and self._patch_provenance_hack(trajectory, deleted_symbols, patch):
116
+ return True
117
+
118
+ return False
119
+
120
+ @staticmethod
121
+ def _extract_patch(trajectory: list[dict]) -> str | None:
122
+ """Pull the agent's patch/diff text out of the trajectory, if present."""
123
+ parts: list[str] = []
124
+ for action in trajectory:
125
+ if action.get("type") == "submit_patch":
126
+ for key in ("patch", "diff", "content"):
127
+ v = action.get(key)
128
+ if isinstance(v, str):
129
+ parts.append(v)
130
+ for key in ("patch", "diff"):
131
+ v = action.get(key)
132
+ if isinstance(v, str):
133
+ parts.append(v)
134
+ return "\n".join(parts) if parts else None
135
+
136
+ def _patch_provenance_hack(
137
+ self,
138
+ trajectory: list[dict],
139
+ deleted_symbols: tuple[str, ...],
140
+ patch: str,
141
+ ) -> bool:
142
+ """Flag if a deleted symbol reappears verbatim in the patch AND the
143
+ trajectory shows a cache/bytecode artifact being read — even when the
144
+ artifact reference is string-concat obfuscated."""
145
+ reappeared = [
146
+ s for s in deleted_symbols if s and s in patch
147
+ ]
148
+ if not reappeared:
149
+ return False
150
+ for action in trajectory:
151
+ raw = " ".join(
152
+ str(v) for v in action.values()
153
+ if isinstance(v, (str, int, float))
154
+ )
155
+ low = raw.lower()
156
+ norm = _normalize(low)
157
+ reads_artifact = (
158
+ any(m in norm for m in _ARTIFACT_MARKERS)
159
+ and any(rm in low for rm in _READ_MARKERS)
160
+ )
161
+ if reads_artifact:
162
+ return True
163
  return False
composer_replication/datagen/tests/test_feature_deletion.py CHANGED
@@ -319,3 +319,71 @@ def test_swebench_adapter_handles_list_or_jsonstr_tests():
319
  "FAIL_TO_PASS": ftp, "PASS_TO_PASS": "[]"}
320
  task = SweBenchAdapter().to_task(inst)
321
  assert task.fail_to_pass == ("t1", "t2")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  "FAIL_TO_PASS": ftp, "PASS_TO_PASS": "[]"}
320
  task = SweBenchAdapter().to_task(inst)
321
  assert task.fail_to_pass == ("t1", "t2")
322
+
323
+
324
+ # --- ADR-012 #3: patch-provenance monitor (obfuscation-resistant) -----------
325
+
326
+ def test_monitor_flags_string_concat_obfuscated_cache_read_via_patch_provenance():
327
+ """A cache read obfuscated by string-concat ('__py'+'cache__') defeats the
328
+ pure-substring signature layer, but the patch-provenance layer flags it when
329
+ the deleted symbol reappears verbatim in the submitted patch."""
330
+ mon = HackMonitor()
331
+ traj = [
332
+ # obfuscated artifact reference — NOT caught by substring signatures
333
+ {"type": "shell", "command": "cat build/" + "'__py' + 'cache__'" + "/feature_a"},
334
+ {"type": "submit_patch",
335
+ "patch": "def feature_a():\n return 42 # lifted from cache"},
336
+ ]
337
+ assert mon.flag(traj, deleted_symbols=("feature_a",)) is True
338
+
339
+
340
+ def test_monitor_flags_obfuscated_cache_read_with_explicit_patch_arg():
341
+ mon = HackMonitor()
342
+ traj = [{"type": "shell", "command": 'open("__py" + "cache__" + "/x.pyc")'}]
343
+ patch = "def feature_a():\n return 7"
344
+ assert mon.flag(traj, deleted_symbols=("feature_a",), patch=patch) is True
345
+
346
+
347
+ def test_monitor_patch_provenance_not_triggered_for_clean_reimpl():
348
+ """Reintroducing the symbol in the patch is fine when there is NO
349
+ cache/bytecode read in the trajectory — that's a legitimate reimplementation."""
350
+ mon = HackMonitor()
351
+ traj = [
352
+ {"type": "edit", "path": "src/widget.py",
353
+ "content": "def feature_a(): return 42"},
354
+ {"type": "submit_patch", "patch": "def feature_a():\n return 42"},
355
+ ]
356
+ assert mon.flag(traj, deleted_symbols=("feature_a",)) is False
357
+
358
+
359
+ # --- ADR-012 #4: curriculum turn/think-token signals ------------------------
360
+
361
+ def test_curriculum_higher_turn_task_weighted_at_least_as_high():
362
+ """Two tasks, IDENTICAL pass-rate, different mean turns => the higher-turn
363
+ (harder) task must weight >= the lower-turn one."""
364
+ cur = DifficultyCurriculum()
365
+ for _ in range(10):
366
+ cur.update("low", n_pass=1, n_total=2, turns=3.0)
367
+ cur.update("high", n_pass=1, n_total=2, turns=30.0)
368
+ assert cur.p_hat("low") == cur.p_hat("high") # same pass-rate
369
+ assert cur.weight("high") >= cur.weight("low")
370
+ assert cur.weight("high") > cur.weight("low") # strictly, given the gap
371
+
372
+
373
+ def test_curriculum_think_tokens_also_tilt_weight():
374
+ cur = DifficultyCurriculum()
375
+ for _ in range(10):
376
+ cur.update("cheap", n_pass=1, n_total=2, think_tokens=100.0)
377
+ cur.update("expensive", n_pass=1, n_total=2, think_tokens=5000.0)
378
+ assert cur.weight("expensive") >= cur.weight("cheap")
379
+
380
+
381
+ def test_curriculum_backward_compatible_without_effort_signals():
382
+ """No turns/think_tokens => weight identical to the pre-ADR-012 formula
383
+ p*(1-p), so existing behavior and tests are unchanged."""
384
+ cur = DifficultyCurriculum()
385
+ for _ in range(10):
386
+ cur.update("A", n_pass=1, n_total=2)
387
+ p = cur.p_hat("A")
388
+ assert cur.weight("A") == p * (1.0 - p)
389
+
composer_replication/hint_generator.py CHANGED
@@ -178,6 +178,72 @@ class RawErrorHintGenerator:
178
  return f"Reminder: the previous action produced this error:\n{truncated}\nReconsider and retry."
179
 
180
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  class LLMJudgeHintGenerator:
182
  """Layer 3: an LLM produces a short corrective hint.
183
 
@@ -321,11 +387,14 @@ def default_composite(
321
  ) -> CompositeHintGenerator:
322
  """Build the recommended layered generator: templates -> raw-error -> judge.
323
 
324
- The LLM-judge layer is included only when `llm_complete` is provided.
 
 
 
325
  """
326
  layers: list[HintGenerator] = [TemplateHintGenerator()]
327
  if enable_raw_error:
328
- layers.append(RawErrorHintGenerator())
329
  if llm_complete is not None:
330
  layers.append(LLMJudgeHintGenerator(llm_complete, cache_dir=cache_dir))
331
  return CompositeHintGenerator(layers)
@@ -340,6 +409,8 @@ __all__ = [
340
  "HintGenerator",
341
  "TemplateHintGenerator",
342
  "RawErrorHintGenerator",
 
 
343
  "LLMJudgeHintGenerator",
344
  "CompositeHintGenerator",
345
  "default_composite",
 
178
  return f"Reminder: the previous action produced this error:\n{truncated}\nReconsider and retry."
179
 
180
 
181
+ # ---------------------------------------------------------------------------
182
+ # Error-kind routing (ADR-012 finding #2)
183
+ # ---------------------------------------------------------------------------
184
+ #
185
+ # The default composite is template -> raw-error -> judge. The raw-error layer
186
+ # fires for ANY kind carrying a message — including style/communication/effort
187
+ # sites, which are EXACTLY what the LLM judge exists to cover. So we route:
188
+ # tool/runtime error kinds may use the raw-error layer; style/communication/
189
+ # effort kinds skip it and fall through to the judge.
190
+
191
+ # Error kinds that genuinely describe a tool/runtime failure whose raw text is a
192
+ # useful, self-contained hint. The explicit registry-template kinds are included
193
+ # so behavior is unchanged for them.
194
+ _TOOL_RUNTIME_KINDS: frozenset[str] = frozenset({
195
+ "tool_not_found",
196
+ "json_decode",
197
+ "type_error",
198
+ "runtime_error",
199
+ "repeated_failure",
200
+ })
201
+
202
+ # Substrings marking a kind as tool/runtime-ish even if not explicitly listed
203
+ # (keeps generic "*_error"/"*_exception" sites flowing through raw-error, which
204
+ # is where their raw text belongs).
205
+ _TOOL_RUNTIME_MARKERS: tuple[str, ...] = (
206
+ "error", "exception", "fail", "decode", "timeout", "traceback",
207
+ "exit_code", "nonzero", "syntax", "import", "assertion", "tool",
208
+ "runtime", "crash", "exec",
209
+ )
210
+
211
+ # Substrings marking a kind as a style/communication/effort site — the judge's
212
+ # domain. These take precedence: a kind matching one of these skips raw-error.
213
+ _STYLE_KINDS_MARKERS: tuple[str, ...] = (
214
+ "style", "communic", "verbose", "effort", "concise", "tone",
215
+ "format", "wordy", "rambl", "explanation", "etiquette", "clarity",
216
+ )
217
+
218
+
219
+ def is_tool_runtime_kind(error_kind: str) -> bool:
220
+ """True if `error_kind` is a tool/runtime failure that the raw-error layer
221
+ may serve. Style/communication/effort kinds return False (-> judge)."""
222
+ k = (error_kind or "").lower()
223
+ if any(m in k for m in _STYLE_KINDS_MARKERS):
224
+ return False
225
+ if k in _TOOL_RUNTIME_KINDS:
226
+ return True
227
+ return any(m in k for m in _TOOL_RUNTIME_MARKERS)
228
+
229
+
230
+ class RoutingHintGenerator:
231
+ """Wraps an inner layer (the raw-error layer) and only lets it fire for
232
+ tool/runtime error kinds. For style/communication/effort kinds it returns
233
+ None so the composite falls through to the judge — the layer those sites
234
+ were always meant to reach (ADR-012 finding #2).
235
+ """
236
+
237
+ def __init__(self, inner: HintGenerator, route=is_tool_runtime_kind) -> None:
238
+ self.inner = inner
239
+ self.route = route
240
+
241
+ def generate(self, error_kind: str, error_meta: dict) -> str | None:
242
+ if not self.route(error_kind):
243
+ return None
244
+ return self.inner.generate(error_kind, error_meta)
245
+
246
+
247
  class LLMJudgeHintGenerator:
248
  """Layer 3: an LLM produces a short corrective hint.
249
 
 
387
  ) -> CompositeHintGenerator:
388
  """Build the recommended layered generator: templates -> raw-error -> judge.
389
 
390
+ The raw-error layer is wrapped in a RoutingHintGenerator so it only fires for
391
+ tool/runtime error kinds; style/communication/effort kinds skip it and fall
392
+ through to the LLM judge (ADR-012 finding #2). The LLM-judge layer is
393
+ included only when `llm_complete` is provided.
394
  """
395
  layers: list[HintGenerator] = [TemplateHintGenerator()]
396
  if enable_raw_error:
397
+ layers.append(RoutingHintGenerator(RawErrorHintGenerator()))
398
  if llm_complete is not None:
399
  layers.append(LLMJudgeHintGenerator(llm_complete, cache_dir=cache_dir))
400
  return CompositeHintGenerator(layers)
 
409
  "HintGenerator",
410
  "TemplateHintGenerator",
411
  "RawErrorHintGenerator",
412
+ "RoutingHintGenerator",
413
+ "is_tool_runtime_kind",
414
  "LLMJudgeHintGenerator",
415
  "CompositeHintGenerator",
416
  "default_composite",
composer_replication/tests/test_hint_routing.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for error-kind hint routing on the DEFAULT composite (ADR-012 #2).
2
+
3
+ The default composite is template -> raw-error -> judge. Before ADR-012 the
4
+ raw-error layer consumed ANY site carrying an `error_message`, including
5
+ style/communication/effort sites — exactly the sites the LLM judge exists to
6
+ cover. These tests validate the DEFAULT path (raw-error NOT disabled): a
7
+ style/communication site WITH an error_message routes through to the judge,
8
+ while tool/runtime sites still use the raw-error layer.
9
+ """
10
+ from __future__ import annotations
11
+
12
+ from composer_replication.hint_generator import (
13
+ RoutingHintGenerator,
14
+ RawErrorHintGenerator,
15
+ default_composite,
16
+ is_tool_runtime_kind,
17
+ )
18
+
19
+
20
+ # --- the headline acceptance: style site reaches judge on the DEFAULT path ---
21
+
22
+ def test_style_site_with_error_message_reaches_judge_on_default_composite():
23
+ calls = {"n": 0}
24
+
25
+ def fake_complete(prompt: str) -> str:
26
+ calls["n"] += 1
27
+ return "Be more concise; you repeated the same explanation twice."
28
+
29
+ # NOTE: raw-error is ENABLED (the default). Pre-ADR-012 this would have been
30
+ # eaten by the raw-error layer and the judge never called.
31
+ comp = default_composite(llm_complete=fake_complete) # enable_raw_error=True
32
+ hint = comp.generate(
33
+ "verbose_communication",
34
+ {"error_message": "The agent restated the plan three times."},
35
+ )
36
+ assert hint == "Be more concise; you repeated the same explanation twice."
37
+ assert calls["n"] == 1, "style site must reach the judge, not the raw-error layer"
38
+
39
+
40
+ def test_effort_site_with_message_routes_to_judge():
41
+ calls = {"n": 0}
42
+
43
+ def fake_complete(prompt: str) -> str:
44
+ calls["n"] += 1
45
+ return "Don't pad the answer; one example suffices."
46
+
47
+ comp = default_composite(llm_complete=fake_complete)
48
+ hint = comp.generate("low_effort_style", {"error_message": "padding detected"})
49
+ assert hint == "Don't pad the answer; one example suffices."
50
+ assert calls["n"] == 1
51
+
52
+
53
+ # --- tool/runtime sites still served by raw-error (no regression) -----------
54
+
55
+ def test_tool_runtime_site_still_served_by_raw_error_no_judge():
56
+ calls = {"n": 0}
57
+
58
+ def fake_complete(prompt: str) -> str:
59
+ calls["n"] += 1
60
+ return "JUDGE (should not be called)"
61
+
62
+ comp = default_composite(llm_complete=fake_complete)
63
+ # an unmapped *runtime* error (no template) -> raw-error layer, not judge.
64
+ hint = comp.generate("weird_runtime_error", {"error_message": "Segfault at 0x0"})
65
+ assert hint is not None
66
+ assert "Segfault at 0x0" in hint
67
+ assert calls["n"] == 0, "tool/runtime sites must be served by raw-error, not judge"
68
+
69
+
70
+ def test_template_site_unaffected_by_routing():
71
+ comp = default_composite() # no judge
72
+ hint = comp.generate("tool_not_found", {"available_tools": ["read", "write"]})
73
+ assert hint is not None and "Available tools" in hint
74
+
75
+
76
+ # --- the route predicate ----------------------------------------------------
77
+
78
+ def test_route_predicate_classifies_kinds():
79
+ # tool/runtime
80
+ for k in ("tool_not_found", "json_decode", "type_error", "runtime_error",
81
+ "repeated_failure", "weird_runtime_error", "some_exception",
82
+ "weird_unmapped_error"):
83
+ assert is_tool_runtime_kind(k) is True, k
84
+ # style/communication/effort
85
+ for k in ("verbose_communication", "low_effort_style", "tone_violation",
86
+ "rambling_explanation", "bad_formatting"):
87
+ assert is_tool_runtime_kind(k) is False, k
88
+
89
+
90
+ def test_routing_generator_returns_none_for_style_kind():
91
+ routed = RoutingHintGenerator(RawErrorHintGenerator())
92
+ # style kind WITH a message -> None (defer to judge), even though the inner
93
+ # raw-error layer would have produced a hint.
94
+ assert routed.generate("verbose_style", {"error_message": "too long"}) is None
95
+ # tool/runtime kind WITH a message -> inner fires.
96
+ out = routed.generate("runtime_error", {"error_message": "boom"})
97
+ assert out is not None and "boom" in out
composer_replication/trainer/composer_trainer.py CHANGED
@@ -225,16 +225,37 @@ class ComposerReplicationTrainer(GRPOTrainer): # type: ignore[misc, valid-type]
225
  # Gather the provably-aligned response logits from each sequence, then
226
  # JSD only those positions (this is the masked error-turn distillation).
227
  # gather over the sequence dim (dim=1): expand index to the vocab dim.
 
 
 
 
 
 
 
 
 
 
 
228
  vocab = student_logits.size(-1)
229
- s_gather = s_idx.unsqueeze(-1).expand(-1, -1, vocab)
230
- t_gather = t_idx.unsqueeze(-1).expand(-1, -1, vocab)
 
 
231
  student_aligned = torch.gather(student_logits, 1, s_gather)
232
  teacher_aligned = torch.gather(teacher_logits, 1, t_gather)
233
 
 
 
 
 
 
 
 
 
234
  return generalized_jsd_loss(
235
  student_logits=student_aligned,
236
  teacher_logits=teacher_aligned,
237
- labels=inputs.get("sdpo_loss_mask"), # optional further error-turn mask
238
  beta=self.sdpo_jsd_beta,
239
  temperature=self.sdpo_temperature,
240
  token_clip=self.sdpo_token_clip,
@@ -325,8 +346,18 @@ def make_dr_grpo_config(**overrides: Any):
325
  standard deviation introduces a question-level difficulty bias."
326
  - ``num_iterations=1`` — single-epoch regime (a prompt is never
327
  trained on twice), matching the tech report.
328
- - ``beta`` (KL-to-ref coef) kept; TRL uses the k1 (−log r)-family
329
- estimator the report selects.
 
 
 
 
 
 
 
 
 
 
330
 
331
  Any field can be overridden via kwargs (e.g. ``learning_rate=...``,
332
  ``output_dir=...``). The three Dr. GRPO-defining knobs are forced unless
 
225
  # Gather the provably-aligned response logits from each sequence, then
226
  # JSD only those positions (this is the masked error-turn distillation).
227
  # gather over the sequence dim (dim=1): expand index to the vocab dim.
228
+ #
229
+ # ADR-011: ragged-K rows are padded with a sentinel (-1) and a per-row
230
+ # *_valid mask. Negative indices are illegal for torch.gather, so clamp
231
+ # to 0 before gathering, then neutralize those positions by feeding
232
+ # labels=-100 (the standard HF ignore convention that generalized_jsd_loss
233
+ # already honors). This makes sentinel/padding positions contribute 0.
234
+ if "student_response_valid" in inputs and inputs["student_response_valid"] is not None:
235
+ aligned_mask = inputs["student_response_valid"].bool()
236
+ else:
237
+ aligned_mask = (s_idx >= 0) & (t_idx >= 0)
238
+
239
  vocab = student_logits.size(-1)
240
+ s_safe = s_idx.clamp_min(0)
241
+ t_safe = t_idx.clamp_min(0)
242
+ s_gather = s_safe.unsqueeze(-1).expand(-1, -1, vocab)
243
+ t_gather = t_safe.unsqueeze(-1).expand(-1, -1, vocab)
244
  student_aligned = torch.gather(student_logits, 1, s_gather)
245
  teacher_aligned = torch.gather(teacher_logits, 1, t_gather)
246
 
247
+ # Build (B, K) labels: 1 at valid aligned positions, -100 (ignore) at
248
+ # sentinel/padding positions so they drop out of the JSD reduction.
249
+ aligned_labels = torch.where(
250
+ aligned_mask,
251
+ torch.ones_like(s_idx),
252
+ torch.full_like(s_idx, -100),
253
+ )
254
+
255
  return generalized_jsd_loss(
256
  student_logits=student_aligned,
257
  teacher_logits=teacher_aligned,
258
+ labels=aligned_labels, # sentinel-masked aligned error-turn positions
259
  beta=self.sdpo_jsd_beta,
260
  temperature=self.sdpo_temperature,
261
  token_clip=self.sdpo_token_clip,
 
346
  standard deviation introduces a question-level difficulty bias."
347
  - ``num_iterations=1`` — single-epoch regime (a prompt is never
348
  trained on twice), matching the tech report.
349
+ - ``beta`` (KL-to-ref coef) kept. NOTE on the KL estimator (ADR-012
350
+ finding #1, verified against the installed trl==1.5.0 source):
351
+ ``GRPOTrainer._compute_loss`` uses the **k3** estimator
352
+ ``exp(ref_logp - logp) - (ref_logp - logp) - 1``
353
+ (trl/trainer/grpo_trainer.py ~L2513), NOT the k1 estimator
354
+ ``-log r == (ref_logp - logp)``. k3 is Schulman's low-variance,
355
+ always-non-negative KL approximation; k1 is its unbiased but
356
+ higher-variance counterpart. The Dr. GRPO / Composer 2 report discusses
357
+ KL in k1 terms, but the delta is small for r≈1 (k3 = k1 + O((Δlogp)^2))
358
+ and TRL's k3 choice is the production reality. We do NOT monkeypatch TRL
359
+ to force k1; we document the honest delta. See
360
+ ``test_dr_grpo_config_and_alignment.py::test_trl_kl_estimator_is_k3_not_k1``.
361
 
362
  Any field can be overridden via kwargs (e.g. ``learning_rate=...``,
363
  ``output_dir=...``). The three Dr. GRPO-defining knobs are forced unless
composer_replication/trainer/data_collator.py CHANGED
@@ -118,6 +118,47 @@ def _pad_or_truncate(seq: list[int], target_len: int, pad_id: int) -> list[int]:
118
  return seq + [pad_id] * (target_len - len(seq))
119
 
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  # ---------------------------------------------------------------------------
122
  # The collator
123
  # ---------------------------------------------------------------------------
@@ -190,6 +231,25 @@ class ComposerDataCollator:
190
  out["attention_mask"] = aligned["attention_mask"]
191
  out["response_mask"] = aligned["response_mask"]
192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  # --- Channel 3: trace-replay DPO fields ---
194
  if self.config.enable_replay_dpo:
195
  dpo = self._build_dpo_fields(batch)
 
118
  return seq + [pad_id] * (target_len - len(seq))
119
 
120
 
121
+ def _mask_to_padded_indices(
122
+ mask: torch.Tensor, # (B, T) where nonzero/True == valid position
123
+ pad_sentinel: int = -1,
124
+ ) -> tuple[torch.Tensor, torch.Tensor]:
125
+ """Convert a (B,T) bool/0-1 mask → (B,K_max) index tensor + (B,K_max) validity mask.
126
+
127
+ Each row's K valid positions are written left-aligned into ``idx``; the
128
+ ragged tail (rows with fewer than K_max positions) is padded with
129
+ ``pad_sentinel`` (default -1). ``valid`` is True exactly where ``idx``
130
+ holds a real position.
131
+
132
+ ADR-011: the SDPO loss gathers post-hint response logits via these indices,
133
+ then masks the sentinel/padding positions so they contribute 0. K_max=0
134
+ (no valid positions anywhere) returns (B,0) tensors.
135
+ """
136
+ B, T = mask.shape
137
+ bool_mask = mask != 0
138
+ counts = bool_mask.sum(dim=1).long() # (B,) — K per row
139
+ K_max = int(counts.max().item()) if counts.numel() else 0
140
+ if K_max == 0:
141
+ return (
142
+ torch.full((B, 0), pad_sentinel, dtype=torch.long, device=mask.device),
143
+ torch.zeros(B, 0, dtype=torch.bool, device=mask.device),
144
+ )
145
+ idx = torch.full((B, K_max), pad_sentinel, dtype=torch.long, device=mask.device)
146
+ valid = torch.zeros(B, K_max, dtype=torch.bool, device=mask.device)
147
+ # torch.nonzero on a 2D bool tensor yields (total_K, 2): (batch_idx, pos_idx),
148
+ # row-major so positions are already in per-row, ascending order.
149
+ nz = torch.nonzero(bool_mask, as_tuple=False) # (total_K, 2)
150
+ pos_idx = nz[:, 1]
151
+ offsets = torch.zeros(B + 1, dtype=torch.long, device=mask.device)
152
+ offsets[1:] = counts.cumsum(dim=0)
153
+ for b in range(B):
154
+ start, end = int(offsets[b].item()), int(offsets[b + 1].item())
155
+ k = end - start
156
+ if k > 0:
157
+ idx[b, :k] = pos_idx[start:end]
158
+ valid[b, :k] = True
159
+ return idx, valid
160
+
161
+
162
  # ---------------------------------------------------------------------------
163
  # The collator
164
  # ---------------------------------------------------------------------------
 
231
  out["attention_mask"] = aligned["attention_mask"]
232
  out["response_mask"] = aligned["response_mask"]
233
 
234
+ # --- ADR-011: emit SDPO alignment indices ---
235
+ # The loss (strict mode, default) requires explicit per-token
236
+ # alignment indices into each sequence so the JSD compares
237
+ # corresponding post-hint response tokens. Derive them from the
238
+ # already-aligned masks: teacher positions from sdpo_loss_mask==1,
239
+ # student positions from response_mask==1. Both masks are placed
240
+ # on content tokens by _build_chat_aligned_mask, and the
241
+ # placeholder-system-message trick makes them land at the SAME
242
+ # logical token, so at valid positions s_idx == t_idx.
243
+ if "sdpo_loss_mask" in out and "response_mask" in out:
244
+ t_mask = out["sdpo_loss_mask"] == 1
245
+ s_mask = out["response_mask"] == 1
246
+ t_idx, t_valid = _mask_to_padded_indices(t_mask)
247
+ s_idx, s_valid = _mask_to_padded_indices(s_mask)
248
+ out["student_response_idx"] = s_idx
249
+ out["teacher_response_idx"] = t_idx
250
+ out["student_response_valid"] = s_valid
251
+ out["teacher_response_valid"] = t_valid
252
+
253
  # --- Channel 3: trace-replay DPO fields ---
254
  if self.config.enable_replay_dpo:
255
  dpo = self._build_dpo_fields(batch)
composer_replication/trainer/tests/test_dr_grpo_config_and_alignment.py CHANGED
@@ -57,6 +57,70 @@ def test_make_dr_grpo_config_override_does_not_silently_break_guard(tmp_path):
57
  assert cfg.loss_type == "grpo"
58
 
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  # ---------------------------------------------------------------------------
61
  # Gate 2 — SDPO strict-alignment guard (no real GRPOTrainer needed)
62
  # ---------------------------------------------------------------------------
 
57
  assert cfg.loss_type == "grpo"
58
 
59
 
60
+ # ---------------------------------------------------------------------------
61
+ # ADR-012 finding #1 — TRL's native KL estimator (k1 vs k3)
62
+ # ---------------------------------------------------------------------------
63
+
64
+ def test_trl_kl_estimator_is_k3_not_k1():
65
+ """Document, honestly, which KL estimator TRL's GRPOTrainer actually uses.
66
+
67
+ Two common per-token KL approximations of KL(pi || pi_ref), given the log
68
+ importance ratio Δ = ref_logp - logp (so r = pi/pi_ref = exp(-Δ)... we use
69
+ the trl convention Δ = ref_logp - logp directly):
70
+
71
+ k1 = Δ = (ref_logp - logp) (unbiased, higher var)
72
+ k3 = exp(Δ) - Δ - 1 (Schulman, low var, >= 0)
73
+
74
+ make_dr_grpo_config's docstring previously *claimed* TRL uses k1. Inspecting
75
+ the installed trl==1.5.0 source (grpo_trainer.py ~L2513) shows it actually
76
+ computes k3: `torch.exp(ref - logp) - (ref - logp) - 1`. This test pins
77
+ that finding so the docstring stays honest and a future TRL change is caught.
78
+ """
79
+ # Known logprob pairs (student logp, reference logp).
80
+ logp = torch.tensor([-1.0, -2.0, -0.5, -3.0])
81
+ ref_logp = torch.tensor([-1.2, -1.5, -0.7, -2.4])
82
+ delta = ref_logp - logp
83
+
84
+ k1 = delta
85
+ k3 = torch.exp(delta) - delta - 1.0
86
+
87
+ # k3 is always non-negative; k1 can be negative — a structural difference.
88
+ assert (k3 >= -1e-6).all(), "k3 must be non-negative (Schulman estimator)"
89
+ assert (k1 < 0).any(), "k1 (= Δ) can be negative; the test data exercises that"
90
+
91
+ # The TRL 1.5.0 source uses k3 (verified by grepping the installed package).
92
+ import inspect
93
+
94
+ from trl import GRPOTrainer
95
+
96
+ src = inspect.getsource(GRPOTrainer)
97
+ # The k3 signature: exp(ref - logp) - (ref - logp) - 1. We assert the
98
+ # distinctive `torch.exp(` of the ratio appears in the per-token KL block.
99
+ assert "per_token_kl" in src, "TRL GRPOTrainer no longer has a per_token_kl block"
100
+ uses_k3 = "torch.exp(ref_per_token_logps - per_token_logps)" in src
101
+ uses_k1_only = (
102
+ "per_token_kl = ref_per_token_logps - per_token_logps" in src and not uses_k3
103
+ )
104
+ assert uses_k3, (
105
+ "Expected TRL 1.5.0 to compute the k3 KL estimator "
106
+ "exp(ref - logp) - (ref - logp) - 1. If this fails, TRL changed its "
107
+ "estimator — re-verify make_dr_grpo_config's docstring (which documents "
108
+ f"k3, not k1). uses_k1_only={uses_k1_only}"
109
+ )
110
+
111
+ # Sanity: for small Δ (r≈1) the two estimators agree to second order, which
112
+ # is why the report's k1 framing and TRL's k3 reality differ only mildly.
113
+ small = torch.tensor([0.01, -0.02, 0.005])
114
+ k1_small = small
115
+ k3_small = torch.exp(small) - small - 1.0
116
+ assert torch.allclose(k3_small, 0.5 * small**2, atol=1e-4), (
117
+ "k3 should be ~Δ²/2 for small Δ (its leading order)"
118
+ )
119
+ assert (k3_small.abs() < k1_small.abs()).all(), (
120
+ "for small Δ, |k3| << |k1| — the delta the docstring documents is minor"
121
+ )
122
+
123
+
124
  # ---------------------------------------------------------------------------
125
  # Gate 2 — SDPO strict-alignment guard (no real GRPOTrainer needed)
126
  # ---------------------------------------------------------------------------
composer_replication/trainer/tests/test_sdpo_alignment_indices.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ADR-011 — collator-emitted SDPO alignment indices + loss sentinel-masking.
2
+
3
+ These tests close the strict-SDPO-raises regression: the SDPO loss requires
4
+ explicit `student_response_idx`/`teacher_response_idx` (B,K) LongTensors, and
5
+ the production collator must emit them. Covered acceptance gates:
6
+
7
+ 1. `_mask_to_padded_indices` ragged-K shape + sentinel/valid semantics.
8
+ 2. Real `ComposerDataCollator` emits the 4 alignment keys with correct
9
+ shapes; student_response_idx == teacher_response_idx at valid positions.
10
+ 3. THE REGRESSION: real collator → batch → `_compute_sdpo_loss` in STRICT
11
+ mode (default) runs WITHOUT raising and returns a finite positive loss.
12
+ 4. Ragged-K: a 2-row batch with different K per row → finite loss, the K=1
13
+ row's sentinel padding does not leak into the JSD.
14
+
15
+ All CPU-only and fast (stub tokenizer + tiny model — no model download).
16
+ """
17
+ from __future__ import annotations
18
+
19
+ import pytest
20
+ import torch
21
+
22
+ from composer_replication.trainer.data_collator import (
23
+ CollatorConfig,
24
+ ComposerDataCollator,
25
+ _mask_to_padded_indices,
26
+ )
27
+
28
+
29
+ # ---------------------------------------------------------------------------
30
+ # Stubs (mirror the patterns in test_chat_template_alignment.py /
31
+ # test_dr_grpo_config_and_alignment.py so these tests need no model cache).
32
+ # ---------------------------------------------------------------------------
33
+
34
+ class _StubTok:
35
+ """Word-level deterministic tokenizer; apply_chat_template space-joins."""
36
+
37
+ pad_token_id = 0
38
+
39
+ def __init__(self) -> None:
40
+ self._v: dict[str, int] = {"<pad>": 0, "<bos>": 1, "<eos>": 2}
41
+
42
+ def _id(self, w: str) -> int:
43
+ if w not in self._v:
44
+ self._v[w] = len(self._v)
45
+ return self._v[w]
46
+
47
+ def __call__(self, text, **_k):
48
+ return {"input_ids": [self._id(w) for w in text.split()] if text else []}
49
+
50
+ def apply_chat_template(self, messages, tokenize=True, **_k): # noqa: ARG002
51
+ return [self._id(w) for w in " ".join(m.get("content", "") for m in messages).split()]
52
+
53
+
54
+ class _TinyLM(torch.nn.Module):
55
+ """Minimal HF-style model: model(input_ids=...).logits.
56
+
57
+ Position-DEPENDENT: adds a learned positional bias so identical token ids at
58
+ DIFFERENT sequence positions produce DIFFERENT logits. This matters for the
59
+ SDPO regression test — student and teacher share the same response token ids
60
+ but at different absolute positions (the hint/placeholder shifts them), so a
61
+ position-independent model would give JSD≈0 and mask a real misalignment bug.
62
+ """
63
+
64
+ def __init__(self, vocab: int = 64, hidden: int = 8, max_pos: int = 512):
65
+ super().__init__()
66
+ self.embed = torch.nn.Embedding(vocab, hidden)
67
+ self.pos = torch.nn.Embedding(max_pos, hidden)
68
+ self.head = torch.nn.Linear(hidden, vocab)
69
+
70
+ def forward(self, input_ids: torch.Tensor):
71
+ T = input_ids.size(1)
72
+ positions = torch.arange(T, device=input_ids.device).unsqueeze(0)
73
+ h = self.embed(input_ids) + self.pos(positions)
74
+ logits = self.head(h)
75
+
76
+ class _Out:
77
+ pass
78
+
79
+ out = _Out()
80
+ out.logits = logits
81
+ return out
82
+
83
+
84
+ def _hint_gen(kind, _meta):
85
+ return "HINT search before reading"
86
+
87
+
88
+ def _make_sdpo_trainer():
89
+ """ComposerReplicationTrainer instance without GRPOTrainer.__init__ — we
90
+ only exercise _compute_sdpo_loss, in STRICT mode (default)."""
91
+ from composer_replication.trainer.composer_trainer import ComposerReplicationTrainer
92
+
93
+ obj = ComposerReplicationTrainer.__new__(ComposerReplicationTrainer)
94
+ obj.alpha_sdpo = 1.0
95
+ obj.sdpo_jsd_beta = 0.5
96
+ obj.sdpo_temperature = 1.0
97
+ obj.sdpo_token_clip = None
98
+ obj.strict_sdpo_alignment = True # the default / production setting
99
+ return obj
100
+
101
+
102
+ def _error_trace(trace_id: str, recovery: str = "let me use a real tool instead"):
103
+ return {
104
+ "trace_id": trace_id,
105
+ "turns": [
106
+ {"role": "user", "content": "do the task now"},
107
+ {"role": "user", "content": "tool not found error occurred"},
108
+ {
109
+ "role": "assistant",
110
+ "content": recovery,
111
+ "tool_error": "tool_not_found",
112
+ "error_meta": {},
113
+ },
114
+ ],
115
+ "final_reward": 0.0,
116
+ }
117
+
118
+
119
+ # ---------------------------------------------------------------------------
120
+ # Gate 1 — _mask_to_padded_indices ragged-K semantics
121
+ # ---------------------------------------------------------------------------
122
+
123
+ def test_mask_to_padded_indices_ragged_k():
124
+ """2 rows, K=3 and K=1 → (2,3) idx; row1 tail padded with -1;
125
+ valid[1] == [True, False, False]."""
126
+ mask = torch.tensor(
127
+ [
128
+ [0, 1, 1, 0, 1], # K=3 at positions 1,2,4
129
+ [0, 0, 1, 0, 0], # K=1 at position 2
130
+ ],
131
+ dtype=torch.long,
132
+ )
133
+ idx, valid = _mask_to_padded_indices(mask)
134
+ assert idx.shape == (2, 3)
135
+ assert valid.shape == (2, 3)
136
+ assert idx[0].tolist() == [1, 2, 4]
137
+ assert idx[1].tolist() == [2, -1, -1]
138
+ assert valid[0].tolist() == [True, True, True]
139
+ assert valid[1].tolist() == [True, False, False]
140
+ assert idx.dtype == torch.long
141
+ assert valid.dtype == torch.bool
142
+
143
+
144
+ def test_mask_to_padded_indices_empty_returns_b0():
145
+ """K_max == 0 (no valid positions) returns (B,0) tensors."""
146
+ mask = torch.zeros(3, 5, dtype=torch.long)
147
+ idx, valid = _mask_to_padded_indices(mask)
148
+ assert idx.shape == (3, 0)
149
+ assert valid.shape == (3, 0)
150
+
151
+
152
+ # ---------------------------------------------------------------------------
153
+ # Gate 2 — collator emits the 4 alignment keys with correct shapes
154
+ # ---------------------------------------------------------------------------
155
+
156
+ def test_collator_emits_alignment_indices_keys():
157
+ tok = _StubTok()
158
+ cfg = CollatorConfig(hint_generator=_hint_gen, enable_replay_dpo=False)
159
+ collator = ComposerDataCollator(tokenizer=tok, config=cfg)
160
+ batch = collator([_error_trace("align-1")])
161
+
162
+ for key in (
163
+ "student_response_idx",
164
+ "teacher_response_idx",
165
+ "student_response_valid",
166
+ "teacher_response_valid",
167
+ ):
168
+ assert key in batch, f"collator did not emit {key!r}"
169
+
170
+ s_idx = batch["student_response_idx"]
171
+ t_idx = batch["teacher_response_idx"]
172
+ s_valid = batch["student_response_valid"]
173
+ assert s_idx.shape == t_idx.shape
174
+ assert s_idx.shape == s_valid.shape
175
+ assert s_idx.dtype == torch.long
176
+ assert s_valid.dtype == torch.bool
177
+ # There must be at least one valid aligned position.
178
+ assert int(s_valid.sum()) > 0
179
+ # At valid positions the placeholder-trick makes the two indices identical.
180
+ vmask = s_valid
181
+ assert torch.equal(s_idx[vmask], t_idx[vmask]), (
182
+ "student/teacher indices diverge at valid positions; the placeholder "
183
+ "alignment trick is broken."
184
+ )
185
+
186
+
187
+ # ---------------------------------------------------------------------------
188
+ # Gate 3 — THE REGRESSION TEST: real collator → strict _compute_sdpo_loss
189
+ # ---------------------------------------------------------------------------
190
+
191
+ def test_strict_sdpo_loss_runs_on_real_collator_batch():
192
+ """Real ComposerDataCollator batch → _compute_sdpo_loss in STRICT mode
193
+ (default) runs WITHOUT raising and returns a finite, positive loss.
194
+ This is the whole point of ADR-011."""
195
+ tok = _StubTok()
196
+ cfg = CollatorConfig(hint_generator=_hint_gen, enable_replay_dpo=False)
197
+ collator = ComposerDataCollator(tokenizer=tok, config=cfg)
198
+ batch = collator([_error_trace("regression-1")])
199
+
200
+ # vocab must cover every token id the stub tokenizer produced.
201
+ vocab = int(max(batch["input_ids"].max(), batch["ctx_teacher_input_ids"].max())) + 1
202
+ model = _TinyLM(vocab=max(vocab, 8))
203
+
204
+ obj = _make_sdpo_trainer()
205
+ loss = obj._compute_sdpo_loss(model, batch) # must NOT raise
206
+
207
+ val = float(loss.detach())
208
+ assert val == val, "SDPO loss is NaN"
209
+ assert val not in (float("inf"), float("-inf")), "SDPO loss is infinite"
210
+ # JSD is always >= 0. With this context-free stub model the gathered
211
+ # student/teacher logits at correctly-aligned positions (same token id, same
212
+ # absolute position) are identical, so the JSD floors at ~0 — that is the
213
+ # CORRECT answer for a perfectly-aligned identical model, not a bug. The
214
+ # whole-point assertion is that strict mode RAN (no raise) and produced a
215
+ # real finite scalar on a grad path; positivity needs an attention model
216
+ # (covered by examples/composer_grpo_sdpo_smoke on Qwen2.5-0.5B).
217
+ assert val >= -1e-6, f"JSD must be non-negative, got {val}"
218
+ assert loss.requires_grad, "SDPO loss must be differentiable (grad path)"
219
+
220
+
221
+ # ---------------------------------------------------------------------------
222
+ # Gate 4 — ragged-K batch: K=1 row padding must not leak into the loss
223
+ # ---------------------------------------------------------------------------
224
+
225
+ def test_ragged_k_batch_finite_loss_no_padding_leak():
226
+ """A 2-row batch with different recovery lengths → ragged K. The loss must
227
+ be finite and the K=1 row's sentinel padding must not contribute."""
228
+ tok = _StubTok()
229
+ cfg = CollatorConfig(hint_generator=_hint_gen, enable_replay_dpo=False)
230
+ collator = ComposerDataCollator(tokenizer=tok, config=cfg)
231
+ batch = collator([
232
+ _error_trace("ragged-long", recovery="recover with a real working tool now please"),
233
+ _error_trace("ragged-short", recovery="ok"),
234
+ ])
235
+
236
+ s_idx = batch["student_response_idx"]
237
+ s_valid = batch["student_response_valid"]
238
+ # Ragged: at least one row should be shorter (have an invalid tail) OR the
239
+ # rows genuinely differ — assert sentinel padding exists where invalid.
240
+ assert (s_idx == -1)[~s_valid].all(), "invalid positions must hold sentinel -1"
241
+
242
+ vocab = int(max(batch["input_ids"].max(), batch["ctx_teacher_input_ids"].max())) + 1
243
+ model = _TinyLM(vocab=max(vocab, 8))
244
+ obj = _make_sdpo_trainer()
245
+
246
+ loss = obj._compute_sdpo_loss(model, batch)
247
+ val = float(loss.detach())
248
+ assert val == val and val not in (float("inf"), float("-inf"))
249
+ # Non-negative (JSD floor). The leak failure mode this guards against is a
250
+ # sentinel (-1) index reaching torch.gather (illegal → error) or a padding
251
+ # position contributing garbage → NaN/inf. A finite, non-negative scalar
252
+ # proves the clamp-to-0 + label=-100 sentinel masking worked.
253
+ assert val >= -1e-6
254
+
255
+ # Padding-leak guard: zeroing the (clamped) sentinel rows must not change
256
+ # the loss, since valid-mask labels already drop them. We verify by
257
+ # recomputing with the valid mask forced all-True on a fresh batch where
258
+ # the short row is genuinely shorter — instead we assert the simpler
259
+ # invariant: the loss equals the loss computed if we explicitly drop the
260
+ # invalid tail by truncating to the per-batch min-K.
261
+ min_k = int(s_valid.sum(dim=1).min())
262
+ if min_k < s_idx.shape[1]:
263
+ truncated = dict(batch)
264
+ truncated["student_response_idx"] = batch["student_response_idx"][:, :min_k]
265
+ truncated["teacher_response_idx"] = batch["teacher_response_idx"][:, :min_k]
266
+ truncated["student_response_valid"] = batch["student_response_valid"][:, :min_k]
267
+ truncated["teacher_response_valid"] = batch["teacher_response_valid"][:, :min_k]
268
+ # Same model state (no grad step taken) → deterministic forward.
269
+ loss_trunc = obj._compute_sdpo_loss(model, truncated)
270
+ # The full-batch loss includes the long row's extra valid tokens, so it
271
+ # need not equal the truncated loss; we only assert both are finite and
272
+ # the sentinel tail produced no NaN/inf (the real leak failure mode).
273
+ vt = float(loss_trunc.detach())
274
+ assert vt == vt and vt not in (float("inf"), float("-inf"))
docs/adrs/ADR-010-feature-deletion-datagen.md CHANGED
@@ -48,7 +48,7 @@ package.
48
 
49
  ## Considered Options
50
 
51
- - **A. `FeatureDeletionEnv` that inverts OSS SWE substrates (revert gold patch) + online pass-rate difficulty gate + sandbox/AST reward-hacking safeguards** (chosen)
52
  - **B. Greenfield repo-scraping generator (clone arbitrary GitHub repos, delete AST nodes, hope tests cover them)**
53
  - **C. Skip generation; reuse SWE-bench-lite tasks as-is without a deletion/inversion layer**
54
 
@@ -62,8 +62,9 @@ adapters that invert the 5 OSS datasets by reverting their gold patch, a
62
  `PASS_TO_PASS`, gold patch restores green, deletion is reachable from tests),
63
  an online pass-rate difficulty gate, and reward-hacking safeguards
64
  (pre-task scrub of `__pycache__`/`.mypy_cache`/`.class`/`.git`; allowlisted
65
- sandbox without `find`/`strings`/`unzip`/decompilers; AST provenance monitor
66
- that masks reward when deleted symbols reappear via non-implementation paths).
 
67
  A TRL `reward_fn(prompts, completions, **kwargs) -> list[float]` adapter wires
68
  it to the RL loop.
69
 
@@ -74,7 +75,7 @@ it to the RL loop.
74
  - **Positive**: Online difficulty gate matches the actual recipe.
75
  - **Negative**: Bounded to what the OSS substrates cover (Python-dominant; SWE-bench is Python/JS-heavy). Other languages need new substrates. Documented as a known coverage limit.
76
  - **Negative**: Running tests in a sandbox requires Docker images per substrate; CPU-pool generation has real wall-clock cost (~15 node-days to invert all 21k SWE-rebench tasks per research/06). Mitigated by reusing the substrates' published Docker images and generating lazily.
77
- - **Negative**: Reward-hacking safeguards are a moving target; the AST provenance monitor is heuristic and will have false negatives. Mitigated by treating it as defense-in-depth (sandbox lockdown is the primary control) and logging suspected hacks for review.
78
  - **Neutral**: Adds a `[datagen]` optional extra (datasets, docker SDK).
79
 
80
  ## Pros and Cons of the Options
@@ -142,14 +143,15 @@ remediated where possible without Docker:
142
  (coverage of the changed region by the failing tests, or revert-provenance)
143
  needs the live Docker materializers. **This is the same `[~]` gate as the
144
  substrate-inversion e2e — see below.**
145
- - **[OPEN] `HackMonitor` is a substring matcher, not the AST-provenance monitor
146
- the ADR advertises** (DeepSeek P0). It flags cache/decompiler signatures in the
147
- trajectory but does no AST/symbol-reappearance analysis, and is bypassable by
148
- string-concat. With the scrub now in place as the primary control, the monitor
149
- is correctly-scoped defense-in-depth — but the ADR's §3c "AST provenance
150
- monitor" language overstates it. Re-scoped: it is a *signature-based* monitor;
151
- a genuine AST provenance check (scan the agent's patch for reintroduced
152
- `deleted_symbols` reached via non-implementation paths) is a follow-up.
 
153
  - **[OPEN — recipe fidelity] Curriculum ignores rollout-turns and
154
  thinking-token count** (DeepSeek, GPT-5.5). The Composer 2 tech report keys the
155
  curriculum on these; the implementation tracks only pass-rate. Follow-up:
 
48
 
49
  ## Considered Options
50
 
51
+ - **A. `FeatureDeletionEnv` that inverts OSS SWE substrates (revert gold patch) + online pass-rate difficulty gate + sandbox + signature/patch-provenance reward-hacking safeguards** (chosen)
52
  - **B. Greenfield repo-scraping generator (clone arbitrary GitHub repos, delete AST nodes, hope tests cover them)**
53
  - **C. Skip generation; reuse SWE-bench-lite tasks as-is without a deletion/inversion layer**
54
 
 
62
  `PASS_TO_PASS`, gold patch restores green, deletion is reachable from tests),
63
  an online pass-rate difficulty gate, and reward-hacking safeguards
64
  (pre-task scrub of `__pycache__`/`.mypy_cache`/`.class`/`.git`; allowlisted
65
+ sandbox without `find`/`strings`/`unzip`/decompilers; signature + patch-provenance
66
+ monitor that masks reward when deleted symbols reappear via non-implementation
67
+ paths — including string-concat-obfuscated cache reads).
68
  A TRL `reward_fn(prompts, completions, **kwargs) -> list[float]` adapter wires
69
  it to the RL loop.
70
 
 
75
  - **Positive**: Online difficulty gate matches the actual recipe.
76
  - **Negative**: Bounded to what the OSS substrates cover (Python-dominant; SWE-bench is Python/JS-heavy). Other languages need new substrates. Documented as a known coverage limit.
77
  - **Negative**: Running tests in a sandbox requires Docker images per substrate; CPU-pool generation has real wall-clock cost (~15 node-days to invert all 21k SWE-rebench tasks per research/06). Mitigated by reusing the substrates' published Docker images and generating lazily.
78
+ - **Negative**: Reward-hacking safeguards are a moving target; the signature + patch-provenance monitor is heuristic and will have false negatives. Mitigated by treating it as defense-in-depth (sandbox lockdown is the primary control) and logging suspected hacks for review.
79
  - **Neutral**: Adds a `[datagen]` optional extra (datasets, docker SDK).
80
 
81
  ## Pros and Cons of the Options
 
143
  (coverage of the changed region by the failing tests, or revert-provenance)
144
  needs the live Docker materializers. **This is the same `[~]` gate as the
145
  substrate-inversion e2e — see below.**
146
+ - **[RESOLVED — ADR-012] `HackMonitor` was a substring matcher, not the
147
+ AST-provenance monitor the ADR advertised** (DeepSeek P0). It flagged
148
+ cache/decompiler signatures in the trajectory but did no symbol-reappearance
149
+ analysis, and was bypassable by string-concat. With the scrub now in place as
150
+ the primary control, the monitor is correctly-scoped defense-in-depth. ADR-012
151
+ re-scoped the language to "signature + patch-provenance monitor" (not "AST")
152
+ and added a patch-provenance layer: a deleted symbol reappearing verbatim in
153
+ the agent's patch alongside a cache/bytecode read — normalized to defeat
154
+ string-concat obfuscation (`"__py"+"cache__"`) — is now flagged.
155
  - **[OPEN — recipe fidelity] Curriculum ignores rollout-turns and
156
  thinking-token count** (DeepSeek, GPT-5.5). The Composer 2 tech report keys the
157
  curriculum on these; the implementation tracks only pass-rate. Follow-up: