Baladithya Balamurugan Claude Opus 4.8 (1M context) committed on
Commit 7a55e1e · 1 Parent(s): c11cf49

Wave 2: 4 new modules (kill-switch, EKS/SageMaker executors, DockerSandbox) + B4/B7 completion

Built by the parallel execution team (worktree-isolated), integrated + tested here.

New modules (all CPU-testable via mock/lazy-import; optional deps gated):
- composer_replication/safety/ — HeldOutGuard run-level collapse kill-switch
(held-out-declines-while-reward-rises streak + KL-to-init hard-stop 0.08 +
proxy-real hacking-gap; EMA-denoised, latched-fire, CollapseStopError). The
documented #2 safeguard for the self-evolving flywheel. 23 tests.
- composer_replication/diloco/serverless/eks.py — EKSExecutor satisfying the
ServerlessExecutor Protocol via a SINGLE Kubernetes Indexed Job → N rank-ordered
ReplicaHandles, gang-cancel (Background propagation), REPLICA_RANK via the
downward API, S3 rendezvous (IRSA). 28 tests (mock BatchV1/CoreV1).
- composer_replication/diloco/serverless/sagemaker.py — SageMakerExecutor
(boto3 create_training_job, same S3 rendezvous, status mapping). Its 13-test
module was written during integration (the build agent shipped it without tests).
- composer_replication/datagen/docker_sandbox.py — DockerSandbox (ephemeral
container, --network none, mem/pids limits, gVisor runtime option) + refactored
the per-class _scrub_tree into a shared module-level scrub_tree free function
so every sandbox backend applies the identical reward-hack control. Live Docker
tests pass; LocalSubprocessSandbox/FeatureDeletionEnv unaffected (review: clean).
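
The kill-switch in the first bullet can be pictured roughly as follows. This is a minimal sketch under stated assumptions, not the shipped HeldOutGuard: only the `CollapseStopError` name and the 0.08 KL threshold come from the message above. The class name `HeldOutGuardSketch`, the `update` signature, the EMA factor, and the streak length are hypothetical, and the proxy-real hacking-gap check is left out.

```python
from dataclasses import dataclass, field
from typing import Optional


class CollapseStopError(RuntimeError):
    """Raised when the guard fires (name taken from the commit message)."""


@dataclass
class HeldOutGuardSketch:
    kl_hard_stop: float = 0.08  # threshold quoted in the commit message
    streak_to_fire: int = 3     # assumed value, for illustration only
    ema_alpha: float = 0.5      # assumed smoothing factor
    fired: bool = field(default=False, init=False)
    _reward_ema: Optional[float] = field(default=None, init=False)
    _heldout_ema: Optional[float] = field(default=None, init=False)
    _streak: int = field(default=0, init=False)

    def _ema(self, prev: Optional[float], value: float) -> float:
        # The first observation seeds the average; later ones are blended in.
        if prev is None:
            return value
        return self.ema_alpha * value + (1.0 - self.ema_alpha) * prev

    def update(self, reward: float, heldout: float, kl_to_init: float) -> None:
        # Latched: once fired, every later call raises again.
        if self.fired:
            raise CollapseStopError("guard already fired")
        new_reward = self._ema(self._reward_ema, reward)
        new_heldout = self._ema(self._heldout_ema, heldout)
        # Divergence = smoothed reward rose while smoothed held-out score fell.
        diverging = (
            self._reward_ema is not None
            and self._heldout_ema is not None
            and new_reward > self._reward_ema
            and new_heldout < self._heldout_ema
        )
        self._streak = self._streak + 1 if diverging else 0
        self._reward_ema, self._heldout_ema = new_reward, new_heldout
        if kl_to_init > self.kl_hard_stop:
            self.fired = True
            raise CollapseStopError(f"KL-to-init {kl_to_init} > {self.kl_hard_stop}")
        if self._streak >= self.streak_to_fire:
            self.fired = True
            raise CollapseStopError(f"held-out declined for {self._streak} steps")
```

A training loop would call `update(...)` once per evaluation step and let the exception end the run; the latch keeps a caller that swallows the first exception from silently continuing.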

Wiring + completion:
- Re-exported EKSExecutor/SageMakerExecutor (serverless __init__) and
DockerSandbox/scrub_tree (datagen __init__).
- pyproject: added [eks] (kubernetes) + [aws] (boto3) extras.
- B7-complete: added make_dr_grpo_config/make_po_config/PO_OBJECTIVES to the
TOP-LEVEL __all__ (were importable but missing from __all__).
- B4-complete: reconciled the 4 surviving stale "115 passing" current-framed
claims (README/OVERVIEW/VISION_VALIDATION) to the canonical 266/62.
- All new files ruff-clean (E,F,W,I,N,UP,B).
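
The B7 item (names importable but absent from `__all__`) is the kind of drift a small audit can surface. The helper below is a hypothetical sketch, not part of this commit: `missing_from_all` is an invented name, and its output is only a candidate list to review, since a module's namespace also holds helpers that are deliberately not exported.

```python
import types


def missing_from_all(module: types.ModuleType) -> list:
    """Return public, non-module names bound in `module` but not listed in __all__."""
    declared = set(getattr(module, "__all__", ()))
    candidates = [
        name
        for name, value in vars(module).items()
        if not name.startswith("_") and not isinstance(value, types.ModuleType)
    ]
    return sorted(name for name in candidates if name not in declared)


# Usage on a synthetic module (stands in for a real package __init__).
demo = types.ModuleType("demo")
demo.__all__ = ["replay_trace"]
demo.replay_trace = lambda: None
demo.make_po_config = lambda: None  # importable, but not declared
demo._internal = object()           # private, ignored
print(missing_from_all(demo))
```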

Full suite: 355 passed / 65 skipped / 1 flaky-under-contention (spike-006
loss-trend test, passes in isolation — tracked as R11, not a regression).
Wave-3 backlog (R1-R12) filed in docs/BACKLOG_RESOLUTION_2026-06-09.md from the
concurrent review team.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
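
For readers unfamiliar with the single-Indexed-Job design named in the EKSExecutor bullet, the sketch below builds such a manifest as a plain dictionary. It is an illustration under assumptions, not the module's code: the function name `indexed_job_manifest`, its parameters, and the container name are hypothetical. `completionMode: Indexed` and the `batch.kubernetes.io/job-completion-index` annotation are standard Kubernetes Job features, and `REPLICA_RANK` is the variable named in the message above.

```python
def indexed_job_manifest(name: str, image: str, n_replicas: int, command: list) -> dict:
    """Build a batch/v1 Indexed Job that runs n_replicas pods in parallel."""
    # Downward API: expose the pod's completion index as REPLICA_RANK.
    rank_env = {
        "name": "REPLICA_RANK",
        "valueFrom": {
            "fieldRef": {
                "fieldPath": "metadata.annotations['batch.kubernetes.io/job-completion-index']"
            }
        },
    }
    return {
        "apiVersion": "batch/v1",
        "kind": "Job",
        "metadata": {"name": name},
        "spec": {
            "completionMode": "Indexed",
            "completions": n_replicas,   # one completion per rank
            "parallelism": n_replicas,   # all ranks run at once
            "backoffLimit": 0,           # assumed: no retries in this sketch
            "template": {
                "spec": {
                    "restartPolicy": "Never",
                    "containers": [
                        {
                            "name": "replica",
                            "image": image,
                            "command": list(command),
                            "env": [rank_env],
                        }
                    ],
                }
            },
        },
    }


manifest = indexed_job_manifest("diloco-run", "example/trainer:latest", 4, ["python", "train.py"])
```

Because all ranks belong to one Job, deleting that Job with Background propagation removes every pod together, which is presumably what the message means by gang-cancel.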

composer_replication/__init__.py CHANGED
@@ -132,6 +132,9 @@ __all__ = [
132
  "replay_trace",
133
  # Trainer
134
  "ComposerReplicationTrainer",
 
 
 
135
  # DiLoCo (optional)
136
  "make_diloco_outer_loop",
137
  # Meta
 
132
  "replay_trace",
133
  # Trainer
134
  "ComposerReplicationTrainer",
135
+ "make_dr_grpo_config",
136
+ "make_po_config",
137
+ "PO_OBJECTIVES",
138
  # DiLoCo (optional)
139
  "make_diloco_outer_loop",
140
  # Meta
composer_replication/datagen/__init__.py CHANGED
@@ -8,6 +8,7 @@ Public surface:
8
  - FeatureDeletionTask — the task tuple (schema.py)
9
  - FeatureDeletionEnv — Gym/OpenEnv-style env + TRL reward_fn adapter (env.py)
10
  - Sandbox / FakeSandbox / LocalSubprocessSandbox — execution backends (sandbox.py)
 
11
  - HackMonitor — reward-hacking provenance monitor (monitor.py)
12
  - DifficultyCurriculum — online pass-rate difficulty gate (curriculum.py)
13
  - validate_task — 4-gate solvability validator (validator.py)
@@ -20,11 +21,13 @@ from __future__ import annotations
20
  from composer_replication.datagen.curriculum import DifficultyCurriculum
21
  from composer_replication.datagen.env import FeatureDeletionEnv, StepResult
22
  from composer_replication.datagen.monitor import HackMonitor
 
23
  from composer_replication.datagen.sandbox import (
24
  FakeSandbox,
25
  LocalSubprocessSandbox,
26
  Sandbox,
27
  TestRunResult,
 
28
  )
29
  from composer_replication.datagen.schema import FeatureDeletionTask
30
  from composer_replication.datagen.substrates import SweBenchAdapter
@@ -37,7 +40,9 @@ __all__ = [
37
  "Sandbox",
38
  "FakeSandbox",
39
  "LocalSubprocessSandbox",
 
40
  "TestRunResult",
 
41
  "HackMonitor",
42
  "DifficultyCurriculum",
43
  "validate_task",
 
8
  - FeatureDeletionTask — the task tuple (schema.py)
9
  - FeatureDeletionEnv — Gym/OpenEnv-style env + TRL reward_fn adapter (env.py)
10
  - Sandbox / FakeSandbox / LocalSubprocessSandbox — execution backends (sandbox.py)
11
+ - DockerSandbox — hardened ephemeral-container backend (docker_sandbox.py)
12
  - HackMonitor — reward-hacking provenance monitor (monitor.py)
13
  - DifficultyCurriculum — online pass-rate difficulty gate (curriculum.py)
14
  - validate_task — 4-gate solvability validator (validator.py)
 
21
  from composer_replication.datagen.curriculum import DifficultyCurriculum
22
  from composer_replication.datagen.env import FeatureDeletionEnv, StepResult
23
  from composer_replication.datagen.monitor import HackMonitor
24
+ from composer_replication.datagen.docker_sandbox import DockerSandbox
25
  from composer_replication.datagen.sandbox import (
26
  FakeSandbox,
27
  LocalSubprocessSandbox,
28
  Sandbox,
29
  TestRunResult,
30
+ scrub_tree,
31
  )
32
  from composer_replication.datagen.schema import FeatureDeletionTask
33
  from composer_replication.datagen.substrates import SweBenchAdapter
 
40
  "Sandbox",
41
  "FakeSandbox",
42
  "LocalSubprocessSandbox",
43
+ "DockerSandbox",
44
  "TestRunResult",
45
+ "scrub_tree",
46
  "HackMonitor",
47
  "DifficultyCurriculum",
48
  "validate_task",
composer_replication/datagen/docker_sandbox.py ADDED
@@ -0,0 +1,331 @@
1
+ """docker_sandbox.py — the hardened container backend for the FD env (ADR-010 §3).
2
+
3
+ `DockerSandbox` is a drop-in `Sandbox` (boot/exec/run_tests/trajectory/
4
+ is_command_allowed) that runs the agent's tool calls and the verifiable test
5
+ command inside an ephemeral, locked-down Docker container instead of a raw host
6
+ subprocess. It is the production execution path for genuinely UNTRUSTED
7
+ model-generated code — the LocalSubprocessSandbox sibling runs everything in the
8
+ host process and only enforces the scrub + denylist, which is fine for
9
+ first-party / dev use but is NOT a host-security boundary.
10
+
11
+ The lockdown recipe (CIS Docker benchmark 5.x + gVisor guidance, see the
12
+ `sandbox-container` research digest):
13
+ - `network_mode='none'` — no egress: no decompiler downloads, no
14
+ signature exfil/recovery over the wire (the SWE-RL reward-hack threat).
15
+ - `read_only=True` root fs + a small `tmpfs={'/tmp': ...}` for scratch.
16
+ - the working tree bind-mounted RW at /work (the agent must mutate the repo).
17
+ - `cap_drop=['ALL']` + `security_opt=['no-new-privileges:true']`.
18
+ - `user='1000:1000'` — never run agent code as root in the container.
19
+ - `pids_limit` (fork-bomb guard) + `mem_limit==memswap_limit` (OOM guard, no
20
+ swap) + `nano_cpus` (CPU quota).
21
+ - optional `runtime='runsc'` (gVisor) — a userspace kernel that intercepts
22
+ syscalls so a kernel-exploit payload hits the Sentry, not the host. This is
23
+ the RECOMMENDED runtime for untrusted model code per the ADR-010 threat
24
+ model, but requires host setup (`sudo runsc install` writes the 'runsc'
25
+ entry into /etc/docker/daemon.json + dockerd restart) so it is OPTIONAL and
26
+ defaults to None (= the daemon default, runc).
27
+
28
+ PRIMARY reward-hack control: the SAME host-side `scrub_tree(workdir)` used by
29
+ LocalSubprocessSandbox runs in `boot()` BEFORE the container starts. The bind
30
+ mount is shared host<->container, so scrubbing __pycache__/.git/*.pyc on the
31
+ host pre-boot is exactly equivalent to scrubbing inside the container. The
32
+ command denylist remains cheap defense-in-depth, not the wall.
33
+
34
+ `docker` is LAZY-imported inside methods so the pure-Python core and the
35
+ FakeSandbox path never require the SDK; a clear RuntimeError is raised if the
36
+ SDK or the daemon is absent.
37
+ """
38
+ from __future__ import annotations
39
+
40
+ import shlex
41
+ from dataclasses import dataclass, field
42
+ from uuid import uuid4
43
+
44
+ from composer_replication.datagen.sandbox import (
45
+ SANDBOX_DENYLIST,
46
+ TestRunResult,
47
+ scrub_tree,
48
+ )
49
+
50
+ # Label stamped on every container we create, so a reaper can sweep ephemeral
51
+ # containers leaked by a crashed episode (docker-py-native `--rm` durability
52
+ # without the auto_remove log-loss/race problem).
53
+ _LABEL_KEY = "composer_replication"
54
+ _LABEL_VALUE = "datagen"
55
+
56
+
57
+ def _require_docker():
58
+ """Lazy-import the docker SDK and return the module, or raise a clear
59
+ RuntimeError if the SDK is not installed. (Kept separate from the client so
60
+ callers can introspect `docker.errors` without opening a connection.)"""
61
+ try:
62
+ import docker # noqa: PLC0415 (intentional lazy import)
63
+ except ImportError as e: # pragma: no cover - exercised via monkeypatch
64
+ raise RuntimeError(
65
+ "DockerSandbox requires the 'docker' Python SDK (docker>=7). "
66
+ "Install it with `pip install docker` (or the project's [datagen] "
67
+ "extra). It is lazy-imported so the FakeSandbox/core path never "
68
+ "needs it."
69
+ ) from e
70
+ return docker
71
+
72
+
73
+ def _make_client():
74
+ """Construct a `docker.from_env()` client or raise a clear RuntimeError if
75
+ the daemon is unreachable. The SDK constructs lazily, so we ping with
76
+ `client.ping()` to surface a dead daemon here rather than at first use."""
77
+ docker = _require_docker()
78
+ try:
79
+ client = docker.from_env()
80
+ client.ping()
81
+ return client
82
+ except Exception as e: # docker.errors.DockerException, ConnectionError, ...
83
+ raise RuntimeError(
84
+ "DockerSandbox could not reach a Docker daemon (is `docker info` "
85
+ f"healthy?). Underlying error: {e!r}"
86
+ ) from e
87
+
88
+
89
+ def runsc_available(client=None) -> bool:
90
+ """True iff the gVisor 'runsc' runtime is registered with the daemon.
91
+
92
+ Mirrors the `_docker_available()` gating philosophy: runsc is not installed
93
+ on most dev/CI boxes, so any runsc-specific behavior must be gated on this.
94
+ """
95
+ try:
96
+ client = client or _make_client()
97
+ runtimes = client.info().get("Runtimes", {}) or {}
98
+ return "runsc" in runtimes
99
+ except Exception:
100
+ return False
101
+
102
+
103
+ @dataclass
104
+ class DockerSandbox:
105
+ """Hardened ephemeral-container `Sandbox`. See module docstring.
106
+
107
+ Args:
108
+ workdir: host path to the materialized repo; bind-mounted RW at /work and
109
+ scrubbed on the host before boot (the primary reward-hack
110
+ control). MUST be an existing directory by `boot()` time.
111
+ runtime: None (=> daemon default, runc) or 'runsc' (gVisor) for untrusted
112
+ model code. Requires host-side `sudo runsc install` + dockerd
113
+ restart; gate with `runsc_available()`.
114
+ mem_limit / memswap_limit: OOM guard; equal values forbid swap.
115
+ pids_limit: fork-bomb guard.
116
+ nano_cpus: CPU quota in 1e-9 CPUs (2_000_000_000 == 2 CPUs).
117
+ user: non-root uid:gid the agent code runs as inside the container.
118
+ exec_timeout_s: wall-clock cap injected via coreutils `timeout` (exec_run
119
+ has no timeout param — docker-py #2651).
120
+ """
121
+
122
+ workdir: str
123
+ runtime: str | None = None
124
+ mem_limit: str = "1g"
125
+ memswap_limit: str = "1g"
126
+ pids_limit: int = 256
127
+ nano_cpus: int = 2_000_000_000 # 2 CPUs
128
+ user: str = "1000:1000"
129
+ container_workdir: str = "/work"
130
+ tmpfs_size: str = "64m"
131
+ exec_timeout_s: int = 600
132
+ keep_root_writable: bool = False # escape hatch if read-only fs breaks tooling
133
+
134
+ _trajectory: list[dict] = field(default_factory=list, init=False)
135
+ booted_image: str | None = field(default=None, init=False)
136
+ _client: object | None = field(default=None, init=False)
137
+ _container: object | None = field(default=None, init=False)
138
+
139
+ # ---- construction of the hardening kwargs --------------------------------
140
+
141
+ def container_kwargs(self, image: str) -> dict:
142
+ """The full hardened `containers.run` kwarg set. Pulled out as a method
143
+ so the pure-unit tests can assert the config (network_disabled,
144
+ mem_limit, runtime, ...) WITHOUT a live daemon."""
145
+ kwargs: dict = {
146
+ "image": image,
147
+ # Long-lived idle container; exec_run drives the actual work.
148
+ "command": ["sleep", "infinity"],
149
+ "detach": True,
150
+ # --- network egress kill-switch ---
151
+ # network_disabled removes networking entirely; we ALSO set
152
+ # network_mode='none' for parity with the existing CLI substrate
153
+ # test (`--network none`) and belt-and-suspenders.
154
+ "network_disabled": True,
155
+ "network_mode": "none",
156
+ # --- filesystem lockdown ---
157
+ "read_only": not self.keep_root_writable,
158
+ "tmpfs": {"/tmp": f"rw,noexec,nosuid,size={self.tmpfs_size}"},
159
+ "volumes": {
160
+ self.workdir: {"bind": self.container_workdir, "mode": "rw"}
161
+ },
162
+ "working_dir": self.container_workdir,
163
+ # --- privilege lockdown ---
164
+ "user": self.user,
165
+ "cap_drop": ["ALL"],
166
+ "security_opt": ["no-new-privileges:true"],
167
+ # --- resource limits ---
168
+ "pids_limit": self.pids_limit,
169
+ "mem_limit": self.mem_limit,
170
+ "memswap_limit": self.memswap_limit,
171
+ "nano_cpus": self.nano_cpus,
172
+ # --- lifecycle / reaping ---
173
+ "name": f"fd-{uuid4().hex[:12]}",
174
+ "labels": {_LABEL_KEY: _LABEL_VALUE},
175
+ }
176
+ # runtime is OPTIONAL; only pass it through when set so the default
177
+ # (runc) path never references a runtime that may not exist.
178
+ if self.runtime:
179
+ kwargs["runtime"] = self.runtime
180
+ return kwargs
181
+
182
+ # ---- Sandbox Protocol ----------------------------------------------------
183
+
184
+ def boot(self, image: str) -> None:
185
+ """Scrub the HOST workdir (primary control), reap any leaked siblings,
186
+ then start the hardened ephemeral container."""
187
+ self.booted_image = image
188
+ self._trajectory = []
189
+ # PRIMARY reward-hack control — run on the host before the bind mount.
190
+ scrub_tree(self.workdir)
191
+
192
+ self._client = _make_client()
193
+ self.reap_leaked(self._client)
194
+
195
+ docker = _require_docker()
196
+ kwargs = self.container_kwargs(image)
197
+ try:
198
+ self._container = self._client.containers.run(**kwargs)
199
+ except docker.errors.ImageNotFound as e:
200
+ raise RuntimeError(
201
+ f"DockerSandbox.boot: image {image!r} not found locally and "
202
+ "could not be pulled (the container is --network none). Pull it "
203
+ f"on the host first. Underlying: {e!r}"
204
+ ) from e
205
+ except docker.errors.APIError as e:
206
+ raise RuntimeError(
207
+ f"DockerSandbox.boot: Docker API error starting {image!r} with "
208
+ f"runtime={self.runtime!r}: {e!r}"
209
+ ) from e
210
+
211
+ def is_command_allowed(self, command: str) -> bool:
212
+ # First-token-only check — see SANDBOX_DENYLIST notes. NOT a boundary on
213
+ # its own; the container isolation + host scrub are the real controls.
214
+ return command not in SANDBOX_DENYLIST
215
+
216
+ def _exec(self, cmd: str) -> tuple[int, str]:
217
+ """Run one shell command in the live container via exec_run, enforcing a
218
+ wall-clock cap with coreutils `timeout` (exec_run has no timeout param —
219
+ docker-py #2651). Returns (exit_code, combined_output)."""
220
+ if self._container is None:
221
+ raise RuntimeError("DockerSandbox.exec called before boot()")
222
+ # Prefix the command with `timeout` and run it under /bin/sh -c so PATH
223
+ # lookups work (the cap covers only the first simple command). demux keeps
224
+ # stdout/stderr separate; we concatenate them like the local sandbox does.
225
+ wrapped = f"timeout {self.exec_timeout_s} {cmd}"
226
+ full = ["/bin/sh", "-c", wrapped]
227
+ res = self._container.exec_run(
228
+ full, workdir=self.container_workdir, demux=True
229
+ )
230
+ exit_code = res.exit_code if res.exit_code is not None else -1
231
+ out = res.output
232
+ # demux=True => output is (stdout_bytes|None, stderr_bytes|None).
233
+ if isinstance(out, tuple):
234
+ stdout_b, stderr_b = out
235
+ else: # defensive — some daemons return raw bytes
236
+ stdout_b, stderr_b = out, None
237
+ text = self._decode(stdout_b) + self._decode(stderr_b)
238
+ return exit_code, text
239
+
240
+ @staticmethod
241
+ def _decode(b) -> str:
242
+ """Untrusted code can emit non-UTF-8 bytes; never rely on text mode."""
243
+ if not b:
244
+ return ""
245
+ if isinstance(b, bytes):
246
+ return b.decode("utf-8", errors="replace")
247
+ return str(b)
248
+
249
+ def exec(self, action: dict) -> str:
250
+ self._trajectory.append(action)
251
+ cmd = str(action.get("command", ""))
252
+ if not cmd.strip():
253
+ return ""
254
+ head = cmd.strip().split()[0]
255
+ if not self.is_command_allowed(head):
256
+ return f"ERROR: command '{head}' is not allowed in the sandbox."
257
+ _exit, out = self._exec(cmd)
258
+ return out
259
+
260
+ def run_tests(self, test_command: str, tests: tuple[str, ...]) -> TestRunResult:
261
+ # shlex.quote each node id — SWE-bench node ids contain spaces/brackets
262
+ # (parametrized tests) and could otherwise break the shell or inject
263
+ # commands (matches LocalSubprocessSandbox).
264
+ node_ids = " ".join(shlex.quote(t) for t in tests)
265
+ cmd = f"{test_command} {node_ids}".strip()
266
+ returncode, out = self._exec(cmd)
267
+ # Same parse as LocalSubprocessSandbox: a test counts as passed if its
268
+ # node id appears with PASSED, or if the run exited 0 and no failure has
269
+ # been recorded yet; otherwise failed. Collection errors => collected_ok False.
270
+ passed, failed = set(), set()
271
+ collected_ok = "errors during collection" not in out.lower()
272
+ for t in tests:
273
+ if f"{t} PASSED" in out or (returncode == 0 and not failed):
274
+ passed.add(t)
275
+ else:
276
+ failed.add(t)
277
+ return TestRunResult(
278
+ passed=frozenset(passed),
279
+ failed=frozenset(failed),
280
+ stdout=out,
281
+ collected_ok=collected_ok,
282
+ )
283
+
284
+ def trajectory(self) -> list[dict]:
285
+ return list(self._trajectory)
286
+
287
+ # ---- lifecycle / cleanup -------------------------------------------------
288
+
289
+ def close(self) -> None:
290
+ """Tear down the ephemeral container (force=True). Idempotent; swallows
291
+ errors so an already-gone container never masks the episode result."""
292
+ c = self._container
293
+ self._container = None
294
+ if c is not None:
295
+ try:
296
+ c.remove(force=True)
297
+ except Exception:
298
+ pass
299
+
300
+ @staticmethod
301
+ def reap_leaked(client=None) -> int:
302
+ """Sweep ephemeral containers leaked by a crashed episode (labelled
303
+ composer_replication=datagen). Callable at boot and shutdown. Returns
304
+ the count removed. Best-effort — never raises."""
305
+ removed = 0
306
+ try:
307
+ client = client or _make_client()
308
+ leaked = client.containers.list(
309
+ all=True, filters={"label": f"{_LABEL_KEY}={_LABEL_VALUE}"}
310
+ )
311
+ for c in leaked:
312
+ try:
313
+ c.remove(force=True)
314
+ removed += 1
315
+ except Exception:
316
+ pass
317
+ except Exception:
318
+ pass
319
+ return removed
320
+
321
+ def __enter__(self) -> DockerSandbox:
322
+ return self
323
+
324
+ def __exit__(self, *exc) -> None:
325
+ self.close()
326
+
327
+ def __del__(self): # pragma: no cover - best-effort GC cleanup
328
+ try:
329
+ self.close()
330
+ except Exception:
331
+ pass
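
One detail of `_exec` in the file above is worth spelling out: the command string is prefixed with `timeout` inside `/bin/sh -c`, so for a compound command such as `a && b` the cap governs `a` only. A possible alternative, shown purely as a sketch and not as the module's behaviour, is to make `timeout` the outer process. Both helper names below are hypothetical.

```python
import shlex


def inner_timeout_argv(cmd: str, seconds: int) -> list:
    """Mirror of the wrapping used in the diff: timeout sits inside the shell."""
    return ["/bin/sh", "-c", f"timeout {seconds} {cmd}"]


def outer_timeout_argv(cmd: str, seconds: int) -> list:
    """Alternative sketch: timeout wraps the whole shell invocation."""
    return ["timeout", str(seconds), "/bin/sh", "-c", cmd]


compound = "cd pkg && pytest -q"
# The shell parses the first form as `timeout 600 cd pkg` followed by `pytest -q`.
print(shlex.join(inner_timeout_argv(compound, 600)))
# In the second form the entire compound string is one argument to /bin/sh.
print(shlex.join(outer_timeout_argv(compound, 600)))
```

Either list can be passed as the `cmd` argument of `exec_run`, which accepts an argv list as well as a string.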
composer_replication/datagen/sandbox.py CHANGED
@@ -50,11 +50,51 @@ SANDBOX_DENYLIST: frozenset[str] = frozenset(
50
  # primary control. `is_command_allowed` checks only the first whitespace token,
51
  # so `/usr/bin/find`, `sh -c "strings x"`, and especially `python -c "import
52
  # marshal,dis; ..."` all bypass it. The ADR-claimed PRIMARY control is the
53
- # pre-task cache/.git scrub in `boot()` — see `_scrub_tree` below, which is now
54
  # implemented (was previously absent, making the denylist the only — broken —
55
  # defense). The denylist remains as cheap defense-in-depth, not the wall.
56
 
57
 
58
  @runtime_checkable
59
  class Sandbox(Protocol):
60
  """An execution environment for one FD episode."""
@@ -120,13 +160,11 @@ class LocalSubprocessSandbox:
120
  _trajectory: list[dict] = field(default_factory=list)
121
  booted_image: str | None = None
122
 
123
- # Cache/history artifacts that let an agent recover a deleted signature
124
- # WITHOUT reimplementing it (the Composer-blog reward-hacks). Scrubbed at
125
- # boot() so the denylist isn't the only — and bypassable — line of defense.
126
- _SCRUB_NAMES: tuple[str, ...] = (
127
- "__pycache__", ".mypy_cache", ".pytest_cache", ".git", ".hg",
128
- )
129
- _SCRUB_SUFFIXES: tuple[str, ...] = (".pyc", ".pyo", ".class")
130
 
131
  def boot(self, image: str) -> None:
132
  self.booted_image = image
@@ -134,26 +172,9 @@ class LocalSubprocessSandbox:
134
  self._scrub_tree()
135
 
136
  def _scrub_tree(self) -> None:
137
- """PRIMARY reward-hack control (ADR-010 §3): physically remove byte-code
138
- caches, type-check caches, and VCS history from the working tree before
139
- the episode starts, so there is no cached signature to recover. This is
140
- the wall; the command denylist is only cheap defense-in-depth on top.
141
- Cross-family review 2026-05-29 found this was previously UNIMPLEMENTED —
142
- boot() only recorded the image string."""
143
- if not self.workdir or not os.path.isdir(self.workdir):
144
- return
145
- for root, dirs, files in os.walk(self.workdir, topdown=True):
146
- # Remove (and stop descending into) scrub-named directories.
147
- for d in list(dirs):
148
- if d in self._SCRUB_NAMES:
149
- shutil.rmtree(os.path.join(root, d), ignore_errors=True)
150
- dirs.remove(d)
151
- for f in files:
152
- if f.endswith(self._SCRUB_SUFFIXES):
153
- try:
154
- os.remove(os.path.join(root, f))
155
- except OSError:
156
- pass
157
 
158
  def is_command_allowed(self, command: str) -> bool:
159
  # NOTE: first-token-only check — see SANDBOX_DENYLIST comment. This is
 
50
  # primary control. `is_command_allowed` checks only the first whitespace token,
51
  # so `/usr/bin/find`, `sh -c "strings x"`, and especially `python -c "import
52
  # marshal,dis; ..."` all bypass it. The ADR-claimed PRIMARY control is the
53
+ # pre-task cache/.git scrub in `boot()` — see `scrub_tree` below, which is now
54
  # implemented (was previously absent, making the denylist the only — broken —
55
  # defense). The denylist remains as cheap defense-in-depth, not the wall.
56
 
57
 
58
+ # Cache/history artifacts that let an agent recover a deleted signature WITHOUT
59
+ # reimplementing it (the Composer-blog reward-hacks). Scrubbed at boot() so the
60
+ # denylist isn't the only — and bypassable — line of defense. These are module
61
+ # level so EVERY sandbox backend (LocalSubprocessSandbox AND DockerSandbox)
62
+ # applies the identical primary control via the shared `scrub_tree` free
63
+ # function below.
64
+ SCRUB_NAMES: tuple[str, ...] = (
65
+ "__pycache__", ".mypy_cache", ".pytest_cache", ".git", ".hg",
66
+ )
67
+ SCRUB_SUFFIXES: tuple[str, ...] = (".pyc", ".pyo", ".class")
68
+
69
+
70
+ def scrub_tree(workdir: str) -> None:
71
+ """PRIMARY reward-hack control (ADR-010 §3): physically remove byte-code
72
+ caches, type-check caches, and VCS history from the working tree before the
73
+ episode starts, so there is no cached signature to recover. This is the
74
+ wall; the command denylist is only cheap defense-in-depth on top.
75
+
76
+ Shared by LocalSubprocessSandbox (scrubs the subprocess cwd) and
77
+ DockerSandbox (scrubs the HOST workdir BEFORE the bind mount — the mount is
78
+ shared host<->container, so a host-side scrub pre-boot is exactly equivalent
79
+ to scrubbing inside the container). Cross-family review 2026-05-29 found this
80
+ was previously UNIMPLEMENTED — boot() only recorded the image string.
81
+ """
82
+ if not workdir or not os.path.isdir(workdir):
83
+ return
84
+ for root, dirs, files in os.walk(workdir, topdown=True):
85
+ # Remove (and stop descending into) scrub-named directories.
86
+ for d in list(dirs):
87
+ if d in SCRUB_NAMES:
88
+ shutil.rmtree(os.path.join(root, d), ignore_errors=True)
89
+ dirs.remove(d)
90
+ for f in files:
91
+ if f.endswith(SCRUB_SUFFIXES):
92
+ try:
93
+ os.remove(os.path.join(root, f))
94
+ except OSError:
95
+ pass
96
+
97
+
98
  @runtime_checkable
99
  class Sandbox(Protocol):
100
  """An execution environment for one FD episode."""
 
160
  _trajectory: list[dict] = field(default_factory=list)
161
  booted_image: str | None = None
162
 
163
+ # Back-compat aliases for the module-level scrub constants (callers/tests
164
+ # that referenced the old instance attributes keep working). The real
165
+ # control is the shared module-level `scrub_tree` free function.
166
+ _SCRUB_NAMES: tuple[str, ...] = SCRUB_NAMES
167
+ _SCRUB_SUFFIXES: tuple[str, ...] = SCRUB_SUFFIXES
 
 
168
 
169
  def boot(self, image: str) -> None:
170
  self.booted_image = image
 
172
  self._scrub_tree()
173
 
174
  def _scrub_tree(self) -> None:
175
+ """Delegate to the shared module-level `scrub_tree` (see its docstring).
176
+ Kept as a method for back-compat with existing callers."""
177
+ scrub_tree(self.workdir)
178
 
179
  def is_command_allowed(self, command: str) -> bool:
180
  # NOTE: first-token-only check — see SANDBOX_DENYLIST comment. This is
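
To see the shared scrub in isolation, the following self-contained sketch copies the `scrub_tree` logic and constants from the diff above and runs them on a throwaway directory. The file names in the demo tree and the `demo` helper are made up for illustration.

```python
import os
import shutil
import tempfile

SCRUB_NAMES = ("__pycache__", ".mypy_cache", ".pytest_cache", ".git", ".hg")
SCRUB_SUFFIXES = (".pyc", ".pyo", ".class")


def scrub_tree(workdir: str) -> None:
    """Remove cache/VCS directories and compiled files under workdir."""
    if not workdir or not os.path.isdir(workdir):
        return
    for root, dirs, files in os.walk(workdir, topdown=True):
        # Remove (and stop descending into) scrub-named directories.
        for d in list(dirs):
            if d in SCRUB_NAMES:
                shutil.rmtree(os.path.join(root, d), ignore_errors=True)
                dirs.remove(d)
        for f in files:
            if f.endswith(SCRUB_SUFFIXES):
                try:
                    os.remove(os.path.join(root, f))
                except OSError:
                    pass


def demo() -> list:
    """Build a small tree, scrub it, and return the surviving relative paths."""
    with tempfile.TemporaryDirectory() as root:
        os.makedirs(os.path.join(root, "pkg", "__pycache__"))
        os.makedirs(os.path.join(root, ".git"))
        for rel in (
            "pkg/mod.py",
            "pkg/stale.pyc",
            "pkg/__pycache__/mod.cpython-311.pyc",
            ".git/HEAD",
        ):
            with open(os.path.join(root, *rel.split("/")), "w") as fh:
                fh.write("x")
        scrub_tree(root)
        survivors = []
        for d, _dirs, files in os.walk(root):
            for f in files:
                rel_path = os.path.relpath(os.path.join(d, f), root)
                survivors.append(rel_path.replace(os.sep, "/"))
        return sorted(survivors)


print(demo())
```

Only the source file survives, so an agent working in the tree has no byte-code cache or VCS history from which to recover a deleted signature.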
composer_replication/datagen/tests/test_docker_sandbox.py ADDED
@@ -0,0 +1,620 @@
1
+ """test_docker_sandbox.py — DockerSandbox unit + live-Docker coverage.
2
+
3
+ Two tiers, mirroring the repo's `test_docker_substrate_e2e.py` gating and the
4
+ ModalSpawnExecutor mock pattern:
5
+
6
+ 1. PURE-UNIT (always run, no daemon): a mock docker client/container asserts
7
+ the hardening config (network_disabled, mem_limit, runtime, cap_drop, ...),
8
+ the shared host-side scrub, the bytes->str decode, the pytest-summary
9
+ parse in run_tests, the denylist short-circuit, and the missing-SDK /
10
+ dead-daemon RuntimeError paths. These cover DockerSandbox even on a box
11
+ with no Docker.
12
+
13
+ 2. LIVE-DOCKER (skipped unless `_docker_available()`): boots a REAL hardened
14
+ `python:3.11-slim` container with --network none and runs the 4 inversion
15
+ gates + a cache-scrub check + a network-isolation check inside it. Since
16
+ Docker is available on this host, these ACTUALLY RUN.
17
+ """
18
+ from __future__ import annotations
19
+
20
+ import os
21
+ import shutil
22
+ import subprocess
23
+ import tempfile
24
+ import textwrap
25
+ import types
26
+ from collections import namedtuple
27
+
28
+ import pytest
29
+
30
+ from composer_replication.datagen import docker_sandbox as ds_mod
31
+ from composer_replication.datagen.docker_sandbox import DockerSandbox
32
+ from composer_replication.datagen.sandbox import SANDBOX_DENYLIST, Sandbox
33
+
34
+
35
+ # ---------------------------------------------------------------------------
36
+ # Live-Docker gate (mirrors test_docker_substrate_e2e.py)
37
+ # ---------------------------------------------------------------------------
38
+ def _docker_available() -> bool:
39
+ """True iff a usable Docker daemon is reachable via the CLI."""
40
+ if shutil.which("docker") is None:
41
+ return False
42
+ try:
43
+ r = subprocess.run(["docker", "info"], capture_output=True, timeout=10)
44
+ return r.returncode == 0
45
+ except Exception:
46
+ return False
47
+
48
+
49
+ # A tiny image we know exists locally on this host (212MB python:3.11-slim).
50
+ # The live tests run `--network none`, so the image MUST be present already
51
+ # (no pull possible inside a network-disabled container).
52
+ _TEST_IMAGE = "python:3.11-slim"
53
+
54
+
55
+ def _image_present(image: str) -> bool:
56
+ try:
57
+ r = subprocess.run(
58
+ ["docker", "image", "inspect", image], capture_output=True, timeout=15
59
+ )
60
+ return r.returncode == 0
61
+ except Exception:
62
+ return False
63
+
64
+
65
+ # ===========================================================================
66
+ # TIER 1 — PURE-UNIT (mock docker client, no daemon required)
67
+ # ===========================================================================
68
+
69
+ _ExecResult = namedtuple("ExecResult", ["exit_code", "output"])
70
+
71
+
72
+ class _MockContainer:
73
+ """Stand-in for a docker-py Container.
74
+
75
+ Knobs:
76
+ - exec_script: callable(cmd_argv) -> (exit_code, stdout_bytes, stderr_bytes)
77
+ used to fake exec_run results. Defaults to an empty success.
78
+ """
79
+
80
+ def __init__(self, *, exec_script=None):
81
+ self.exec_calls: list[tuple] = []
82
+ self.removed = False
83
+ self.remove_force = None
84
+ self._exec_script = exec_script or (lambda argv: (0, b"", b""))
85
+
86
+ def exec_run(self, cmd, *, workdir=None, demux=False, **kw):
87
+ self.exec_calls.append((cmd, {"workdir": workdir, "demux": demux, **kw}))
88
+ code, out, err = self._exec_script(cmd)
89
+ if demux:
90
+ return _ExecResult(code, (out, err))
91
+ return _ExecResult(code, (out or b"") + (err or b""))
92
+
93
+ def remove(self, force=False):
94
+ self.removed = True
95
+ self.remove_force = force
96
+
97
+
98
+ class _MockContainers:
99
+ def __init__(self, container, *, run_raises=None):
100
+ self._container = container
101
+ self.run_kwargs: dict | None = None
102
+ self._run_raises = run_raises
103
+ self._listed: list = []
104
+
105
+ def run(self, **kwargs):
106
+ self.run_kwargs = kwargs
107
+ if self._run_raises is not None:
108
+ raise self._run_raises
109
+ return self._container
110
+
111
+ def list(self, all=False, filters=None): # noqa: A002 - matches docker-py
112
+ return list(self._listed)
113
+
114
+
115
+ class _MockClient:
116
+ def __init__(self, container, *, run_raises=None, runtimes=None):
117
+ self.containers = _MockContainers(container, run_raises=run_raises)
118
+ self._info = {"Runtimes": runtimes or {"runc": {}}}
119
+ self.pinged = False
120
+
121
+ def ping(self):
122
+ self.pinged = True
123
+ return True
124
+
125
+ def info(self):
126
+ return self._info
127
+
128
+
129
+ def _patch_client(monkeypatch, client):
130
+ """Make DockerSandbox use `client` instead of a real daemon, and stub the
131
+ lazy `docker` module so `docker.errors.*` resolve in boot()."""
132
+ monkeypatch.setattr(ds_mod, "_make_client", lambda: client)
133
+
134
+ fake_docker = types.ModuleType("docker")
135
+ errors = types.ModuleType("docker.errors")
136
+
137
+ class ImageNotFound(Exception): # noqa: N818 — mirrors docker.errors.ImageNotFound name
138
+ pass
139
+
140
+ class APIError(Exception):
141
+ pass
142
+
143
+ errors.ImageNotFound = ImageNotFound
144
+ errors.APIError = APIError
145
+ fake_docker.errors = errors
146
+ fake_docker.from_env = lambda: client
147
+ monkeypatch.setattr(ds_mod, "_require_docker", lambda: fake_docker)
148
+ return fake_docker
149
+
150
+
151
+ def test_dockersandbox_is_a_sandbox_protocol_instance():
152
+ """Drop-in for FakeSandbox/LocalSubprocessSandbox in env/validator."""
153
+ sb = DockerSandbox(workdir="/tmp")
154
+ assert isinstance(sb, Sandbox)
155
+
156
+
157
+ def test_container_kwargs_hardening_config():
158
+ """The lockdown recipe is present and correct WITHOUT any daemon."""
159
+ sb = DockerSandbox(workdir="/some/work")
160
+ kw = sb.container_kwargs(_TEST_IMAGE)
161
+
162
+ # network egress kill-switch
163
+ assert kw["network_disabled"] is True
164
+ assert kw["network_mode"] == "none"
165
+ # filesystem lockdown
166
+ assert kw["read_only"] is True
167
+ assert kw["tmpfs"] == {"/tmp": "rw,noexec,nosuid,size=64m"}
168
+ assert kw["volumes"] == {"/some/work": {"bind": "/work", "mode": "rw"}}
169
+ assert kw["working_dir"] == "/work"
170
+ # privilege lockdown
171
+ assert kw["user"] == "1000:1000"
172
+ assert kw["cap_drop"] == ["ALL"]
173
+ assert kw["security_opt"] == ["no-new-privileges:true"]
174
+ # resource limits
175
+ assert kw["pids_limit"] == 256
176
+ assert kw["mem_limit"] == "1g"
177
+ assert kw["memswap_limit"] == "1g" # == mem_limit => no swap
178
+ assert kw["nano_cpus"] == 2_000_000_000
179
+ # lifecycle
180
+ assert kw["detach"] is True
181
+ assert kw["command"] == ["sleep", "infinity"]
182
+ assert kw["labels"] == {"composer_replication": "datagen"}
183
+ assert kw["name"].startswith("fd-")
184
+
185
+
186
+ def test_runtime_optional_default_runc():
187
+ """runtime defaults to None => the 'runtime' kwarg is omitted (daemon
188
+ default runc), so the default path never names a runtime that may not
189
+ exist."""
190
+ assert "runtime" not in DockerSandbox(workdir="/w").container_kwargs("img")
191
+
192
+
193
+ def test_runtime_runsc_passed_through_when_set():
194
+ """When the caller opts into gVisor, runtime='runsc' reaches run()."""
195
+ kw = DockerSandbox(workdir="/w", runtime="runsc").container_kwargs("img")
196
+ assert kw["runtime"] == "runsc"
197
+
198
+
199
+ def test_resource_limits_are_configurable():
200
+ sb = DockerSandbox(
201
+ workdir="/w", mem_limit="256m", memswap_limit="256m",
202
+ pids_limit=64, nano_cpus=1_000_000_000, tmpfs_size="16m",
203
+ )
204
+ kw = sb.container_kwargs("img")
205
+ assert kw["mem_limit"] == "256m"
206
+ assert kw["memswap_limit"] == "256m"
207
+ assert kw["pids_limit"] == 64
208
+ assert kw["nano_cpus"] == 1_000_000_000
209
+ assert kw["tmpfs"] == {"/tmp": "rw,noexec,nosuid,size=16m"}
210
+
211
+
212
+ def test_keep_root_writable_escape_hatch():
213
+ kw = DockerSandbox(workdir="/w", keep_root_writable=True).container_kwargs("i")
214
+ assert kw["read_only"] is False
215
+
216
+
217
+ def test_boot_scrubs_host_tree_before_container(monkeypatch):
218
+ """PRIMARY reward-hack control: scrub_tree runs on the HOST workdir in boot()
219
+ BEFORE the container starts (the bind mount is shared)."""
220
+ with tempfile.TemporaryDirectory() as d:
221
+ os.makedirs(os.path.join(d, "__pycache__"))
222
+ with open(os.path.join(d, "__pycache__", "x.cpython-311.pyc"), "wb") as f:
223
+ f.write(b"\x00stale-bytecode")
224
+ os.makedirs(os.path.join(d, ".git"))
225
+ with open(os.path.join(d, "mod.pyc"), "wb") as f:
226
+ f.write(b"\x00")
227
+ with open(os.path.join(d, "keep.py"), "w") as f:
228
+ f.write("x = 1\n")
229
+
230
+ container = _MockContainer()
231
+ client = _MockClient(container)
232
+ _patch_client(monkeypatch, client)
233
+
234
+ sb = DockerSandbox(workdir=d)
235
+ sb.boot(_TEST_IMAGE)
236
+
237
+ assert not os.path.exists(os.path.join(d, "__pycache__"))
238
+ assert not os.path.exists(os.path.join(d, ".git"))
239
+ assert not os.path.exists(os.path.join(d, "mod.pyc"))
240
+ assert os.path.exists(os.path.join(d, "keep.py")) # real source survives
241
+ # the container was actually started with the hardened kwargs
242
+ assert client.containers.run_kwargs["network_disabled"] is True
243
+ assert sb.booted_image == _TEST_IMAGE
244
+
245
+
246
+ def test_exec_uses_timeout_and_workdir_and_denylist(monkeypatch):
247
+ """exec() wraps the command with coreutils `timeout`, runs in /work, and
248
+ short-circuits denied commands without touching the container."""
249
+ container = _MockContainer(
250
+ exec_script=lambda argv: (0, b"hello\n", b"")
251
+ )
252
+ client = _MockClient(container)
253
+ _patch_client(monkeypatch, client)
254
+
255
+ sb = DockerSandbox(workdir="/w", exec_timeout_s=42)
256
+ sb._container = container # skip boot for this focused unit
257
+
258
+ out = sb.exec({"command": "echo hello"})
259
+ assert out == "hello\n"
260
+ cmd_argv, kw = container.exec_calls[-1]
261
+ assert cmd_argv == ["/bin/sh", "-c", "timeout 42 echo hello"]
262
+ assert kw["workdir"] == "/work"
263
+ assert kw["demux"] is True
264
+
265
+ # a denylisted first token never reaches the container
266
+ n_before = len(container.exec_calls)
267
+ denied = sorted(SANDBOX_DENYLIST)[0]
268
+ msg = sb.exec({"command": f"{denied} something"})
269
+ assert "not allowed" in msg
270
+ assert len(container.exec_calls) == n_before # no new exec
271
+
272
+
273
+ def test_exec_decodes_non_utf8_bytes(monkeypatch):
274
+ """Untrusted code can emit invalid UTF-8 on stdout; we must not crash."""
275
+ container = _MockContainer(
276
+ exec_script=lambda argv: (0, b"\xff\xfe bad bytes", b"")
277
+ )
278
+ sb = DockerSandbox(workdir="/w")
279
+ sb._container = container
280
+ out = sb.exec({"command": "echo x"})
281
+ assert "bad bytes" in out # replaced, not crashed
282
+ assert "\ufffd" in out # U+FFFD replacement char
283
+
284
+
285
+ def test_run_tests_parses_pytest_summary(monkeypatch):
286
+ """run_tests applies the SAME conservative parse as LocalSubprocessSandbox:
287
+ a node id is passed iff '<nodeid> PASSED' appears."""
288
+ tests = ("t.py::test_a", "t.py::test_b")
289
+ out = b"t.py::test_a PASSED\nt.py::test_b FAILED\n1 failed, 1 passed\n"
290
+ container = _MockContainer(exec_script=lambda argv: (1, out, b""))
291
+ sb = DockerSandbox(workdir="/w")
292
+ sb._container = container
293
+
294
+ res = sb.run_tests("pytest -v", tests)
295
+ assert res.passed == frozenset({"t.py::test_a"})
296
+ assert res.failed == frozenset({"t.py::test_b"})
297
+ assert res.collected_ok is True
298
+
299
+
300
+ def test_run_tests_collection_error(monkeypatch):
301
+ tests = ("t.py::test_a",)
302
+ out = b"ERROR collecting t.py\n!!! errors during collection !!!\n"
303
+ container = _MockContainer(exec_script=lambda argv: (2, out, b""))
304
+ sb = DockerSandbox(workdir="/w")
305
+ sb._container = container
306
+ res = sb.run_tests("pytest -v", tests)
307
+ assert res.collected_ok is False
308
+ assert res.failed == frozenset({"t.py::test_a"})
309
+
310
+
311
+ def test_run_tests_quotes_node_ids(monkeypatch):
312
+ """Parametrized node ids with spaces/brackets must be shlex-quoted (shell
313
+ injection guard the repo already fixed for the local sandbox)."""
314
+ captured = {}
315
+
316
+ def script(argv):
317
+ captured["argv"] = argv
318
+ return (0, b"", b"")
319
+
320
+ container = _MockContainer(exec_script=script)
321
+ sb = DockerSandbox(workdir="/w")
322
+ sb._container = container
323
+ sb.run_tests("pytest -v", ("t.py::test_x[a b]",))
324
+ # the dangerous node id is quoted inside the timeout-wrapped sh -c string
325
+ shell_cmd = captured["argv"][-1]
326
+ assert "'t.py::test_x[a b]'" in shell_cmd
327
+
328
+
329
+ def test_exec_before_boot_raises():
330
+ sb = DockerSandbox(workdir="/w")
331
+ with pytest.raises(RuntimeError, match="before boot"):
332
+ sb.exec({"command": "echo hi"})
333
+
334
+
335
+ def test_trajectory_records_actions(monkeypatch):
336
+ container = _MockContainer()
337
+ sb = DockerSandbox(workdir="/w")
338
+ sb._container = container
339
+ sb.exec({"command": "echo a"})
340
+ sb.exec({"command": "echo b"})
341
+ traj = sb.trajectory()
342
+ assert [a["command"] for a in traj] == ["echo a", "echo b"]
343
+
344
+
345
+ def test_close_removes_container_force(monkeypatch):
346
+ container = _MockContainer()
347
+ sb = DockerSandbox(workdir="/w")
348
+ sb._container = container
349
+ sb.close()
350
+ assert container.removed is True
351
+ assert container.remove_force is True
352
+ # idempotent
353
+ sb.close()
354
+
355
+
356
+ def test_context_manager_closes(monkeypatch):
357
+ container = _MockContainer()
358
+ client = _MockClient(container)
359
+ _patch_client(monkeypatch, client)
360
+ with tempfile.TemporaryDirectory() as d:
361
+ with DockerSandbox(workdir=d) as sb:
362
+ sb.boot(_TEST_IMAGE)
363
+ assert sb._container is container
364
+ assert container.removed is True
365
+
366
+
367
+ def test_reap_leaked_sweeps_labelled_containers(monkeypatch):
368
+ leaked = [_MockContainer(), _MockContainer()]
369
+ container = _MockContainer()
370
+ client = _MockClient(container)
371
+ client.containers._listed = leaked
372
+ n = DockerSandbox.reap_leaked(client)
373
+ assert n == 2
374
+ assert all(c.removed for c in leaked)
375
+
376
+
377
+ def test_boot_image_not_found_raises_runtimeerror(monkeypatch):
378
+ container = _MockContainer()
379
+ client = _MockClient(container)
380
+ fake_docker = _patch_client(monkeypatch, client)
381
+ # make run() raise ImageNotFound
382
+ client.containers._run_raises = fake_docker.errors.ImageNotFound("nope")
383
+ with tempfile.TemporaryDirectory() as d:
384
+ sb = DockerSandbox(workdir=d)
385
+ with pytest.raises(RuntimeError, match="not found locally"):
386
+ sb.boot("ghost:latest")
387
+
388
+
389
+ def test_boot_api_error_raises_runtimeerror(monkeypatch):
390
+ container = _MockContainer()
391
+ client = _MockClient(container)
392
+ fake_docker = _patch_client(monkeypatch, client)
393
+ client.containers._run_raises = fake_docker.errors.APIError("bad runtime")
394
+ with tempfile.TemporaryDirectory() as d:
395
+ sb = DockerSandbox(workdir=d, runtime="runsc")
396
+ with pytest.raises(RuntimeError, match="Docker API error"):
397
+ sb.boot(_TEST_IMAGE)
398
+
399
+
400
+ def test_require_docker_missing_sdk_raises(monkeypatch):
401
+ """If the docker SDK is absent, a clear RuntimeError is raised (lazy import
402
+ means the FakeSandbox/core path never needs it)."""
403
+ import builtins
404
+
405
+ real_import = builtins.__import__
406
+
407
+ def fake_import(name, *args, **kwargs):
408
+ if name == "docker":
409
+ raise ImportError("No module named 'docker'")
410
+ return real_import(name, *args, **kwargs)
411
+
412
+ monkeypatch.setattr(builtins, "__import__", fake_import)
413
+ with pytest.raises(RuntimeError, match="requires the 'docker' Python SDK"):
414
+ ds_mod._require_docker()
415
+
416
+
417
+ def test_make_client_dead_daemon_raises(monkeypatch):
418
+ """A dead/unreachable daemon surfaces a clear RuntimeError at client build."""
419
+ fake_docker = types.ModuleType("docker")
420
+
421
+ def from_env():
422
+ raise RuntimeError("Cannot connect to the Docker daemon")
423
+
424
+ fake_docker.from_env = from_env
425
+ monkeypatch.setattr(ds_mod, "_require_docker", lambda: fake_docker)
426
+ with pytest.raises(RuntimeError, match="could not reach a Docker daemon"):
427
+ ds_mod._make_client()
428
+
429
+
430
+ def test_runsc_available_false_when_only_runc(monkeypatch):
431
+ client = _MockClient(_MockContainer(), runtimes={"runc": {}})
432
+ monkeypatch.setattr(ds_mod, "_make_client", lambda: client)
433
+ assert ds_mod.runsc_available() is False
434
+
435
+
436
+ def test_runsc_available_true_when_registered(monkeypatch):
437
+ client = _MockClient(_MockContainer(), runtimes={"runc": {}, "runsc": {}})
438
+ monkeypatch.setattr(ds_mod, "_make_client", lambda: client)
439
+ assert ds_mod.runsc_available() is True
440
+
441
+
442
+ # ===========================================================================
443
+ # TIER 2 — LIVE DOCKER (skipif on daemon availability)
444
+ # ===========================================================================
445
+
446
+ live = pytest.mark.skipif(
447
+ not _docker_available(),
448
+ reason="Docker daemon not available — DockerSandbox live tests are "
449
+ "hardware-gated (mirror test_docker_substrate_e2e.py).",
450
+ )
451
+
452
+ # Minimal synthetic FD task (same shape as test_docker_substrate_e2e.py).
453
+ _MODULE_SOLVED = textwrap.dedent('''\
454
+ def add(a, b):
455
+ return a + b
456
+
457
+ def mul(a, b):
458
+ return a * b
459
+ ''')
460
+ _MODULE_BROKEN = textwrap.dedent('''\
461
+ def add(a, b):
462
+ return a + b
463
+ ''')
464
+
465
+ # A stdlib-only pytest substitute: NO pip install, so --network none holds.
466
+ # Writes a tiny runner that imports `feature`, checks both add/mul, and prints
467
+ # pytest-style '<nodeid> PASSED/FAILED' lines that run_tests parses. We pass the
468
+ # NODE IDs to check on argv (trusted, test-author-controlled) and the runner
469
+ # evaluates FIXED expressions per node id — no eval() of untrusted input.
470
+ _RUNNER_TMPL = '''\
471
+ import sys
472
+ # Fixed expectations keyed by node id. The deleted-feature episode is detected
473
+ # by import-time AttributeError on `feature.mul`, never by evaluating a string.
474
+ CHECKS = {
475
+ "feature.py::test_add": lambda m: m.add(2, 3) == 5,
476
+ "feature.py::test_mul": lambda m: m.mul(2, 3) == 6,
477
+ }
478
+ nodeid = sys.argv[1]
479
+ try:
480
+ import feature
481
+ ok = bool(CHECKS[nodeid](feature))
482
+ except Exception as e:
483
+ print(nodeid, "FAILED", "(exc:", type(e).__name__, e, ")")
484
+ sys.exit(1)
485
+ print(nodeid, "PASSED" if ok else "FAILED")
486
+ sys.exit(0 if ok else 1)
487
+ '''
488
+
489
+ # Host-side network probe written into the workdir, then run inside the
490
+ # container as a plain file (avoids fragile inline `python -c` quoting through
491
+ # `sh -c`). Prints CONNECTED if egress works, BLOCKED otherwise.
492
+ _NETPROBE = '''\
493
+ import socket
494
+ s = socket.socket()
495
+ s.settimeout(3)
496
+ try:
497
+ s.connect(("1.1.1.1", 53))
498
+ print("CONNECTED")
499
+ except Exception as e:
500
+ print("BLOCKED", type(e).__name__)
501
+ '''
502
+
503
+
504
+ def _materialize(d: str, module_src: str) -> None:
505
+ with open(os.path.join(d, "feature.py"), "w") as f:
506
+ f.write(module_src)
507
+ with open(os.path.join(d, "runner.py"), "w") as f:
508
+ f.write(_RUNNER_TMPL)
509
+
510
+
511
+ @live
512
+ def test_live_image_present_guard():
513
+ """The live tests run --network none and cannot pull; assert the image is
514
+ already on the host so a missing-image failure reads clearly."""
515
+ if not _image_present(_TEST_IMAGE):
516
+ pytest.skip(f"{_TEST_IMAGE} not present locally; `docker pull {_TEST_IMAGE}` to enable")
517
+
518
+
519
+ @live
520
+ def test_live_four_inversion_gates_in_hardened_container():
521
+ """The 4 ADR-010 gates against a REAL hardened DockerSandbox container."""
522
+ if not _image_present(_TEST_IMAGE):
523
+ pytest.skip(f"{_TEST_IMAGE} not present locally")
524
+
525
+ target = "feature.py::test_mul" # FAIL_TO_PASS — exercises the deleted symbol
526
+ guard = "feature.py::test_add" # PASS_TO_PASS — must survive the deletion
527
+
528
+ def _run(module_src, node):
529
+ with tempfile.TemporaryDirectory() as d:
530
+ _materialize(d, module_src)
531
+ sb = DockerSandbox(workdir=d, exec_timeout_s=60)
532
+ sb.boot(_TEST_IMAGE)
533
+ try:
534
+ # run_tests appends the shlex-quoted node id to the command, and
535
+ # the runner uses it to pick which FIXED check to run.
536
+ res = sb.run_tests("python runner.py", (node,))
537
+ return node in res.passed, res.stdout
538
+ finally:
539
+ sb.close()
540
+
541
+ # Gate 1 — solved: both pass.
542
+ g1t, _ = _run(_MODULE_SOLVED, target)
543
+ g1g, _ = _run(_MODULE_SOLVED, guard)
544
+ assert g1t and g1g, "Gate 1 (baseline green) failed in hardened container"
545
+
546
+ # Gate 2 — broken: target FAILS (mul gone).
547
+ g2t, out2 = _run(_MODULE_BROKEN, target)
548
+ assert not g2t, f"Gate 2 (deletion breaks target) failed:\n{out2}"
549
+
550
+ # Gate 3 — broken: guard still PASSES.
551
+ g3g, out3 = _run(_MODULE_BROKEN, guard)
552
+ assert g3g, f"Gate 3 (remains functional) failed:\n{out3}"
553
+
554
+ # Gate 4 — gold restores: target passes again.
555
+ g4t, _ = _run(_MODULE_SOLVED, target)
556
+ assert g4t, "Gate 4 (gold restores) failed"
557
+
558
+
559
+ @live
560
+ def test_live_network_is_disabled():
561
+ """--network none / network_disabled actually blocks egress in the live
562
+ container — the reward-hack egress kill-switch."""
563
+ if not _image_present(_TEST_IMAGE):
564
+ pytest.skip(f"{_TEST_IMAGE} not present locally")
565
+ with tempfile.TemporaryDirectory() as d:
566
+ _materialize(d, _MODULE_SOLVED)
567
+ with open(os.path.join(d, "netprobe.py"), "w") as f:
568
+ f.write(_NETPROBE)
569
+ sb = DockerSandbox(workdir=d, exec_timeout_s=30)
570
+ sb.boot(_TEST_IMAGE)
571
+ try:
572
+ out = sb.exec({"command": "python netprobe.py"})
573
+ assert "CONNECTED" not in out, f"network egress was NOT blocked:\n{out}"
574
+ assert "BLOCKED" in out, f"unexpected network probe output:\n{out}"
575
+ finally:
576
+ sb.close()
577
+
578
+
579
+ @live
580
+ def test_live_cache_scrub_removes_bytecode():
581
+ """The cache scrub primary control holds on a real container: a stale .pyc
582
+ on the host mount is removed by boot() before the (broken) episode."""
583
+ with tempfile.TemporaryDirectory() as d:
584
+ _materialize(d, _MODULE_BROKEN)
585
+ os.makedirs(os.path.join(d, "__pycache__"), exist_ok=True)
586
+ with open(os.path.join(d, "__pycache__", "feature.cpython-311.pyc"), "wb") as f:
587
+ f.write(b"\x00stale-bytecode-with-mul-signature")
588
+
589
+ if not _image_present(_TEST_IMAGE):
590
+ # a missing image makes boot() raise, so skip like the other live tests
591
+ pytest.skip(f"{_TEST_IMAGE} not present locally")
592
+ container = None
593
+ try:
594
+ sb = DockerSandbox(workdir=d)
595
+ sb.boot(_TEST_IMAGE)
596
+ container = sb
597
+ assert not os.path.exists(os.path.join(d, "__pycache__")), \
598
+ "cache scrub did not remove __pycache__ in DockerSandbox.boot()"
599
+ finally:
600
+ if container is not None:
601
+ container.close()
602
+
603
+
604
+ @live
605
+ def test_live_runsc_runtime():
606
+ """If gVisor is registered, boot with runtime='runsc' and run a test in it;
607
+ else skip (runsc is not installed on most hosts)."""
608
+ if not ds_mod.runsc_available():
609
+ pytest.skip("gVisor 'runsc' runtime not registered with this daemon")
610
+ if not _image_present(_TEST_IMAGE):
611
+ pytest.skip(f"{_TEST_IMAGE} not present locally")
612
+ with tempfile.TemporaryDirectory() as d:
613
+ _materialize(d, _MODULE_SOLVED)
614
+ sb = DockerSandbox(workdir=d, runtime="runsc", exec_timeout_s=60)
615
+ sb.boot(_TEST_IMAGE)
616
+ try:
617
+ res = sb.run_tests("python runner.py", ("feature.py::test_mul",))
618
+ assert "feature.py::test_mul" in res.passed
619
+ finally:
620
+ sb.close()
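
A minimal sketch of the command construction and output parsing that the unit tests above pin down. The helper names `build_test_command` and `parse_passed` are hypothetical and are not taken from `docker_sandbox.py`; the sketch only mirrors three behaviours the tests assert: node ids are shlex-quoted, the command is wrapped in coreutils `timeout` inside `sh -c`, and a node id counts as passed only if a `<nodeid> PASSED` line appears.

```python
import shlex


def build_test_command(test_cmd, node_ids, timeout_s):
    """Return an argv for exec_run (hypothetical helper).

    Each node id is shlex-quoted so ids such as 't.py::test_x[a b]' cannot
    break out of the shell string handed to `sh -c`.
    """
    quoted = " ".join(shlex.quote(n) for n in node_ids)
    inner = f"{test_cmd} {quoted}".strip()
    return ["/bin/sh", "-c", f"timeout {timeout_s} {inner}"]


def parse_passed(raw, node_ids):
    """Conservative parse: a node id passed only if '<nodeid> PASSED' starts a line.

    Decoding uses errors='replace' so invalid UTF-8 from untrusted code
    becomes U+FFFD instead of raising UnicodeDecodeError.
    """
    text = raw.decode("utf-8", errors="replace")
    lines = [line.strip() for line in text.splitlines()]
    return frozenset(
        n for n in node_ids
        if any(line.startswith(f"{n} PASSED") for line in lines)
    )


argv = build_test_command("pytest -v", ("t.py::test_x[a b]",), 42)
passed = parse_passed(
    b"t.py::test_a PASSED\nt.py::test_b FAILED\n\xff\n",
    ("t.py::test_a", "t.py::test_b"),
)
```

Treating anything other than an explicit `PASSED` line as not passed keeps the parse biased toward under-reporting success, which suits a reward signal.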
composer_replication/diloco/serverless/__init__.py CHANGED
@@ -47,6 +47,7 @@ from composer_replication.diloco.serverless.allreduce import (
47
  MockManager,
48
  ObjectStoreAllReduce,
49
  )
 
50
  from composer_replication.diloco.serverless.executor import (
51
  LocalProcessExecutor,
52
  ReplicaHandle,
@@ -55,13 +56,16 @@ from composer_replication.diloco.serverless.executor import (
55
  from composer_replication.diloco.serverless.hf_jobs import HFJobsExecutor
56
  from composer_replication.diloco.serverless.modal import ModalExecutor
57
  from composer_replication.diloco.serverless.modal_spawn import ModalSpawnExecutor
 
58
 
59
  __all__ = [
 
60
  "HFJobsExecutor",
61
  "LocalProcessExecutor",
62
  "MockManager",
63
  "ModalExecutor",
64
  "ModalSpawnExecutor",
 
65
  "ObjectStoreAllReduce",
66
  "ReplicaHandle",
67
  "ServerlessExecutor",
 
47
  MockManager,
48
  ObjectStoreAllReduce,
49
  )
50
+ from composer_replication.diloco.serverless.eks import EKSExecutor
51
  from composer_replication.diloco.serverless.executor import (
52
  LocalProcessExecutor,
53
  ReplicaHandle,
 
56
  from composer_replication.diloco.serverless.hf_jobs import HFJobsExecutor
57
  from composer_replication.diloco.serverless.modal import ModalExecutor
58
  from composer_replication.diloco.serverless.modal_spawn import ModalSpawnExecutor
59
+ from composer_replication.diloco.serverless.sagemaker import SageMakerExecutor
60
 
61
  __all__ = [
62
+ "EKSExecutor",
63
  "HFJobsExecutor",
64
  "LocalProcessExecutor",
65
  "MockManager",
66
  "ModalExecutor",
67
  "ModalSpawnExecutor",
68
+ "SageMakerExecutor",
69
  "ObjectStoreAllReduce",
70
  "ReplicaHandle",
71
  "ServerlessExecutor",
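
The `EKSExecutor` re-exported here runs every rank inside one Indexed Job, so per-rank status has to be derived from the Job's compressed index strings (for example `"1,3-5,7"`). A minimal sketch of that lookup follows. The function names, the status labels, and the choice to let a failed index win over a completed one are assumptions for illustration, not the module's actual `poll` implementation.

```python
def expand_indexes(spec):
    """Expand a string like '1,3-5,7' into a set of ints.

    Simplified: malformed tokens raise ValueError here rather than
    being skipped.
    """
    out = set()
    for token in (spec or "").split(","):
        token = token.strip()
        if not token:
            continue
        lo, sep, hi = token.partition("-")
        # A bare number has no '-', so it expands to a one-element range.
        out.update(range(int(lo), int(hi if sep else lo) + 1))
    return out


def rank_status(rank, completed_indexes, failed_indexes):
    """Map one rank to a status label (labels are hypothetical)."""
    if rank in expand_indexes(failed_indexes):
        return "failed"
    if rank in expand_indexes(completed_indexes):
        return "succeeded"
    return "running"


statuses = [rank_status(r, "0,2-3", "1") for r in range(5)]
```

Because all handles share one Job, a single status read can answer `rank_status` for every rank at once.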
composer_replication/diloco/serverless/eks.py ADDED
@@ -0,0 +1,674 @@
1
+ """EKSExecutor — production Amazon EKS / Kubernetes-backed serverless executor.
2
+
3
+ This is the v0-finished k8s sibling of `ModalSpawnExecutor`. It implements
4
+ the `ServerlessExecutor` Protocol against the Kubernetes ``BatchV1Api`` using
5
+ the **single Indexed Job** topology recommended for gang-scheduled DiLoCo
6
+ replicas.
7
+
8
+ Topology (the load-bearing design choice)
9
+ ------------------------------------------
10
+ There are two ways to map N replicas onto k8s:
11
+
12
+ (A) ONE Indexed Job — ``completions=N, parallelism=N,
13
+ completionMode='Indexed'``. The control plane assigns each pod a
14
+ ``JOB_COMPLETION_INDEX`` 0..N-1 which IS the rank, all pods share one
15
+ rendezvous URI, scheduling is atomic, and a single delete cancels the
16
+ whole cohort.
17
+ (B) N separate non-indexed Jobs, one per rank.
18
+
19
+ `EKSExecutor` uses **(A)** because it is the better fit for DiLoCo: rank
20
+ assignment is free, scheduling is gang-atomic, and one delete tears down the
21
+ cohort — which matches ``ObjectStoreAllReduce``'s all-or-nothing barrier. The
22
+ reconciliation with the per-replica ``ReplicaHandle`` contract: ``launch_replicas``
23
+ creates ONE Indexed Job but still returns N ``ReplicaHandle`` objects
24
+ (``handles[i].rank == i``) whose ``metadata`` stores the SHARED
25
+ ``job_name``/``namespace`` plus that rank.
26
+
27
+ This is materially different from ``ModalSpawnExecutor`` where each handle is
28
+ an independent ``FunctionCall``:
29
+
30
+ * ``poll(handle)`` reads the single Job status and checks whether
31
+ ``handle.rank`` is in the run-length-compressed ``completed_indexes`` /
32
+ ``failed_indexes`` strings.
33
+ * ``cancel(handle)`` on ANY handle deletes the WHOLE Job (intentional gang
34
+ semantics — cancelling one rank tears down the whole replica cohort).
35
+
36
+ Rank plumbing
37
+ -------------
38
+ The repo's ``replica_entrypoint`` reads ``REPLICA_RANK``. We bridge the k8s
39
+ completion-index to that env var via the downward API rather than relying on
40
+ the auto-injected ``JOB_COMPLETION_INDEX``::
41
+
42
+ V1EnvVar(
43
+ name="REPLICA_RANK",
44
+ value_from=V1EnvVarSource(field_ref=V1ObjectFieldSelector(
45
+ field_path="metadata.annotations['batch.kubernetes.io/job-completion-index']")),
46
+ )
47
+
48
+ so the unchanged entrypoint's ``REPLICA_RANK`` read just works. ``WORLD_SIZE``
49
+ is set as a literal env var.
50
+
51
+ S3 rendezvous via IRSA / Pod Identity
52
+ -------------------------------------
53
+ ``EKSExecutor`` accepts ``service_account_name`` and references it on the
54
+ PodSpec. With IRSA, the mutating webhook injects ``AWS_ROLE_ARN`` +
55
+ ``AWS_WEB_IDENTITY_TOKEN_FILE`` (and a projected token volume); EKS Pod
56
+ Identity injects container-credential env vars instead. Either way
57
+ ``boto3``/``s3fs``/``fsspec`` pick up credentials with ZERO code change inside
58
+ the replica, so the ``s3://`` rendezvous works out of the box. ``EKSExecutor``
59
+ only REFERENCES a pre-configured ServiceAccount; it never creates IAM/OIDC resources.
60
+
61
+ Sandboxing (advanced, optional)
62
+ -------------------------------
63
+ ``runtime_class_name`` references a pre-existing cluster-scoped ``RuntimeClass``
64
+ (``runsc`` for gVisor, ``kata`` for Kata). It defaults to ``None``.
65
+
66
+ .. warning::
67
+ Combining ``gpu`` with ``runtime_class_name`` is advanced and unverified.
68
+ gVisor (runsc) needs ``nvproxy`` enabled and only supports a fixed allowlist
69
+ of NVIDIA driver families; Kata runs a microVM that caps CPU/mem and needs
70
+ GPU passthrough (PCIe/IOMMU + NVIDIA Kata Manager + CDI). Do not silently
71
+ combine the two without operator validation. ``EKSExecutor`` cannot create
72
+ the RuntimeClass — it only references one that already exists.
73
+
74
+ References
75
+ ----------
76
+ - k8s Indexed Jobs: https://kubernetes.io/docs/tasks/job/indexed-parallel-processing-static/
77
+ - kubernetes-client/python job_crud example + V1JobSpec / V1JobStatus docs
78
+ - EKS IRSA: https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts.html
79
+ - ADR-005 (executor protocol design)
80
+ """
81
+ from __future__ import annotations
82
+
83
+ import time
84
+ import uuid
85
+ from collections.abc import Callable, Mapping
86
+ from typing import Any
87
+
88
+ from composer_replication.diloco.serverless.executor import (
89
+ ReplicaHandle,
90
+ )
91
+
92
+ # Logical GPU spec ("A100"/"H100") -> (gpu_count_string, node_selector merge).
93
+ # The Protocol's `gpu` arg is a logical name; map it to a concrete EKS node
94
+ # class + GPU count rather than passing the opaque string straight through.
95
+ _GPU_SPEC_TABLE: dict[str, tuple[str, dict[str, str]]] = {
96
+ "A100": ("1", {"node.kubernetes.io/instance-type": "p4d.24xlarge"}),
97
+ "H100": ("1", {"node.kubernetes.io/instance-type": "p5.48xlarge"}),
98
+ "A10G": ("1", {"node.kubernetes.io/instance-type": "g5.xlarge"}),
99
+ "T4": ("1", {"node.kubernetes.io/instance-type": "g4dn.xlarge"}),
100
+ }
101
+
102
+
103
+ def _expand_indexes(spec: str | None) -> set[int]:
104
+ """Expand a run-length-compressed completion-index string to a set.
105
+
106
+ The k8s ``V1JobStatus.completed_indexes`` / ``failed_indexes`` fields are
107
+ strings like ``"1,3-5,7"`` (comma-separated singletons and ``a-b`` ranges).
108
+ ``_expand_indexes("1,3-5,7") == {1, 3, 4, 5, 7}``. Empty/None -> empty set.
109
+ """
110
+ out: set[int] = set()
111
+ if not spec:
112
+ return out
113
+ for token in spec.split(","):
114
+ token = token.strip()
115
+ if not token:
116
+ continue
117
+ if "-" in token:
118
+ lo_s, _, hi_s = token.partition("-")
119
+ try:
120
+ lo, hi = int(lo_s), int(hi_s)
121
+ except ValueError:
122
+ continue
123
+ if hi < lo:
124
+ lo, hi = hi, lo
125
+ out.update(range(lo, hi + 1))
126
+ else:
127
+ try:
128
+ out.add(int(token))
129
+ except ValueError:
130
+ continue
131
+ return out
132
+
133
+
134
+ class EKSExecutor:
135
+ """Run N DiLoCo replicas as a single Kubernetes Indexed Job on EKS.
136
+
137
+ Implements the `ServerlessExecutor` Protocol. ``launch_replicas`` creates
138
+ ONE Indexed Job (``completions == parallelism == n_replicas``,
139
+ ``completionMode='Indexed'``) and returns N ``ReplicaHandle`` objects that
140
+ all share the same ``job_name``/``namespace`` (gang semantics).
141
+
142
+ Args:
143
+ image: container image that has ``composer_replication`` installed and
144
+ runs the replica entrypoint.
145
+ namespace: k8s namespace for the Job. Default ``"default"``.
146
+ service_account_name: ServiceAccount to attach to the PodSpec for IRSA /
147
+ EKS Pod Identity S3 access. ``EKSExecutor`` references it; it does
148
+ NOT create it or any IAM/OIDC resources.
149
+ node_selector: extra node selector merged into the GPU node selector.
150
+ tolerations: PodSpec tolerations. If GPU is requested and the caller did
151
+ not supply tolerations, the standard ``nvidia.com/gpu`` NoSchedule
152
+ toleration is added automatically.
153
+ runtime_class_name: optional pre-existing RuntimeClass (e.g. ``"gvisor"``
154
+ / ``"kata"``). Default ``None``. See the module-level warning before
155
+ combining with ``gpu``.
156
+ command: container command. Defaults to the repo replica entrypoint
157
+ module ``["python", "-m",
158
+ "composer_replication.diloco.serverless.replica_entrypoint"]``.
159
+ cpu_request / memory_request: PodSpec resource requests.
160
+ ttl_seconds_after_finished: auto-delete the finished Job (and its pods,
161
+ cascadingly) after this many seconds. Default 3600.
162
+ backoff_limit: Job retry budget. Default 0 (fail-fast — RL gangs usually
163
+ do NOT want the k8s default of 6 retries).
164
+ gpu_resource_key: the GPU resource key. Default ``"nvidia.com/gpu"``.
165
+ run_id: optional run id baked into the generated Job name.
166
+ batch_api / core_api: dependency-injected ``BatchV1Api`` / ``CoreV1Api``
167
+ instances. When ``None`` (the default), they are built lazily on
168
+ first use via in-cluster or kube-config loading. Tests inject mocks.
169
+
170
+ Raises:
171
+ RuntimeError: if the ``kubernetes`` client is not installed AND no api
172
+ was injected (the import is needed to construct V1 model objects).
173
+ """
174
+
175
+ backend_name = "eks"
176
+ # Pods are network-isolated by default; rendezvous is S3 (ObjectStoreAllReduce).
177
+ supports_inter_replica_network = False
178
+
179
+ def __init__(
180
+ self,
181
+ image: str,
182
+ *,
183
+ namespace: str = "default",
184
+ service_account_name: str | None = None,
185
+ node_selector: dict[str, str] | None = None,
186
+ tolerations: list[Any] | None = None,
187
+ runtime_class_name: str | None = None,
188
+ command: list[str] | None = None,
189
+ cpu_request: str = "4",
190
+ memory_request: str = "16Gi",
191
+ ttl_seconds_after_finished: int = 3600,
192
+ backoff_limit: int = 0,
193
+ gpu_resource_key: str = "nvidia.com/gpu",
194
+ run_id: str | None = None,
195
+ batch_api: Any = None,
196
+ core_api: Any = None,
197
+ ) -> None:
198
+ # `kubernetes` is only strictly required when we have to BUILD V1 model
199
+ # objects ourselves (launch_replicas) or load cluster config (when no
200
+ # api is injected). We surface a clear error here only if we definitely
201
+ # need it and it is absent — i.e. when no api was injected. When apis
202
+ # ARE injected (tests, or callers that pre-built clients), we tolerate a
203
+ # missing top-level package and lazy-import `client` per call.
204
+ if batch_api is None or core_api is None:
205
+ try:
206
+ import kubernetes # noqa: F401
207
+ except ImportError as e:
208
+ raise RuntimeError(
209
+ 'EKSExecutor requires the kubernetes client: '
210
+ 'pip install "kubernetes>=29" (or '
211
+ "`pip install -e .[eks]`). Got: " + repr(e)
212
+ ) from e
213
+
214
+ self.image = image
215
+ self.namespace = namespace
216
+ self.service_account_name = service_account_name
217
+ self.node_selector = dict(node_selector) if node_selector else None
218
+ self.tolerations = list(tolerations) if tolerations else None
219
+ self.runtime_class_name = runtime_class_name
220
+ self.command = command or [
221
+ "python",
222
+ "-m",
223
+ "composer_replication.diloco.serverless.replica_entrypoint",
224
+ ]
225
+ self.cpu_request = cpu_request
226
+ self.memory_request = memory_request
227
+ self.ttl_seconds_after_finished = ttl_seconds_after_finished
228
+ self.backoff_limit = backoff_limit
229
+ self.gpu_resource_key = gpu_resource_key
230
+ self.run_id = run_id or "diloco"
231
+
232
+ self._batch_api = batch_api
233
+ self._core_api = core_api
234
+ # rank -> {"job_name", "namespace", "result"}; lets poll/collect cache.
235
+ self._handles: dict[int, dict[str, Any]] = {}
236
+
237
+ # -----------------------------------------------------------------
238
+ # Lazy client construction (config loading only when not injected)
239
+ # -----------------------------------------------------------------
240
+
241
+ def _load_config(self) -> None:
242
+ """Load k8s config once: in-cluster first, then ~/.kube/config."""
243
+ from kubernetes import config
244
+
245
+ try:
246
+ config.load_incluster_config()
247
+ except config.ConfigException:
248
+ config.load_kube_config()
249
+
250
+ def _batch(self) -> Any:
251
+ if self._batch_api is None:
252
+ from kubernetes import client
253
+
254
+ self._load_config()
255
+ self._batch_api = client.BatchV1Api()
256
+ return self._batch_api
257
+
258
+ def _core(self) -> Any:
259
+ if self._core_api is None:
260
+ from kubernetes import client
261
+
262
+ self._load_config()
263
+ self._core_api = client.CoreV1Api()
264
+ return self._core_api
265
+
266
+ # -----------------------------------------------------------------
267
+ # Job-spec construction
268
+ # -----------------------------------------------------------------
269
+
270
+ def _build_env(
271
+ self, world_size: int, entrypoint_args: Mapping[str, Any]
272
+ ) -> list[Any]:
273
+ """Build the container env list, including the downward-API rank var."""
274
+ from kubernetes import client
275
+
276
+ env: list[Any] = [
277
+ # REPLICA_RANK from the per-pod completion-index annotation via the
278
+ # downward API — bridges k8s indexing to the repo entrypoint's
279
+ # REPLICA_RANK read with no entrypoint change.
280
+ client.V1EnvVar(
281
+ name="REPLICA_RANK",
282
+ value_from=client.V1EnvVarSource(
283
+ field_ref=client.V1ObjectFieldSelector(
284
+ field_path=(
285
+ "metadata.annotations["
286
+ "'batch.kubernetes.io/job-completion-index']"
287
+ )
288
+ )
289
+ ),
290
+ ),
291
+ client.V1EnvVar(name="WORLD_SIZE", value=str(world_size)),
292
+ ]
293
+ # rendezvous_uri (and any other scalar kwargs) passed as literal env so
294
+ # the entrypoint / user code can read them. `rank_env` is the
295
+ # LocalProcessExecutor convention — drop it (same as ModalSpawnExecutor).
296
+ for key, value in entrypoint_args.items():
297
+ if key == "rank_env":
298
+ continue
299
+ if isinstance(value, (str, int, float, bool)):
300
+ env.append(
301
+ client.V1EnvVar(name=key.upper(), value=str(value))
302
+ )
303
+ return env
304
+
305
+ def _build_resources(self, gpu: str | None) -> tuple[Any, dict[str, str], list[Any]]:
306
+ """Build V1ResourceRequirements + (node_selector, tolerations) for GPU.
307
+
308
+ Returns (resources, node_selector, tolerations). The GPU count is
309
+ ALWAYS a STRING ('1', not int 1) — the OpenAPI type for the limits map
310
+ is dict[str, str] and an int can serialize wrong or raise.
311
+ """
312
+ from kubernetes import client
313
+
314
+ requests = {"cpu": self.cpu_request, "memory": self.memory_request}
315
+ limits: dict[str, str] = {}
316
+ node_selector: dict[str, str] = dict(self.node_selector or {})
317
+ tolerations: list[Any] = list(self.tolerations or [])
318
+
319
+ if gpu is not None:
320
+ gpu_count, gpu_node_selector = _GPU_SPEC_TABLE.get(
321
+ gpu, ("1", {})
322
+ )
323
+ # STRING, always.
324
+ limits[self.gpu_resource_key] = str(gpu_count)
325
+ # Merge the mapped node selector under any caller-supplied one
326
+ # (caller wins on key conflicts).
327
+ for k, v in gpu_node_selector.items():
328
+ node_selector.setdefault(k, v)
329
+ # Auto-add the GPU NoSchedule toleration unless the caller overrode
330
+ # tolerations explicitly.
331
+ if not self.tolerations:
332
+ tolerations.append(
333
+ client.V1Toleration(
334
+ key=self.gpu_resource_key,
335
+ operator="Exists",
336
+ effect="NoSchedule",
337
+ )
338
+ )
339
+
340
+ resources = client.V1ResourceRequirements(
341
+ requests=requests,
342
+ limits=limits or None,
343
+ )
344
+ return resources, node_selector, tolerations
345
+
346
+ def _build_job(
347
+ self,
348
+ *,
349
+ job_name: str,
350
+ n_replicas: int,
351
+ gpu: str | None,
352
+ timeout: int,
353
+ entrypoint_args: Mapping[str, Any],
354
+ ) -> Any:
355
+ """Assemble the full V1Job (Indexed) bottom-up."""
356
+ from kubernetes import client
357
+
358
+ env = self._build_env(n_replicas, entrypoint_args)
359
+ resources, node_selector, tolerations = self._build_resources(gpu)
360
+
361
+ container = client.V1Container(
362
+ name="replica",
363
+ image=self.image,
364
+ command=list(self.command),
365
+ env=env,
366
+ resources=resources,
367
+ )
368
+
369
+ pod_spec = client.V1PodSpec(
370
+ restart_policy="Never", # required for Indexed jobs / fail-fast RL
371
+ containers=[container],
372
+ service_account_name=self.service_account_name,
373
+ node_selector=node_selector or None,
374
+ tolerations=tolerations or None,
375
+ runtime_class_name=self.runtime_class_name,
376
+ )
377
+
378
+ labels = {"app": "composer-diloco", "job-name": job_name}
379
+ pod_template = client.V1PodTemplateSpec(
380
+ metadata=client.V1ObjectMeta(labels=labels),
381
+ spec=pod_spec,
382
+ )
383
+
384
+ job_spec = client.V1JobSpec(
385
+ template=pod_template,
386
+ completions=n_replicas,
387
+ parallelism=n_replicas,
388
+ completion_mode="Indexed",
389
+ backoff_limit=self.backoff_limit,
390
+ ttl_seconds_after_finished=self.ttl_seconds_after_finished,
391
+ active_deadline_seconds=timeout,
392
+ )
393
+
394
+ return client.V1Job(
395
+ api_version="batch/v1",
396
+ kind="Job",
397
+ metadata=client.V1ObjectMeta(name=job_name, labels=labels),
398
+ spec=job_spec,
399
+ )
400
+
401
+ # -----------------------------------------------------------------
402
+ # ServerlessExecutor Protocol
403
+ # -----------------------------------------------------------------
404
+
405
+ def launch_replicas(
406
+ self,
407
+ n_replicas: int,
408
+ entrypoint: str | Callable[..., Any],
409
+ entrypoint_args: Mapping[str, Any],
410
+ *,
411
+ gpu: str | None = None,
412
+ timeout: int = 3600,
413
+ ) -> list[ReplicaHandle]:
414
+ """Create ONE Indexed Job of N pods and return N rank-ordered handles.
415
+
416
+ ``entrypoint`` is always ignored (k8s runs a container command, not an
417
+ in-process callable); the container command is fixed at construction
418
+ (``command`` ctor arg) and defaults to the repo entrypoint module.
419
+ Scalar ``entrypoint_args`` values are passed as upper-cased env vars
420
+ (``rank_env`` is dropped) so ``replica_entrypoint`` / user code can read them.
421
+ ``gpu`` maps to a ``gpu_resource_key`` limit + node selector; ``timeout``
422
+ becomes the Job's ``active_deadline_seconds`` hard wall-clock kill.
423
+ """
424
+ del entrypoint # k8s runs a container command, not an in-process fn
425
+
426
+ if n_replicas < 1:
427
+ raise ValueError(f"n_replicas must be >= 1, got {n_replicas}")
428
+
429
+ job_name = f"{self.run_id}-{uuid.uuid4().hex[:8]}"
430
+ job = self._build_job(
431
+ job_name=job_name,
432
+ n_replicas=n_replicas,
433
+ gpu=gpu,
434
+ timeout=timeout,
435
+ entrypoint_args=entrypoint_args,
436
+ )
437
+
438
+ self._batch().create_namespaced_job(namespace=self.namespace, body=job)
439
+
440
+ handles: list[ReplicaHandle] = []
441
+ for rank in range(n_replicas):
442
+ handles.append(
443
+ ReplicaHandle(
444
+ rank=rank,
445
+ backend_name=self.backend_name,
446
+ metadata={
447
+ "job_name": job_name,
448
+ "namespace": self.namespace,
449
+ "rank": rank,
450
+ },
451
+ )
452
+ )
453
+ self._handles[rank] = {
454
+ "job_name": job_name,
455
+ "namespace": self.namespace,
456
+ "result": None,
457
+ }
458
+ return handles
459
+
460
+ def poll(self, handle: ReplicaHandle) -> str:
461
+ """Poll this rank's status off the shared Indexed Job.
462
+
463
+ Calls ``read_namespaced_job_status`` once, then maps the Job status to
464
+ this rank: in ``completed_indexes`` -> ``succeeded``; in ``failed_indexes``
465
+ or a Job-level ``Failed`` condition -> ``failed``; ``active > 0`` ->
466
+ ``running``; else ``pending``. A 404 (Job deleted) -> ``cancelled``.
467
+
468
+ Returns one of: ``pending`` | ``running`` | ``succeeded`` | ``failed`` |
469
+ ``cancelled``.
470
+ """
471
+ from kubernetes.client.exceptions import ApiException
472
+
473
+ job_name = handle.metadata["job_name"]
474
+ namespace = handle.metadata["namespace"]
475
+ rank = handle.metadata.get("rank", handle.rank)
476
+
477
+ try:
478
+ status = self._batch().read_namespaced_job_status(
479
+ name=job_name, namespace=namespace
480
+ ).status
481
+ except ApiException as e:
482
+ if getattr(e, "status", None) == 404:
483
+ return "cancelled"
484
+ raise
485
+
486
+ completed = _expand_indexes(getattr(status, "completed_indexes", None))
487
+ if rank in completed:
488
+ return "succeeded"
489
+
490
+ failed = _expand_indexes(getattr(status, "failed_indexes", None))
491
+ if rank in failed:
492
+ return "failed"
493
+
494
+ # Whole-job terminal Failed (e.g. DeadlineExceeded / backoff) with no
495
+ # per-index attribution -> treat this rank as failed.
496
+ for cond in (getattr(status, "conditions", None) or []):
497
+ if (
498
+ getattr(cond, "type", None) == "Failed"
499
+ and getattr(cond, "status", None) == "True"
500
+ ):
501
+ return "failed"
502
+
503
+ active = getattr(status, "active", None) or 0
504
+ if active > 0:
505
+ return "running"
506
+ return "pending"
507
+
508
+ def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
509
+ """Read recent logs for this rank's pod.
510
+
511
+ Finds the pod whose ``batch.kubernetes.io/job-completion-index``
512
+ annotation (or label) equals the rank, then reads its log tail. Returns
513
+ a placeholder string (rather than raising) when the pod has not started
514
+ or the Job is gone — mirrors ``LocalProcessExecutor``.
515
+ """
516
+ from kubernetes.client.exceptions import ApiException
517
+
518
+ job_name = handle.metadata["job_name"]
519
+ namespace = handle.metadata["namespace"]
520
+ rank = handle.metadata.get("rank", handle.rank)
521
+ idx_key = "batch.kubernetes.io/job-completion-index"
522
+
523
+ try:
524
+ pods = self._core().list_namespaced_pod(
525
+ namespace=namespace, label_selector=f"job-name={job_name}"
526
+ )
527
+ except ApiException:
528
+ return f"<rank {rank}: job not found / no pods yet>"
529
+
530
+ pod_name = None
531
+ for pod in getattr(pods, "items", None) or []:
532
+ meta = getattr(pod, "metadata", None)
533
+ annotations = getattr(meta, "annotations", None) or {}
534
+ labels = getattr(meta, "labels", None) or {}
535
+ if annotations.get(idx_key) == str(rank) or labels.get(idx_key) == str(rank):
536
+ pod_name = getattr(meta, "name", None)
537
+ break
538
+
539
+ if pod_name is None:
540
+ # Fall back to the Indexed-Job pod-name prefix (``<job>-<index>-``).
541
+ prefix = f"{job_name}-{rank}-"
542
+ for pod in getattr(pods, "items", None) or []:
543
+ name = getattr(getattr(pod, "metadata", None), "name", "") or ""
544
+ if name.startswith(prefix):
545
+ pod_name = name
546
+ break
547
+
548
+ if pod_name is None:
549
+ return f"<rank {rank}: pod not started / no logs yet>"
550
+
551
+ try:
552
+ return self._core().read_namespaced_pod_log(
553
+ name=pod_name,
554
+ namespace=namespace,
555
+ container="replica",
556
+ tail_lines=n_lines,
557
+ )
558
+ except ApiException as e:
559
+ if getattr(e, "status", None) in (400, 404):
560
+ return f"<rank {rank}: pod not started / no logs yet>"
561
+ raise
562
+
563
+ def cancel(self, handle: ReplicaHandle) -> None:
564
+ """Delete the WHOLE shared Indexed Job (gang teardown).
565
+
566
+ Because ``EKSExecutor`` uses one shared Indexed Job, cancelling ANY rank
567
+ tears down the entire replica cohort — intentional gang semantics for
568
+ the DiLoCo all-reduce barrier (a single straggler being cancelled should
569
+ not leave the rest spinning and burning GPU).
570
+
571
+ Uses ``propagation_policy='Background'`` so the pods are cascadingly
572
+ deleted (the k8s default ORPHANS pods, which would keep burning GPU —
573
+ the exact failure mode for RL). Idempotent: a 404 (already deleted) is
574
+ swallowed, and an unknown handle never raises, honoring the Protocol's
575
+ "no exception if already terminated" contract.
576
+ """
577
+ from kubernetes import client
578
+ from kubernetes.client.exceptions import ApiException
579
+
580
+ job_name = handle.metadata.get("job_name")
581
+ namespace = handle.metadata.get("namespace", self.namespace)
582
+ if not job_name:
583
+ return # unknown handle — no-op
584
+
585
+ try:
586
+ self._batch().delete_namespaced_job(
587
+ name=job_name,
588
+ namespace=namespace,
589
+ body=client.V1DeleteOptions(
590
+ propagation_policy="Background",
591
+ grace_period_seconds=0,
592
+ ),
593
+ )
594
+ except ApiException as e:
595
+ if getattr(e, "status", None) == 404:
596
+ return # already deleted
597
+ # Best-effort: swallow other API errors (network blip, etc.).
598
+ return
599
+ except Exception:
600
+ return
601
+
602
+ def collect(
603
+ self,
604
+ handles: list[ReplicaHandle],
605
+ *,
606
+ timeout: int | None = None,
607
+ ) -> list[dict[str, Any]]:
608
+ """Poll until every rank reaches a terminal state or the deadline.
609
+
610
+ Sleeps between polls (Job status is eventually consistent — do not
611
+ hammer the API server). Returns per-rank result dicts in handles order::
612
+
613
+ {"rank", "status", "exit_code", "error", "job_name"}
614
+
615
+ ``exit_code`` is 0 for succeeded, 1 for failed, ``None`` for
616
+ running/pending/cancelled — matching the Protocol's documented shape.
617
+ """
618
+ deadline = time.time() + (timeout if timeout is not None else 86400)
619
+ poll_interval = float(self._collect_poll_interval())
620
+ terminal = {"succeeded", "failed", "cancelled"}
621
+ results_by_rank: dict[int, dict[str, Any]] = {}
622
+
623
+ pending = list(handles)
624
+ while pending and time.time() < deadline:
625
+ still_pending: list[ReplicaHandle] = []
626
+ for h in pending:
627
+ state = self.poll(h)
628
+ if state in terminal:
629
+ results_by_rank[h.rank] = self._result_dict(h, state)
630
+ else:
631
+ still_pending.append(h)
632
+ pending = still_pending
633
+ if not pending:
634
+ break
635
+ remaining = deadline - time.time()
636
+ if remaining <= 0:
637
+ break
638
+ time.sleep(min(poll_interval, max(0.0, remaining)))
639
+
640
+ # Any rank still non-terminal at the deadline -> report its last state.
641
+ for h in pending:
642
+ state = self.poll(h)
643
+ results_by_rank[h.rank] = self._result_dict(h, state)
644
+
645
+ return [results_by_rank[h.rank] for h in handles]
646
+
647
+ # -----------------------------------------------------------------
648
+ # Internals
649
+ # -----------------------------------------------------------------
650
+
651
+ def _collect_poll_interval(self) -> float:
652
+ """Seconds between collect() polls. Overridable in tests."""
653
+ return 5.0
654
+
655
+ @staticmethod
656
+ def _result_dict(handle: ReplicaHandle, state: str) -> dict[str, Any]:
657
+ exit_code = {"succeeded": 0, "failed": 1}.get(state, None)
658
+ error = None
659
+ if state == "failed":
660
+ error = f"rank {handle.rank} reported failed by Job status"
661
+ elif state == "cancelled":
662
+ error = f"rank {handle.rank} Job no longer exists (cancelled)"
663
+ elif state in ("running", "pending"):
664
+ error = f"rank {handle.rank} not terminal at deadline (state={state})"
665
+ return {
666
+ "rank": handle.rank,
667
+ "status": state,
668
+ "exit_code": exit_code,
669
+ "error": error,
670
+ "job_name": handle.metadata.get("job_name"),
671
+ }
672
+
673
+
674
+ __all__ = ["EKSExecutor"]
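`poll` above relies on `_expand_indexes`, whose body is not shown in this part of the diff. The following is a minimal sketch, assuming it parses the compressed index string Kubernetes reports in a Job's `completedIndexes` / `failedIndexes` status fields (for example `"1,3-5,7"`). The names `expand_indexes` and `rank_state` are hypothetical stand-ins for illustration, not the repository's actual helpers; `rank_state` only mirrors the decision order that `poll` documents.

```python
from typing import Optional


def expand_indexes(compressed: Optional[str]) -> set:
    """Expand a Job-status index string such as "1,3-5,7" into a set of ints."""
    indexes = set()
    if not compressed:
        return indexes
    for part in compressed.split(","):
        part = part.strip()
        if not part:
            continue
        if "-" in part:
            # A hyphenated run "a-b" covers every index from a to b inclusive.
            first, last = part.split("-", 1)
            indexes.update(range(int(first), int(last) + 1))
        else:
            indexes.add(int(part))
    return indexes


def rank_state(
    rank: int,
    completed: Optional[str],
    failed: Optional[str],
    job_failed: bool,
    active: int,
) -> str:
    """Map whole-Job status fields to one rank (hypothetical mirror of poll())."""
    if rank in expand_indexes(completed):
        return "succeeded"
    if rank in expand_indexes(failed) or job_failed:
        return "failed"
    return "running" if active > 0 else "pending"


if __name__ == "__main__":
    print(sorted(expand_indexes("1,3-5,7")))
    print(rank_state(4, "1,3-5,7", None, False, 0))
```

Checking the per-index sets before the Job-level `Failed` condition means a rank that already completed is still reported as `succeeded` even if the Job as a whole later fails.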
composer_replication/diloco/serverless/executor.py CHANGED
@@ -36,9 +36,10 @@ class ReplicaHandle:
36
  class ServerlessExecutor(Protocol):
37
  """Uniform interface for launching N replicas on a serverless backend.
38
 
39
- Implementations: `LocalProcessExecutor` (test/dev), `ModalExecutor`
40
- (Modal, v0), `HFJobsExecutor` (HuggingFace Jobs, v0). Future:
41
- `RunPodExecutor`, `SageMakerExecutor`, `K8sExecutor`.
 
42
 
43
  Note on rank assignment: the Protocol guarantees that handles are
44
  returned in rank order (`handles[i].rank == i`). The replica entrypoint
 
36
  class ServerlessExecutor(Protocol):
37
  """Uniform interface for launching N replicas on a serverless backend.
38
 
39
+ Implementations: `LocalProcessExecutor` (test/dev), `ModalSpawnExecutor`
40
+ (Modal, production), `EKSExecutor` (Amazon EKS / Kubernetes Indexed Job,
41
+ production), `SageMakerExecutor` (SageMaker Training Jobs, production),
42
+ `ModalExecutor` / `HFJobsExecutor` (v0 skeletons). Future: `RunPodExecutor`.
43
 
44
  Note on rank assignment: the Protocol guarantees that handles are
45
  returned in rank order (`handles[i].rank == i`). The replica entrypoint
composer_replication/diloco/serverless/sagemaker.py ADDED
@@ -0,0 +1,619 @@
1
+ """SageMakerExecutor — production boto3-backed serverless executor.
2
+
3
+ This is a fully-working cloud adapter (the sibling of `ModalSpawnExecutor`,
4
+ not the loud-failing `modal.py` / `hf_jobs.py` skeletons). It implements the
5
+ `ServerlessExecutor` Protocol against Amazon SageMaker Training Jobs via the
6
+ boto3 low-level `sagemaker` client.
7
+
8
+ Design choices
9
+ --------------
10
+
11
+ 1. **N independent single-instance jobs, NOT one multi-instance job.**
12
+ SageMaker's *native* distributed training (``ResourceConfig.InstanceCount > 1``)
13
+ groups instances into ONE job with an in-cluster NCCL/MPI fabric wired via
14
+ ``/opt/ml/input/config/resourceconfig.json``. That is the WRONG model for
15
+ DiLoCo replicas — it would couple replicas through SageMaker's intra-job
16
+ network and break the "each replica is an independent DiLoCo worker that
17
+ syncs only through S3" design. So ``launch_replicas`` submits N **separate**
18
+ training jobs, each with ``ResourceConfig.InstanceCount == 1``, tagged with
19
+ ``REPLICA_RANK=i`` / ``WORLD_SIZE=N`` via the ``Environment`` map. This
20
+ mirrors ``ModalSpawnExecutor`` spawning N independent Modal calls.
21
+
22
+ 2. **Same S3 ``ObjectStoreAllReduce`` rendezvous — DiLoCo math untouched.**
23
+ Cross-replica communication is EXCLUSIVELY the object-store rendezvous; the
24
+ executor passes ``rendezvous_uri`` (an ``s3://...`` URI) through to
25
+ ``replica_entrypoint.py`` unchanged. ``allreduce.py`` / ``MockManager`` /
26
+ ``make_diloco_outer_loop`` / the trainer all stay byte-for-byte identical.
27
+
28
+ 3. **Stateless after launch; rank via ``Environment``.** Handle metadata is the
29
+ ``training_job_name`` (plus submit timestamp). ``replica_entrypoint.py``
30
+ already reads ``REPLICA_RANK`` from ``os.environ``, so the cleanest channel
31
+ is the ``Environment`` map (string->string, max 100 entries, value <= 512
32
+ chars). The container command is passed as ``ContainerEntrypoint`` and the
33
+ rendezvous args via ``AlgorithmSpecification.ContainerArguments``.
34
+
35
+ 4. **``supports_inter_replica_network = False``.** Separate single-instance
36
+ training jobs have no mutual network path by design — they rendezvous only
37
+ through S3. (SageMaker's algo-N container fabric and
38
+ ``EnableInterContainerTrafficEncryption`` only exist WITHIN a single
39
+ multi-instance job, which this design deliberately does not use.)
40
+
41
+ Load-bearing gotcha — ``EnableNetworkIsolation`` MUST stay ``False``
42
+ --------------------------------------------------------------------
43
+ When ``EnableNetworkIsolation=True`` the training *container* has no outbound
44
+ network access. SageMaker's host-side processes still stage input channels and
45
+ ship CloudWatch logs, but the container itself cannot make S3 GET/PUT calls.
46
+ ``ObjectStoreAllReduce`` needs live S3 PUT+GET every outer round, so network
47
+ isolation would silently dead-lock the allreduce poll loop until its timeout.
48
+ This executor pins ``EnableNetworkIsolation=False`` (the API default) and never
49
+ exposes it as a knob. The rendezvous bucket access must instead be granted on
50
+ the execution ``RoleArn`` — the SageMaker analog of EKS IRSA.
51
+
52
+ HyperPod <-> EKS 1:1 control-plane mapping (recommended hybrid)
53
+ ---------------------------------------------------------------
54
+ Per the SageMaker docs: *"The high-level architecture of Amazon EKS support in
55
+ HyperPod involves a 1-to-1 mapping between an EKS cluster (control plane) and a
56
+ HyperPod cluster (worker nodes) within a VPC."*
57
+ (https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-hyperpod-eks.html)
58
+
59
+ Consequence for this repo's hybrid: "use HyperPod for the inner GRPO trainer"
60
+ does NOT mean leaving EKS — it means attaching a HyperPod-managed
61
+ (auto-recovering, deep-health-checked, PyTorch-job auto-resume) node-group to
62
+ the SAME EKS control plane that runs the outer loop. The sibling ``EKSExecutor``
63
+ (kubernetes client, Indexed Jobs) therefore targets both plain Karpenter GPU
64
+ nodes AND HyperPod nodes transparently. ``SageMakerExecutor`` (ephemeral
65
+ Training Jobs via boto3) is the SEPARATE bursty-fallback inner-loop path for
66
+ when you don't want a persistent cluster: Training Jobs suit periodic /
67
+ smaller-model / pay-per-use runs; HyperPod suits continuous / large-model /
68
+ persistent runs. Both share the IDENTICAL S3 rendezvous, so a run can move
69
+ between them with zero trainer / loss / DiLoCo changes.
70
+
71
+ References
72
+ ----------
73
+ - create_training_job: https://docs.aws.amazon.com/boto3/latest/reference/services/sagemaker/client/create_training_job.html
74
+ - describe_training_job: https://docs.aws.amazon.com/boto3/latest/reference/services/sagemaker/client/describe_training_job.html
75
+ - stop_training_job: https://docs.aws.amazon.com/boto3/latest/reference/services/sagemaker/client/stop_training_job.html
76
+ - network isolation: https://repost.aws/knowledge-center/sagemaker-access-network-isolation
77
+ - HyperPod-EKS: https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-hyperpod-eks.html
78
+ - ADR-005 (executor protocol design)
79
+ """
80
+ from __future__ import annotations
81
+
82
+ import json
83
+ import time
84
+ import uuid
85
+ from collections.abc import Callable, Mapping
86
+ from typing import Any
87
+
88
+ from composer_replication.diloco.serverless.executor import (
89
+ ReplicaHandle,
90
+ )
91
+
92
+ # SageMaker TrainingJobStatus -> Protocol status vocabulary.
93
+ # describe_training_job's TrainingJobStatus is EXACTLY one of:
94
+ # 'InProgress' | 'Completed' | 'Failed' | 'Stopping' | 'Stopped'.
95
+ # We map Stopping -> 'running' (transient; still terminating, so collect()
96
+ # keeps waiting) and Stopped -> 'cancelled'.
97
+ _STATUS_MAP = {
98
+ "InProgress": "running",
99
+ "Completed": "succeeded",
100
+ "Failed": "failed",
101
+ "Stopping": "running",
102
+ "Stopped": "cancelled",
103
+ }
104
+
105
+ # SecondaryStatus values that mean "queued / not yet executing user code" —
106
+ # used to refine an InProgress job into the Protocol's 'pending'.
107
+ _PENDING_SECONDARY = frozenset(
108
+ {"Starting", "Pending", "LaunchingMLInstances", "PreparingTrainingStack"}
109
+ )
110
+
111
+ # Abstract Protocol GPU strings -> SageMaker instance types.
112
+ _GPU_INSTANCE_MAP = {
113
+ "A100": "ml.p4d.24xlarge",
114
+ "H100": "ml.p5.48xlarge",
115
+ "H200": "ml.p5e.48xlarge",
116
+ "B200": "ml.p6-b200.48xlarge",
117
+ "L40S": "ml.g6e.12xlarge",
118
+ "A10G": "ml.g5.2xlarge",
119
+ "L4": "ml.g6.2xlarge",
120
+ }
121
+
122
+ _CLOUDWATCH_LOG_GROUP = "/aws/sagemaker/TrainingJobs"
123
+
124
+
125
+ class SageMakerExecutor:
126
+ """Run replicas as N independent SageMaker Training Jobs.
127
+
128
+ Implements the `ServerlessExecutor` Protocol against the boto3
129
+ ``sagemaker`` client. Each replica is one single-instance training job;
130
+ cross-replica communication happens only through the shared S3
131
+ ``ObjectStoreAllReduce`` rendezvous.
132
+
133
+ Args:
134
+ role_arn: IAM execution role SageMaker assumes for the job. Must grant
135
+ S3 access to the rendezvous + output buckets (the SageMaker analog of
136
+ EKS IRSA). The caller's credentials need ``iam:PassRole`` on it.
137
+ image_uri: ECR image URI for the training container. The image must
138
+ bake an entrypoint that runs
139
+ ``python -m composer_replication.diloco.serverless.replica_entrypoint``
140
+ (this executor also passes ``ContainerEntrypoint`` explicitly so a
141
+ generic image works too).
142
+ output_s3_path: ``s3://...`` prefix for ``OutputDataConfig.S3OutputPath``
143
+ (model artifacts / failure output).
144
+ instance_type: default SageMaker instance type when ``gpu`` is not
145
+ mapped (e.g. ``"ml.g5.2xlarge"``). ``gpu=None`` at launch falls
146
+ back to ``cpu_instance_type``.
147
+ cpu_instance_type: instance type used when ``gpu`` is ``None`` (CPU
148
+ smoke tests). Default ``"ml.m5.xlarge"``.
149
+ volume_size_gb: ``ResourceConfig.VolumeSizeInGB`` per job.
150
+ run_id: prefix for generated training-job names. Defaults to a short
151
+ random token so names are unique per region+account.
152
+ region: AWS region for the lazily-constructed boto3 clients. ``None``
153
+ uses the ambient boto3 default-region resolution.
154
+ sagemaker_client: inject a pre-built ``boto3.client('sagemaker')`` (or a
155
+ mock) instead of constructing one. Used by tests.
156
+ logs_client: inject a pre-built ``boto3.client('logs')`` (or a mock).
157
+
158
+ Raises:
159
+ RuntimeError: if boto3 is not installed and no client was injected.
160
+ """
161
+
162
+ backend_name = "sagemaker"
163
+ # Separate single-instance jobs have no mutual network path — S3 only.
164
+ supports_inter_replica_network = False
165
+
166
+ def __init__(
167
+ self,
168
+ *,
169
+ role_arn: str,
170
+ image_uri: str,
171
+ output_s3_path: str,
172
+ instance_type: str = "ml.g5.2xlarge",
173
+ cpu_instance_type: str = "ml.m5.xlarge",
174
+ volume_size_gb: int = 100,
175
+ run_id: str | None = None,
176
+ region: str | None = None,
177
+ sagemaker_client: Any = None,
178
+ logs_client: Any = None,
179
+ ) -> None:
180
+ self.role_arn = role_arn
181
+ self.image_uri = image_uri
182
+ self.output_s3_path = output_s3_path
183
+ self.instance_type = instance_type
184
+ self.cpu_instance_type = cpu_instance_type
185
+ self.volume_size_gb = volume_size_gb
186
+ self.run_id = run_id or f"diloco-{uuid.uuid4().hex[:8]}"
187
+ self._region = region
188
+
189
+ # Lazy boto3 — only constructed if the caller didn't inject a client.
190
+ # This keeps `import composer_replication.diloco.serverless` free of a
191
+ # hard boto3 dependency (boto3 lives in the optional [aws] extra), and
192
+ # lets tests inject a _MockSMClient with zero AWS calls.
193
+ if sagemaker_client is None:
194
+ sagemaker_client = self._make_boto3_client("sagemaker")
195
+ self._client = sagemaker_client
196
+ self._logs_client = logs_client # built lazily on first stream_logs()
197
+
198
+ # rank -> {"job_name": str, "result": dict | None}
199
+ self._handles: dict[int, dict[str, Any]] = {}
200
+
201
+ # -----------------------------------------------------------------
202
+ # boto3 plumbing (lazy)
203
+ # -----------------------------------------------------------------
204
+
205
+ def _make_boto3_client(self, service: str) -> Any:
206
+ try:
207
+ import boto3
208
+ except ImportError as e:
209
+ raise RuntimeError(
210
+ "SageMakerExecutor requires boto3. Install with "
211
+ "`pip install -e .[aws]` (or `pip install boto3`). "
212
+ f"Got: {e!r}"
213
+ ) from e
214
+ if self._region is not None:
215
+ return boto3.client(service, region_name=self._region)
216
+ return boto3.client(service)
217
+
218
+ def _map_gpu(self, gpu: str | None) -> str:
219
+ """Translate the Protocol's abstract gpu string to an instance type.
220
+
221
+ ``gpu=None`` -> ``cpu_instance_type`` (smoke tests). A literal ``ml.*``
222
+ instance type that is not in the map is honoured as-is; any other
223
+ unrecognized string falls back to ``instance_type``.
224
+ """
225
+ if gpu is None:
226
+ return self.cpu_instance_type
227
+ if gpu in _GPU_INSTANCE_MAP:
228
+ return _GPU_INSTANCE_MAP[gpu]
229
+ # Caller may have passed a literal "ml.*" instance type.
230
+ if gpu.startswith("ml."):
231
+ return gpu
232
+ return self.instance_type
233
+
234
+ def _job_name(self, rank: int) -> str:
235
+ """Build a unique, regex-safe training-job name (<= 63 chars).
236
+
237
+ Pattern required by the API: ``[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}``.
238
+ """
239
+ suffix = f"-r{rank:04d}-{int(time.time())}"  # keep the unique suffix intact
240
+ return self.run_id[: 63 - len(suffix)].rstrip("-") + suffix
241
+
242
+ # -----------------------------------------------------------------
243
+ # ServerlessExecutor Protocol
244
+ # -----------------------------------------------------------------
245
+
246
+ def launch_replicas(
247
+ self,
248
+ n_replicas: int,
249
+ entrypoint: str | Callable[..., Any],
250
+ entrypoint_args: Mapping[str, Any],
251
+ *,
252
+ gpu: str | None = None,
253
+ timeout: int = 3600,
254
+ ) -> list[ReplicaHandle]:
255
+ """Submit N independent single-instance SageMaker Training Jobs.
256
+
257
+ Args:
258
+ n_replicas: number of replicas (= number of training jobs).
259
+ entrypoint: ignored — the container command is baked into the
260
+ image / passed as ``ContainerEntrypoint``. Kept for Protocol
261
+ compatibility.
262
+ entrypoint_args: must contain ``rendezvous_uri`` (``s3://...``) and
263
+ ``trainer_module``. Optional: ``trainer_fn`` (default
264
+ ``"train"``), ``trainer_kwargs`` (dict, JSON-encoded into the
265
+ container args). The conventional ``rank_env`` key (from
266
+ ``LocalProcessExecutor``) is ignored — rank goes through the
267
+ ``Environment`` map instead.
268
+ gpu: abstract GPU spec mapped to an instance type via ``_map_gpu``.
269
+ ``None`` -> CPU instance.
270
+ timeout: ``StoppingCondition.MaxRuntimeInSeconds`` per job.
271
+
272
+ Returns:
273
+ ``list[ReplicaHandle]`` of length ``n_replicas`` in rank order
274
+ (``handles[i].rank == i``).
275
+ """
276
+ del entrypoint # container command is baked / passed explicitly
277
+
278
+ if n_replicas < 1:
279
+ raise ValueError(f"n_replicas must be >= 1, got {n_replicas}")
280
+
281
+ rendezvous_uri = entrypoint_args.get("rendezvous_uri")
282
+ if not rendezvous_uri:
283
+ raise ValueError(
284
+ "entrypoint_args must include 'rendezvous_uri' (the s3:// "
285
+ "ObjectStoreAllReduce rendezvous prefix)."
286
+ )
287
+ trainer_module = entrypoint_args.get("trainer_module")
288
+ if not trainer_module:
289
+ raise ValueError(
290
+ "entrypoint_args must include 'trainer_module' (importable "
291
+ "module path of the user's train function)."
292
+ )
293
+ trainer_fn = entrypoint_args.get("trainer_fn", "train")
294
+ trainer_kwargs = entrypoint_args.get("trainer_kwargs", {})
295
+
296
+ instance_type = self._map_gpu(gpu)
297
+
298
+ # Container args: each element is a SINGLE token (StackOverflow
299
+ # 77994925 — `['--world-size', '4']` NOT `['--world-size 4']`).
300
+ container_args = [
301
+ "--rendezvous", str(rendezvous_uri),
302
+ "--world-size", str(n_replicas),
303
+ "--trainer-module", str(trainer_module),
304
+ "--trainer-fn", str(trainer_fn),
305
+ "--trainer-kwargs-json", json.dumps(trainer_kwargs),
306
+ ]
307
+
308
+ handles: list[ReplicaHandle] = []
309
+ for rank in range(n_replicas):
310
+ job_name = self._job_name(rank)
311
+ request = {
312
+ "TrainingJobName": job_name,
313
+ "AlgorithmSpecification": {
314
+ "TrainingImage": self.image_uri,
315
+ "TrainingInputMode": "File",
316
+ "ContainerEntrypoint": [
317
+ "python", "-m",
318
+ "composer_replication.diloco.serverless.replica_entrypoint",
319
+ ],
320
+ "ContainerArguments": container_args,
321
+ },
322
+ "RoleArn": self.role_arn,
323
+ # InputDataConfig intentionally omitted — the replica pulls
324
+ # data via its own code / the S3 rendezvous, not SM channels.
325
+ "OutputDataConfig": {"S3OutputPath": self.output_s3_path},
326
+ "ResourceConfig": {
327
+ "InstanceType": instance_type,
328
+ "InstanceCount": 1,
329
+ "VolumeSizeInGB": self.volume_size_gb,
330
+ },
331
+ "StoppingCondition": {"MaxRuntimeInSeconds": int(timeout)},
332
+ # REPLICA_RANK / WORLD_SIZE injected as container env vars;
333
+ # replica_entrypoint.py reads os.environ['REPLICA_RANK'].
334
+ "Environment": {
335
+ "REPLICA_RANK": str(rank),
336
+ "WORLD_SIZE": str(n_replicas),
337
+ "RENDEZVOUS_URI": str(rendezvous_uri),
338
+ },
339
+ # MUST stay False — True severs the container's S3 access and
340
+ # dead-locks the allreduce poll loop. See module docstring.
341
+ "EnableNetworkIsolation": False,
342
+ }
343
+ try:
344
+ self._client.create_training_job(**request)
345
+ except Exception as e:
346
+ # Best-effort stop of already-launched siblings, then raise.
347
+ for prior in handles:
348
+ try:
349
+ self.cancel(prior)
350
+ except Exception:
351
+ pass
+                raise RuntimeError(
+                    f"SageMakerExecutor.launch_replicas failed at rank={rank} "
+                    f"of {n_replicas} (best-effort stop requested for "
+                    f"already-launched siblings). Underlying error: {e!r}"
+                ) from e
357
+
358
+ handle = ReplicaHandle(
359
+ rank=rank,
360
+ backend_name=self.backend_name,
361
+ metadata={
362
+ "training_job_name": job_name,
363
+ "submit_ts": time.time(),
364
+ },
365
+ )
366
+ self._handles[rank] = {"job_name": job_name, "result": None}
367
+ handles.append(handle)
368
+
369
+ return handles
370
+
371
+ def poll(self, handle: ReplicaHandle) -> str:
372
+ """Poll a training job's status.
373
+
374
+ Returns one of: ``"pending"`` | ``"running"`` | ``"succeeded"`` |
375
+ ``"failed"`` | ``"cancelled"``.
376
+
377
+ Maps ``describe_training_job``'s ``TrainingJobStatus`` via
378
+ ``_STATUS_MAP``; refines ``InProgress`` to ``"pending"`` while the job
379
+ is still queued (``SecondaryStatus`` in ``_PENDING_SECONDARY``). A
380
+ vanished job (``ResourceNotFound``) is treated as ``"cancelled"``.
381
+ """
382
+ meta = self._handles.get(handle.rank)
383
+ if meta is None:
384
+ return "cancelled"
385
+ if meta["result"] is not None:
386
+ return meta["result"]["status"]
387
+
388
+ job_name = meta["job_name"]
389
+ try:
390
+ resp = self._client.describe_training_job(TrainingJobName=job_name)
391
+ except Exception as e:
392
+ if self._is_resource_not_found(e):
393
+ return "cancelled"
394
+ raise
395
+
+        sm_status = resp.get("TrainingJobStatus", "InProgress")
+        mapped = _STATUS_MAP.get(sm_status, "running")
+
+        if sm_status == "InProgress":
+            if resp.get("SecondaryStatus") in _PENDING_SECONDARY:
+                return "pending"
+            return "running"
+
+        if sm_status not in ("Completed", "Failed", "Stopped"):
+            # e.g. "Stopping": not terminal yet, so report without caching.
+            return mapped
+
+        # Terminal: cache a result dict so collect()/repeat-poll are cheap.
+        meta["result"] = self._terminal_result(handle.rank, sm_status, resp)
+        return mapped
407
+
408
+ def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
409
+ """Read recent CloudWatch logs for this replica's training job.
410
+
411
+ SageMaker writes container stdout/stderr to the
412
+ ``/aws/sagemaker/TrainingJobs`` log group, stream
413
+ ``<job-name>/algo-<n>-<epoch>``. We discover the exact stream name by
414
+ prefix then read the tail. Falls back to a CloudWatch console pointer
415
+ on any error (mirrors ModalSpawnExecutor's dashboard-URL fallback).
416
+ """
417
+ meta = self._handles.get(handle.rank)
418
+ if meta is None:
419
+ return f"<replica {handle.rank}: no metadata>"
420
+ job_name = meta["job_name"]
421
+
422
+ try:
423
+ logs = self._logs()
424
+ prefix = f"{job_name}/"
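+            # CloudWatch Logs rejects orderBy="LastEventTime" when a
+            # logStreamNamePrefix is given, so keep the default name ordering.
+            # A single-instance job normally has one stream under this prefix.
+            streams = logs.describe_log_streams(
+                logGroupName=_CLOUDWATCH_LOG_GROUP,
+                logStreamNamePrefix=prefix,
+                limit=1,
+            )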
432
+ stream_list = streams.get("logStreams", [])
433
+ if not stream_list:
434
+ return (
435
+ f"[rank {handle.rank}] job={job_name}: no CloudWatch log "
436
+ f"stream yet (job pending / not started)."
437
+ )
438
+ stream_name = stream_list[0]["logStreamName"]
439
+ events = logs.get_log_events(
440
+ logGroupName=_CLOUDWATCH_LOG_GROUP,
441
+ logStreamName=stream_name,
442
+ limit=n_lines,
443
+ startFromHead=False,
444
+ )
445
+ lines = [e.get("message", "") for e in events.get("events", [])]
446
+ body = "\n".join(lines) if lines else "<no log events>"
447
+ return f"[rank {handle.rank}] job={job_name} stream={stream_name}\n{body}"
448
+ except Exception as e:
449
+ region = self._region or "<region>"
450
+ url = (
451
+ f"https://{region}.console.aws.amazon.com/cloudwatch/home"
452
+ f"?region={region}#logsV2:log-groups/log-group/"
453
+ f"$252Faws$252Fsagemaker$252FTrainingJobs"
454
+ )
455
+ return (
456
+ f"[rank {handle.rank}] job={job_name}: log fetch failed "
457
+ f"({type(e).__name__}: {e!r}).\n CloudWatch console: {url}"
458
+ )
459
+
460
+ def cancel(self, handle: ReplicaHandle) -> None:
461
+ """Best-effort stop of a training job.
462
+
463
+ Calls ``stop_training_job`` (SIGTERM + 120s grace), swallowing
464
+ ``ResourceNotFound`` and "already terminal" ``ValidationException`` so
465
+ the contract — "no exception if already terminated" — holds.
466
+ """
467
+ meta = self._handles.get(handle.rank)
468
+ if meta is None:
469
+ return
470
+ try:
471
+ self._client.stop_training_job(TrainingJobName=meta["job_name"])
472
+ except Exception:
473
+ # ResourceNotFound, already-Completed/Stopped ValidationException,
474
+ # transient network blip — all best-effort no-ops.
475
+ pass
476
+
477
+ def collect(
478
+ self,
479
+ handles: list[ReplicaHandle],
480
+ *,
481
+ timeout: int | None = None,
482
+ ) -> list[dict[str, Any]]:
483
+ """Block until all replicas finish; return per-replica result dicts.
484
+
485
+ Polls ``describe_training_job`` per handle until the job reaches a
486
+ terminal status (``Completed`` / ``Failed`` / ``Stopped``) or the
487
+ shared deadline elapses. Returns results aligned to the input handle
488
+ order (Protocol contract; mirrors ``LocalProcessExecutor.collect``).
489
+
490
+ Each result dict has at least
491
+ ``{"rank", "status", "exit_code", "error"}``.
492
+ """
493
+ deadline = time.time() + (timeout if timeout is not None else 86400)
494
+ poll_interval = 30.0
495
+ results: list[dict[str, Any]] = []
496
+
497
+ for h in handles:
498
+ meta = self._handles.get(h.rank)
499
+ if meta is None:
500
+ results.append({
501
+ "rank": h.rank,
502
+ "status": "cancelled",
503
+ "exit_code": None,
504
+ "error": "handle has no metadata (cancelled or unknown)",
505
+ "result": None,
506
+ "training_job_name": h.metadata.get("training_job_name"),
507
+ })
508
+ continue
509
+
510
+ # Already cached by an earlier poll()/collect().
511
+ if meta["result"] is not None:
512
+ results.append(meta["result"])
513
+ continue
514
+
515
+ job_name = meta["job_name"]
516
+ result_dict: dict[str, Any] | None = None
517
+ while True:
518
+ try:
519
+ resp = self._client.describe_training_job(
520
+ TrainingJobName=job_name
521
+ )
522
+ except Exception as e:
523
+ if self._is_resource_not_found(e):
524
+ result_dict = {
525
+ "rank": h.rank,
526
+ "status": "cancelled",
527
+ "exit_code": None,
528
+ "error": "training job not found (deleted?)",
529
+ "result": None,
530
+ "training_job_name": job_name,
531
+ }
532
+ break
533
+ raise
534
+
535
+ sm_status = resp.get("TrainingJobStatus", "InProgress")
536
+ if sm_status in ("Completed", "Failed", "Stopped"):
537
+ result_dict = self._terminal_result(h.rank, sm_status, resp)
538
+ break
539
+
540
+ if time.time() >= deadline:
541
+ result_dict = {
542
+ "rank": h.rank,
543
+ "status": "running",
544
+ "exit_code": None,
545
+ "error": "timeout before terminal",
546
+ "result": None,
547
+ "training_job_name": job_name,
548
+ }
549
+ break
550
+
551
+ # Sleep, but never overrun the deadline.
552
+ time.sleep(min(poll_interval, max(0.0, deadline - time.time())))
553
+
554
+ # Cache only terminal results (not the timeout 'running' sentinel,
555
+ # so a later collect() can re-check the job).
556
+ if result_dict["status"] in ("succeeded", "failed", "cancelled"):
557
+ meta["result"] = result_dict
558
+ results.append(result_dict)
559
+
560
+ return results
561
+
562
+ # -----------------------------------------------------------------
563
+ # Helpers
564
+ # -----------------------------------------------------------------
565
+
566
+ def _logs(self) -> Any:
567
+ """Lazily build the CloudWatch Logs client (separate from sagemaker)."""
568
+ if self._logs_client is None:
569
+ self._logs_client = self._make_boto3_client("logs")
570
+ return self._logs_client
571
+
572
+ @staticmethod
573
+ def _terminal_result(
574
+ rank: int, sm_status: str, resp: Mapping[str, Any]
575
+ ) -> dict[str, Any]:
576
+ """Build a result dict from a terminal describe_training_job response."""
577
+ mapped = _STATUS_MAP.get(sm_status, "failed")
578
+ if sm_status == "Completed":
579
+ exit_code: int | None = 0
580
+ error = None
581
+ elif sm_status == "Stopped":
582
+ exit_code = None
583
+ error = resp.get("FailureReason")
584
+ else: # Failed
585
+ exit_code = 1
586
+ error = resp.get("FailureReason") or "training job failed"
587
+ artifacts = resp.get("ModelArtifacts", {}) or {}
588
+ return {
589
+ "rank": rank,
590
+ "status": mapped,
591
+ "exit_code": exit_code,
592
+ "error": error,
593
+ "result": artifacts.get("S3ModelArtifacts"),
594
+ "training_job_name": resp.get("TrainingJobName"),
595
+ }
596
+
597
+ def _is_resource_not_found(self, exc: Exception) -> bool:
598
+ """True if ``exc`` is the boto3 ResourceNotFound for the sagemaker client.
599
+
600
+ Handles both the typed client exception
601
+ (``client.exceptions.ResourceNotFound``) and a generic botocore
602
+ ``ClientError`` whose error code is ``ResourceNotFound`` /
603
+ ``ValidationException`` naming a missing job — robust across whether a
604
+ real boto3 client or a mock is in use.
605
+ """
606
+ rnf = getattr(getattr(self._client, "exceptions", None),
607
+ "ResourceNotFound", None)
608
+ if rnf is not None and isinstance(exc, rnf):
609
+ return True
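+        # Generic botocore ClientError fallback.
+        resp = getattr(exc, "response", None)
+        if isinstance(resp, Mapping):
+            err = resp.get("Error", {}) or {}
+            code = err.get("Code", "")
+            if code == "ResourceNotFound":
+                return True
+            # A ValidationException only counts when it names a missing
+            # resource; any other validation error must propagate.
+            message = str(err.get("Message", "")).lower()
+            if code == "ValidationException" and "not found" in message:
+                return True
+        return False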
617
+
618
+
619
+ __all__ = ["SageMakerExecutor"]
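
The status refinement described in `poll` can be sketched in isolation. This is a minimal illustration rather than the module's code: the contents of `STATUS_MAP` and `PENDING_SECONDARY` are assumptions (the real `_STATUS_MAP` and `_PENDING_SECONDARY` are defined outside this hunk), and `map_status` is a hypothetical helper name.

```python
from __future__ import annotations

from typing import Any, Mapping

# Assumed mapping from SageMaker TrainingJobStatus to the Protocol's strings.
STATUS_MAP = {
    "InProgress": "running",
    "Completed": "succeeded",
    "Failed": "failed",
    "Stopping": "running",
    "Stopped": "cancelled",
}
# Assumed SecondaryStatus values meaning "queued, container not started yet".
PENDING_SECONDARY = frozenset(
    {"Starting", "Pending", "LaunchingMLInstances", "PreparingTrainingStack"}
)
# Only these statuses are final; "Stopping" can still change.
TERMINAL = frozenset({"Completed", "Failed", "Stopped"})


def map_status(resp: Mapping[str, Any]) -> tuple[str, bool]:
    """Return (protocol_status, is_terminal) for a describe response."""
    sm_status = resp.get("TrainingJobStatus", "InProgress")
    if sm_status == "InProgress" and resp.get("SecondaryStatus") in PENDING_SECONDARY:
        return "pending", False
    return STATUS_MAP.get(sm_status, "running"), sm_status in TERMINAL
```

Returning the terminal flag alongside the status lets a caller decide whether a result is safe to cache, which matters for transitional statuses such as `Stopping`.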
composer_replication/diloco/serverless/tests/test_eks_executor.py ADDED
@@ -0,0 +1,625 @@
1
+ """Tests for EKSExecutor — the Kubernetes Indexed-Job-backed executor.
2
+
3
+ These tests exercise the executor's contract WITHOUT a live cluster and
4
+ WITHOUT the `kubernetes` client actually being installed. They:
5
+
6
+ * inject a fake `kubernetes` module into ``sys.modules`` so the executor's
7
+ lazy ``from kubernetes import client`` / ``...client.exceptions`` calls
8
+ resolve to recording stand-in V1* model classes (this is the k8s analogue
9
+ of the modal test's ``_MockFunctionCall``), and
10
+ * pass mock ``batch_api`` / ``core_api`` via dependency injection (the
11
+ constructor's ``batch_api=`` / ``core_api=`` args), so no config loading or
12
+ cluster contact happens.
13
+
14
+ For real-cluster integration testing you would gate behind cluster
15
+ availability (e.g. ``config.load_kube_config()`` succeeding), exactly like
16
+ ``test_modal_spawn_executor.py`` gates on ``_is_modal_installed()``.
17
+
18
+ Run: ``.venv/bin/python -m pytest <thisfile> -q``
19
+ """
20
+ from __future__ import annotations
21
+
22
+ import sys
23
+ import types
24
+
25
+ import pytest
26
+
27
+ from composer_replication.diloco.serverless import EKSExecutor, ReplicaHandle
28
+ from composer_replication.diloco.serverless.eks import _expand_indexes
29
+
30
+ # ---------------------------------------------------------------------
31
+ # Fake `kubernetes` module — recording V1* model stand-ins + ApiException
32
+ # ---------------------------------------------------------------------
33
+
34
+
35
+ class _Rec:
36
+ """Generic recording model: stores all ctor kwargs as attributes.
37
+
38
+ Stands in for the kubernetes client's ``V1*`` model classes (V1Job,
39
+ V1JobSpec, V1Container, V1EnvVar, ...). Every attr the executor sets is
40
+ inspectable by tests. Mirrors how the modal mock records ``.spawn`` args.
41
+ """
42
+
+    def __init__(self, **kwargs):
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+
+    def __getattr__(self, name):
+        # Only called when normal lookup fails: fields the executor never
+        # set read as None, so assertions never raise AttributeError.
+        return None
51
+
52
+
53
+ class _ApiException(Exception): # noqa: N818 — mirrors kubernetes.client.exceptions.ApiException name
54
+ """Stand-in for kubernetes.client.exceptions.ApiException."""
55
+
56
+ def __init__(self, status=None, reason=None, body=None):
57
+ super().__init__(f"ApiException(status={status})")
58
+ self.status = status
59
+ self.reason = reason
60
+ self.body = body
61
+
62
+
63
+ # The set of V1* names the executor constructs. Each maps to _Rec.
64
+ _V1_NAMES = [
65
+ "V1Job",
66
+ "V1JobSpec",
67
+ "V1ObjectMeta",
68
+ "V1PodTemplateSpec",
69
+ "V1PodSpec",
70
+ "V1Container",
71
+ "V1EnvVar",
72
+ "V1EnvVarSource",
73
+ "V1ObjectFieldSelector",
74
+ "V1ResourceRequirements",
75
+ "V1Toleration",
76
+ "V1DeleteOptions",
77
+ ]
78
+
79
+
80
+ @pytest.fixture
81
+ def fake_kubernetes(monkeypatch):
82
+ """Install a fake `kubernetes` package into sys.modules for the test.
83
+
84
+ Provides:
85
+ - kubernetes.client.<V1*> -> recording _Rec classes
86
+ - kubernetes.client.exceptions.ApiException
87
+ - kubernetes.client.BatchV1Api / CoreV1Api (unused — apis are injected)
88
+ - kubernetes.config.load_incluster_config / load_kube_config / ConfigException
89
+ """
90
+ kubernetes = types.ModuleType("kubernetes")
91
+ client = types.ModuleType("kubernetes.client")
92
+ exceptions = types.ModuleType("kubernetes.client.exceptions")
93
+ config = types.ModuleType("kubernetes.config")
94
+
95
+ for name in _V1_NAMES:
96
+ setattr(client, name, _Rec)
97
+
98
+ # Default api classes (only hit if NOT injected — we always inject).
99
+ client.BatchV1Api = lambda *a, **k: pytest.fail("BatchV1Api should be injected")
100
+ client.CoreV1Api = lambda *a, **k: pytest.fail("CoreV1Api should be injected")
101
+
102
+ exceptions.ApiException = _ApiException
103
+ client.exceptions = exceptions
104
+
105
+ class _ConfigException(Exception): # noqa: N818 — mirrors kubernetes.config.ConfigException name
106
+ pass
107
+
108
+ config.ConfigException = _ConfigException
+    def _raise_not_in_cluster(*a, **k):
+        raise _ConfigException("not in cluster")
+
+    config.load_incluster_config = _raise_not_in_cluster
112
+ config.load_kube_config = lambda *a, **k: None
113
+
114
+ kubernetes.client = client
115
+ kubernetes.config = config
116
+
117
+ monkeypatch.setitem(sys.modules, "kubernetes", kubernetes)
118
+ monkeypatch.setitem(sys.modules, "kubernetes.client", client)
119
+ monkeypatch.setitem(sys.modules, "kubernetes.client.exceptions", exceptions)
120
+ monkeypatch.setitem(sys.modules, "kubernetes.config", config)
121
+ return kubernetes
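
The injection technique used by this fixture can be shown on its own. The sketch below is illustrative only: `hypothetical_cloud_sdk`, `install_fake_module`, and `lazy_backend_name` are made-up names, and the manual save/restore stands in for what `monkeypatch.setitem` does automatically in the fixture.

```python
import importlib
import sys
import types


def install_fake_module(name: str, **attrs) -> types.ModuleType:
    """Register a stand-in module under ``name`` and return it."""
    module = types.ModuleType(name)
    for key, value in attrs.items():
        setattr(module, key, value)
    sys.modules[name] = module
    return module


def lazy_backend_name() -> str:
    """Mimic production code that imports an optional dependency lazily."""
    sdk = importlib.import_module("hypothetical_cloud_sdk")
    return sdk.BACKEND


# Remember whatever was registered before so the injection can be undone.
saved = sys.modules.get("hypothetical_cloud_sdk")
install_fake_module("hypothetical_cloud_sdk", BACKEND="fake")
try:
    resolved = lazy_backend_name()
finally:
    if saved is None:
        sys.modules.pop("hypothetical_cloud_sdk", None)
    else:
        sys.modules["hypothetical_cloud_sdk"] = saved
```

Because the import system consults `sys.modules` before searching the filesystem, the lazy import resolves to the stand-in. For a dotted package, each dotted name needs its own `sys.modules` entry and the parent needs the child set as an attribute, which is why the fixture registers four names.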
122
+
123
+
124
+ # ---------------------------------------------------------------------
125
+ # Mock BatchV1Api / CoreV1Api (the _MockBatchV1 the task asks for)
126
+ # ---------------------------------------------------------------------
127
+
128
+
129
+ class _MockBatchV1Api:
130
+ """Records create/read-status/delete calls; returns a settable status."""
131
+
132
+ def __init__(self):
133
+ self.created_jobs: list[tuple[str, object]] = []
134
+ self.delete_calls: list[dict] = []
135
+ # status object returned by read_namespaced_job_status().status
136
+ self.status_obj = _Rec(
137
+ active=None,
138
+ succeeded=None,
139
+ failed=None,
140
+ completed_indexes=None,
141
+ failed_indexes=None,
142
+ conditions=None,
143
+ )
144
+ # Optional: raise this ApiException on read (e.g. 404 -> cancelled)
145
+ self.read_raises: Exception | None = None
146
+
147
+ def create_namespaced_job(self, namespace, body):
148
+ self.created_jobs.append((namespace, body))
149
+ return body
150
+
151
+ def read_namespaced_job_status(self, name, namespace):
152
+ if self.read_raises is not None:
153
+ raise self.read_raises
154
+ return _Rec(status=self.status_obj)
155
+
156
+ def delete_namespaced_job(self, name, namespace, body=None):
157
+ self.delete_calls.append(
158
+ {
159
+ "name": name,
160
+ "namespace": namespace,
161
+ "propagation_policy": getattr(body, "propagation_policy", None),
162
+ "grace_period_seconds": getattr(body, "grace_period_seconds", None),
163
+ }
164
+ )
165
+ return _Rec(status="Success")
166
+
167
+
168
+ class _MockCoreV1Api:
169
+ """Canned list_namespaced_pod + read_namespaced_pod_log."""
170
+
171
+ def __init__(self, pods=None, logs="line1\nline2\n"):
172
+ self._pods = pods if pods is not None else []
173
+ self._logs = logs
174
+ self.log_calls: list[dict] = []
175
+ self.list_calls: list[dict] = []
176
+ self.log_raises: Exception | None = None
177
+
178
+ def list_namespaced_pod(self, namespace, label_selector=None):
179
+ self.list_calls.append({"namespace": namespace, "label_selector": label_selector})
180
+ return _Rec(items=list(self._pods))
181
+
182
+ def read_namespaced_pod_log(self, name, namespace, container=None, tail_lines=None):
183
+ self.log_calls.append(
184
+ {
185
+ "name": name,
186
+ "namespace": namespace,
187
+ "container": container,
188
+ "tail_lines": tail_lines,
189
+ }
190
+ )
191
+ if self.log_raises is not None:
192
+ raise self.log_raises
193
+ return self._logs
194
+
195
+
196
+ def _make_pod(name, rank):
197
+ """Build a fake pod with the completion-index annotation set."""
198
+ return _Rec(
199
+ metadata=_Rec(
200
+ name=name,
201
+ annotations={"batch.kubernetes.io/job-completion-index": str(rank)},
202
+ labels={"job-name": name.rsplit("-", 2)[0]},
203
+ ),
204
+ status=_Rec(phase="Running"),
205
+ )
206
+
207
+
208
+ def _make_executor(fake_kubernetes, *, batch=None, core=None, **kwargs):
209
+ batch = batch or _MockBatchV1Api()
210
+ core = core or _MockCoreV1Api()
211
+ ex = EKSExecutor(
212
+ image="myrepo/composer-replica:latest",
213
+ batch_api=batch,
214
+ core_api=core,
215
+ **kwargs,
216
+ )
217
+ # Speed up collect() loops in tests.
218
+ ex._collect_poll_interval = lambda: 0.0
219
+ return ex, batch, core
220
+
221
+
222
+ # ---------------------------------------------------------------------
223
+ # _expand_indexes — the run-length-range parser
224
+ # ---------------------------------------------------------------------
225
+
226
+
227
+ def test_expand_indexes_singletons_and_ranges():
228
+ assert _expand_indexes("1,3-5,7") == {1, 3, 4, 5, 7}
229
+ assert _expand_indexes("0") == {0}
230
+ assert _expand_indexes("0-3") == {0, 1, 2, 3}
231
+ assert _expand_indexes("") == set()
232
+ assert _expand_indexes(None) == set()
233
+ # Reversed range is tolerated.
234
+ assert _expand_indexes("5-3") == {3, 4, 5}
235
+ # Whitespace / junk tolerated.
236
+ assert _expand_indexes(" 2 , 4-6 ") == {2, 4, 5, 6}
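
For reference, a parser satisfying the assertions above could look like the following. It is a hypothetical re-implementation written against this test only, not the code in `eks.py`; the name `expand_indexes` and the choice to skip malformed tokens silently are assumptions.

```python
from __future__ import annotations


def expand_indexes(spec: str | None) -> set[int]:
    """Expand a compressed index list such as "1,3-5,7" into a set of ints."""
    result: set[int] = set()
    if not spec:
        return result
    for part in spec.split(","):
        part = part.strip()
        if not part:
            continue
        if "-" in part:
            lo_text, hi_text = part.split("-", 1)
            try:
                lo, hi = int(lo_text), int(hi_text)
            except ValueError:
                continue  # skip junk tokens rather than fail a poll
            if lo > hi:
                lo, hi = hi, lo  # tolerate a reversed range
            result.update(range(lo, hi + 1))
        else:
            try:
                result.add(int(part))
            except ValueError:
                continue
    return result
```

Returning a set keeps the per-rank membership check in a poll loop cheap and makes duplicate or overlapping ranges harmless.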
237
+
238
+
239
+ # ---------------------------------------------------------------------
240
+ # Construction / preconditions
241
+ # ---------------------------------------------------------------------
242
+
243
+
244
+ def test_missing_kubernetes_raises_runtime_error_when_no_api_injected():
245
+ """With kubernetes absent AND no injected api, ctor must raise clearly.
246
+
247
+ The import-guard path can ONLY be exercised when `kubernetes` is genuinely
248
+ not importable in this interpreter. When it IS installed (e.g. via the
249
+ `[eks]`/`[serverless]` extra in CI), the lazy import succeeds and the ctor
250
+ legitimately does not raise — so skip rather than assert a false precondition.
251
+ """
252
+ import importlib.util
253
+
254
+ if importlib.util.find_spec("kubernetes") is not None:
255
+ pytest.skip("kubernetes is importable in this interpreter; the absent-path cannot be exercised")
256
+ with pytest.raises(RuntimeError, match="kubernetes"):
257
+ EKSExecutor(image="x")
258
+
259
+
260
+ def test_construction_with_injected_apis_does_not_need_kubernetes():
261
+ """When both apis are injected, ctor must not require the kubernetes import."""
262
+ batch = _MockBatchV1Api()
263
+ core = _MockCoreV1Api()
264
+ ex = EKSExecutor(image="img", batch_api=batch, core_api=core)
265
+ assert ex.backend_name == "eks"
266
+ assert ex.supports_inter_replica_network is False
267
+ assert ex.image == "img"
268
+
269
+
270
+ # ---------------------------------------------------------------------
271
+ # launch_replicas — N handles, indexed-job spec correctness
272
+ # ---------------------------------------------------------------------
273
+
274
+
275
+ def test_launch_returns_n_rank_ordered_handles(fake_kubernetes):
276
+ ex, batch, _ = _make_executor(fake_kubernetes)
277
+ handles = ex.launch_replicas(
278
+ n_replicas=4,
279
+ entrypoint="ignored",
280
+ entrypoint_args={"rendezvous_uri": "s3://b/run42/", "world_size": 4},
281
+ )
282
+ assert len(handles) == 4
283
+ for i, h in enumerate(handles):
284
+ assert isinstance(h, ReplicaHandle)
285
+ assert h.rank == i
286
+ assert h.backend_name == "eks"
287
+ assert h.metadata["rank"] == i
288
+ # ALL handles share the same job_name / namespace (gang).
289
+ assert h.metadata["job_name"] == handles[0].metadata["job_name"]
290
+ assert h.metadata["namespace"] == "default"
291
+
292
+ # Exactly ONE job was created (single Indexed Job topology).
293
+ assert len(batch.created_jobs) == 1
294
+
295
+
296
+ def test_launch_creates_indexed_job_spec(fake_kubernetes):
297
+ ex, batch, _ = _make_executor(fake_kubernetes)
298
+ ex.launch_replicas(
299
+ n_replicas=3,
300
+ entrypoint="ignored",
301
+ entrypoint_args={"rendezvous_uri": "s3://b/r/", "world_size": 3},
302
+ )
303
+ ns, job = batch.created_jobs[0]
304
+ assert ns == "default"
305
+ assert job.api_version == "batch/v1"
306
+ assert job.kind == "Job"
307
+ spec = job.spec
308
+ assert spec.completions == 3
309
+ assert spec.parallelism == 3
310
+ assert spec.completion_mode == "Indexed"
311
+ assert spec.backoff_limit == 0
312
+ assert spec.ttl_seconds_after_finished == 3600
313
+ # active_deadline_seconds == timeout (default 3600 here).
314
+ assert spec.active_deadline_seconds == 3600
315
+ # restart_policy Never (required for Indexed jobs).
316
+ assert spec.template.spec.restart_policy == "Never"
317
+
318
+
319
+ def test_launch_rank_env_uses_downward_api_field_ref(fake_kubernetes):
320
+ ex, batch, _ = _make_executor(fake_kubernetes)
321
+ ex.launch_replicas(
322
+ n_replicas=2,
323
+ entrypoint="ignored",
324
+ entrypoint_args={"rendezvous_uri": "s3://b/r/", "world_size": 2},
325
+ )
326
+ _, job = batch.created_jobs[0]
327
+ env = job.spec.template.spec.containers[0].env
328
+ by_name = {e.name: e for e in env}
329
+
330
+ # REPLICA_RANK from the downward-API annotation (NOT a literal value).
331
+ rr = by_name["REPLICA_RANK"]
332
+ assert rr.value is None
333
+ field_ref = rr.value_from.field_ref
334
+ assert (
335
+ field_ref.field_path
336
+ == "metadata.annotations['batch.kubernetes.io/job-completion-index']"
337
+ )
338
+
339
+ # WORLD_SIZE is a literal string.
340
+ assert by_name["WORLD_SIZE"].value == "2"
341
+
342
+ # rendezvous_uri passed through as an upper-cased literal env var.
343
+ assert by_name["RENDEZVOUS_URI"].value == "s3://b/r/"
344
+
345
+
346
+ def test_launch_strips_rank_env_kwarg(fake_kubernetes):
347
+ """`rank_env` is the LocalProcessExecutor convention — must not become env."""
348
+ ex, batch, _ = _make_executor(fake_kubernetes)
349
+ ex.launch_replicas(
350
+ n_replicas=1,
351
+ entrypoint="ignored",
352
+ entrypoint_args={"rank_env": "REPLICA_RANK", "rendezvous_uri": "s3://x/"},
353
+ )
354
+ _, job = batch.created_jobs[0]
355
+ env_names = {e.name for e in job.spec.template.spec.containers[0].env}
356
+ assert "RANK_ENV" not in env_names
357
+ assert "RENDEZVOUS_URI" in env_names
358
+
359
+
360
+ def test_launch_gpu_limit_is_string(fake_kubernetes):
361
+ ex, batch, _ = _make_executor(fake_kubernetes)
362
+ ex.launch_replicas(
363
+ n_replicas=2,
364
+ entrypoint="ignored",
365
+ entrypoint_args={"rendezvous_uri": "s3://x/"},
366
+ gpu="A100",
367
+ )
368
+ _, job = batch.created_jobs[0]
369
+ container = job.spec.template.spec.containers[0]
370
+ limits = container.resources.limits
371
+ assert limits["nvidia.com/gpu"] == "1"
372
+ # MUST be a string, not an int.
373
+ assert isinstance(limits["nvidia.com/gpu"], str)
374
+ # GPU node selector merged in.
375
+ node_selector = job.spec.template.spec.node_selector
376
+ assert node_selector["node.kubernetes.io/instance-type"] == "p4d.24xlarge"
377
+ # GPU NoSchedule toleration auto-added.
378
+ tols = job.spec.template.spec.tolerations
379
+ assert any(
380
+ t.key == "nvidia.com/gpu" and t.effect == "NoSchedule" for t in tols
381
+ )
382
+
383
+
384
+ def test_launch_cpu_only_omits_gpu_limit(fake_kubernetes):
385
+ ex, batch, _ = _make_executor(fake_kubernetes)
386
+ ex.launch_replicas(
387
+ n_replicas=2,
388
+ entrypoint="ignored",
389
+ entrypoint_args={"rendezvous_uri": "s3://x/"},
390
+ gpu=None,
391
+ )
392
+ _, job = batch.created_jobs[0]
393
+ limits = job.spec.template.spec.containers[0].resources.limits
+    # No GPU -> no nvidia.com/gpu key at all (limits is None or empty).
+    assert "nvidia.com/gpu" not in (limits or {})
396
+
397
+
398
+ def test_launch_passes_service_account_and_runtime_class(fake_kubernetes):
399
+ ex, batch, _ = _make_executor(
400
+ fake_kubernetes,
401
+ service_account_name="diloco-irsa-sa",
402
+ runtime_class_name="gvisor",
403
+ )
404
+ ex.launch_replicas(
405
+ n_replicas=1,
406
+ entrypoint="ignored",
407
+ entrypoint_args={"rendezvous_uri": "s3://x/"},
408
+ )
409
+ _, job = batch.created_jobs[0]
410
+ pod_spec = job.spec.template.spec
411
+ assert pod_spec.service_account_name == "diloco-irsa-sa"
412
+ assert pod_spec.runtime_class_name == "gvisor"
413
+
414
+
415
+ def test_launch_timeout_becomes_active_deadline(fake_kubernetes):
416
+ ex, batch, _ = _make_executor(fake_kubernetes)
417
+ ex.launch_replicas(
418
+ n_replicas=1,
419
+ entrypoint="ignored",
420
+ entrypoint_args={"rendezvous_uri": "s3://x/"},
421
+ timeout=7200,
422
+ )
423
+ _, job = batch.created_jobs[0]
424
+ assert job.spec.active_deadline_seconds == 7200
425
+
426
+
427
+ def test_launch_uses_default_entrypoint_command(fake_kubernetes):
428
+ ex, batch, _ = _make_executor(fake_kubernetes)
429
+ ex.launch_replicas(
430
+ n_replicas=1, entrypoint="ignored", entrypoint_args={"rendezvous_uri": "s3://x/"}
431
+ )
432
+ _, job = batch.created_jobs[0]
433
+ cmd = job.spec.template.spec.containers[0].command
434
+ assert cmd == [
435
+ "python",
436
+ "-m",
437
+ "composer_replication.diloco.serverless.replica_entrypoint",
438
+ ]
439
+
440
+
441
+ def test_launch_rejects_zero_or_negative(fake_kubernetes):
442
+ ex, _, _ = _make_executor(fake_kubernetes)
443
+ with pytest.raises(ValueError, match="n_replicas"):
444
+ ex.launch_replicas(n_replicas=0, entrypoint="x", entrypoint_args={})
445
+ with pytest.raises(ValueError, match="n_replicas"):
446
+ ex.launch_replicas(n_replicas=-1, entrypoint="x", entrypoint_args={})
447
+
448
+
449
+ # ---------------------------------------------------------------------
450
+ # poll — state mapping from completed/failed indexes + active count
451
+ # ---------------------------------------------------------------------
452
+
453
+
+def _launch_two(fake_kubernetes, batch=None, core=None):
+    """Launch a four-replica job (despite the name) and return its parts."""
+    ex, batch, core = _make_executor(fake_kubernetes, batch=batch, core=core)
+    handles = ex.launch_replicas(
+        n_replicas=4, entrypoint="x", entrypoint_args={"rendezvous_uri": "s3://x/"}
+    )
+    return ex, batch, core, handles
460
+
461
+
462
+ def test_poll_pending_when_nothing_active(fake_kubernetes):
463
+ ex, batch, _, handles = _launch_two(fake_kubernetes)
464
+ batch.status_obj = _Rec(active=0, completed_indexes=None, failed_indexes=None)
465
+ assert ex.poll(handles[0]) == "pending"
466
+
467
+
468
+ def test_poll_running_when_active(fake_kubernetes):
469
+ ex, batch, _, handles = _launch_two(fake_kubernetes)
470
+ batch.status_obj = _Rec(active=4, completed_indexes=None, failed_indexes=None)
471
+ assert ex.poll(handles[2]) == "running"
472
+
473
+
474
+ def test_poll_succeeded_when_rank_in_completed_indexes(fake_kubernetes):
475
+ ex, batch, _, handles = _launch_two(fake_kubernetes)
476
+ # completed_indexes "0,2-3" -> ranks {0,2,3} succeeded; rank 1 still running.
477
+ batch.status_obj = _Rec(
478
+ active=1, completed_indexes="0,2-3", failed_indexes=None
479
+ )
480
+ assert ex.poll(handles[0]) == "succeeded"
481
+ assert ex.poll(handles[2]) == "succeeded"
482
+ assert ex.poll(handles[3]) == "succeeded"
483
+ assert ex.poll(handles[1]) == "running"
484
+
485
+
486
+ def test_poll_failed_when_rank_in_failed_indexes(fake_kubernetes):
487
+ ex, batch, _, handles = _launch_two(fake_kubernetes)
488
+ batch.status_obj = _Rec(
489
+ active=0, completed_indexes="0", failed_indexes="1,3"
490
+ )
491
+ assert ex.poll(handles[1]) == "failed"
492
+ assert ex.poll(handles[3]) == "failed"
493
+ assert ex.poll(handles[0]) == "succeeded"
494
+
495
+
496
+ def test_poll_failed_on_whole_job_failed_condition(fake_kubernetes):
497
+ """DeadlineExceeded etc.: a Failed condition with no per-index info -> failed."""
498
+ ex, batch, _, handles = _launch_two(fake_kubernetes)
499
+ batch.status_obj = _Rec(
500
+ active=0,
501
+ completed_indexes=None,
502
+ failed_indexes=None,
503
+ conditions=[_Rec(type="Failed", status="True", reason="DeadlineExceeded")],
504
+ )
505
+ assert ex.poll(handles[0]) == "failed"
506
+
507
+
508
+ def test_poll_cancelled_on_404(fake_kubernetes):
509
+ ex, batch, _, handles = _launch_two(fake_kubernetes)
510
+ batch.read_raises = _ApiException(status=404)
511
+ assert ex.poll(handles[0]) == "cancelled"
512
+
513
+
514
+ def test_poll_reraises_non_404_api_exception(fake_kubernetes):
515
+ ex, batch, _, handles = _launch_two(fake_kubernetes)
516
+ batch.read_raises = _ApiException(status=500)
517
+ with pytest.raises(_ApiException):
518
+ ex.poll(handles[0])
519
+
520
+
521
+ # ---------------------------------------------------------------------
522
+ # cancel — Background propagation on the shared job, idempotent
523
+ # ---------------------------------------------------------------------
524
+
525
+
526
+ def test_cancel_uses_background_propagation_on_shared_job(fake_kubernetes):
527
+ ex, batch, _, handles = _launch_two(fake_kubernetes)
528
+ ex.cancel(handles[2])
529
+ assert len(batch.delete_calls) == 1
530
+ call = batch.delete_calls[0]
531
+ assert call["propagation_policy"] == "Background"
532
+ assert call["grace_period_seconds"] == 0
533
+ # Cancelling ANY rank deletes the WHOLE shared job (gang semantics).
534
+ assert call["name"] == handles[0].metadata["job_name"]
535
+ assert call["namespace"] == "default"
536
+
537
+
538
+ def test_cancel_swallows_404(fake_kubernetes):
539
+ ex, batch, _, handles = _launch_two(fake_kubernetes)
540
+
541
+ def _raise_404(name, namespace, body=None):
542
+ raise _ApiException(status=404)
543
+
544
+ batch.delete_namespaced_job = _raise_404
545
+ # Must NOT raise (already deleted == success per the Protocol).
546
+ ex.cancel(handles[0])
547
+
548
+
549
+ def test_cancel_unknown_handle_is_noop(fake_kubernetes):
550
+ ex, batch, _, _ = _launch_two(fake_kubernetes)
551
+ fake = ReplicaHandle(rank=99, backend_name="eks", metadata={})
552
+ ex.cancel(fake) # no job_name in metadata -> no-op, no delete call
553
+ assert len(batch.delete_calls) == 0
554
+
555
+
556
+ # ---------------------------------------------------------------------
557
+ # stream_logs — find pod by completion-index annotation
558
+ # ---------------------------------------------------------------------
559
+
560
+
561
+ def test_stream_logs_reads_pod_for_rank(fake_kubernetes):
562
+ pods = [
563
+ _make_pod("diloco-abcd1234-0-xyz", 0),
564
+ _make_pod("diloco-abcd1234-1-xyz", 1),
565
+ ]
566
+ core = _MockCoreV1Api(pods=pods, logs="hello from rank 1\n")
567
+ ex, _, _, handles = _launch_two(fake_kubernetes, core=core)
568
+ out = ex.stream_logs(handles[1], n_lines=50)
569
+ assert out == "hello from rank 1\n"
570
+ # Read the right pod, container 'replica', tail_lines honored.
571
+ last = core.log_calls[-1]
572
+ assert last["name"] == "diloco-abcd1234-1-xyz"
573
+ assert last["container"] == "replica"
574
+ assert last["tail_lines"] == 50
575
+
576
+
577
+ def test_stream_logs_placeholder_when_pod_missing(fake_kubernetes):
578
+ core = _MockCoreV1Api(pods=[]) # no pods yet
579
+ ex, _, _, handles = _launch_two(fake_kubernetes, core=core)
580
+ out = ex.stream_logs(handles[0])
581
+ assert "rank 0" in out
582
+ assert "not started" in out or "no logs" in out
583
+
584
+
585
+ def test_stream_logs_placeholder_on_400(fake_kubernetes):
586
+ pods = [_make_pod("diloco-abcd1234-0-xyz", 0)]
587
+ core = _MockCoreV1Api(pods=pods)
588
+ core.log_raises = _ApiException(status=400) # pod not started yet
589
+ ex, _, _, handles = _launch_two(fake_kubernetes, core=core)
590
+ out = ex.stream_logs(handles[0])
591
+ assert "rank 0" in out
592
+
593
+
594
+ # ---------------------------------------------------------------------
595
+ # collect — per-rank result dicts in handles order
596
+ # ---------------------------------------------------------------------
597
+
598
+
599
+ def test_collect_returns_terminal_results_in_order(fake_kubernetes):
600
+ ex, batch, _, handles = _launch_two(fake_kubernetes)
601
+ # All four ranks done: 0-2 succeeded, 3 failed.
602
+ batch.status_obj = _Rec(
603
+ active=0, completed_indexes="0-2", failed_indexes="3"
604
+ )
605
+ results = ex.collect(handles, timeout=5)
606
+ assert len(results) == 4
607
+ for i, r in enumerate(results):
608
+ assert r["rank"] == i
609
+ assert r["job_name"] == handles[0].metadata["job_name"]
610
+ assert results[0]["status"] == "succeeded" and results[0]["exit_code"] == 0
611
+ assert results[1]["status"] == "succeeded"
612
+ assert results[2]["status"] == "succeeded"
613
+ assert results[3]["status"] == "failed" and results[3]["exit_code"] == 1
614
+ assert results[3]["error"] is not None
615
+
616
+
617
+ def test_collect_returns_non_terminal_state_at_deadline(fake_kubernetes):
618
+ ex, batch, _, handles = _launch_two(fake_kubernetes)
619
+ # Never finishes: active stays > 0.
620
+ batch.status_obj = _Rec(active=4, completed_indexes=None, failed_indexes=None)
621
+ results = ex.collect(handles, timeout=0) # immediate deadline
622
+ assert len(results) == 4
623
+ for r in results:
624
+ assert r["status"] in ("running", "pending")
625
+ assert r["exit_code"] is None
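The poll tests above depend on the compressed index strings an Indexed Job reports, such as "0,2-3". The following is a minimal sketch of how such a string could be expanded and turned into a per-rank state. The helper names `parse_index_ranges` and `classify_rank` are hypothetical and are not taken from the executor. The precedence order (per-index results first, then a whole-job failure, then the active count) is an assumption chosen only because it agrees with the assertions in these tests.

```python
from __future__ import annotations


def parse_index_ranges(spec: str | None) -> set[int]:
    """Expand a string such as "0,2-3" into {0, 2, 3}; None or "" gives an empty set."""
    indexes: set[int] = set()
    if not spec:
        return indexes
    for part in spec.split(","):
        part = part.strip()
        if not part:
            continue
        if "-" in part:
            # An inclusive range like "2-3".
            lo, hi = part.split("-", 1)
            indexes.update(range(int(lo), int(hi) + 1))
        else:
            indexes.add(int(part))
    return indexes


def classify_rank(
    rank: int,
    active: int,
    completed_indexes: str | None,
    failed_indexes: str | None,
    job_failed: bool = False,
) -> str:
    """Map shared-job status fields to one rank's state (assumed precedence)."""
    if rank in parse_index_ranges(failed_indexes):
        return "failed"
    if rank in parse_index_ranges(completed_indexes):
        return "succeeded"
    if job_failed:
        # Whole-job Failed condition with no per-index information.
        return "failed"
    return "running" if active > 0 else "pending"
```

Because every rank shares one job status object, a single read of that status is enough to classify all ranks; the real executor may order these checks differently.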
composer_replication/diloco/serverless/tests/test_sagemaker_executor.py ADDED
@@ -0,0 +1,244 @@
1
+ """Tests for SageMakerExecutor (composer_replication.diloco.serverless.sagemaker).
2
+
3
+ The executor is exercised with an INJECTED mock boto3 sagemaker client (the
4
+ `sagemaker_client=` ctor arg), so these run on any host without boto3 or AWS
5
+ credentials — mirroring the _MockFunctionCall pattern in
6
+ test_modal_spawn_executor.py and the _MockBatchV1Api pattern in
7
+ test_eks_executor.py.
8
+
9
+ Closes the test-coverage gap left when the SageMakerExecutor was first written
10
+ without a test module (caught during Wave-2 integration, 2026-06-09).
11
+ """
12
+ from __future__ import annotations
13
+
14
+ import importlib.util
15
+
16
+ import pytest
17
+
18
+ from composer_replication.diloco.serverless import SageMakerExecutor
19
+ from composer_replication.diloco.serverless.executor import ReplicaHandle
20
+
21
+ # ---------------------------------------------------------------------
22
+ # Mock boto3 sagemaker client
23
+ # ---------------------------------------------------------------------
24
+
25
+
26
+ class _MockSMClient:
27
+ """Records create/stop calls and serves a scripted status per job name."""
28
+
29
+ def __init__(self):
30
+ self.created: list[dict] = []
31
+ self.stopped: list[str] = []
32
+ # job_name -> (TrainingJobStatus, SecondaryStatus)
33
+ self._status: dict[str, tuple[str, str]] = {}
34
+ self.raise_not_found_on: set[str] = set()
35
+
36
+ def create_training_job(self, **request):
37
+ self.created.append(request)
38
+ # default a newly-created job to InProgress/Starting (== pending)
39
+ self._status[request["TrainingJobName"]] = ("InProgress", "Starting")
40
+ return {"TrainingJobArn": f"arn:aws:sagemaker:::training-job/{request['TrainingJobName']}"}
41
+
42
+ def describe_training_job(self, TrainingJobName): # noqa: N803 (boto3 casing)
43
+ if TrainingJobName in self.raise_not_found_on:
44
+ raise _ResourceNotFoundError(f"job {TrainingJobName} not found")
45
+ status, secondary = self._status.get(TrainingJobName, ("InProgress", "Training"))
46
+ return {
47
+ "TrainingJobName": TrainingJobName,
48
+ "TrainingJobStatus": status,
49
+ "SecondaryStatus": secondary,
50
+ "TrainingJobArn": f"arn:aws:sagemaker:::training-job/{TrainingJobName}",
51
+ }
52
+
53
+ def stop_training_job(self, TrainingJobName): # noqa: N803
54
+ self.stopped.append(TrainingJobName)
55
+
56
+ # test helper
57
+ def set_status(self, job_name, status, secondary="Completed"):
58
+ self._status[job_name] = (status, secondary)
59
+
60
+
61
+ class _ResourceNotFoundError(Exception):
62
+ """Stand-in for botocore ResourceNotFound (the executor matches on name/text)."""
63
+
64
+ def __init__(self, msg):
65
+ super().__init__(msg)
66
+ # botocore-style response shape some impls check
67
+ self.response = {"Error": {"Code": "ResourceNotFound", "Message": msg}}
68
+
69
+
70
+ def _make_executor(client=None):
71
+ return SageMakerExecutor(
72
+ image_uri="123.dkr.ecr.us-east-1.amazonaws.com/trainer:latest",
73
+ role_arn="arn:aws:iam::123:role/SMRole",
74
+ output_s3_path="s3://bucket/out/",
75
+ region="us-east-1",
76
+ sagemaker_client=client or _MockSMClient(),
77
+ )
78
+
79
+
80
+ _VALID_ARGS = {
81
+ "rendezvous_uri": "s3://bucket/rendezvous/run1/",
82
+ "trainer_module": "my_pkg.trainer",
83
+ }
84
+
85
+
86
+ # ---------------------------------------------------------------------
87
+ # Construction
88
+ # ---------------------------------------------------------------------
89
+
90
+
91
+ def test_backend_identity():
92
+ ex = _make_executor()
93
+ assert ex.backend_name == "sagemaker"
94
+ assert ex.supports_inter_replica_network is False
95
+
96
+
97
+ def test_missing_boto3_raises_when_no_client_injected():
98
+ """The import-guard path only fires when boto3 is genuinely absent."""
99
+ if importlib.util.find_spec("boto3") is not None:
100
+ pytest.skip("boto3 importable; absent-path cannot be exercised")
101
+ with pytest.raises(RuntimeError, match="boto3"):
102
+ SageMakerExecutor(
103
+ image_uri="x", role_arn="r", output_s3_path="s3://b/o/",
104
+ )
105
+
106
+
107
+ def test_construction_with_injected_client_needs_no_boto3():
108
+ ex = _make_executor()
109
+ assert ex is not None
110
+
111
+
112
+ # ---------------------------------------------------------------------
113
+ # launch_replicas
114
+ # ---------------------------------------------------------------------
115
+
116
+
117
+ def test_launch_returns_rank_ordered_handles():
118
+ client = _MockSMClient()
119
+ ex = _make_executor(client)
120
+ handles = ex.launch_replicas(3, entrypoint="ignored", entrypoint_args=_VALID_ARGS)
121
+ assert len(handles) == 3
122
+ assert [h.rank for h in handles] == [0, 1, 2]
123
+ assert all(isinstance(h, ReplicaHandle) and h.backend_name == "sagemaker" for h in handles)
124
+ assert len(client.created) == 3
125
+
126
+
127
+ def test_launch_injects_rank_world_size_and_rendezvous_env():
128
+ client = _MockSMClient()
129
+ ex = _make_executor(client)
130
+ ex.launch_replicas(2, entrypoint="ignored", entrypoint_args=_VALID_ARGS)
131
+ for rank, req in enumerate(client.created):
132
+ env = req["Environment"]
133
+ assert env["REPLICA_RANK"] == str(rank)
134
+ assert env["WORLD_SIZE"] == "2"
135
+ assert env["RENDEZVOUS_URI"] == _VALID_ARGS["rendezvous_uri"]
136
+ # network isolation MUST stay False (else S3 rendezvous deadlocks)
137
+ assert req["EnableNetworkIsolation"] is False
138
+ assert req["OutputDataConfig"]["S3OutputPath"] == "s3://bucket/out/"
139
+ assert req["ResourceConfig"]["InstanceCount"] == 1
140
+
141
+
142
+ def test_launch_validates_n_replicas():
143
+ ex = _make_executor()
144
+ with pytest.raises(ValueError, match="n_replicas"):
145
+ ex.launch_replicas(0, entrypoint="x", entrypoint_args=_VALID_ARGS)
146
+
147
+
148
+ def test_launch_requires_rendezvous_and_trainer_module():
149
+ ex = _make_executor()
150
+ with pytest.raises(ValueError, match="rendezvous_uri"):
151
+ ex.launch_replicas(1, entrypoint="x", entrypoint_args={"trainer_module": "m"})
152
+ with pytest.raises(ValueError, match="trainer_module"):
153
+ ex.launch_replicas(1, entrypoint="x", entrypoint_args={"rendezvous_uri": "s3://b/r/"})
154
+
155
+
156
+ def test_launch_partial_failure_stops_siblings_and_raises():
157
+ class _FailingClient(_MockSMClient):
158
+ def create_training_job(self, **request):
159
+ if len(self.created) >= 2: # 3rd create fails
160
+ raise RuntimeError("ThrottlingException")
161
+ return super().create_training_job(**request)
162
+
163
+ client = _FailingClient()
164
+ ex = _make_executor(client)
165
+ with pytest.raises(RuntimeError, match="rank=2"):
166
+ ex.launch_replicas(3, entrypoint="x", entrypoint_args=_VALID_ARGS)
167
+ # the two already-launched siblings were best-effort stopped
168
+ assert len(client.stopped) == 2
169
+
170
+
171
+ # ---------------------------------------------------------------------
172
+ # poll status mapping
173
+ # ---------------------------------------------------------------------
174
+
175
+
176
+ def test_poll_status_mapping():
177
+ client = _MockSMClient()
178
+ ex = _make_executor(client)
179
+ handles = ex.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)
180
+ h = handles[0]
181
+ job = client.created[0]["TrainingJobName"]
182
+
183
+ client.set_status(job, "InProgress", "Starting")
184
+ assert ex.poll(h) == "pending"
185
+ client.set_status(job, "InProgress", "Training")
186
+ assert ex.poll(h) == "running"
187
+ client.set_status(job, "Completed")
188
+ assert ex.poll(h) == "succeeded"
189
+
190
+
191
+ def test_poll_failed_and_stopped():
192
+ client = _MockSMClient()
193
+ ex = _make_executor(client)
194
+ h = ex.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0]
195
+ job = client.created[0]["TrainingJobName"]
196
+ client.set_status(job, "Failed")
197
+ assert ex.poll(h) == "failed"
198
+
199
+ client2 = _MockSMClient()
200
+ ex2 = _make_executor(client2)
201
+ h2 = ex2.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0]
202
+ job2 = client2.created[0]["TrainingJobName"]
203
+ client2.set_status(job2, "Stopped")
204
+ assert ex2.poll(h2) == "cancelled"
205
+
206
+
207
+ def test_poll_vanished_job_is_cancelled():
208
+ client = _MockSMClient()
209
+ ex = _make_executor(client)
210
+ h = ex.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0]
211
+ client.raise_not_found_on.add(client.created[0]["TrainingJobName"])
212
+ assert ex.poll(h) == "cancelled"
213
+
214
+
215
+ def test_poll_unknown_handle_is_cancelled():
216
+ ex = _make_executor()
217
+ orphan = ReplicaHandle(rank=99, backend_name="sagemaker", metadata={})
218
+ assert ex.poll(orphan) == "cancelled"
219
+
220
+
221
+ # ---------------------------------------------------------------------
222
+ # cancel
223
+ # ---------------------------------------------------------------------
224
+
225
+
226
+ def test_cancel_calls_stop_training_job():
227
+ client = _MockSMClient()
228
+ ex = _make_executor(client)
229
+ h = ex.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0]
230
+ ex.cancel(h)
231
+ assert client.stopped == [client.created[0]["TrainingJobName"]]
232
+
233
+
234
+ def test_cancel_swallows_errors():
235
+ class _RaisingStop(_MockSMClient):
236
+ def stop_training_job(self, TrainingJobName): # noqa: N803
237
+ raise _ResourceNotFoundError("already terminal")
238
+
239
+ client = _RaisingStop()
240
+ ex = _make_executor(client)
241
+ h = ex.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0]
242
+ ex.cancel(h) # must not raise
243
+ # unknown handle must also be a no-op
244
+ ex.cancel(ReplicaHandle(rank=42, backend_name="sagemaker", metadata={}))
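The status-mapping tests above imply a small lookup from SageMaker's `TrainingJobStatus` / `SecondaryStatus` pair to the executor's replica states. A minimal sketch of that lookup follows. The function name `map_training_job_status` is hypothetical, and two behaviours are assumptions the tests do not cover: treating `Stopping` the same as `Stopped`, and raising on an unrecognised status.

```python
def map_training_job_status(status: str, secondary: str = "") -> str:
    """Translate a SageMaker status pair into a replica state string."""
    if status == "Completed":
        return "succeeded"
    if status == "Failed":
        return "failed"
    if status in ("Stopping", "Stopped"):
        # Assumption: a job that is being stopped is already reported as cancelled.
        return "cancelled"
    if status == "InProgress":
        # The tests treat InProgress/Starting as pending and InProgress/Training as running.
        return "pending" if secondary == "Starting" else "running"
    # Assumption: fail loudly rather than guess at an unknown status.
    raise ValueError(f"unrecognised TrainingJobStatus: {status!r}")
```

Keeping the mapping in one pure function makes it testable without a client at all; the mock client in the tests is then only needed for the launch, cancel, and not-found paths.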
composer_replication/safety/__init__.py ADDED
@@ -0,0 +1,34 @@
1
+ """composer_replication.safety — run-level collapse safeguards.
2
+
3
+ The #2 collapse safeguard for the self-evolving RL flywheel: a held-out disjoint
4
+ eval + a depth/generation kill-switch. The per-task controls live in
5
+ ``composer_replication.datagen`` (4-gate validator, ``HackMonitor`` provenance,
6
+ sandbox denylist); this package adds the missing ACROSS-GENERATION / run-level
7
+ control that watches in-loop (proxy) reward against a disjoint held-out (real)
8
+ eval and HALTS the run when collapse / reward-hacking is caught in the act.
9
+
10
+ Public surface:
11
+ - HeldOutGuard — the stateful kill-switch (kill_switch.py)
12
+ - TripwireStatus — the structured per-update verdict (.fire / .halt / .reason /
13
+ .proxy_real_gap)
14
+ - CollapseStopError — typed exception for exception-based trainer control flow
15
+ - kl_token_trust_filter — per-token KL trust-region mask (torchrl KL-Mask analog)
16
+
17
+ Pure-Python, no torch / cloud deps. See docs/adrs/ADR-015-*.md and the
18
+ 'holdout-killswitch' research digest.
19
+ """
20
+ from __future__ import annotations
21
+
22
+ from composer_replication.safety.kill_switch import (
23
+ CollapseStopError,
24
+ HeldOutGuard,
25
+ TripwireStatus,
26
+ kl_token_trust_filter,
27
+ )
28
+
29
+ __all__ = [
30
+ "HeldOutGuard",
31
+ "TripwireStatus",
32
+ "CollapseStopError",
33
+ "kl_token_trust_filter",
34
+ ]
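The package docstring above describes watching in-loop (proxy) reward against a held-out (real) eval. As a rough, self-contained illustration of the gap quantity involved, the sketch below folds both series into EMAs and reports proxy gain minus real gain since the first checkpoint. It follows the conventions stated in this commit's kill_switch.py comments (alpha is the weight on the prior EMA, and the first sample seeds both the EMA and the baseline), but the function names `fold_ema` and `hacking_gap` are hypothetical and this is not the `HeldOutGuard` implementation.

```python
from __future__ import annotations

from collections.abc import Sequence


def fold_ema(prev: float | None, value: float, alpha: float = 0.9) -> float:
    """One EMA step; alpha weights the PRIOR, and the first sample seeds the EMA."""
    if prev is None:
        return value
    return alpha * prev + (1.0 - alpha) * value


def hacking_gap(
    in_loop: Sequence[float], heldout: Sequence[float], alpha: float = 0.9
) -> float:
    """Return (proxy EMA gain) - (held-out EMA gain), both measured from the first checkpoint."""
    in_ema: float | None = None
    held_ema: float | None = None
    in_base: float | None = None
    held_base: float | None = None
    for proxy, real in zip(in_loop, heldout):
        in_ema = fold_ema(in_ema, proxy, alpha)
        held_ema = fold_ema(held_ema, real, alpha)
        if in_base is None:
            # Baselines are the EMAs after the first update.
            in_base, held_base = in_ema, held_ema
    if in_ema is None or held_ema is None or in_base is None or held_base is None:
        raise ValueError("need at least one checkpoint")
    return (in_ema - in_base) - (held_ema - held_base)
```

With proxy rewards 0.5, 0.6, 0.7 and held-out scores 0.5, 0.45, 0.4 at alpha 0.9, the EMAs end at 0.529 and 0.4855, so the gap is 0.029 + 0.0145 = 0.0435: positive because the proxy rose while the held-out score fell.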
composer_replication/safety/kill_switch.py ADDED
@@ -0,0 +1,447 @@
1
+ """kill_switch.py — held-out collapse tripwire (the #2 collapse safeguard).
2
+
3
+ This is the missing RUN-LEVEL / across-generation control for the self-evolving
4
+ RL flywheel. The per-task controls already exist in ``composer_replication.datagen``
5
+ (the 4-gate solvability validator, the ``HackMonitor`` provenance check, and the
6
+ sandbox denylist); this module sits ABOVE them and watches the whole run.
7
+
8
+ Rationale (the literature is unambiguous that a held-out eval + hard stop is the
9
+ load-bearing control, not a nice-to-have):
10
+
11
+ - **Reward hacking rises monotonically with optimization depth.** Zhao et al.,
12
+ "Reward Hacking in Self-Improving Code Agents" (ICLR 2026 Workshop on RSI,
13
+ OpenReview ``ikrQWGgxYg``) show that going from 10 -> 100 optimization steps
14
+ drives the hacking rate from 26.4% to 57.8% (+31.4 points), and that
15
+ 73.8% of KernelBench / 46.8% of ALE-Bench optimizations show *proxy gains
16
+ without real gains*. They define **Hacking Gap = proxy gain - real gain**;
17
+ this module's ``proxy_real_gap()`` is exactly that quantity. They label an
18
+ optimization reward-hacking when it "improves the public metric WITHOUT
19
+ improving the private metric" — the canonical signature this tripwire fires on.
20
+
21
+ - **Self-critique alone is insufficient.** The same paper's "retrospection"
22
+ self-critique sometimes *increased* hacking; their conclusion: "mitigating
23
+ reward hacking likely requires stronger evaluations and constraints beyond
24
+ self-critique alone." So we build a genuinely disjoint held-out eval plus a
25
+ hard stop, not a critique hook.
26
+
27
+ - **Held-out eval is necessary but NOT sufficient by itself.** EvilGenie
28
+ (arXiv 2511.21654) found "only minimal improvement from the use of held out
29
+ test cases" in isolation and that "holdout tests have many surprising failure
30
+ modes." This module is therefore explicitly *defense-in-depth*, layered ON
31
+ TOP of ``HackMonitor`` (provenance) — neither is sufficient alone, matching
32
+ the repo's existing defense-in-depth framing in ``datagen/monitor.py``.
33
+
34
+ - **Closed-loop RL on self-generated data collapses.** The self-evolving-agents
35
+ survey (Gao et al., TMLR 2026; arXiv 2507.21046 v4) §8.3 names "model
36
+ collapse from closed-loop RL on static synthetic data" and prescribes
37
+ "continuous monitoring ... to detect long-horizon value drift" — i.e. a
38
+ per-generation online tripwire, not a one-time eval. Shumailov et al. (Nature
39
+ 2024, "AI models collapse when trained on recursively generated data") show
40
+ self-training first loses the distribution tails, then converges to a
41
+ low-variance point estimate; the mitigation that matters here is that the
42
+ held-out eval must stay anchored to REAL tasks that are NEVER fed back to the
43
+ generator (see ``HeldoutSplit``), otherwise the eval drifts with the train set.
44
+
45
+ - **KL-to-init hard stop.** The GRPO "healthy progression" band (Orchestra
46
+ Research GRPO SKILL) climbs 0.02 -> 0.05 -> 0.08 -> 0.12 nats/token over a
47
+ run, with 0.08 the top of the "good progression" band and just below the
48
+ code-generation drift zone (0.05-0.15 per-token); >0.5 is "diverging too
49
+ much." So 0.08 nats/token is a sound HARD-STOP default. Catastrophic Goodhart
50
+ (OpenReview ``UXuBzWoZGK``) proves KL regularization alone does NOT prevent
51
+ heavy-tailed reward misspecification, so the KL hard stop is ONE tripwire
52
+ among several, never the sole control.
53
+
54
+ UNITS GOTCHA (load-bearing): the ``kl_to_init`` this module consumes is
55
+ **token-mean KL in nats/token**, matching the repo convention in
56
+ ``composer_replication.integrations.altered_minds.kl_logging.token_mean_kl``.
57
+ A token-mean KL is NOT comparable to a sequence-level / sequence-summed KL
58
+ (whose healthy band is ~0.05-10). The 0.08 default is per-token. Do not pass a
59
+ sequence-summed KL into the per-token hard stop — it will fire instantly.
60
+
61
+ This module is pure-Python: no torch, no cloud deps. ``kl_to_init`` is just a
62
+ float the caller passes (computed upstream by ``token_mean_kl``). It is fully
63
+ CPU-testable.
64
+ """
65
+ from __future__ import annotations
66
+
67
+ from dataclasses import dataclass, field
68
+
69
+
70
+ class CollapseStopError(RuntimeError):
71
+ """Raised (by the caller, optionally) when the tripwire fires a hard stop.
72
+
73
+ The trainer loop can either check ``TripwireStatus.fire`` and stop softly,
74
+ or call ``HeldOutGuard.raise_if_fired(status)`` to convert a fired verdict
75
+ into this typed exception. Carries the structured verdict for logging.
76
+ """
77
+
78
+ def __init__(self, status: TripwireStatus) -> None:
79
+ super().__init__(status.reason)
80
+ self.status = status
81
+
82
+
83
+ @dataclass(frozen=True)
84
+ class TripwireStatus:
85
+ """Structured verdict returned by every ``HeldOutGuard.update(...)`` call.
86
+
87
+ Attributes:
88
+ fire: True => the run should HALT (collapse / reward-hacking detected).
89
+ reason: human-readable WHY (empty string when ``fire`` is False), so the
90
+ trainer can log exactly which tripwire tripped, mirroring how
91
+ ``datagen/monitor.py`` logs suspected hacks for review.
92
+ step: the round/generation index this verdict was computed at.
93
+ proxy_real_gap: the RSI "Hacking Gap" at this step = (in-loop reward gain
94
+ since baseline) - (held-out score gain since baseline). Positive and
95
+ widening => proxy improving faster than (or while) real declines.
96
+ in_loop_ema: EMA of the in-loop / proxy reward at this step.
97
+ heldout_ema: EMA of the held-out / real eval score at this step.
98
+ kl_ema: EMA of ``kl_to_init`` (nats/token), or None if never supplied.
99
+ """
100
+
101
+ fire: bool
102
+ reason: str
103
+ step: int
104
+ proxy_real_gap: float
105
+ in_loop_ema: float
106
+ heldout_ema: float
107
+ kl_ema: float | None = None
108
+
109
+ # `halt` is a documented alias for `fire` — the task spec describes a
110
+ # `should_halt()` / verdict with a `halt` field; expose both names so callers
111
+ # using either name get the same verdict.
112
+ @property
113
+ def halt(self) -> bool:
114
+ return self.fire
115
+
116
+
117
+ @dataclass
118
+ class HeldOutGuard:
119
+ """Across-generation collapse / reward-hacking kill-switch (HeldOutGuard).
120
+
121
+ Tracks, per generation/round: in-loop (proxy) oracle reward, held-out (real)
122
+ eval score, and optional KL-to-init / entropy / reward-std. Computes the
123
+ proxy-minus-real "Hacking Gap" tripwire and fires a structured ``halt``
124
+ verdict when collapse is caught in the act.
125
+
126
+ The guard is **stateful**: call ``update(round_idx, ...)`` once per checkpoint
127
+ in the trainer loop (the same cadence at which ``DifficultyCurriculum.update``
128
+ is called). It maintains denoised EMAs of every metric (raw single-step
129
+ values are too noisy to threshold — theneuralbase early-stopping guidance) and
130
+ returns a ``TripwireStatus``.
131
+
132
+ Fires (``fire=True``) when ANY of:
133
+
134
+ (a) **collapse-caught-in-the-act** — the in-loop reward EMA is RISING while
135
+ the held-out score EMA has DECLINED for >= ``decline_patience``
136
+ consecutive checkpoints (default 3, matching the "monotone for >=3
137
+ checkpoints" rule). This is the canonical reward-hacking signature.
138
+
139
+ (b) **KL breach** — the ``kl_to_init`` EMA exceeds ``kl_hard_stop`` (default
140
+ 0.08 nats/token) on/after ``min_steps``.
141
+
142
+ (c) **proxy-real gap blowout** — the Hacking Gap (proxy gain - real gain
143
+ since baseline) widens beyond ``max_proxy_real_gap``, even if held-out
144
+ has not strictly declined for the full patience window (a fast
145
+ single-generation divergence).
146
+
147
+ No tripwire fires before ``min_steps`` (avoids halting on early-run noise,
148
+ when both signals are still warming up).
149
+
150
+ The guard is idempotent in the sense that re-querying ``last_status`` or
151
+ calling ``should_halt()`` does not advance state — only ``update`` does.
152
+ """
153
+
154
+ # --- thresholds (calibratable; see calibrate_kl_threshold) ---------------
155
+ kl_hard_stop: float = 0.08 # nats/token; top of GRPO "good" band
156
+ max_proxy_real_gap: float = 0.10 # absolute Hacking-Gap blowout ceiling
157
+ # --- temporal gates ------------------------------------------------------
158
+ min_steps: int = 20 # no fire before this many updates
159
+ decline_patience: int = 3 # consecutive held-out declines to fire (a)
160
+ # --- denoising -----------------------------------------------------------
161
+ ema_alpha: float = 0.9 # EMA weight on the PRIOR (0.9 => slow)
162
+ rise_eps: float = 1e-4 # min EMA delta to count as "rising"/"declining"
163
+
164
+ # --- internal state (do not set directly) --------------------------------
165
+ _n: int = field(default=0, init=False)
166
+ _in_loop_ema: float | None = field(default=None, init=False)
167
+ _heldout_ema: float | None = field(default=None, init=False)
168
+ _kl_ema: float | None = field(default=None, init=False)
169
+ _entropy_ema: float | None = field(default=None, init=False)
170
+ _reward_std_ema: float | None = field(default=None, init=False)
171
+ _in_loop_baseline: float | None = field(default=None, init=False)
172
+ _heldout_baseline: float | None = field(default=None, init=False)
173
+ _prev_in_loop_ema: float | None = field(default=None, init=False)
174
+ _prev_heldout_ema: float | None = field(default=None, init=False)
175
+ _heldout_decline_streak: int = field(default=0, init=False)
176
+ _last_status: TripwireStatus | None = field(default=None, init=False)
177
+ _fired: bool = field(default=False, init=False)
178
+
179
+ def __post_init__(self) -> None:
180
+ if not (0.0 <= self.ema_alpha < 1.0):
181
+ raise ValueError(
182
+ f"ema_alpha must be in [0, 1), got {self.ema_alpha!r} "
183
+ "(it is the weight on the PRIOR EMA)."
184
+ )
185
+ if self.kl_hard_stop <= 0.0:
186
+ raise ValueError(f"kl_hard_stop must be > 0, got {self.kl_hard_stop!r}")
187
+ if self.decline_patience < 1:
188
+ raise ValueError(
189
+ f"decline_patience must be >= 1, got {self.decline_patience!r}"
190
+ )
191
+
192
+ # ------------------------------------------------------------------------
193
+ # core API
194
+ # ------------------------------------------------------------------------
195
+ def update(
196
+ self,
197
+ round_idx: int,
198
+ in_loop_reward: float,
199
+ heldout_score: float,
200
+ kl_to_init: float | None = None,
201
+ entropy: float | None = None,
202
+ reward_std: float | None = None,
203
+ ) -> TripwireStatus:
204
+ """Fold one checkpoint's metrics in and return the current verdict.
205
+
206
+ Args:
207
+ round_idx: the generation / round index (for logging; not used for
208
+ gating — the internal update counter ``_n`` drives ``min_steps``
209
+ so the guard is robust to non-contiguous round indices).
210
+ in_loop_reward: mean in-loop (proxy / oracle) reward this round. This
211
+ is what the policy is optimizing against.
212
+ heldout_score: mean score on the DISJOINT held-out eval pool this
213
+ round — REAL tasks the generator never trains on. See
214
+ ``composer_replication.safety.holdout`` design notes / the
215
+ ``HeldoutSplit`` discipline; if held-out drifts with the train
216
+ set the gap signal is meaningless.
217
+ kl_to_init: optional token-mean KL(policy || init) in nats/token
218
+ (this repo's ``token_mean_kl`` convention). NOT sequence-level KL.
219
+ entropy: optional policy entropy (early-warning of entropy collapse,
220
+ "the silent killer of RLVR generalization"). Tracked + exposed,
221
+ not currently a hard gate.
222
+ reward_std: optional std of the reward distribution (tracked; a
223
+ collapsing std is an early collapse signal).
224
+
225
+ Returns:
226
+ A ``TripwireStatus``. Once the guard has fired, every subsequent
227
+ ``update`` keeps ``fire=True`` (latched) so a transient recovery
228
+ after a detected collapse cannot silently un-halt the run.
229
+ """
230
+ self._n += 1
231
+
232
+ # --- EMA folds (alpha on the prior; first sample seeds the EMA) -------
233
+ self._in_loop_ema = self._fold(self._in_loop_ema, float(in_loop_reward))
234
+ self._heldout_ema = self._fold(self._heldout_ema, float(heldout_score))
235
+ if kl_to_init is not None:
236
+ self._kl_ema = self._fold(self._kl_ema, float(kl_to_init))
237
+ if entropy is not None:
238
+ self._entropy_ema = self._fold(self._entropy_ema, float(entropy))
239
+ if reward_std is not None:
240
+ self._reward_std_ema = self._fold(self._reward_std_ema, float(reward_std))
241
+
242
+ # --- baselines: seed on the first update so gains are measured from
243
+ # run start (the RSI Hacking-Gap is a gain-since-baseline quantity). -
244
+ if self._in_loop_baseline is None:
245
+ self._in_loop_baseline = self._in_loop_ema
246
+ if self._heldout_baseline is None:
247
+ self._heldout_baseline = self._heldout_ema
248
+
249
+ # --- track the held-out decline streak (uses EMA deltas, denoised) ----
250
+ in_loop_rising = (
251
+ self._prev_in_loop_ema is not None
252
+ and (self._in_loop_ema - self._prev_in_loop_ema) > self.rise_eps
253
+ )
254
+ heldout_declining = (
255
+ self._prev_heldout_ema is not None
256
+ and (self._heldout_ema - self._prev_heldout_ema) < -self.rise_eps
257
+ )
258
+ # The collapse signature is held-out DOWN while in-loop UP. We only count
259
+ # a decline toward the streak when in-loop is simultaneously rising — a
260
+ # held-out dip during an in-loop dip is just noise / a hard batch, not
261
+ # reward hacking.
262
+ if heldout_declining and in_loop_rising:
263
+ self._heldout_decline_streak += 1
264
+ elif not heldout_declining:
265
+ self._heldout_decline_streak = 0
266
+ # (If held-out declines while in-loop is flat or down, the streak is held:
+ # it is neither grown nor reset. The elif above resets it on any non-decline,
+ # so a single clean checkpoint clears it.)
269
+
270
+ gap = self.proxy_real_gap()
271
+ status = self._evaluate(round_idx, gap)
272
+
273
+ # advance "previous EMA" trackers AFTER evaluation
274
+ self._prev_in_loop_ema = self._in_loop_ema
275
+ self._prev_heldout_ema = self._heldout_ema
276
+ self._last_status = status
277
+ if status.fire:
278
+ self._fired = True
279
+ return status
280
+
281
+ def _evaluate(self, round_idx: int, gap: float) -> TripwireStatus:
282
+ """Decide the verdict from current state. Pure (no state mutation)."""
283
+ assert self._in_loop_ema is not None and self._heldout_ema is not None
284
+
285
+ base = dict(
286
+ step=round_idx,
287
+ proxy_real_gap=gap,
288
+ in_loop_ema=self._in_loop_ema,
289
+ heldout_ema=self._heldout_ema,
290
+ kl_ema=self._kl_ema,
291
+ )
292
+
293
+ # Latched: once fired, stay fired (cannot silently un-halt).
294
+ if self._fired:
295
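+ prev_reason = self._last_status.reason if self._last_status else "collapse"
+ # Strip an existing prefix so repeated latched updates do not stack it.
+ prev_reason = prev_reason.removeprefix("latched: ")
+ return TripwireStatus(fire=True, reason=f"latched: {prev_reason}", **base)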
297
+
298
+ # Warm-up guard: never fire on early-run noise.
299
+ if self._n < self.min_steps:
300
+ return TripwireStatus(fire=False, reason="", **base)
301
+
302
+ # (b) KL hard stop — checked first; it's the cheapest unambiguous breach.
303
+ if self._kl_ema is not None and self._kl_ema > self.kl_hard_stop:
304
+ return TripwireStatus(
305
+ fire=True,
306
+ reason=(
307
+ f"kl_to_init EMA {self._kl_ema:.4f} nats/token exceeds hard "
308
+ f"stop {self.kl_hard_stop:.4f} (policy drifting from init)"
309
+ ),
310
+ **base,
311
+ )
312
+
313
+ # (a) collapse caught in the act — held-out declines while in-loop rises.
314
+ if self._heldout_decline_streak >= self.decline_patience:
315
+ return TripwireStatus(
316
+ fire=True,
317
+ reason=(
318
+ f"reward-hacking signature: held-out score declined while "
319
+ f"in-loop reward rose for {self._heldout_decline_streak} "
320
+ f"consecutive checkpoints (Hacking Gap {gap:.4f})"
321
+ ),
322
+ **base,
323
+ )
324
+
325
+ # (c) proxy-real gap blowout — fast single-generation divergence.
326
+ if gap > self.max_proxy_real_gap:
327
+ return TripwireStatus(
328
+ fire=True,
329
+ reason=(
330
+ f"proxy-real Hacking Gap {gap:.4f} exceeds ceiling "
331
+ f"{self.max_proxy_real_gap:.4f} (proxy reward improving far "
332
+ f"faster than real held-out eval)"
333
+ ),
334
+ **base,
335
+ )
336
+
337
+ return TripwireStatus(fire=False, reason="", **base)
338
+
339
+ # ------------------------------------------------------------------------
340
+ # query helpers (do NOT advance state — idempotent)
341
+ # ------------------------------------------------------------------------
342
+ def should_halt(self) -> bool:
343
+ """True if the most recent ``update`` produced a halt verdict.
344
+
345
+ Idempotent: querying does not advance the EMA state.
346
+ """
347
+ return self._last_status is not None and self._last_status.fire
348
+
349
+ @property
350
+ def last_status(self) -> TripwireStatus | None:
351
+ """The most recent verdict, or None if ``update`` was never called."""
352
+ return self._last_status
353
+
354
+ def raise_if_fired(self, status: TripwireStatus | None = None) -> None:
355
+ """Convert a fired verdict into a typed ``CollapseStopError`` exception.
356
+
357
+ Pass the status returned by ``update`` (or omit to use ``last_status``).
358
+ Trainer loops that prefer exception-based control flow call this right
359
+ after ``update``; loops that prefer flag-checking just read
360
+ ``status.fire`` / ``should_halt()``.
361
+ """
362
+ st = status if status is not None else self._last_status
363
+ if st is not None and st.fire:
364
+ raise CollapseStopError(st)
365
+
366
+ def proxy_real_gap(self) -> float:
367
+ """The RSI Hacking Gap = (in-loop gain) - (held-out gain), both measured
368
+ as EMA-minus-baseline since run start.
369
+
370
+ Returns 0.0 before the first ``update`` (no baseline yet). A positive,
371
+ widening value is the reward-hacking fingerprint: the proxy the policy
372
+ optimizes is improving more than the real held-out objective.
373
+ """
374
+ if (
375
+ self._in_loop_ema is None
376
+ or self._heldout_ema is None
377
+ or self._in_loop_baseline is None
378
+ or self._heldout_baseline is None
379
+ ):
380
+ return 0.0
381
+ in_loop_gain = self._in_loop_ema - self._in_loop_baseline
382
+ heldout_gain = self._heldout_ema - self._heldout_baseline
383
+ return in_loop_gain - heldout_gain
384
+
385
+ # ------------------------------------------------------------------------
386
+ # calibration
387
+ # ------------------------------------------------------------------------
388
+ def calibrate_kl_threshold(
389
+ self, baseline_kls: list[float], factor: float = 3.0
390
+ ) -> float:
391
+ """Set ``kl_hard_stop`` to ``factor`` x the mean of early-run baseline KLs.
392
+
393
+ theneuralbase guidance: "Record baseline KL during the first ~100 steps,
394
+ set max to 3x that." Single fixed thresholds are dataset-dependent; this
395
+ adapts to the run's own KL scale.
396
+
397
+ SAFETY CLAMP: calibration may only ever TIGHTEN the hard stop, never
+ loosen it past the documented collapse band. The returned (and stored)
+ threshold is ``min(factor * mean(baseline_kls), current kl_hard_stop)``, so
+ a noisy / already-drifting baseline cannot raise the ceiling above its
+ current value (0.08 nats/token unless already tightened).
401
+
402
+ Args:
403
+ baseline_kls: per-step token-mean KL values from early in the run.
404
+ factor: multiplier on the baseline mean (default 3.0).
405
+
406
+ Returns:
407
+ The new ``kl_hard_stop`` (also stored on the instance).
408
+
409
+ Raises:
410
+ ValueError: if ``baseline_kls`` is empty.
411
+ """
412
+ if not baseline_kls:
413
+ raise ValueError("baseline_kls must be non-empty to calibrate")
414
+ mean_kl = sum(baseline_kls) / len(baseline_kls)
415
+ calibrated = factor * mean_kl
416
+ # Only tighten: never let calibration loosen past the current ceiling.
417
+ self.kl_hard_stop = min(calibrated, self.kl_hard_stop)
418
+ return self.kl_hard_stop
419
+
420
+ # ------------------------------------------------------------------------
421
+ # internals
422
+ # ------------------------------------------------------------------------
423
+ def _fold(self, prev: float | None, x: float) -> float:
424
+ """EMA fold; the first observation seeds the EMA (no warm-up bias)."""
425
+ if prev is None:
426
+ return x
427
+ return self.ema_alpha * prev + (1.0 - self.ema_alpha) * x
428
+
429
+
430
+ def kl_token_trust_filter(logratio_sq_half: float, threshold: float = 0.08) -> bool:
431
+ """Per-token KL trust-region mask, mirroring torchrl's GRPO "KL-Mask".
432
+
433
+ torchrl masks any TOKEN whose ``0.5 * (log pi/pi_ref)^2`` (the Schulman k2
434
+ estimator of per-token KL) exceeds a threshold, forming a per-token trust
435
+ region. This helper returns True when the token should be MASKED OUT (its
436
+ KL contribution is too large), so it can be wired into a loss later without
437
+ pulling torch into this module — the caller computes ``0.5 * logratio**2``.
438
+
439
+ Args:
440
+ logratio_sq_half: ``0.5 * (log pi/pi_ref)^2`` for one token (nats).
441
+ threshold: per-token KL ceiling (default 0.08 nats, the same band as the
442
+ run-level hard stop).
443
+
444
+ Returns:
445
+ True if the token exceeds the trust region and should be masked.
446
+ """
447
+ return logratio_sq_half > threshold
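The EMA fold and the Hacking Gap used by `HeldOutGuard` can be checked by hand. The sketch below is a standalone illustration, not the module's code: `ema_fold` and `hacking_gap` are hypothetical names that restate the `_fold` and `proxy_real_gap` arithmetic shown in the diff, and the inputs are the ones used in `test_proxy_real_gap_is_gain_difference`.

```python
from __future__ import annotations


def ema_fold(prev: float | None, x: float, alpha: float) -> float:
    """Fold x into an EMA; alpha weights the prior, the first sample seeds it."""
    if prev is None:
        return x
    return alpha * prev + (1.0 - alpha) * x


def hacking_gap(in_loop: list[float], heldout: list[float], alpha: float) -> float:
    """(in-loop EMA gain) - (held-out EMA gain), both measured from the first sample."""
    in_ema = held_ema = None
    in_base = held_base = None
    for r, h in zip(in_loop, heldout):
        in_ema = ema_fold(in_ema, r, alpha)
        held_ema = ema_fold(held_ema, h, alpha)
        if in_base is None:
            # Baselines are seeded once, on the first update.
            in_base, held_base = in_ema, held_ema
    if in_ema is None:
        return 0.0  # no updates yet, so no baseline
    return (in_ema - in_base) - (held_ema - held_base)


# in-loop EMA: 0.20 -> 0.40 (gain 0.20); held-out EMA: 0.20 -> 0.25 (gain 0.05)
gap = hacking_gap([0.20, 0.60], [0.20, 0.30], alpha=0.5)
print(round(gap, 6))  # 0.15
```

A positive gap means the proxy reward gained more than the held-out score since the first update. When both series gain equally the gap stays near zero.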
composer_replication/safety/tests/__init__.py ADDED
File without changes
composer_replication/safety/tests/test_kill_switch.py ADDED
@@ -0,0 +1,320 @@
1
+ """Tests for the held-out collapse kill-switch (HeldOutGuard).
2
+
3
+ CPU-only, pure-Python — no torch, no cloud. Mirrors the
4
+ ``datagen/tests/test_feature_deletion.py`` style (small helpers, behavioral
5
+ asserts). Covers:
6
+ - no-halt on a healthy co-rising run (the held-out-twin "within noise" case);
7
+ - HALT on the canonical signature: held-out declines while in-loop rises;
8
+ - HALT on KL-to-init hard-stop breach;
9
+ - HALT on a fast proxy-real Hacking-Gap blowout;
10
+ - window / patience behavior (min_steps warm-up; decline_patience streak);
11
+ - calibration tightens-only;
12
+ - idempotent query + latched-fire edge cases.
13
+ """
14
+ from __future__ import annotations
15
+
16
+ import pytest
17
+
18
+ from composer_replication.safety import (
19
+ CollapseStopError,
20
+ HeldOutGuard,
21
+ TripwireStatus,
22
+ kl_token_trust_filter,
23
+ )
24
+
25
+
26
+ def _guard(**kw) -> HeldOutGuard:
27
+ # Small min_steps keeps tests fast while still exercising the warm-up gate.
28
+ base = dict(min_steps=3, decline_patience=3, ema_alpha=0.5, kl_hard_stop=0.08)
29
+ base.update(kw)
30
+ return HeldOutGuard(**base)
31
+
32
+
33
+ # --- healthy run: both rise => never halt -----------------------------------
34
+
35
+ def test_no_halt_when_both_rise():
36
+ """Clean run: in-loop and held-out rise together, KL stays in band. The
37
+ held-out twin scores within noise of the proxy => no fire (the well-behaved
38
+ case the literature says a clean model exhibits)."""
39
+ g = _guard()
40
+ status = None
41
+ for i in range(30):
42
+ status = g.update(
43
+ i,
44
+ in_loop_reward=0.30 + 0.01 * i,
45
+ heldout_score=0.28 + 0.01 * i, # tracks proxy within noise
46
+ kl_to_init=0.03,
47
+ )
48
+ assert not status.fire, f"fired unexpectedly at step {i}: {status.reason}"
49
+ assert not g.should_halt()
50
+ # Gap stays near zero because both gained equally.
51
+ assert abs(g.proxy_real_gap()) < 0.05
52
+
53
+
54
+ # --- canonical signature: held-out declines while in-loop rises -------------
55
+
56
+ def test_halt_on_heldout_declines_while_reward_rises():
57
+ g = _guard(max_proxy_real_gap=10.0) # disable gap-blowout path to isolate (a)
58
+ # Warm up past min_steps with a stable healthy stretch.
59
+ for i in range(6):
60
+ s = g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)
61
+ assert not s.fire
62
+ # Now: proxy reward climbs, held-out eval falls — the reward-hacking
63
+ # fingerprint. Should fire once the decline streak hits decline_patience (3).
64
+ fired_at = None
65
+ for j, i in enumerate(range(6, 12)):
66
+ s = g.update(
67
+ i,
68
+ in_loop_reward=0.40 + 0.05 * (j + 1), # rising
69
+ heldout_score=0.40 - 0.05 * (j + 1), # declining
70
+ kl_to_init=0.03, # KL stays in band
71
+ )
72
+ if s.fire:
73
+ fired_at = i
74
+ break
75
+ assert fired_at is not None, "tripwire never fired on the collapse signature"
76
+ assert g.should_halt()
77
+ s = g.last_status
78
+ assert "held-out" in s.reason and "consecutive" in s.reason
79
+ assert s.proxy_real_gap > 0.0 # proxy gained while real lost
80
+
81
+
82
+ def test_does_not_fire_before_patience_window():
83
+ """Held-out declining while in-loop rises for FEWER than decline_patience
84
+ checkpoints must NOT fire (window behavior)."""
85
+ g = _guard(decline_patience=3, max_proxy_real_gap=10.0)
86
+ for i in range(6):
87
+ g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)
88
+ # Only 2 divergent checkpoints (< patience of 3) => no fire.
89
+ s1 = g.update(6, in_loop_reward=0.45, heldout_score=0.35, kl_to_init=0.03)
90
+ s2 = g.update(7, in_loop_reward=0.50, heldout_score=0.30, kl_to_init=0.03)
91
+ assert not s1.fire and not s2.fire
92
+
93
+
94
+ def test_decline_streak_resets_on_recovery():
95
+ """A clean checkpoint (held-out recovers) resets the decline streak, so a
96
+ later short divergence does not inherit prior declines."""
97
+ g = _guard(decline_patience=3, max_proxy_real_gap=10.0)
98
+ for i in range(6):
99
+ g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)
100
+ # 2 declines...
101
+ g.update(6, in_loop_reward=0.45, heldout_score=0.35, kl_to_init=0.03)
102
+ g.update(7, in_loop_reward=0.50, heldout_score=0.30, kl_to_init=0.03)
103
+ # ...then held-out recovers (resets streak)...
104
+ s = g.update(8, in_loop_reward=0.50, heldout_score=0.45, kl_to_init=0.03)
105
+ assert not s.fire
106
+ # ...one more decline is only streak=1, still below patience.
107
+ s = g.update(9, in_loop_reward=0.55, heldout_score=0.40, kl_to_init=0.03)
108
+ assert not s.fire
109
+
110
+
111
+ # --- KL hard-stop ------------------------------------------------------------
112
+
113
+ def test_halt_on_kl_hard_stop_breach():
114
+ g = _guard(kl_hard_stop=0.08, max_proxy_real_gap=10.0)
115
+ # Healthy KL through the warm-up; both metrics flat so only KL can fire.
116
+ for i in range(5):
117
+ s = g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.04)
118
+ assert not s.fire
119
+ # KL spikes well above 0.08; with ema_alpha=0.5 the EMA jumps from 0.04 to 0.12 on the first spiked update and crosses the ceiling.
120
+ fired = False
121
+ for i in range(5, 12):
122
+ s = g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.20)
123
+ if s.fire:
124
+ fired = True
125
+ assert "kl_to_init" in s.reason and "hard stop" in s.reason
126
+ break
127
+ assert fired, "KL hard-stop never fired despite KL EMA crossing the ceiling"
128
+
129
+
130
+ def test_kl_none_never_fires_kl_path():
131
+ """If the caller never supplies kl_to_init, the KL path must be inert (and
132
+ kl_ema stays None) — KL is an optional float."""
133
+ g = _guard(max_proxy_real_gap=10.0)
134
+ s = None
135
+ for i in range(20):
136
+ s = g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=None)
137
+ assert s is not None and not s.fire
138
+ assert s.kl_ema is None
139
+
140
+
141
+ # --- proxy-real gap blowout (fast divergence) -------------------------------
142
+
143
+ def test_halt_on_proxy_real_gap_blowout():
144
+ """A large single-generation divergence (proxy jumps, real stays flat) fires
145
+ via the gap-blowout path even before the decline streak reaches patience."""
146
+ g = _guard(max_proxy_real_gap=0.10, decline_patience=100) # disable (a)
147
+ for i in range(5):
148
+ g.update(i, in_loop_reward=0.30, heldout_score=0.30, kl_to_init=0.03)
149
+ # Proxy blows up; held-out flat. With ema_alpha=0.5 the gap crosses 0.10 fast.
150
+ fired = False
151
+ for i in range(5, 12):
152
+ s = g.update(i, in_loop_reward=0.90, heldout_score=0.30, kl_to_init=0.03)
153
+ if s.fire:
154
+ fired = True
155
+ assert "Hacking Gap" in s.reason
156
+ assert s.proxy_real_gap > 0.10
157
+ break
158
+ assert fired, "gap-blowout tripwire never fired"
159
+
160
+
161
+ # --- warm-up window (min_steps) ---------------------------------------------
162
+
163
+ def test_respects_min_steps_no_early_fire():
164
+ """Even with every signal tripped, no fire before min_steps (avoids halting
165
+ on early-run noise)."""
166
+ g = _guard(min_steps=10, decline_patience=2, kl_hard_stop=0.08,
167
+ max_proxy_real_gap=0.01)
168
+ # Egregiously bad signals from step 0: KL huge, proxy up, held-out down.
169
+ for i in range(9): # 9 updates, all < min_steps=10
170
+ s = g.update(i, in_loop_reward=0.10 + 0.1 * i, heldout_score=0.90 - 0.1 * i,
171
+ kl_to_init=0.9)
172
+ assert not s.fire, f"fired during warm-up at step {i}: {s.reason}"
173
+ # The 10th update (n==10, not < min_steps) is now allowed to fire.
174
+ s = g.update(9, in_loop_reward=1.5, heldout_score=0.0, kl_to_init=0.9)
175
+ assert s.fire
176
+
177
+
178
+ # --- calibration -------------------------------------------------------------
179
+
180
+ def test_calibrate_kl_threshold_tightens_only():
181
+ g = _guard(kl_hard_stop=0.08)
182
+ # Baseline mean 0.01 => 3x = 0.03 < 0.08 => tightens to 0.03.
183
+ new = g.calibrate_kl_threshold([0.008, 0.010, 0.012], factor=3.0)
184
+ assert new == pytest.approx(0.03, abs=1e-9)
185
+ assert g.kl_hard_stop == pytest.approx(0.03, abs=1e-9)
186
+
187
+
188
+ def test_calibrate_never_loosens_past_band():
189
+ g = _guard(kl_hard_stop=0.08)
190
+ # A drifting baseline (mean 0.05 => 3x = 0.15) must NOT loosen past 0.08.
191
+ new = g.calibrate_kl_threshold([0.05, 0.05, 0.05], factor=3.0)
192
+ assert new == pytest.approx(0.08, abs=1e-9)
193
+ assert g.kl_hard_stop == pytest.approx(0.08, abs=1e-9)
194
+
195
+
196
+ def test_calibrate_empty_raises():
197
+ g = _guard()
198
+ with pytest.raises(ValueError, match="non-empty"):
199
+ g.calibrate_kl_threshold([])
200
+
201
+
202
+ # --- proxy_real_gap definition ----------------------------------------------
203
+
204
+ def test_proxy_real_gap_is_gain_difference():
205
+ g = _guard(min_steps=100, max_proxy_real_gap=10.0) # disable firing
206
+ g.update(0, in_loop_reward=0.20, heldout_score=0.20, kl_to_init=0.02) # baseline
207
+ # With ema_alpha=0.5 the second sample moves each EMA halfway.
208
+ g.update(1, in_loop_reward=0.60, heldout_score=0.30, kl_to_init=0.02)
209
+ # in_loop EMA: 0.5*0.20 + 0.5*0.60 = 0.40; gain = 0.40-0.20 = 0.20
210
+ # heldout EMA: 0.5*0.20 + 0.5*0.30 = 0.25; gain = 0.25-0.20 = 0.05
211
+ # gap = 0.20 - 0.05 = 0.15
212
+ assert g.proxy_real_gap() == pytest.approx(0.15, abs=1e-9)
213
+
214
+
215
+ def test_proxy_real_gap_zero_before_update():
216
+ g = _guard()
217
+ assert g.proxy_real_gap() == 0.0
218
+
219
+
220
+ # --- idempotency / edge cases -----------------------------------------------
221
+
222
+ def test_should_halt_is_idempotent_query():
223
+ g = _guard(max_proxy_real_gap=10.0)
224
+ for i in range(6):
225
+ g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)
226
+ # Querying repeatedly must not advance state or change the verdict.
227
+ snap_gap = g.proxy_real_gap()
228
+ assert g.should_halt() is False
229
+ assert g.should_halt() is False
230
+ assert g.proxy_real_gap() == snap_gap # unchanged by querying
231
+ assert g.last_status is not None and not g.last_status.fire
232
+
233
+
234
+ def test_fire_is_latched():
235
+ """Once fired, a subsequent recovery cannot silently un-halt the run."""
236
+ g = _guard(kl_hard_stop=0.08, max_proxy_real_gap=10.0)
237
+ for i in range(5):
238
+ g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.04)
239
+ # Drive a KL breach.
240
+ fired = False
241
+ for i in range(5, 12):
242
+ s = g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.30)
243
+ if s.fire:
244
+ fired = True
245
+ break
246
+ assert fired
247
+ # Now KL recovers to healthy — verdict must stay fired (latched).
248
+ s = g.update(99, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.01)
249
+ assert s.fire and s.reason.startswith("latched:")
250
+ assert g.should_halt()
251
+
252
+
253
+ def test_raise_if_fired_raises_typed_exception():
254
+ g = _guard(kl_hard_stop=0.08, max_proxy_real_gap=10.0)
255
+ for i in range(5):
256
+ g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.04)
257
+ status = None
258
+ for i in range(5, 12):
259
+ status = g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.30)
260
+ if status.fire:
261
+ break
262
+ assert status is not None and status.fire
263
+ with pytest.raises(CollapseStopError) as exc:
264
+ g.raise_if_fired(status)
265
+ assert exc.value.status is status
266
+ assert isinstance(str(exc.value), str) and str(exc.value)
267
+
268
+
269
+ def test_raise_if_fired_noop_when_clean():
270
+ g = _guard(max_proxy_real_gap=10.0)
271
+ s = g.update(0, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)
272
+ # No fire => no raise (uses last_status when arg omitted).
273
+ g.raise_if_fired(s)
274
+ g.raise_if_fired()
275
+
276
+
277
+ def test_status_halt_alias_matches_fire():
278
+ g = _guard(max_proxy_real_gap=10.0)
279
+ s = g.update(0, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)
280
+ assert s.halt is False and s.fire is False
281
+ assert isinstance(s, TripwireStatus)
282
+
283
+
284
+ def test_non_contiguous_round_idx_uses_internal_counter():
285
+ """min_steps gates on the internal update counter, not round_idx, so a caller
286
+ that logs sparse / non-contiguous round indices still warms up correctly."""
287
+ g = _guard(min_steps=3, max_proxy_real_gap=0.01, decline_patience=1)
288
+ # Pass huge round_idx values; only the 3rd UPDATE clears warm-up.
289
+ g.update(1000, in_loop_reward=0.10, heldout_score=0.90, kl_to_init=0.9)
290
+ g.update(2000, in_loop_reward=0.50, heldout_score=0.50, kl_to_init=0.9)
291
+ s = g.update(3000, in_loop_reward=0.90, heldout_score=0.10, kl_to_init=0.9)
292
+ assert s.fire # 3rd update, n==3 not < min_steps
293
+
294
+
295
+ # --- config validation -------------------------------------------------------
296
+
297
+ def test_bad_ema_alpha_rejected():
298
+ with pytest.raises(ValueError, match="ema_alpha"):
299
+ HeldOutGuard(ema_alpha=1.0)
300
+ with pytest.raises(ValueError, match="ema_alpha"):
301
+ HeldOutGuard(ema_alpha=-0.1)
302
+
303
+
304
+ def test_bad_kl_hard_stop_rejected():
305
+ with pytest.raises(ValueError, match="kl_hard_stop"):
306
+ HeldOutGuard(kl_hard_stop=0.0)
307
+
308
+
309
+ def test_bad_decline_patience_rejected():
310
+ with pytest.raises(ValueError, match="decline_patience"):
311
+ HeldOutGuard(decline_patience=0)
312
+
313
+
314
+ # --- kl_token_trust_filter helper -------------------------------------------
315
+
316
+ def test_kl_token_trust_filter_masks_above_threshold():
317
+ # 0.5 * logratio^2; mask when it exceeds the per-token KL ceiling.
318
+ assert kl_token_trust_filter(0.20, threshold=0.08) is True # too large -> mask
319
+ assert kl_token_trust_filter(0.05, threshold=0.08) is False # within trust region
320
+ assert kl_token_trust_filter(0.08, threshold=0.08) is False # boundary, not masked
docs/BACKLOG_RESOLUTION_2026-06-09.md CHANGED
@@ -52,6 +52,31 @@ Goal-driven systematic resolution of every pending item. This doc is the live au
52
  | F1 (`…-cb74`) | **ROTATE exposed HF write-token** — USER-ONLY (requires HF account access). AUDIT done: no live token in tracked tree (only env-var reads). Action = user rotates on huggingface.co. | P1 | DOCUMENTED (user-only) |
53
  | F2 | Real 8B LMA run (A2/A3/A4 arms `…-42f5`,`…-dd7b`) + higher-lr sweep RUNS — GPU + budget + user go/no-go. Harness buildable (E1/E2); the spend is user-only. | — | GATED (harness only) |
54
 
55
+ ## Status log
56
+
57
+ **Wave 1 — DONE (commit `c11cf49`):** B1 ✅ (fixture generated, 8 tests pass), B2 ✅ ([dev] installs on arm64), B3 ✅ ([serverless] deps), B4 ✅ (266/62 canonical), B5 ✅ (WSL footers), B6 ✅ (dead ADR link), B7 ✅ (config factories re-exported + documented), B8 ✅ (refine-summary + OVERVIEW xref), **D1 ✅ (Docker substrate E2E GREEN — 2/2 gates on real container; long-blocked item closed)**. F1 (token rotation) audited — no live token in tracked tree; user-only action documented.
58
+
59
+ **Wave 2 — DONE (built + integrated + tested):** C1 ✅ HeldOutGuard kill-switch (`composer_replication/safety/`, 23 tests), C2 ✅ EKSExecutor (single Indexed Job → N handles, gang-cancel; `eks.py` + 28 tests), C3 ✅ DockerSandbox (`docker_sandbox.py` + shared `scrub_tree` refactor; live Docker tests pass), E3 ✅ SageMakerExecutor (`sagemaker.py`; +13-test module I added — the build agent shipped it test-less, gap closed during integration). All 4 modules lint-clean, re-exported, 90/3 on targeted suite. Grounded in Phase-3 research.
60
+
61
+ **Wave 3 — Phase-7 reconciliation (from the concurrent review team `research/review-*.json`):**
62
+ | ID | Item | Sev | Status |
63
+ |---|---|---|---|
64
+ | R1 | **Wire `HeldOutGuard` into `composer_trainer.py`** at per-checkpoint cadence (alongside `DifficultyCurriculum.update`), feeding `token_mean_kl` as `kl_to_init`, converting a fired verdict to halt via `raise_if_fired`. Currently dead code — the #2 safeguard never fires in production. | HIGH | OPEN |
65
+ | R2 | **Build `composer_replication/safety/holdout.py` `HeldoutSplit`** disjointness enforcer (id/hash set-difference, raises on train↔held-out overlap) — the un-built second half of C1; the guard's gap signal is meaningless without it. | HIGH | OPEN |
66
+ | R3 | **EKS contract bug:** `launch_replicas` default container command runs `replica_entrypoint __main__` (argparse needs `--rendezvous/--world-size/--trainer-module`) but the indexed-job spec passes rank/world via env, not argv → a real run would fail arg-parsing. Reconcile the entrypoint contract. | HIGH | OPEN |
67
+ | R4 | `calibrate_kl_threshold` can yield a NEGATIVE `kl_hard_stop` on `factor<=0`/negative baseline → fires every healthy step. Guard inputs / clamp to positive floor. | LOW | OPEN |
68
+ | R5 | EKS/SageMaker `cancel()` swallow ALL exceptions (report success on real failure). Narrow to already-terminated (404/ResourceNotFound). | LOW | OPEN |
69
+ | R6 | `EKSExecutor.collect()` result dicts miss the `result` key the other backends include — cross-backend shape uniformity. | LOW | OPEN |
70
+ | R7 | **Doc-debt:** the 4 new Wave-2 public symbols (EKSExecutor, SageMakerExecutor, DockerSandbox, HeldOutGuard/safety) are undocumented in API_REFERENCE.md; add §12 + `.eks`/`.aws` extras. | MED | OPEN |
71
+ | R8 | **ADR-015** for the held-out kill-switch — referenced by `safety/__init__.py:17` + kill_switch docstrings but doesn't exist (dangling refs). Author it or drop the refs. | LOW | OPEN |
72
+ | R9 | Re-measure + refresh canonical test count in V1_V8_COVERAGE (Wave 2 added ~93 tests; 328→~420 collected). | LOW | OPEN |
73
+ | R10 | Add a test pinning the kill-switch path-(c) both-rising gap-blowout behavior; document path-(c) as a divergence-rate gate. | LOW | OPEN |
74
+ | R11 | Flaky test `spikes/006-real-hf-model-smoke/tests/test_strict.py::test_alternating_batches_loss_decreases` — fails under CPU contention (full suite w/ concurrent pytest + Docker), PASSES in isolation (verified 3x). Loss-trend assertion is timing/noise-sensitive. Pin seed / widen tolerance / mark flaky. Pre-existing, not a Wave-2 regression. | LOW | OPEN |
76
+ | R12 | B7-complete ✅ (top-level `__all__` now includes the 3 factories) + B4-complete ✅ (the 4 surviving "115" claims → 266/62). | — | DONE |
77
+
78
+ Sandbox refactor verdict: **clean** (no regression to LocalSubprocessSandbox/FeatureDeletionEnv).
79
+
80
  ## Wave plan
81
  - **Wave 1 (parallel):** B1, B2, B3, B4, B5, B6, B7, B8 (bugs + doc debt) ‖ D1 (Docker E2E) ‖ research fan-out (Tavily/Exa/DeepWiki) for C1/C2/E1/E2 best practices.
82
  - **Wave 2 (parallel, after research):** C1 (held-out eval + kill-switch) ‖ C2 (EKSExecutor) ‖ C3 (containerized sandbox) ‖ E1/E2/E3 harnesses.
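Item R4 in the status log notes that `calibrate_kl_threshold` can produce a negative `kl_hard_stop`. One possible shape for the fix is sketched below as a standalone function. It is an assumption about how the guard could look, not the repository's resolution: the name `calibrate_kl_hard_stop`, the `floor` parameter, and its default of `1e-4` are all hypothetical.

```python
import math


def calibrate_kl_hard_stop(baseline_kls, current, factor=3.0, floor=1e-4):
    """Return a tightened-only KL ceiling, never below `floor` or above `current`."""
    if not baseline_kls:
        raise ValueError("baseline_kls must be non-empty to calibrate")
    if not (math.isfinite(factor) and factor > 0.0):
        raise ValueError("factor must be a positive finite number")
    if any((not math.isfinite(k)) or k < 0.0 for k in baseline_kls):
        raise ValueError("baseline KL values must be finite and non-negative")
    calibrated = factor * (sum(baseline_kls) / len(baseline_kls))
    # Apply the positive floor first, then the tighten-only clamp, so the
    # result can never exceed the current ceiling.
    return min(current, max(floor, calibrated))


print(calibrate_kl_hard_stop([0.008, 0.010, 0.012], current=0.08))
print(calibrate_kl_hard_stop([0.05, 0.05, 0.05], current=0.08))
```

The first call tightens the ceiling to roughly 0.03, and the second leaves it at 0.08, mirroring the two calibration tests in this commit. A zero baseline lands on the floor instead of producing a threshold that fires on every healthy step.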
docs/OVERVIEW.md CHANGED
@@ -52,8 +52,8 @@ where channel 1 is real GRPO rather than the LM-CE stub. See
52
  trainer on a real reasoning benchmark.
53
  - **Economic feasibility of channel 3.** 150 real OpenRouter calls, $0.98/trace mean, 0
54
  errors (Spike 001).
55
- - **Installable + tested.** `pip install -e .` works; **115 passing tests + 1 skip-marked**
56
- (canonical count: [`docs/V1_V8_COVERAGE.md`](V1_V8_COVERAGE.md)).
57
 
58
  ## What's gapped (honest, NOT closed)
59
 
 
52
  trainer on a real reasoning benchmark.
53
  - **Economic feasibility of channel 3.** 150 real OpenRouter calls, $0.98/trace mean, 0
54
  errors (Spike 001).
55
+ - **Installable + tested.** `pip install -e .` works; **266 passing / 62 skipped** (measured 2026-06-09;
56
+ canonical count + why skips vary by env: [`docs/V1_V8_COVERAGE.md`](V1_V8_COVERAGE.md)).
57
 
58
  ## What's gapped (honest, NOT closed)
59
 
docs/VISION_VALIDATION.md CHANGED
@@ -1,7 +1,7 @@
1
  # Vision Validation: Does the Framework Encapsulate the Original Brief?
2
 
3
  > **## Status as of 2026-06 (current through ADR-014)**
4
- > The framework is past-skeleton: 8 subpackages (`composer_replication/*`), 115 passing
5
  > tests + 1 skip-marked (see [`docs/V1_V8_COVERAGE.md`](V1_V8_COVERAGE.md) for the
6
  > canonical count), and operational end-to-end examples (`gsm8k_grpo`,
7
  > `sdpo_with_real_traces_production`). The 3-channel loss, layered hint-generation,
 
1
  # Vision Validation: Does the Framework Encapsulate the Original Brief?
2
 
3
  > **## Status as of 2026-06 (current through ADR-014)**
4
+ > The framework is past-skeleton: 8 subpackages (`composer_replication/*`), 266 passing
5
  > tests + 1 skip-marked (see [`docs/V1_V8_COVERAGE.md`](V1_V8_COVERAGE.md) for the
6
  > canonical count), and operational end-to-end examples (`gsm8k_grpo`,
7
  > `sdpo_with_real_traces_production`). The 3-channel loss, layered hint-generation,
pyproject.toml CHANGED
@@ -69,6 +69,15 @@ serverless = [
69
  "boto3>=1.34", # SageMakerExecutor (create_training_job) + S3 IAM
70
  "kubernetes>=29.0", # EKSExecutor (indexed k8s Jobs via BatchV1Api)
71
  ]
72
+ # Amazon EKS / Kubernetes Indexed-Job executor (EKSExecutor, per ADR-005).
73
+ # kubernetes is lazy-imported at adapter-init/method time (not at package import).
74
+ eks = [
75
+ "kubernetes>=29",
76
+ ]
77
+ # Amazon SageMaker training-job executor (SageMakerExecutor, per ADR-005).
78
+ aws = [
79
+ "boto3>=1.34",
80
+ ]
81
  # Replaysim dataset normalization (per ADR-004)
82
  #
83
  # NOTE: data-juicer is intentionally NOT pinned as an extra. The package
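The `eks` extra's comment says `kubernetes` is lazy-imported at adapter-init/method time and not at package import. A minimal sketch of that pattern follows, using only the standard library. `require_optional` is a hypothetical helper, and the distribution name `composer-replication` in the install hint is an assumption, since the project name is not shown in this hunk.

```python
import importlib
from types import ModuleType


def require_optional(
    module_name: str, extra: str, dist: str = "composer-replication"
) -> ModuleType:
    """Import module_name on first use; point at the matching extra if missing."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        # `dist` is an assumed distribution name, used only for the hint text.
        raise ImportError(
            f"{module_name!r} is required for this executor; "
            f"install it with: pip install '{dist}[{extra}]'"
        ) from exc


# A stdlib module stands in for an installed optional dependency here.
json_mod = require_optional("json", extra="aws")
print(json_mod.__name__)  # json
```

Calling a helper like this inside an executor's `__init__` or methods, and never at module top level, keeps the package importable on machines that did not install the `[eks]` or `[aws]` extra.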
research/review-executors.json ADDED
@@ -0,0 +1,44 @@
1
+ {
2
+ "area": "composer_replication/diloco/serverless: EKSExecutor + SageMakerExecutor vs ServerlessExecutor Protocol",
3
+ "verdict": "minor-issues",
4
+ "findings": [
5
+ {
6
+ "severity": "high",
7
+ "what": "EKS rank/arg plumbing contract mismatch: launch_replicas defaults the container command to ['python','-m','composer_replication.diloco.serverless.replica_entrypoint'] with NO container args, and plumbs rendezvous_uri/world_size as UPPER-CASED env vars (RENDEZVOUS_URI, WORLD_SIZE). But replica_entrypoint.py's __main__ block uses argparse with --rendezvous, --world-size, --trainer-module ALL required=True and reads NONE of those env vars (only REPLICA_RANK via os.environ). It also reads trainer_module via --trainer-module, which EKS never plumbs in any form. A pod launched with the documented EKS defaults therefore SystemExits at startup ('the following arguments are required: --rendezvous, --world-size, --trainer-module'). SageMakerExecutor does this correctly (passes ContainerArguments=['--rendezvous',...,'--world-size',...,'--trainer-module',...,'--trainer-fn',...,'--trainer-kwargs-json',...] matching the entrypoint argparse exactly). This is an end-to-end run correctness bug, not a Protocol-signature gap, and it is untested (test_launch_uses_default_entrypoint_command only asserts the command vector; no test asserts the entrypoint can actually parse what EKS supplies; trainer_module is never asserted to reach the container).",
+ "where": "composer_replication/diloco/serverless/eks.py:220-224 (default command), :296-303 (_build_env upper-cases scalars, drops nothing else), :405-458 (launch_replicas passes no ContainerArguments); contract owner composer_replication/diloco/serverless/replica_entrypoint.py:91-109 (argparse required=True, no env fallback)",
+ "recommendation": "Pick ONE: (a) make EKS pass the same arg vector SageMaker does by appending args to self.command (e.g. command + ['--rendezvous', uri, '--world-size', N, '--trainer-module', tm, ...]); OR (b) add an env-var fallback to replica_entrypoint.__main__ (read RENDEZVOUS_URI/WORLD_SIZE/TRAINER_MODULE/etc. from os.environ when CLI args are absent) so the env-only EKS plumbing works unchanged. Add a test that constructs the entrypoint argv/env exactly as EKS would and asserts main() can be invoked (e.g. argparse parse over the supplied tokens, or env-driven path). Ensure trainer_module is plumbed in whichever channel is chosen."
+ },
+ {
+ "severity": "low",
+ "what": "EKS cancel() swallows ALL non-404 ApiExceptions and even generic Exceptions with a bare 'return', reporting success even when the gang delete genuinely failed (e.g. 403 RBAC-denied, 409 conflict). Because the whole point of gang-cancel is to stop the entire GPU-burning cohort, a silently-swallowed real teardown failure leaves the cohort running while the caller believes it was cancelled — the exact failure mode the design calls out. SageMakerExecutor.cancel has the same broad swallow. The Protocol only requires 'no exception if already terminated', which a 404 satisfies; swallowing 403/409 is broader than the contract needs.",
+ "where": "composer_replication/diloco/serverless/eks.py:594-600 (except ApiException -> swallow non-404; except Exception -> swallow); composer_replication/diloco/serverless/sagemaker.py:470-475 (bare except Exception: pass)",
+ "recommendation": "Narrow the swallow to the 'already terminated' cases (404 / ResourceNotFound, and SageMaker's already-terminal ValidationException) and at minimum log/warn (or re-raise) on other API errors so a failed gang-teardown of GPU resources is observable rather than silent. Best-effort can still mean 'do not raise', but it should emit a warning on a non-idempotent failure."
+ },
+ {
+ "severity": "low",
+ "what": "Result-dict shape inconsistency across executors in collect(): SageMaker and Modal/Local include a 'result' key (SageMaker surfaces ModelArtifacts.S3ModelArtifacts path; Local/Modal include the in-process return value), but EKS _result_dict omits 'result' entirely and instead adds 'job_name'. The Protocol only mandates {rank,status,exit_code,error} 'at least', so this is conformant, but the divergence makes a backend-agnostic caller that reads result['result'] KeyError on EKS. Note: this is NOT a 'collect() not reading S3' Protocol violation — the Protocol/ADR-005 do not require collect() to read S3 contents; the payload flows through ObjectStoreAllReduce/S3 written by the replica itself, and collect() correctly returns status metadata (the reference LocalProcessExecutor returns an in-process value, not S3). SageMaker surfacing the S3 artifact path is a nice-to-have, not a requirement.",
+ "where": "composer_replication/diloco/serverless/eks.py:655-671 (_result_dict: no 'result' key, adds 'job_name'); compare sagemaker.py:588-595 (includes 'result': artifacts.get('S3ModelArtifacts')) and executor.py:104-107 (Protocol documents only the 4 required keys)",
+ "recommendation": "For cross-backend uniformity, add a 'result': None (or the rendezvous output URI if known) key to EKS _result_dict so callers can read result['result'] uniformly across executors. Optionally document in the Protocol docstring that 'result' is an optional, backend-specific extra key so callers use .get('result')."
+ }
+ ],
+ "confirmed_good": [
+ "Both EKSExecutor and SageMakerExecutor satisfy the runtime_checkable ServerlessExecutor Protocol: isinstance(EKSExecutor(image=...,batch_api=...,core_api=...), ServerlessExecutor) is True and isinstance(SageMakerExecutor(...), ServerlessExecutor) is True (verified at runtime); both expose backend_name ('eks'/'sagemaker'), supports_inter_replica_network (both False, correct — S3-only rendezvous), and all five methods launch_replicas/poll/stream_logs/cancel/collect.",
+ "Both are exported from serverless/__init__.py and present in __all__ (EKSExecutor line 50/62, SageMakerExecutor line 59/68).",
+ "EKS single-Indexed-Job -> N-handles topology is correct: exactly one create_namespaced_job, completions==parallelism==n_replicas, completionMode='Indexed', restartPolicy='Never', backoffLimit=0, active_deadline_seconds==timeout, ttl_seconds_after_finished set; returns N rank-ordered handles (handles[i].rank==i) all sharing job_name/namespace (test_launch_returns_n_rank_ordered_handles, test_launch_creates_indexed_job_spec).",
+ "EKS gang-cancel is correct: cancel(any handle) deletes the WHOLE shared Job with propagation_policy='Background' (cascading pod deletion, not the k8s default Orphan) and grace_period_seconds=0; idempotent on 404 (test_cancel_uses_background_propagation_on_shared_job, test_cancel_swallows_404, test_cancel_unknown_handle_is_noop).",
+ "EKS rank plumbing via downward API is correct: REPLICA_RANK set via V1EnvVarSource.field_ref field_path metadata.annotations['batch.kubernetes.io/job-completion-index'] (value is None, value_from set), bridging k8s completion-index to the entrypoint's REPLICA_RANK read without modifying the entrypoint; rank_env LocalProcessExecutor convention is stripped (test_launch_rank_env_uses_downward_api_field_ref, test_launch_strips_rank_env_kwarg). NOTE: this rank channel works; the BROKEN channel is rendezvous_uri/trainer_module (see high finding).",
+ "EKS poll status mapping covers all five Protocol states: rank in completed_indexes->succeeded (checked first, so a succeeded rank is not mis-flagged by a whole-job Failed condition), rank in failed_indexes->failed, whole-job Failed condition->failed (DeadlineExceeded/backoff), active>0->running, else pending, 404->cancelled, non-404 ApiException re-raised; run-length index strings expanded correctly incl. reversed ranges and whitespace (test_poll_* x7, test_expand_indexes_*).",
+ "EKS GPU resource limit is always a STRING ('1' not int 1) per OpenAPI dict[str,str] typing; GPU node selector merged (caller wins) and nvidia.com/gpu NoSchedule toleration auto-added; CPU-only omits the gpu limit (test_launch_gpu_limit_is_string, test_launch_cpu_only_omits_gpu_limit).",
+ "EKS partial-failure sibling cleanup is correctly N/A: launch issues exactly ONE create_namespaced_job (atomic gang scheduling), so there are no siblings to clean up if it fails — a genuine advantage of the single-Indexed-Job topology over N-job designs.",
+ "SageMaker correctly uses N independent single-instance training jobs (ResourceConfig.InstanceCount==1) with rank via the Environment map (REPLICA_RANK/WORLD_SIZE/RENDEZVOUS_URI), and correctly passes the entrypoint args via ContainerArguments matching replica_entrypoint argparse; EnableNetworkIsolation pinned False (else S3 rendezvous deadlocks) — verified in test_launch_injects_rank_world_size_and_rendezvous_env.",
+ "SageMaker partial-failure sibling cleanup is correct: a create_training_job failure at rank k best-effort stops the k already-launched siblings then raises with rank context (test_launch_partial_failure_stops_siblings_and_raises asserts 2 siblings stopped).",
+ "SageMaker poll status mapping covers all 5 documented TrainingJobStatus values (InProgress->running, with SecondaryStatus refinement to pending for Starting/Pending/LaunchingMLInstances/PreparingTrainingStack; Completed->succeeded; Failed->failed; Stopping->running; Stopped->cancelled), vanished job (ResourceNotFound)->cancelled, unknown handle->cancelled; collect() correctly checks RAW SM status for terminality so Stopping keeps polling until Stopped (test_poll_status_mapping, test_poll_failed_and_stopped, test_poll_vanished_job_is_cancelled, test_poll_unknown_handle_is_cancelled).",
+ "collect() reading S3: NOT a violation. The Protocol (executor.py:96-108) and ADR-005 require collect() to return status/exit metadata, not S3 contents — the result payload flows through ObjectStoreAllReduce written to S3 by each replica. SageMaker even surfaces the ModelArtifacts S3 path in result['result']. The reference LocalProcessExecutor returns an in-process value, confirming collect is not contractually an S3 reader.",
+ "Full suite green: .venv/bin/python -m pytest composer_replication/diloco/serverless -q => 53 passed, 17 skipped (skips are the boto3/kubernetes/modal absent-path guards that cannot fire when the package is importable in this interpreter, plus integration gates)."
+ ],
+ "new_backlog_items": [
+ "EKS end-to-end run bug: default container command runs replica_entrypoint __main__ (argparse --rendezvous/--world-size/--trainer-module required) but EKSExecutor supplies env vars + no args and never plumbs trainer_module -> pod crashes on startup. Fix by passing ContainerArguments-equivalent args OR adding an env-var fallback to replica_entrypoint.__main__; add a test that the EKS-supplied argv/env actually parses. (Not in BACKLOG_RESOLUTION_2026-06-09; C2 only tracked building EKSExecutor, not the entrypoint contract.)",
+ "Tighten EKSExecutor.cancel and SageMakerExecutor.cancel exception handling: only swallow 'already-terminated' errors (404/ResourceNotFound, already-terminal ValidationException); log/warn on other API errors so a failed gang-teardown of GPU resources is observable instead of silently leaving the cohort burning compute.",
+ "Add a 'result' key to EKSExecutor.collect() result dicts (None or the rendezvous output URI) for cross-backend uniformity with Local/Modal/SageMaker, OR document in the Protocol that 'result' is an optional backend-specific extra so callers use .get('result')."
+ ]
+ }
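Reviewer sketch for the high-severity EKS finding in review-executors.json: option (b), an env-var fallback in the entrypoint's argument parsing, could look roughly like the following. This is a minimal illustration and not the actual `replica_entrypoint.py`; the function names `build_parser` and `parse_replica_args` are hypothetical, and only the three flags and env-var names quoted in the finding are covered.

```python
import argparse
import os


def build_parser(env=None):
    """Build a parser whose required flags fall back to environment variables."""
    env = os.environ if env is None else env
    parser = argparse.ArgumentParser(prog="replica_entrypoint")
    # Flag -> env var pairs follow the names quoted in the review finding.
    # The helper itself is a hypothetical illustration.
    for flag, env_name, cast in (
        ("--rendezvous", "RENDEZVOUS_URI", str),
        ("--world-size", "WORLD_SIZE", int),
        ("--trainer-module", "TRAINER_MODULE", str),
    ):
        fallback = env.get(env_name)
        parser.add_argument(
            flag,
            type=cast,
            # argparse runs string defaults through `type`, so "4" becomes 4.
            default=fallback,
            # Still required when neither the CLI nor the env supplies a value.
            required=fallback is None,
        )
    return parser


def parse_replica_args(argv, env=None):
    """Parse argv, letting env vars stand in for any missing flags."""
    return build_parser(env).parse_args(argv)
```

With this shape, a pod that receives only env vars parses cleanly, an explicit CLI flag still wins over the env value, and a launch that supplies neither channel fails loudly at startup as it does today.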
research/review-newgaps.json ADDED
@@ -0,0 +1,39 @@
+ {
+ "area": "Wave-1+2 broad sweep for NEW gaps (imports/laziness, unfinished-work markers, doc-debt, ADR-015, optional-dep eager-load)",
+ "verdict": "minor-issues",
+ "findings": [
+ {
+ "severity": "medium",
+ "what": "Doc-debt: the public symbols of the 4 NEW Wave-2 modules are entirely undocumented in docs/API_REFERENCE.md. Greps for EKSExecutor / SageMakerExecutor / DockerSandbox / HeldOutGuard / TripwireStatus / CollapseStopError / kl_token_trust_filter all return 0 hits. API_REFERENCE §12 (serverless) header (line 23) lists `.modal`, `.hf_jobs` but not `.eks` / `.sagemaker`, and documents the loud-failing ModalExecutor/HFJobsExecutor stubs while omitting the two NEW *production* executors. There is no `safety` section at all, and no `datagen` section (DockerSandbox + its LocalSubprocessSandbox/FakeSandbox siblings are all undocumented). All of these symbols are real, exported public API (in their package __all__), and both executors are Protocol-conformant (isinstance(eks, ServerlessExecutor) == True).",
+ "where": "docs/API_REFERENCE.md (§12 line 1153-1376; header line 23); new public symbols in composer_replication/diloco/serverless/{eks,sagemaker}.py, composer_replication/datagen/docker_sandbox.py, composer_replication/safety/kill_switch.py",
+ "recommendation": "Add API_REFERENCE entries: under §12 add `class EKSExecutor` and `class SageMakerExecutor` (and update the §12 line-23 module list to include `.eks`, `.sagemaker`); add a `composer_replication.safety` section documenting HeldOutGuard / TripwireStatus / CollapseStopError / kl_token_trust_filter; and a `composer_replication.datagen` section documenting DockerSandbox (alongside the existing-but-also-undocumented LocalSubprocessSandbox/FakeSandbox)."
+ },
+ {
+ "severity": "low",
+ "what": "Dangling ADR reference: composer_replication/safety/__init__.py:17 says 'See docs/adrs/ADR-015-*.md' but no ADR-015 file exists (docs/adrs/ stops at ADR-014). The research plan called for ADR-015 to document the safety/kill-switch design decision; the module docstring already cites the literature (Zhao et al. RSI, EvilGenie, Gao self-evolving survey, Shumailov collapse, Catastrophic Goodhart, GRPO KL band) so the design rationale exists in-code but is not captured as an ADR, and the __init__ points readers to a file that isn't there.",
+ "where": "composer_replication/safety/__init__.py:17 (the dangling 'docs/adrs/ADR-015-*.md' pointer); docs/adrs/ (ADR-015 absent)",
+ "recommendation": "Either author docs/adrs/ADR-015-holdout-killswitch.md (the kill_switch.py module docstring is effectively the ADR draft already — proxy_real_gap Hacking-Gap, KL 0.08 nats/token hard stop, decline-patience collapse signature, defense-in-depth-over-HackMonitor) and index it in docs/adrs/README.md, OR remove the forward reference from safety/__init__.py until the ADR lands."
+ },
+ {
+ "severity": "low",
+ "what": "Test-count drift re-introduced by Wave 2. docs/V1_V8_COVERAGE.md:117 still states the canonical count as '266 passed / 62 skipped / 328 collected (measured 2026-06-09)' — that was the Wave-1 figure. Wave 2 added 93 tests across 4 new files (test_kill_switch 23, test_eks_executor 28, test_sagemaker_executor 14, test_docker_sandbox 28); the tree now collects 420 tests (328 -> 420, +92 net). B4 closed test-drift in Wave 1 but the doc is stale again post-Wave-2.",
+ "where": "docs/V1_V8_COVERAGE.md:117-134 (canonical count claim) vs actual `pytest --collect-only` = 420 collected",
+ "recommendation": "Re-run `.venv/bin/python -m pytest` to get the post-Wave-2 passed/skipped split and update the single canonical figure in V1_V8_COVERAGE.md (the doc explicitly says this line is 'the one canonical figure' that other docs reference)."
+ }
+ ],
+ "confirmed_good": [
+ "Required import smoke test passes: `import composer_replication; from composer_replication.diloco.serverless import EKSExecutor, SageMakerExecutor; from composer_replication.datagen import DockerSandbox; from composer_replication.safety import HeldOutGuard` -> exit 0, 'ALL IMPORTS OK'.",
+ "Optional-dep laziness (question 5) is CORRECT for all 4 new modules: no top-level `import kubernetes/boto3/docker` in eks.py / sagemaker.py / docker_sandbox.py / kill_switch.py (grep for eager imports returns empty). Blocking kubernetes+docker at import time and importing the new modules in isolation succeeds. EKSExecutor lazy-imports `kubernetes` only when no api injected / per-method; SageMakerExecutor lazy-imports boto3 in _make_boto3_client (construction-time, not import-time); DockerSandbox lazy-imports docker via _require_docker() inside methods.",
+ "NOTE on the whole-package blocked-import failure: blocking boto3 breaks `import composer_replication`, but the cause is PRE-EXISTING and NOT a Wave-2 regression — composer_replication/__init__.py:98 imports the trainer, which imports `trl.GRPOTrainer` -> accelerate.commands.config.sagemaker -> `import boto3`. boto3 is already a hard transitive dependency of the base trainer stack on main; Wave 2 did not introduce it.",
+ "No NEW unfinished-work markers (question 2): all NotImplementedError/TODO/FIXME/STUB hits in composer_replication/ are PRE-EXISTING and intentional (prime_rl/composer_loss.py deferred SDPO channel-2, recipes/monarch/actors.py v0 skeleton per ADR-006, diloco/serverless/{modal,hf_jobs,modal_spawn}.py documented loud-failing stubs). The 4 new modules contain ZERO NotImplementedError/TODO/FIXME/STUB — they are finished, not skeletons. SageMakerExecutor's docstring explicitly contrasts itself as 'fully-working, not the loud-failing modal.py/hf_jobs.py skeletons'.",
+ "Both new executors satisfy the runtime_checkable ServerlessExecutor Protocol (isinstance checks pass), expose correct backend_name ('eks'/'sagemaker') and supports_inter_replica_network=False (S3-only rendezvous).",
+ "All 90 collectable Wave-2 tests pass (3 skipped, the live-docker-daemon gated ones) via `pytest composer_replication/safety/tests composer_replication/diloco/serverless/tests/test_{eks,sagemaker}_executor.py composer_replication/datagen/tests/test_docker_sandbox.py`. Whole suite still collects cleanly (420 tests, no collection errors).",
+ "DockerSandbox.run_tests pytest-pass heuristic (`f\"{t} PASSED\" in out or (returncode==0 and not failed)`) is a faithful copy of the established LocalSubprocessSandbox.run_tests (sandbox.py:214) — not a new bug, consistent with the documented sibling behavior.",
+ "safety/ not being in the top-level composer_replication.__all__ is consistent with existing structure (datagen/diloco subpackages aren't fully surfaced at top level either); `composer_replication.safety` imports correctly as a subpackage."
+ ],
+ "new_backlog_items": [
+ "DOC: Document the 4 NEW Wave-2 public symbols in docs/API_REFERENCE.md — add EKSExecutor + SageMakerExecutor under §12 (and add .eks/.sagemaker to the §12 module list at line 23), add a new `composer_replication.safety` section (HeldOutGuard, TripwireStatus, CollapseStopError, kl_token_trust_filter), and a `composer_replication.datagen` section covering DockerSandbox (+ the also-undocumented LocalSubprocessSandbox/FakeSandbox).",
+ "ADR: Author docs/adrs/ADR-015-holdout-killswitch.md (the safety kill-switch / held-out-guard design) — currently referenced by composer_replication/safety/__init__.py:17 as 'docs/adrs/ADR-015-*.md' but the file does not exist; index it in docs/adrs/README.md. The kill_switch.py module docstring is the ready-made draft.",
+ "DOC: Refresh the canonical test count in docs/V1_V8_COVERAGE.md:117 — Wave 2 added 93 tests (collection 328 -> 420); the stated '266 passed / 62 skipped / 328 collected' is the Wave-1 figure and is now stale."
+ ]
+ }
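Reviewer sketch for the optional-dep laziness check described in review-newgaps.json: one way to reproduce "block the dependency, then import the module in isolation" is to plant a `None` entry in `sys.modules`, which makes the import system raise `ImportError` for that name. The helper names `block_imports` and `imports_cleanly_without` are hypothetical, and this may differ from how the review actually ran the check.

```python
import contextlib
import importlib
import sys

_MISSING = object()


@contextlib.contextmanager
def block_imports(*names):
    """Make `import <name>` raise ImportError for the duration of the block."""
    saved = {name: sys.modules.get(name, _MISSING) for name in names}
    try:
        for name in names:
            # A None entry in sys.modules makes the import system raise ImportError.
            sys.modules[name] = None
        yield
    finally:
        # Restore whatever was there before, including "not imported yet".
        for name, module in saved.items():
            if module is _MISSING:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = module


def imports_cleanly_without(module_name, *blocked):
    """Return True if module_name imports while the blocked packages are unavailable."""
    with block_imports(*blocked):
        # Drop any cached copy so the module body really executes again.
        sys.modules.pop(module_name, None)
        try:
            importlib.import_module(module_name)
        except ImportError:
            return False
        return True
```

A call such as `imports_cleanly_without("composer_replication.datagen.docker_sandbox", "docker")` would then express the laziness claim as a boolean. Note that this only detects eager imports executed at module import time; it says nothing about imports deferred into methods, and submodules already cached in `sys.modules` are not re-executed.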
research/review-safety.json ADDED
@@ -0,0 +1,54 @@
+ {
+ "area": "composer_replication/safety/kill_switch.py + test_kill_switch.py (Wave-2 C1)",
+ "verdict": "material-issues",
+ "findings": [
+ {
+ "severity": "high",
+ "what": "C1 was scoped as 'Held-out disjoint eval + depth/generation kill-switch' but ONLY the kill-switch half (HeldOutGuard) was built. The HeldoutSplit disjointness-enforcer does not exist anywhere in the tree (no composer_replication/safety/holdout.py, no HeldoutSplit class). The guard's heldout_score is an unvalidated caller-supplied float; nothing enforces that the held-out pool is actually disjoint from the train/generator set. The module's own docstring (kill_switch.py:41-43, 214-216) states this is load-bearing: 'if held-out drifts with the train set the gap signal is meaningless.' So the kill-switch's central proxy-real-gap and decline-streak signals can be silently meaningless with no guard rail.",
+ "where": "composer_replication/safety/ (missing holdout.py / HeldoutSplit); referenced at kill_switch.py:43, kill_switch.py:214-216",
+ "recommendation": "Build the HeldoutSplit disjointness enforcer (hash/id-based set-difference check that the held-out eval IDs never intersect the generator/train IDs, raising on overlap) as the second half of C1, OR explicitly re-scope C1 to two items and track the disjointness enforcer as a distinct OPEN backlog item. Do not mark C1 done with only the guard built."
+ },
+ {
+ "severity": "high",
+ "what": "HeldOutGuard is NOT wired into the trainer. Zero references to HeldOutGuard / kill_switch / CollapseStopError / should_halt / raise_if_fired in composer_replication/trainer/composer_trainer.py (or anywhere outside the safety package + its own test). The 'most load-bearing collapse safeguard (#2)' for the self-evolving flywheel exists as dead, never-invoked code. The trainer's GRPO loop never calls update() per checkpoint, so the run-level tripwire cannot fire in production.",
+ "where": "composer_replication/trainer/composer_trainer.py (no integration); HeldOutGuard defined composer_replication/safety/kill_switch.py:117",
+ "recommendation": "Wire HeldOutGuard.update(round_idx, in_loop_reward, heldout_score, kl_to_init=token_mean_kl(...)) into the trainer loop at the same checkpoint cadence DifficultyCurriculum.update is called (curriculum.py:78), and convert a fired verdict to a halt via raise_if_fired / should_halt. token_mean_kl already exists (kl_logging.py:53) to supply the per-token KL. Until wired, C1's safety claim is unrealized."
+ },
+ {
+ "severity": "low",
+ "what": "calibrate_kl_threshold does not re-validate the > 0 invariant that __post_init__ enforces. A negative factor (or negative baseline_kls) yields min(negative, 0.08) = a NEGATIVE kl_hard_stop, after which the KL tripwire fires on EVERY healthy step (any positive KL EMA > negative ceiling). Verified empirically: factor=-3.0 on baseline [0.01] sets kl_hard_stop=-0.03 and a healthy KL of 0.01 then fires. The min() 'tighten-only' clamp is satisfied in the literal numeric sense but violates the documented collapse-band semantics.",
+ "where": "composer_replication/safety/kill_switch.py:412-418 (calibrate_kl_threshold)",
+ "recommendation": "Validate factor > 0 and all(k >= 0 for k in baseline_kls) at the top of calibrate_kl_threshold, and/or clamp the result to a small positive floor (e.g. assert calibrated > 0). KL values are non-negative by definition so a negative factor is nonsensical input, but the invariant should be guarded since the method mutates a field __post_init__ otherwise protects."
+ },
+ {
+ "severity": "low",
+ "what": "Dangling cross-references in docstrings to artifacts that do not exist: safety/__init__.py:17-18 cites 'docs/adrs/ADR-015-*.md' (highest existing ADR is ADR-014; no ADR-015 file) and a \"'holdout-killswitch' research digest\" (no such file under research/). kill_switch.py:43,214 cite composer_replication.safety.holdout / HeldoutSplit 'design notes' that do not exist (same missing module as the high finding).",
+ "where": "composer_replication/safety/__init__.py:17-18; composer_replication/safety/kill_switch.py:43, 214-216",
+ "recommendation": "Either author ADR-015 documenting the kill-switch design decision (the module is substantial enough to warrant one and the docstring already promises it), or drop the dangling citations. Keep doc references honest to avoid the stale-cross-ref foot-guns the backlog (B5/B6/B8) is already cleaning up."
+ },
+ {
+ "severity": "low",
+ "what": "Gap-blowout path (c) fires when the proxy gain exceeds the real gain by more than max_proxy_real_gap EVEN WHEN the held-out (real) score is still genuinely RISING. Verified: with both rising but proxy faster, it halts the run while real improvement is ongoing. This is defensible per the docstring ('fast single-generation divergence', lines 144-145), and the reason string is accurate, but it is a potential false-positive halt on a healthy-but-fast-proxy run and is not covered by a test asserting the desired behavior in the both-rising case (only the real-flat case, with the proxy rising, is tested at test_kill_switch.py:143).",
+ "where": "composer_replication/safety/kill_switch.py:326-335 (path c); test gap test_kill_switch.py:143-158 only exercises real-flat",
+ "recommendation": "Add a test pinning the intended behavior when BOTH rise but proxy outpaces real beyond the ceiling (assert whether it should fire), and document in the docstring that path (c) is a divergence-RATE gate, not a real-decline gate, so future readers do not mistake a fired path-(c) for confirmed real regression."
+ }
+ ],
+ "confirmed_good": [
+ "All 23 tests in composer_replication/safety pass (.venv pytest, 23 passed).",
+ "Latched-fire is correct and cannot un-halt: _fired flips True in update() (line 277-278) and _evaluate() short-circuits with a 'latched:' verdict carrying the original reason before any threshold re-check (lines 294-296). Verified a full KL/gap recovery after fire stays fire=True.",
+ "Three halt conditions are individually correct: (b) KL EMA > kl_hard_stop checked first; (a) held-out-declines-while-in-loop-rises only increments the streak when BOTH conditions hold (a both-declining 'hard batch' correctly does NOT count, verified), fires at decline_patience; (c) proxy-real gap > ceiling. min_steps warm-up gate uses the internal _n counter (robust to non-contiguous round_idx, tested).",
+ "EMA denoising is sound: _fold seeds on first sample (no warm-up bias), alpha is weight-on-prior validated to [0,1); first-sample baseline seeding makes proxy_real_gap a gain-since-baseline quantity exactly matching the RSI Hacking-Gap definition. proxy_real_gap math verified (0.15 expected case) and returns 0.0 before first update.",
+ "CollapseStopError raise path: raise_if_fired raises the typed exception carrying .status only when fired, is a no-op when clean, and is a safe no-op before any update (last_status None). Strict > boundary on gap/KL confirmed (gap==ceiling does not fire).",
+ "calibrate tighten-only works for the intended (positive) inputs: min(3x baseline, current) so a drifting baseline cannot loosen past 0.08 (tested), and only tightens for a clean low baseline.",
+ "kl_token_trust_filter boundary correct (strict >, so threshold value itself is not masked).",
+ "Docstring cross-refs that DO resolve: DifficultyCurriculum.update (curriculum.py:78) and token_mean_kl (kl_logging.py:53) both exist, so the claimed cadence and KL-units convention are anchored to real code.",
+ "No false claim anywhere in examples/ or docs that the kill-switch is already wired/used (grep clean)."
+ ],
+ "new_backlog_items": [
+ "Build composer_replication/safety/holdout.py with a HeldoutSplit disjointness enforcer (id/hash set-difference, raises on train/held-out overlap) — the un-built second half of C1 that the kill-switch's gap/decline signals depend on for validity.",
+ "Wire HeldOutGuard into composer_replication/trainer/composer_trainer.py at the per-checkpoint cadence (alongside DifficultyCurriculum.update), feeding token_mean_kl as kl_to_init and converting a fired verdict to a halt via raise_if_fired/should_halt — the C1 safeguard is currently dead code.",
+ "Guard calibrate_kl_threshold against factor<=0 / negative baseline_kls (or clamp result to a positive floor) so calibration cannot drive kl_hard_stop negative and make the KL tripwire fire on every healthy step.",
+ "Author docs/adrs/ADR-015 for the held-out kill-switch (referenced by safety/__init__.py:17 but nonexistent) or remove the dangling ADR-015 + 'holdout-killswitch research digest' citations.",
+ "Add a test pinning path-(c) gap-blowout behavior in the BOTH-rising case (proxy outpaces a still-rising real) to lock the intended false-positive/true-positive decision."
+ ]
+ }
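Reviewer sketch for the `calibrate_kl_threshold` finding in review-safety.json: the guard it recommends can be illustrated with a standalone function. This is not the `HeldOutGuard` method itself (which, per the finding, mutates `kl_hard_stop` in place); the free-function shape, the parameter names, and the use of the mean as the baseline aggregate are assumptions made for illustration. Only the `min(factor * baseline, current)` tighten-only rule and the 0.08 default come from the review text.

```python
import math

KL_HARD_STOP_DEFAULT = 0.08  # nats/token hard stop cited in the review


def calibrate_kl_threshold(baseline_kls, current=KL_HARD_STOP_DEFAULT, factor=3.0):
    """Return a tighten-only KL ceiling, rejecting inputs that would make it <= 0."""
    kls = list(baseline_kls)
    if not kls:
        raise ValueError("baseline_kls must be non-empty")
    if not (math.isfinite(factor) and factor > 0):
        raise ValueError(f"factor must be a positive finite number, got {factor!r}")
    if any((not math.isfinite(k)) or k < 0 for k in kls):
        raise ValueError("baseline KL values must be finite and non-negative")

    # Assumption: the baseline is summarised by its mean; the real method may differ.
    baseline = sum(kls) / len(kls)

    # Tighten-only: calibration may lower the ceiling but never raise it past `current`.
    calibrated = min(factor * baseline, current)
    if calibrated <= 0:
        # An all-zero baseline would otherwise make every positive KL trip the wire.
        raise ValueError("calibrated threshold must be positive")
    return calibrated
```

Under these assumptions the reported failure case (`factor=-3.0` on baseline `[0.01]`) raises instead of silently producing a ceiling of -0.03, while a clean low baseline still tightens and a drifting one stays capped at 0.08.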
research/review-sandbox.json ADDED
@@ -0,0 +1,29 @@
+ {
+ "area": "composer_replication/datagen/docker_sandbox.py + sandbox.py scrub_tree refactor",
+ "verdict": "clean",
+ "findings": [
+ {
+ "severity": "low",
+ "what": "run_tests pass/fail parse carries the order-dependent fallback clause `if f\"{t} PASSED\" in out or (returncode == 0 and not failed)` verbatim from LocalSubprocessSandbox. If a runner exits 0 but does not print '<nodeid> PASSED' for every node id, the first un-printed node is marked passed solely on the exit code (and `not failed` is true only until the first failure is recorded). This is a pre-existing pattern (identical on main's LocalSubprocessSandbox at sandbox.py:214) faithfully mirrored into DockerSandbox, NOT a new regression — flagged only for completeness.",
+ "where": "composer_replication/datagen/docker_sandbox.py:272-276 (and the source LocalSubprocessSandbox at sandbox.py:212-217)",
+ "recommendation": "No action required for this review. If ever hardened, require an explicit PASSED token per node id and stop trusting the bare exit code; do it in both sandboxes together so they stay in lock-step."
+ }
+ ],
+ "confirmed_good": [
+ "REFACTOR DID NOT BREAK LocalSubprocessSandbox: boot() still scrubs — boot() (sandbox.py:169-172) calls self._scrub_tree() which delegates to the shared module-level scrub_tree() free function (sandbox.py:174-177). Smoke test confirmed __pycache__, .git, and *.pyc are removed on boot while real source (keep.py) survives.",
+ "No broken/dangling references to the old per-class _scrub_tree: the only remaining _scrub_tree occurrences are (a) the intentional back-compat delegating method + its self-call in boot, and (b) one descriptive comment in test_docker_substrate_e2e.py:161. grep for external callers of .SCRUB_NAMES/._SCRUB_NAMES/.SCRUB_SUFFIXES returned EMPTY.",
+ "Back-compat preserved: LocalSubprocessSandbox._SCRUB_NAMES / ._SCRUB_SUFFIXES class aliases still point at the module-level SCRUB_NAMES/SCRUB_SUFFIXES; the _scrub_tree() method is retained.",
+ "FeatureDeletionEnv unaffected: env.py uses the Sandbox Protocol generically (boot/exec/run_tests/trajectory at env.py:59,69,86,89) — agnostic to the scrub refactor.",
+ "SCRUB-BEFORE-MOUNT ORDERING IS CORRECT (no security bug): DockerSandbox.boot() runs scrub_tree(self.workdir) at line 190 BEFORE self._client.containers.run(**kwargs) at line 198. The container (and thus the RW bind mount) does not exist when the host-side scrub runs, so the scrub is provably pre-mount. The scrub-AFTER-mount security bug the audit asked to rule out is NOT present.",
+ "--network none: both network_disabled=True AND network_mode='none' set (docker_sandbox.py:154-155); live test_live_network_is_disabled actually ran on a real container and asserted egress BLOCKED / not CONNECTED.",
+ "Resource limits: mem_limit == memswap_limit (forbids swap), pids_limit (fork-bomb guard), nano_cpus (CPU quota); all present, configurable, and unit-asserted.",
+ "Ephemeral teardown: close() force-removes (idempotent, swallows errors), reap_leaked() sweeps label-filtered orphan containers at boot and shutdown, __enter__/__exit__/__del__ wired. Verified by test_close_removes_container_force, test_context_manager_closes, test_reap_leaked_sweeps_labelled_containers.",
+ "gVisor runtime option: runtime defaults to None (=> 'runtime' kwarg omitted, daemon-default runc); 'runsc' is only passed through when explicitly set (docker_sandbox.py:178-179) and gated by runsc_available(). test_live_runsc_runtime correctly SKIPPED (gVisor not installed on host).",
+ "Lazy docker import: _require_docker() imports `docker` inside the function with a clear RuntimeError on ImportError; docker SDK is never required by the FakeSandbox/pure-core path. Verified by test_require_docker_missing_sdk_raises.",
+ "Privilege lockdown: cap_drop=['ALL'], security_opt=['no-new-privileges:true'], user='1000:1000' (non-root), read_only root fs with tmpfs /tmp (noexec,nosuid), keep_root_writable escape hatch.",
+ "shlex.quote applied to every test node id in run_tests (shell-injection guard, matches LocalSubprocessSandbox); non-UTF-8 output decoded with errors='replace' (test_exec_decodes_non_utf8_bytes); exec wraps commands in coreutils `timeout`.",
+ "TEST SUITE: `.venv/bin/python -m pytest composer_replication/datagen -q` => 61 passed, 1 skipped (runsc only). The LIVE Docker E2E genuinely RAN (not skipped): test_live_four_inversion_gates_in_hardened_container, test_live_network_is_disabled, test_live_cache_scrub_removes_bytecode all PASSED on a real python:3.11-slim container. The long-blocked D1 substrate E2E (test_docker_substrate_e2e.py) is also GREEN (2/2). Broader regression datagen+safety+serverless => 137 passed, 18 skipped, no failures.",
+ "Public surface re-exports DockerSandbox and scrub_tree from composer_replication/datagen/__init__.py and __all__; package imports cleanly."
+ ],
+ "new_backlog_items": []
+ }
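Reviewer sketch for the low-severity `run_tests` finding in review-sandbox.json: if the pass/fail parse is ever hardened as recommended, requiring an explicit `PASSED` token per node id could look like this. The function name `parse_passed` is hypothetical, and the assumed input is verbose pytest output where each result line starts with the node id followed by its status. It deliberately ignores the exit code, which is the behavioural change the recommendation describes.

```python
import re


def parse_passed(output, node_ids):
    """Map each node id to True only if the output has an explicit '<nodeid> PASSED' line."""
    results = {}
    for node_id in node_ids:
        # Anchor at line start and require whitespace after the id so that
        # 'test_a' does not match a line reporting 'test_ab PASSED'.
        pattern = re.compile(
            r"^" + re.escape(node_id) + r"\s+PASSED\b",
            re.MULTILINE,
        )
        results[node_id] = bool(pattern.search(output))
    return results
```

Unlike the order-dependent fallback quoted in the finding, a node id that the runner never printed is reported as not passed regardless of the return code. As the review notes, any such change would need to land in both sandboxes together so they stay in lock-step.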