Wave 2: 4 new modules (kill-switch, EKS/SageMaker executors, DockerSandbox) + B4/B7 completion
Built by the parallel execution team (worktree-isolated), integrated + tested here.
New modules (all CPU-testable via mock/lazy-import; optional deps gated):
- composer_replication/safety/ — HeldOutGuard run-level collapse kill-switch
(held-out-declines-while-reward-rises streak + KL-to-init hard-stop 0.08 +
proxy-real hacking-gap; EMA-denoised, latched-fire, CollapseStopError). The
documented #2 safeguard for the self-evolving flywheel. 23 tests.
- composer_replication/diloco/serverless/eks.py — EKSExecutor satisfying the
ServerlessExecutor Protocol via a SINGLE Kubernetes Indexed Job → N rank-ordered
ReplicaHandles, gang-cancel (Background propagation), REPLICA_RANK via the
downward API, S3 rendezvous (IRSA). 28 tests (mock BatchV1/CoreV1).
- composer_replication/diloco/serverless/sagemaker.py — SageMakerExecutor
(boto3 create_training_job, same S3 rendezvous, status mapping). +13-test
module written during integration (the build agent shipped it test-less).
- composer_replication/datagen/docker_sandbox.py — DockerSandbox (ephemeral
container, --network none, mem/pids limits, gVisor runtime option) + refactored
the per-class _scrub_tree into a shared module-level scrub_tree free function
so every sandbox backend applies the identical reward-hack control. Live Docker
tests pass; LocalSubprocessSandbox/FeatureDeletionEnv unaffected (review: clean).
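
The kill-switch behavior summarized in the first bullet above can be illustrated with a minimal sketch. This is not the code in composer_replication/safety/kill_switch.py (its diff is not shown here). The names `ToyCollapseGuard` and `CollapseStop`, the EMA weight, the streak length, and the `>=` comparison are assumptions. Only the 0.08 KL hard-stop, the EMA denoising, the latched fire, and the held-out-declines-while-reward-rises streak come from the commit message. The proxy-real hacking-gap check is left out.

```python
from __future__ import annotations

from dataclasses import dataclass


class CollapseStop(RuntimeError):
    """Hypothetical stand-in for the commit's CollapseStopError."""


@dataclass
class ToyCollapseGuard:
    kl_hard_stop: float = 0.08  # threshold quoted in the commit message
    streak_to_fire: int = 3     # assumed streak length
    ema_alpha: float = 0.5      # assumed EMA smoothing weight
    fired: bool = False         # latched: stays True once tripped
    _reward_ema: float | None = None
    _heldout_ema: float | None = None
    _streak: int = 0

    def _smooth(self, prev: float | None, x: float) -> float:
        # Exponential moving average; the first sample seeds the average.
        return x if prev is None else self.ema_alpha * x + (1 - self.ema_alpha) * prev

    def _trip(self, reason: str) -> None:
        self.fired = True
        raise CollapseStop(reason)

    def update(self, train_reward: float, heldout_score: float, kl_to_init: float) -> None:
        if self.fired:
            raise CollapseStop("guard already fired (latched)")
        # Hard stop: the policy has drifted too far from its initialization.
        if kl_to_init >= self.kl_hard_stop:
            self._trip(f"KL-to-init {kl_to_init:.3f} reached {self.kl_hard_stop}")
        prev_r, prev_h = self._reward_ema, self._heldout_ema
        self._reward_ema = self._smooth(prev_r, train_reward)
        self._heldout_ema = self._smooth(prev_h, heldout_score)
        # Collapse signature: smoothed train reward rises while held-out falls.
        diverging = (
            prev_r is not None
            and prev_h is not None
            and self._reward_ema > prev_r
            and self._heldout_ema < prev_h
        )
        self._streak = self._streak + 1 if diverging else 0
        if self._streak >= self.streak_to_fire:
            self._trip(f"held-out declined while reward rose for {self._streak} updates")
```

Latching matters because a training loop that catches the exception and calls `update` again should keep failing rather than silently resume.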
Wiring + completion:
- Re-exported EKSExecutor/SageMakerExecutor (serverless __init__) and
DockerSandbox/scrub_tree (datagen __init__).
- pyproject: added [eks] (kubernetes) + [aws] (boto3) extras.
- B7-complete: added make_dr_grpo_config/make_po_config/PO_OBJECTIVES to the
TOP-LEVEL __all__ (were importable but missing from __all__).
- B4-complete: reconciled the 4 surviving stale "115 passing" current-framed
claims (README/OVERVIEW/VISION_VALIDATION) to the canonical 266/62.
- All new files ruff-clean (E,F,W,I,N,UP,B).
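
The single Indexed Job approach attributed to EKSExecutor above can be sketched as a plain manifest builder. The eks.py diff is not shown here, so `indexed_job_manifest`, the container name, and the `backoffLimit`/`restartPolicy` values are assumptions. The `completionMode: Indexed` field and the `batch.kubernetes.io/job-completion-index` annotation exposed through the downward API are standard Kubernetes features.

```python
from __future__ import annotations


def indexed_job_manifest(name: str, image: str, replicas: int) -> dict:
    """Build a batch/v1 Job that runs `replicas` pods, each with a stable index."""
    return {
        "apiVersion": "batch/v1",
        "kind": "Job",
        "metadata": {"name": name},
        "spec": {
            # Indexed mode gives every pod a completion index in 0..replicas-1.
            "completionMode": "Indexed",
            "completions": replicas,
            "parallelism": replicas,  # start all ranks together
            "backoffLimit": 0,        # assumed: do not retry failed ranks
            "template": {
                "spec": {
                    "restartPolicy": "Never",  # assumed
                    "containers": [
                        {
                            "name": "replica",  # hypothetical container name
                            "image": image,
                            "env": [
                                {
                                    # Downward API: expose the pod's index as the rank.
                                    "name": "REPLICA_RANK",
                                    "valueFrom": {
                                        "fieldRef": {
                                            "fieldPath": (
                                                "metadata.annotations"
                                                "['batch.kubernetes.io/job-completion-index']"
                                            )
                                        }
                                    },
                                }
                            ],
                        }
                    ],
                }
            },
        },
    }
```

Because all ranks belong to one Job, deleting that Job with a Background propagation policy removes every pod together, which is one way to read the "gang-cancel" note in the commit message.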
Full suite: 355 passed / 65 skipped / 1 flaky-under-contention (spike-006
loss-trend test, passes in isolation — tracked as R11, not a regression).
Wave-3 backlog (R1-R12) filed in docs/BACKLOG_RESOLUTION_2026-06-09.md from the
concurrent review team.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
- composer_replication/__init__.py +3 -0
- composer_replication/datagen/__init__.py +5 -0
- composer_replication/datagen/docker_sandbox.py +331 -0
- composer_replication/datagen/sandbox.py +49 -28
- composer_replication/datagen/tests/test_docker_sandbox.py +620 -0
- composer_replication/diloco/serverless/__init__.py +4 -0
- composer_replication/diloco/serverless/eks.py +674 -0
- composer_replication/diloco/serverless/executor.py +4 -3
- composer_replication/diloco/serverless/sagemaker.py +619 -0
- composer_replication/diloco/serverless/tests/test_eks_executor.py +625 -0
- composer_replication/diloco/serverless/tests/test_sagemaker_executor.py +244 -0
- composer_replication/safety/__init__.py +34 -0
- composer_replication/safety/kill_switch.py +447 -0
- composer_replication/safety/tests/__init__.py +0 -0
- composer_replication/safety/tests/test_kill_switch.py +320 -0
- docs/BACKLOG_RESOLUTION_2026-06-09.md +25 -0
- docs/OVERVIEW.md +2 -2
- docs/VISION_VALIDATION.md +1 -1
- pyproject.toml +9 -0
- research/review-executors.json +44 -0
- research/review-newgaps.json +39 -0
- research/review-safety.json +54 -0
- research/review-sandbox.json +29 -0
@@ -132,6 +132,9 @@ __all__ = [
     "replay_trace",
     # Trainer
     "ComposerReplicationTrainer",
+    "make_dr_grpo_config",
+    "make_po_config",
+    "PO_OBJECTIVES",
     # DiLoCo (optional)
     "make_diloco_outer_loop",
     # Meta
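
The hunk above adds names that were already importable to `__all__`. The practical difference shows up with star imports: when a module defines `__all__`, `from module import *` binds only the listed names. A small sketch with synthetic modules (the module names and the `make_module`/`star_import` helpers are hypothetical, not part of the package):

```python
import sys
import types


def make_module(name: str, all_names: list) -> types.ModuleType:
    """Register a synthetic module that defines two names and a given __all__."""
    mod = types.ModuleType(name)
    mod.make_po_config = lambda: {"objective": "po"}
    mod.other = 1
    mod.__all__ = list(all_names)
    sys.modules[name] = mod
    return mod


def star_import(name: str) -> set:
    """Return the public names bound by `from <name> import *`."""
    namespace: dict = {}
    exec(f"from {name} import *", namespace)
    # Drop dunder entries such as __builtins__ that exec adds.
    return {key for key in namespace if not key.startswith("__")}


make_module("toy_pkg_before", ["other"])
make_module("toy_pkg_after", ["other", "make_po_config"])
```

Here `star_import("toy_pkg_before")` omits `make_po_config` even though `from toy_pkg_before import make_po_config` still works, which matches the "importable but missing from __all__" state the commit message describes.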
@@ -8,6 +8,7 @@ Public surface:
 - FeatureDeletionTask — the task tuple (schema.py)
 - FeatureDeletionEnv — Gym/OpenEnv-style env + TRL reward_fn adapter (env.py)
 - Sandbox / FakeSandbox / LocalSubprocessSandbox — execution backends (sandbox.py)
+- DockerSandbox — hardened ephemeral-container backend (docker_sandbox.py)
 - HackMonitor — reward-hacking provenance monitor (monitor.py)
 - DifficultyCurriculum — online pass-rate difficulty gate (curriculum.py)
 - validate_task — 4-gate solvability validator (validator.py)
@@ -20,11 +21,13 @@ from __future__ import annotations
 from composer_replication.datagen.curriculum import DifficultyCurriculum
 from composer_replication.datagen.env import FeatureDeletionEnv, StepResult
 from composer_replication.datagen.monitor import HackMonitor
+from composer_replication.datagen.docker_sandbox import DockerSandbox
 from composer_replication.datagen.sandbox import (
     FakeSandbox,
     LocalSubprocessSandbox,
     Sandbox,
     TestRunResult,
+    scrub_tree,
 )
 from composer_replication.datagen.schema import FeatureDeletionTask
 from composer_replication.datagen.substrates import SweBenchAdapter
@@ -37,7 +40,9 @@ __all__ = [
     "Sandbox",
     "FakeSandbox",
     "LocalSubprocessSandbox",
+    "DockerSandbox",
     "TestRunResult",
+    "scrub_tree",
     "HackMonitor",
     "DifficultyCurriculum",
     "validate_task",
@@ -0,0 +1,331 @@
+"""docker_sandbox.py — the hardened container backend for the FD env (ADR-010 §3).
+
+`DockerSandbox` is a drop-in `Sandbox` (boot/exec/run_tests/trajectory/
+is_command_allowed) that runs the agent's tool calls and the verifiable test
+command inside an ephemeral, locked-down Docker container instead of a raw host
+subprocess. It is the production execution path for genuinely UNTRUSTED
+model-generated code — the LocalSubprocessSandbox sibling runs everything in the
+host process and only enforces the scrub + denylist, which is fine for
+first-party / dev use but is NOT a host-security boundary.
+
+The lockdown recipe (CIS Docker benchmark 5.x + gVisor guidance, see the
+`sandbox-container` research digest):
+  - `network_mode='none'` — no egress: no decompiler downloads, no
+    signature exfil/recovery over the wire (the SWE-RL reward-hack threat).
+  - `read_only=True` root fs + a small `tmpfs={'/tmp': ...}` for scratch.
+  - the working tree bind-mounted RW at /work (the agent must mutate the repo).
+  - `cap_drop=['ALL']` + `security_opt=['no-new-privileges:true']`.
+  - `user='1000:1000'` — never run agent code as root in the container.
+  - `pids_limit` (fork-bomb guard) + `mem_limit==memswap_limit` (OOM guard, no
+    swap) + `nano_cpus` (CPU quota).
+  - optional `runtime='runsc'` (gVisor) — a userspace kernel that intercepts
+    syscalls so a kernel-exploit payload hits the Sentry, not the host. This is
+    the RECOMMENDED runtime for untrusted model code per the ADR-010 threat
+    model, but requires host setup (`sudo runsc install` writes the 'runsc'
+    entry into /etc/docker/daemon.json + dockerd restart) so it is OPTIONAL and
+    defaults to None (= the daemon default, runc).
+
+PRIMARY reward-hack control: the SAME host-side `scrub_tree(workdir)` used by
+LocalSubprocessSandbox runs in `boot()` BEFORE the container starts. The bind
+mount is shared host<->container, so scrubbing __pycache__/.git/*.pyc on the
+host pre-boot is exactly equivalent to scrubbing inside the container. The
+command denylist remains cheap defense-in-depth, not the wall.
+
+`docker` is LAZY-imported inside methods so the pure-Python core and the
+FakeSandbox path never require the SDK; a clear RuntimeError is raised if the
+SDK or the daemon is absent.
+"""
+from __future__ import annotations
+
+import shlex
+from dataclasses import dataclass, field
+from uuid import uuid4
+
+from composer_replication.datagen.sandbox import (
+    SANDBOX_DENYLIST,
+    TestRunResult,
+    scrub_tree,
+)
+
+# Label stamped on every container we create, so a reaper can sweep ephemeral
+# containers leaked by a crashed episode (docker-py-native `--rm` durability
+# without the auto_remove log-loss/race problem).
+_LABEL_KEY = "composer_replication"
+_LABEL_VALUE = "datagen"
+
+
+def _require_docker():
+    """Lazy-import the docker SDK and return the module, or raise a clear
+    RuntimeError if the SDK is not installed. (Kept separate from the client so
+    callers can introspect `docker.errors` without opening a connection.)"""
+    try:
+        import docker  # noqa: PLC0415 (intentional lazy import)
+    except ImportError as e:  # pragma: no cover - exercised via monkeypatch
+        raise RuntimeError(
+            "DockerSandbox requires the 'docker' Python SDK (docker>=7). "
+            "Install it with `pip install docker` (or the project's [datagen] "
+            "extra). It is lazy-imported so the FakeSandbox/core path never "
+            "needs it."
+        ) from e
+    return docker
+
+
+def _make_client():
+    """Construct a `docker.from_env()` client or raise a clear RuntimeError if
+    the daemon is unreachable. The SDK constructs lazily, so we ping with
+    `client.ping()` to surface a dead daemon here rather than at first use."""
+    docker = _require_docker()
+    try:
+        client = docker.from_env()
+        client.ping()
+        return client
+    except Exception as e:  # docker.errors.DockerException, ConnectionError, ...
+        raise RuntimeError(
+            "DockerSandbox could not reach a Docker daemon (is `docker info` "
+            f"healthy?). Underlying error: {e!r}"
+        ) from e
+
+
+def runsc_available(client=None) -> bool:
+    """True iff the gVisor 'runsc' runtime is registered with the daemon.
+
+    Mirrors the `_docker_available()` gating philosophy: runsc is not installed
+    on most dev/CI boxes, so any runsc-specific behavior must be gated on this.
+    """
+    try:
+        client = client or _make_client()
+        runtimes = client.info().get("Runtimes", {}) or {}
+        return "runsc" in runtimes
+    except Exception:
+        return False
+
+
+@dataclass
+class DockerSandbox:
+    """Hardened ephemeral-container `Sandbox`. See module docstring.
+
+    Args:
+        workdir: host path to the materialized repo; bind-mounted RW at /work and
+            scrubbed on the host before boot (the primary reward-hack
+            control). MUST be an existing directory by `boot()` time.
+        runtime: None (=> daemon default, runc) or 'runsc' (gVisor) for untrusted
+            model code. Requires host-side `sudo runsc install` + dockerd
+            restart; gate with `runsc_available()`.
+        mem_limit / memswap_limit: OOM guard; equal values forbid swap.
+        pids_limit: fork-bomb guard.
+        nano_cpus: CPU quota in 1e-9 CPUs (2_000_000_000 == 2 CPUs).
+        user: non-root uid:gid the agent code runs as inside the container.
+        exec_timeout_s: wall-clock cap injected via coreutils `timeout` (exec_run
+            has no timeout param — docker-py #2651).
+    """
+
+    workdir: str
+    runtime: str | None = None
+    mem_limit: str = "1g"
+    memswap_limit: str = "1g"
+    pids_limit: int = 256
+    nano_cpus: int = 2_000_000_000  # 2 CPUs
+    user: str = "1000:1000"
+    container_workdir: str = "/work"
+    tmpfs_size: str = "64m"
+    exec_timeout_s: int = 600
+    keep_root_writable: bool = False  # escape hatch if read-only fs breaks tooling
+
+    _trajectory: list[dict] = field(default_factory=list, init=False)
+    booted_image: str | None = field(default=None, init=False)
+    _client: object | None = field(default=None, init=False)
+    _container: object | None = field(default=None, init=False)
+
+    # ---- construction of the hardening kwargs --------------------------------
+
+    def container_kwargs(self, image: str) -> dict:
+        """The full hardened `containers.run` kwarg set. Pulled out as a method
+        so the pure-unit tests can assert the config (network_disabled,
+        mem_limit, runtime, ...) WITHOUT a live daemon."""
+        kwargs: dict = {
+            "image": image,
+            # Long-lived idle container; exec_run drives the actual work.
+            "command": ["sleep", "infinity"],
+            "detach": True,
+            # --- network egress kill-switch ---
+            # network_disabled removes networking entirely; we ALSO set
+            # network_mode='none' for parity with the existing CLI substrate
+            # test (`--network none`) and belt-and-suspenders.
+            "network_disabled": True,
+            "network_mode": "none",
+            # --- filesystem lockdown ---
+            "read_only": not self.keep_root_writable,
+            "tmpfs": {"/tmp": f"rw,noexec,nosuid,size={self.tmpfs_size}"},
+            "volumes": {
+                self.workdir: {"bind": self.container_workdir, "mode": "rw"}
+            },
+            "working_dir": self.container_workdir,
+            # --- privilege lockdown ---
+            "user": self.user,
+            "cap_drop": ["ALL"],
+            "security_opt": ["no-new-privileges:true"],
+            # --- resource limits ---
+            "pids_limit": self.pids_limit,
+            "mem_limit": self.mem_limit,
+            "memswap_limit": self.memswap_limit,
+            "nano_cpus": self.nano_cpus,
+            # --- lifecycle / reaping ---
+            "name": f"fd-{uuid4().hex[:12]}",
+            "labels": {_LABEL_KEY: _LABEL_VALUE},
+        }
+        # runtime is OPTIONAL; only pass it through when set so the default
+        # (runc) path never references a runtime that may not exist.
+        if self.runtime:
+            kwargs["runtime"] = self.runtime
+        return kwargs
+
+    # ---- Sandbox Protocol ----------------------------------------------------
+
+    def boot(self, image: str) -> None:
+        """Scrub the HOST workdir (primary control), reap any leaked siblings,
+        then start the hardened ephemeral container."""
+        self.booted_image = image
+        self._trajectory = []
+        # PRIMARY reward-hack control — run on the host before the bind mount.
+        scrub_tree(self.workdir)
+
+        self._client = _make_client()
+        self.reap_leaked(self._client)
+
+        docker = _require_docker()
+        kwargs = self.container_kwargs(image)
+        try:
+            self._container = self._client.containers.run(**kwargs)
+        except docker.errors.ImageNotFound as e:
+            raise RuntimeError(
+                f"DockerSandbox.boot: image {image!r} not found locally and "
+                "could not be pulled (the container is --network none). Pull it "
+                f"on the host first. Underlying: {e!r}"
+            ) from e
+        except docker.errors.APIError as e:
+            raise RuntimeError(
+                f"DockerSandbox.boot: Docker API error starting {image!r} with "
+                f"runtime={self.runtime!r}: {e!r}"
+            ) from e
+
+    def is_command_allowed(self, command: str) -> bool:
+        # First-token-only check — see SANDBOX_DENYLIST notes. NOT a boundary on
+        # its own; the container isolation + host scrub are the real controls.
+        return command not in SANDBOX_DENYLIST
+
+    def _exec(self, cmd: str) -> tuple[int, str]:
+        """Run one shell command in the live container via exec_run, enforcing a
+        wall-clock cap with coreutils `timeout` (exec_run has no timeout param —
+        docker-py #2651). Returns (exit_code, combined_output)."""
+        if self._container is None:
+            raise RuntimeError("DockerSandbox.exec called before boot()")
+        # Wrap in `timeout` then a login-ish shell so PATH lookups work. demux
+        # keeps stdout/stderr separable but we combine them like the local
+        # sandbox does for the pytest parser.
+        wrapped = f"timeout {self.exec_timeout_s} {cmd}"
+        full = ["/bin/sh", "-c", wrapped]
+        res = self._container.exec_run(
+            full, workdir=self.container_workdir, demux=True
+        )
+        exit_code = res.exit_code if res.exit_code is not None else -1
+        out = res.output
+        # demux=True => output is (stdout_bytes|None, stderr_bytes|None).
+        if isinstance(out, tuple):
+            stdout_b, stderr_b = out
+        else:  # defensive — some daemons return raw bytes
+            stdout_b, stderr_b = out, None
+        text = self._decode(stdout_b) + self._decode(stderr_b)
+        return exit_code, text
+
+    @staticmethod
+    def _decode(b) -> str:
+        """Untrusted code can emit non-UTF-8 bytes; never rely on text mode."""
+        if not b:
+            return ""
+        if isinstance(b, bytes):
+            return b.decode("utf-8", errors="replace")
+        return str(b)
+
+    def exec(self, action: dict) -> str:
+        self._trajectory.append(action)
+        cmd = str(action.get("command", ""))
+        if not cmd.strip():
+            return ""
+        head = cmd.strip().split()[0]
+        if not self.is_command_allowed(head):
+            return f"ERROR: command '{head}' is not allowed in the sandbox."
+        _exit, out = self._exec(cmd)
+        return out
+
+    def run_tests(self, test_command: str, tests: tuple[str, ...]) -> TestRunResult:
+        # shlex.quote each node id — SWE-bench node ids contain spaces/brackets
+        # (parametrized tests) and could otherwise break the shell or inject
+        # commands (matches LocalSubprocessSandbox).
+        node_ids = " ".join(shlex.quote(t) for t in tests)
+        cmd = f"{test_command} {node_ids}".strip()
+        returncode, out = self._exec(cmd)
+        # Same conservative parse as LocalSubprocessSandbox: a test is "passed"
+        # only if its node id appears with PASSED, else failed; collection
+        # errors => collected_ok False.
+        passed, failed = set(), set()
+        collected_ok = "errors during collection" not in out.lower()
+        for t in tests:
+            if f"{t} PASSED" in out or (returncode == 0 and not failed):
+                passed.add(t)
+            else:
+                failed.add(t)
+        return TestRunResult(
+            passed=frozenset(passed),
+            failed=frozenset(failed),
+            stdout=out,
+            collected_ok=collected_ok,
+        )
+
+    def trajectory(self) -> list[dict]:
+        return list(self._trajectory)
+
+    # ---- lifecycle / cleanup -------------------------------------------------
+
+    def close(self) -> None:
+        """Tear down the ephemeral container (force=True). Idempotent; swallows
+        errors so an already-gone container never masks the episode result."""
+        c = self._container
+        self._container = None
+        if c is not None:
+            try:
+                c.remove(force=True)
+            except Exception:
+                pass
+
+    @staticmethod
+    def reap_leaked(client=None) -> int:
+        """Sweep ephemeral containers leaked by a crashed episode (labelled
+        composer_replication=datagen). Callable at boot and shutdown. Returns
+        the count removed. Best-effort — never raises."""
+        removed = 0
+        try:
+            client = client or _make_client()
+            leaked = client.containers.list(
+                all=True, filters={"label": f"{_LABEL_KEY}={_LABEL_VALUE}"}
+            )
+            for c in leaked:
+                try:
+                    c.remove(force=True)
+                    removed += 1
+                except Exception:
+                    pass
+        except Exception:
+            pass
+        return removed
+
+    def __enter__(self) -> DockerSandbox:
+        return self
+
+    def __exit__(self, *exc) -> None:
+        self.close()
+
+    def __del__(self):  # pragma: no cover - best-effort GC cleanup
+        try:
+            self.close()
+        except Exception:
+            pass
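
A note on the `_exec` method in the file above: it prefixes the command string with `timeout` and passes the result to `/bin/sh -c`. For a compound command such as `cd sub && pytest -q`, the shell parses the wrapped string as `timeout 600 cd sub` followed by `pytest -q`, so the cap would cover only the first simple command. One possible alternative, offered as an assumption and not as what the commit does, is an argv list in which `timeout` wraps the whole shell invocation. `build_exec_argv` and `string_wrapped_argv` are hypothetical helpers; docker-py's `exec_run` accepts either a string or a list for its command.

```python
from __future__ import annotations


def string_wrapped_argv(cmd: str, timeout_s: int) -> list[str]:
    """Mirror of the wrapping in the diff: `timeout` is part of the shell script."""
    return ["/bin/sh", "-c", f"timeout {timeout_s} {cmd}"]


def build_exec_argv(cmd: str, timeout_s: int) -> list[str]:
    """Alternative sketch: `timeout` supervises the shell that runs the script.

    Because `cmd` is a single argv element, no extra quoting is needed and the
    cap applies to every command in a compound script.
    """
    return ["timeout", str(timeout_s), "/bin/sh", "-c", cmd]
```

With the second form, the container still needs a coreutils-style `timeout` binary on its PATH, the same requirement the original wrapping has.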
@@ -50,11 +50,51 @@ SANDBOX_DENYLIST: frozenset[str] = frozenset(
 # primary control. `is_command_allowed` checks only the first whitespace token,
 # so `/usr/bin/find`, `sh -c "strings x"`, and especially `python -c "import
 # marshal,dis; ..."` all bypass it. The ADR-claimed PRIMARY control is the
-# pre-task cache/.git scrub in `boot()` — see `
+# pre-task cache/.git scrub in `boot()` — see `scrub_tree` below, which is now
 # implemented (was previously absent, making the denylist the only — broken —
 # defense). The denylist remains as cheap defense-in-depth, not the wall.
 
 
+# Cache/history artifacts that let an agent recover a deleted signature WITHOUT
+# reimplementing it (the Composer-blog reward-hacks). Scrubbed at boot() so the
+# denylist isn't the only — and bypassable — line of defense. These are module
+# level so EVERY sandbox backend (LocalSubprocessSandbox AND DockerSandbox)
+# applies the identical primary control via the shared `scrub_tree` free
+# function below.
+SCRUB_NAMES: tuple[str, ...] = (
+    "__pycache__", ".mypy_cache", ".pytest_cache", ".git", ".hg",
+)
+SCRUB_SUFFIXES: tuple[str, ...] = (".pyc", ".pyo", ".class")
+
+
+def scrub_tree(workdir: str) -> None:
+    """PRIMARY reward-hack control (ADR-010 §3): physically remove byte-code
+    caches, type-check caches, and VCS history from the working tree before the
+    episode starts, so there is no cached signature to recover. This is the
+    wall; the command denylist is only cheap defense-in-depth on top.
+
+    Shared by LocalSubprocessSandbox (scrubs the subprocess cwd) and
+    DockerSandbox (scrubs the HOST workdir BEFORE the bind mount — the mount is
+    shared host<->container, so a host-side scrub pre-boot is exactly equivalent
+    to scrubbing inside the container). Cross-family review 2026-05-29 found this
+    was previously UNIMPLEMENTED — boot() only recorded the image string.
+    """
+    if not workdir or not os.path.isdir(workdir):
+        return
+    for root, dirs, files in os.walk(workdir, topdown=True):
+        # Remove (and stop descending into) scrub-named directories.
+        for d in list(dirs):
+            if d in SCRUB_NAMES:
+                shutil.rmtree(os.path.join(root, d), ignore_errors=True)
+                dirs.remove(d)
+        for f in files:
+            if f.endswith(SCRUB_SUFFIXES):
+                try:
+                    os.remove(os.path.join(root, f))
+                except OSError:
+                    pass
+
+
 @runtime_checkable
 class Sandbox(Protocol):
     """An execution environment for one FD episode."""
@@ -120,13 +160,11 @@ class LocalSubprocessSandbox:
     _trajectory: list[dict] = field(default_factory=list)
     booted_image: str | None = None
 
-    #
-    #
-    #
-    _SCRUB_NAMES: tuple[str, ...] =
-
-    )
-    _SCRUB_SUFFIXES: tuple[str, ...] = (".pyc", ".pyo", ".class")
+    # Back-compat aliases for the module-level scrub constants (callers/tests
+    # that referenced the old instance attributes keep working). The real
+    # control is the shared module-level `scrub_tree` free function.
+    _SCRUB_NAMES: tuple[str, ...] = SCRUB_NAMES
+    _SCRUB_SUFFIXES: tuple[str, ...] = SCRUB_SUFFIXES
 
     def boot(self, image: str) -> None:
         self.booted_image = image
@@ -134,26 +172,9 @@ class LocalSubprocessSandbox:
         self._scrub_tree()
 
     def _scrub_tree(self) -> None:
-        """
-
-
-        the wall; the command denylist is only cheap defense-in-depth on top.
-        Cross-family review 2026-05-29 found this was previously UNIMPLEMENTED —
-        boot() only recorded the image string."""
-        if not self.workdir or not os.path.isdir(self.workdir):
-            return
-        for root, dirs, files in os.walk(self.workdir, topdown=True):
-            # Remove (and stop descending into) scrub-named directories.
-            for d in list(dirs):
-                if d in self._SCRUB_NAMES:
-                    shutil.rmtree(os.path.join(root, d), ignore_errors=True)
-                    dirs.remove(d)
-            for f in files:
-                if f.endswith(self._SCRUB_SUFFIXES):
-                    try:
-                        os.remove(os.path.join(root, f))
-                    except OSError:
-                        pass
+        """Delegate to the shared module-level `scrub_tree` (see its docstring).
+        Kept as a method for back-compat with existing callers."""
+        scrub_tree(self.workdir)
 
     def is_command_allowed(self, command: str) -> bool:
         # NOTE: first-token-only check — see SANDBOX_DENYLIST comment. This is
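
The comments in the sandbox.py diff argue that a first-token denylist is easy to bypass and that the boot-time scrub is the real control. Both points can be checked with a self-contained sketch. The scrub below re-implements the logic from the diff so it runs without the package installed; `TOY_DENYLIST`, `first_token_allowed`, and `demo` are hypothetical, since the actual SANDBOX_DENYLIST entries are not visible in this hunk.

```python
import os
import shutil
import tempfile

SCRUB_NAMES = ("__pycache__", ".mypy_cache", ".pytest_cache", ".git", ".hg")
SCRUB_SUFFIXES = (".pyc", ".pyo", ".class")
TOY_DENYLIST = frozenset({"find", "strings"})  # hypothetical entries


def scrub_tree(workdir: str) -> None:
    """Remove cache/VCS directories and compiled files under workdir."""
    if not workdir or not os.path.isdir(workdir):
        return
    for root, dirs, files in os.walk(workdir, topdown=True):
        for d in list(dirs):
            if d in SCRUB_NAMES:
                shutil.rmtree(os.path.join(root, d), ignore_errors=True)
                dirs.remove(d)  # do not descend into the removed directory
        for f in files:
            if f.endswith(SCRUB_SUFFIXES):
                try:
                    os.remove(os.path.join(root, f))
                except OSError:
                    pass


def first_token_allowed(command: str) -> bool:
    """First-whitespace-token check, as described in the diff's comments."""
    tokens = command.strip().split()
    return not tokens or tokens[0] not in TOY_DENYLIST


def demo() -> list:
    """Build a small tree with cache artifacts, scrub it, list what is left."""
    with tempfile.TemporaryDirectory() as root:
        pkg = os.path.join(root, "pkg")
        os.makedirs(os.path.join(pkg, "__pycache__"))
        for rel in ("__pycache__/m.cpython-311.pyc", "stale.pyc", "m.py"):
            with open(os.path.join(pkg, *rel.split("/")), "wb"):
                pass
        scrub_tree(root)
        return sorted(os.listdir(pkg))
```

In this sketch `first_token_allowed("find . -name x")` is rejected while `first_token_allowed("/usr/bin/find . -name x")` passes, and `demo()` leaves only the source file behind, which is why the scrub rather than the denylist carries the weight.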
@@ -0,0 +1,620 @@
|
|
| 1 |
+
"""test_docker_sandbox.py — DockerSandbox unit + live-Docker coverage.
|
| 2 |
+
|
| 3 |
+
Two tiers, mirroring the repo's `test_docker_substrate_e2e.py` gating and the
|
| 4 |
+
ModalSpawnExecutor mock pattern:
|
| 5 |
+
|
| 6 |
+
1. PURE-UNIT (always run, no daemon): a mock docker client/container asserts
|
| 7 |
+
the hardening config (network_disabled, mem_limit, runtime, cap_drop, ...),
|
| 8 |
+
the shared host-side scrub, the bytes->str decode, the pytest-summary
|
| 9 |
+
parse in run_tests, the denylist short-circuit, and the missing-SDK /
|
| 10 |
+
dead-daemon RuntimeError paths. These cover DockerSandbox even on a box
|
| 11 |
+
with no Docker.
|
| 12 |
+
|
| 13 |
+
2. LIVE-DOCKER (skipped unless `_docker_available()` is true): boots a REAL hardened
|
| 14 |
+
`python:3.11-slim` container with --network none and runs the 4 inversion
|
| 15 |
+
gates + a cache-scrub check + a network-isolation check inside it. Since
|
| 16 |
+
Docker is available on this host, these ACTUALLY RUN.
|
| 17 |
+
"""
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import os
|
| 21 |
+
import shutil
|
| 22 |
+
import subprocess
|
| 23 |
+
import tempfile
|
| 24 |
+
import textwrap
|
| 25 |
+
import types
|
| 26 |
+
from collections import namedtuple
|
| 27 |
+
|
| 28 |
+
import pytest
|
| 29 |
+
|
| 30 |
+
from composer_replication.datagen import docker_sandbox as ds_mod
|
| 31 |
+
from composer_replication.datagen.docker_sandbox import DockerSandbox
|
| 32 |
+
from composer_replication.datagen.sandbox import SANDBOX_DENYLIST, Sandbox
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ---------------------------------------------------------------------------
|
| 36 |
+
# Live-Docker gate (mirrors test_docker_substrate_e2e.py)
|
| 37 |
+
# ---------------------------------------------------------------------------
|
| 38 |
+
def _docker_available() -> bool:
|
| 39 |
+
"""True iff a usable Docker daemon is reachable via the CLI."""
|
| 40 |
+
if shutil.which("docker") is None:
|
| 41 |
+
return False
|
| 42 |
+
try:
|
| 43 |
+
r = subprocess.run(["docker", "info"], capture_output=True, timeout=10)
|
| 44 |
+
return r.returncode == 0
|
| 45 |
+
except Exception:
|
| 46 |
+
return False
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# A tiny image we know exists locally on this host (212MB python:3.11-slim).
|
| 50 |
+
# The live tests run `--network none`, so the image MUST be present already
|
| 51 |
+
# (no pull possible inside a network-disabled container).
|
| 52 |
+
_TEST_IMAGE = "python:3.11-slim"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _image_present(image: str) -> bool:
|
| 56 |
+
try:
|
| 57 |
+
r = subprocess.run(
|
| 58 |
+
["docker", "image", "inspect", image], capture_output=True, timeout=15
|
| 59 |
+
)
|
| 60 |
+
return r.returncode == 0
|
| 61 |
+
except Exception:
|
| 62 |
+
return False
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# ===========================================================================
|
| 66 |
+
# TIER 1 — PURE-UNIT (mock docker client, no daemon required)
|
| 67 |
+
# ===========================================================================
|
| 68 |
+
|
| 69 |
+
_ExecResult = namedtuple("ExecResult", ["exit_code", "output"])
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class _MockContainer:
|
| 73 |
+
"""Stand-in for a docker-py Container.
|
| 74 |
+
|
| 75 |
+
Knobs:
|
| 76 |
+
- exec_script: callable(cmd_argv) -> (exit_code, stdout_bytes, stderr_bytes)
|
| 77 |
+
used to fake exec_run results. Defaults to an empty success.
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
def __init__(self, *, exec_script=None):
|
| 81 |
+
self.exec_calls: list[tuple] = []
|
| 82 |
+
self.removed = False
|
| 83 |
+
self.remove_force = None
|
| 84 |
+
self._exec_script = exec_script or (lambda argv: (0, b"", b""))
|
| 85 |
+
|
| 86 |
+
def exec_run(self, cmd, *, workdir=None, demux=False, **kw):
|
| 87 |
+
self.exec_calls.append((cmd, {"workdir": workdir, "demux": demux, **kw}))
|
| 88 |
+
code, out, err = self._exec_script(cmd)
|
| 89 |
+
if demux:
|
| 90 |
+
return _ExecResult(code, (out, err))
|
| 91 |
+
return _ExecResult(code, (out or b"") + (err or b""))
|
| 92 |
+
|
| 93 |
+
def remove(self, force=False):
|
| 94 |
+
self.removed = True
|
| 95 |
+
self.remove_force = force
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class _MockContainers:
|
| 99 |
+
def __init__(self, container, *, run_raises=None):
|
| 100 |
+
self._container = container
|
| 101 |
+
self.run_kwargs: dict | None = None
|
| 102 |
+
self._run_raises = run_raises
|
| 103 |
+
self._listed: list = []
|
| 104 |
+
|
| 105 |
+
def run(self, **kwargs):
|
| 106 |
+
self.run_kwargs = kwargs
|
| 107 |
+
if self._run_raises is not None:
|
| 108 |
+
raise self._run_raises
|
| 109 |
+
return self._container
|
| 110 |
+
|
| 111 |
+
def list(self, all=False, filters=None): # noqa: A002 - matches docker-py
|
| 112 |
+
return list(self._listed)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class _MockClient:
|
| 116 |
+
def __init__(self, container, *, run_raises=None, runtimes=None):
|
| 117 |
+
self.containers = _MockContainers(container, run_raises=run_raises)
|
| 118 |
+
self._info = {"Runtimes": runtimes or {"runc": {}}}
|
| 119 |
+
self.pinged = False
|
| 120 |
+
|
| 121 |
+
def ping(self):
|
| 122 |
+
self.pinged = True
|
| 123 |
+
return True
|
| 124 |
+
|
| 125 |
+
def info(self):
|
| 126 |
+
return self._info
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def _patch_client(monkeypatch, client):
|
| 130 |
+
"""Make DockerSandbox use `client` instead of a real daemon, and stub the
|
| 131 |
+
lazy `docker` module so `docker.errors.*` resolve in boot()."""
|
| 132 |
+
monkeypatch.setattr(ds_mod, "_make_client", lambda: client)
|
| 133 |
+
|
| 134 |
+
fake_docker = types.ModuleType("docker")
|
| 135 |
+
errors = types.ModuleType("docker.errors")
|
| 136 |
+
|
| 137 |
+
class ImageNotFound(Exception): # noqa: N818 — mirrors docker.errors.ImageNotFound name
|
| 138 |
+
pass
|
| 139 |
+
|
| 140 |
+
class APIError(Exception):
|
| 141 |
+
pass
|
| 142 |
+
|
| 143 |
+
errors.ImageNotFound = ImageNotFound
|
| 144 |
+
errors.APIError = APIError
|
| 145 |
+
fake_docker.errors = errors
|
| 146 |
+
fake_docker.from_env = lambda: client
|
| 147 |
+
monkeypatch.setattr(ds_mod, "_require_docker", lambda: fake_docker)
|
| 148 |
+
return fake_docker
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def test_dockersandbox_is_a_sandbox_protocol_instance():
|
| 152 |
+
"""Drop-in for FakeSandbox/LocalSubprocessSandbox in env/validator."""
|
| 153 |
+
sb = DockerSandbox(workdir="/tmp")
|
| 154 |
+
assert isinstance(sb, Sandbox)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def test_container_kwargs_hardening_config():
|
| 158 |
+
"""The lockdown recipe is present and correct WITHOUT any daemon."""
|
| 159 |
+
sb = DockerSandbox(workdir="/some/work")
|
| 160 |
+
kw = sb.container_kwargs(_TEST_IMAGE)
|
| 161 |
+
|
| 162 |
+
# network egress kill-switch
|
| 163 |
+
assert kw["network_disabled"] is True
|
| 164 |
+
assert kw["network_mode"] == "none"
|
| 165 |
+
# filesystem lockdown
|
| 166 |
+
assert kw["read_only"] is True
|
| 167 |
+
assert kw["tmpfs"] == {"/tmp": "rw,noexec,nosuid,size=64m"}
|
| 168 |
+
assert kw["volumes"] == {"/some/work": {"bind": "/work", "mode": "rw"}}
|
| 169 |
+
assert kw["working_dir"] == "/work"
|
| 170 |
+
# privilege lockdown
|
| 171 |
+
assert kw["user"] == "1000:1000"
|
| 172 |
+
assert kw["cap_drop"] == ["ALL"]
|
| 173 |
+
assert kw["security_opt"] == ["no-new-privileges:true"]
|
| 174 |
+
# resource limits
|
| 175 |
+
assert kw["pids_limit"] == 256
|
| 176 |
+
assert kw["mem_limit"] == "1g"
|
| 177 |
+
assert kw["memswap_limit"] == "1g" # == mem_limit => no swap
|
| 178 |
+
assert kw["nano_cpus"] == 2_000_000_000
|
| 179 |
+
# lifecycle
|
| 180 |
+
assert kw["detach"] is True
|
| 181 |
+
assert kw["command"] == ["sleep", "infinity"]
|
| 182 |
+
assert kw["labels"] == {"composer_replication": "datagen"}
|
| 183 |
+
assert kw["name"].startswith("fd-")
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def test_runtime_optional_default_runc():
|
| 187 |
+
"""runtime defaults to None => the 'runtime' kwarg is omitted (daemon
|
| 188 |
+
default runc), so the default path never names a runtime that may not
|
| 189 |
+
exist."""
|
| 190 |
+
assert "runtime" not in DockerSandbox(workdir="/w").container_kwargs("img")
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def test_runtime_runsc_passed_through_when_set():
|
| 194 |
+
"""When the caller opts into gVisor, runtime='runsc' reaches run()."""
|
| 195 |
+
kw = DockerSandbox(workdir="/w", runtime="runsc").container_kwargs("img")
|
| 196 |
+
assert kw["runtime"] == "runsc"
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def test_resource_limits_are_configurable():
|
| 200 |
+
sb = DockerSandbox(
|
| 201 |
+
workdir="/w", mem_limit="256m", memswap_limit="256m",
|
| 202 |
+
pids_limit=64, nano_cpus=1_000_000_000, tmpfs_size="16m",
|
| 203 |
+
)
|
| 204 |
+
kw = sb.container_kwargs("img")
|
| 205 |
+
assert kw["mem_limit"] == "256m"
|
| 206 |
+
assert kw["memswap_limit"] == "256m"
|
| 207 |
+
assert kw["pids_limit"] == 64
|
| 208 |
+
assert kw["nano_cpus"] == 1_000_000_000
|
| 209 |
+
assert kw["tmpfs"] == {"/tmp": "rw,noexec,nosuid,size=16m"}
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def test_keep_root_writable_escape_hatch():
|
| 213 |
+
kw = DockerSandbox(workdir="/w", keep_root_writable=True).container_kwargs("i")
|
| 214 |
+
assert kw["read_only"] is False
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def test_boot_scrubs_host_tree_before_container(monkeypatch):
|
| 218 |
+
"""PRIMARY reward-hack control: scrub_tree runs on the HOST workdir in boot()
|
| 219 |
+
BEFORE the container starts (the bind mount is shared)."""
|
| 220 |
+
with tempfile.TemporaryDirectory() as d:
|
| 221 |
+
os.makedirs(os.path.join(d, "__pycache__"))
|
| 222 |
+
with open(os.path.join(d, "__pycache__", "x.cpython-311.pyc"), "wb") as f:
|
| 223 |
+
f.write(b"\x00stale-bytecode")
|
| 224 |
+
os.makedirs(os.path.join(d, ".git"))
|
| 225 |
+
with open(os.path.join(d, "mod.pyc"), "wb") as f:
|
| 226 |
+
f.write(b"\x00")
|
| 227 |
+
with open(os.path.join(d, "keep.py"), "w") as f:
|
| 228 |
+
f.write("x = 1\n")
|
| 229 |
+
|
| 230 |
+
container = _MockContainer()
|
| 231 |
+
client = _MockClient(container)
|
| 232 |
+
_patch_client(monkeypatch, client)
|
| 233 |
+
|
| 234 |
+
sb = DockerSandbox(workdir=d)
|
| 235 |
+
sb.boot(_TEST_IMAGE)
|
| 236 |
+
|
| 237 |
+
assert not os.path.exists(os.path.join(d, "__pycache__"))
|
| 238 |
+
assert not os.path.exists(os.path.join(d, ".git"))
|
| 239 |
+
assert not os.path.exists(os.path.join(d, "mod.pyc"))
|
| 240 |
+
assert os.path.exists(os.path.join(d, "keep.py")) # real source survives
|
| 241 |
+
# the container was actually started with the hardened kwargs
|
| 242 |
+
assert client.containers.run_kwargs["network_disabled"] is True
|
| 243 |
+
assert sb.booted_image == _TEST_IMAGE
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def test_exec_uses_timeout_and_workdir_and_denylist(monkeypatch):
|
| 247 |
+
"""exec() wraps the command with coreutils `timeout`, runs in /work, and
|
| 248 |
+
short-circuits denied commands without touching the container."""
|
| 249 |
+
container = _MockContainer(
|
| 250 |
+
exec_script=lambda argv: (0, b"hello\n", b"")
|
| 251 |
+
)
|
| 252 |
+
client = _MockClient(container)
|
| 253 |
+
_patch_client(monkeypatch, client)
|
| 254 |
+
|
| 255 |
+
sb = DockerSandbox(workdir="/w", exec_timeout_s=42)
|
| 256 |
+
sb._container = container # skip boot for this focused unit
|
| 257 |
+
|
| 258 |
+
out = sb.exec({"command": "echo hello"})
|
| 259 |
+
assert out == "hello\n"
|
| 260 |
+
cmd_argv, kw = container.exec_calls[-1]
|
| 261 |
+
assert cmd_argv == ["/bin/sh", "-c", "timeout 42 echo hello"]
|
| 262 |
+
assert kw["workdir"] == "/work"
|
| 263 |
+
assert kw["demux"] is True
|
| 264 |
+
|
| 265 |
+
# a denylisted first token never reaches the container
|
| 266 |
+
n_before = len(container.exec_calls)
|
| 267 |
+
denied = sorted(SANDBOX_DENYLIST)[0]
|
| 268 |
+
msg = sb.exec({"command": f"{denied} something"})
|
| 269 |
+
assert "not allowed" in msg
|
| 270 |
+
assert len(container.exec_calls) == n_before # no new exec
|
| 271 |
+
|
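The exec test above fixes the shape of the call that reaches the container: a `/bin/sh -c` string prefixed with coreutils `timeout`, run in `/work` with `demux=True`. A minimal sketch of a builder that produces exactly that shape is shown here; `build_exec_call` is a hypothetical helper name, not a function from the module.

```python
def build_exec_call(command: str, timeout_s: int) -> tuple[list[str], dict]:
    """Return the argv and keyword arguments for one exec_run call."""
    # The command is passed to sh as one string, so the image must ship both
    # /bin/sh and a `timeout` binary.
    argv = ["/bin/sh", "-c", f"timeout {timeout_s} {command}"]
    # demux=True asks docker-py to return stdout and stderr separately.
    kwargs = {"workdir": "/work", "demux": True}
    return argv, kwargs
```

One caveat of this form, stated as a general shell fact rather than a claim about the module: in a compound command such as `a; b`, the `timeout` prefix applies only to `a`.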
| 272 |
+
|
| 273 |
+
def test_exec_decodes_non_utf8_bytes(monkeypatch):
|
| 274 |
+
"""Untrusted code can emit invalid UTF-8 on stdout; we must not crash."""
|
| 275 |
+
container = _MockContainer(
|
| 276 |
+
exec_script=lambda argv: (0, b"\xff\xfe bad bytes", b"")
|
| 277 |
+
)
|
| 278 |
+
sb = DockerSandbox(workdir="/w")
|
| 279 |
+
sb._container = container
|
| 280 |
+
out = sb.exec({"command": "echo x"})
|
| 281 |
+
assert "bad bytes" in out # replaced, not crashed
|
| 282 |
+
assert "\ufffd" in out # U+FFFD replacement char
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def test_run_tests_parses_pytest_summary(monkeypatch):
|
| 286 |
+
"""run_tests applies the SAME conservative parse as LocalSubprocessSandbox:
|
| 287 |
+
a node id is passed iff '<nodeid> PASSED' appears."""
|
| 288 |
+
tests = ("t.py::test_a", "t.py::test_b")
|
| 289 |
+
out = b"t.py::test_a PASSED\nt.py::test_b FAILED\n1 failed, 1 passed\n"
|
| 290 |
+
container = _MockContainer(exec_script=lambda argv: (1, out, b""))
|
| 291 |
+
sb = DockerSandbox(workdir="/w")
|
| 292 |
+
sb._container = container
|
| 293 |
+
|
| 294 |
+
res = sb.run_tests("pytest -v", tests)
|
| 295 |
+
assert res.passed == frozenset({"t.py::test_a"})
|
| 296 |
+
assert res.failed == frozenset({"t.py::test_b"})
|
| 297 |
+
assert res.collected_ok is True
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def test_run_tests_collection_error(monkeypatch):
|
| 301 |
+
tests = ("t.py::test_a",)
|
| 302 |
+
out = b"ERROR collecting t.py\n!!! errors during collection !!!\n"
|
| 303 |
+
container = _MockContainer(exec_script=lambda argv: (2, out, b""))
|
| 304 |
+
sb = DockerSandbox(workdir="/w")
|
| 305 |
+
sb._container = container
|
| 306 |
+
res = sb.run_tests("pytest -v", tests)
|
| 307 |
+
assert res.collected_ok is False
|
| 308 |
+
assert res.failed == frozenset({"t.py::test_a"})
|
| 309 |
+
|
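The two `run_tests` tests above describe a conservative parse: a node id counts as passed only if `<nodeid> PASSED` appears in the output, and a collection error fails everything. A minimal sketch of that rule is given below. The `RunResult` type, the function name, and the exact collection-error markers are assumptions chosen to match the test inputs, not the module's definitions.

```python
from dataclasses import dataclass

# Assumed markers, taken from the sample output used in the test above.
_COLLECTION_ERROR_MARKERS = ("ERROR collecting", "errors during collection")


@dataclass(frozen=True)
class RunResult:
    passed: frozenset
    failed: frozenset
    collected_ok: bool
    stdout: str


def parse_results(stdout: str, tests: tuple[str, ...]) -> RunResult:
    """Classify each requested node id from pytest -v style output."""
    collected_ok = not any(m in stdout for m in _COLLECTION_ERROR_MARKERS)
    passed = frozenset(
        t for t in tests if collected_ok and f"{t} PASSED" in stdout
    )
    # Anything not positively seen as PASSED is treated as failed.
    failed = frozenset(tests) - passed
    return RunResult(passed, failed, collected_ok, stdout)
```

Defaulting to "failed" means a truncated or garbled log can never be mistaken for a green run.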
| 310 |
+
|
| 311 |
+
def test_run_tests_quotes_node_ids(monkeypatch):
|
| 312 |
+
"""Parametrized node ids with spaces/brackets must be shlex-quoted (shell
|
| 313 |
+
injection guard the repo already fixed for the local sandbox)."""
|
| 314 |
+
captured = {}
|
| 315 |
+
|
| 316 |
+
def script(argv):
|
| 317 |
+
captured["argv"] = argv
|
| 318 |
+
return (0, b"", b"")
|
| 319 |
+
|
| 320 |
+
container = _MockContainer(exec_script=script)
|
| 321 |
+
sb = DockerSandbox(workdir="/w")
|
| 322 |
+
sb._container = container
|
| 323 |
+
sb.run_tests("pytest -v", ("t.py::test_x[a b]",))
|
| 324 |
+
# the dangerous node id is quoted inside the timeout-wrapped sh -c string
|
| 325 |
+
shell_cmd = captured["argv"][-1]
|
| 326 |
+
assert "'t.py::test_x[a b]'" in shell_cmd
|
| 327 |
+
|
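The quoting this test checks for is what `shlex.quote` from the standard library provides. A minimal sketch of appending quoted node ids to a test command is shown here; `build_test_command` is a hypothetical helper name.

```python
import shlex


def build_test_command(test_command: str, node_ids: tuple[str, ...]) -> str:
    """Append each node id as a single, shell-safe argument."""
    return " ".join([test_command, *(shlex.quote(n) for n in node_ids)])
```

For example, `build_test_command("pytest -v", ("t.py::test_x[a b]",))` yields `pytest -v 't.py::test_x[a b]'`, so the space and brackets reach pytest as one argument instead of being interpreted by the shell.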
| 328 |
+
|
| 329 |
+
def test_exec_before_boot_raises():
|
| 330 |
+
sb = DockerSandbox(workdir="/w")
|
| 331 |
+
with pytest.raises(RuntimeError, match="before boot"):
|
| 332 |
+
sb.exec({"command": "echo hi"})
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def test_trajectory_records_actions(monkeypatch):
|
| 336 |
+
container = _MockContainer()
|
| 337 |
+
sb = DockerSandbox(workdir="/w")
|
| 338 |
+
sb._container = container
|
| 339 |
+
sb.exec({"command": "echo a"})
|
| 340 |
+
sb.exec({"command": "echo b"})
|
| 341 |
+
traj = sb.trajectory()
|
| 342 |
+
assert [a["command"] for a in traj] == ["echo a", "echo b"]
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def test_close_removes_container_force(monkeypatch):
|
| 346 |
+
container = _MockContainer()
|
| 347 |
+
sb = DockerSandbox(workdir="/w")
|
| 348 |
+
sb._container = container
|
| 349 |
+
sb.close()
|
| 350 |
+
assert container.removed is True
|
| 351 |
+
assert container.remove_force is True
|
| 352 |
+
# idempotent
|
| 353 |
+
sb.close()
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def test_context_manager_closes(monkeypatch):
|
| 357 |
+
container = _MockContainer()
|
| 358 |
+
client = _MockClient(container)
|
| 359 |
+
_patch_client(monkeypatch, client)
|
| 360 |
+
with tempfile.TemporaryDirectory() as d:
|
| 361 |
+
with DockerSandbox(workdir=d) as sb:
|
| 362 |
+
sb.boot(_TEST_IMAGE)
|
| 363 |
+
assert sb._container is container
|
| 364 |
+
assert container.removed is True
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def test_reap_leaked_sweeps_labelled_containers(monkeypatch):
|
| 368 |
+
leaked = [_MockContainer(), _MockContainer()]
|
| 369 |
+
container = _MockContainer()
|
| 370 |
+
client = _MockClient(container)
|
| 371 |
+
client.containers._listed = leaked
|
| 372 |
+
n = DockerSandbox.reap_leaked(client)
|
| 373 |
+
assert n == 2
|
| 374 |
+
assert all(c.removed for c in leaked)
|
| 375 |
+
|
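A sweep like the one tested above can be sketched with docker-py's label filter on `containers.list`. The label value comes from the hardening test in this file; the fake client classes and the swallow-and-continue error handling are assumptions added so the sketch runs without a daemon.

```python
LABEL = "composer_replication=datagen"


def reap_leaked(client) -> int:
    """Force-remove every container carrying the sandbox label."""
    reaped = 0
    for container in client.containers.list(all=True, filters={"label": LABEL}):
        try:
            container.remove(force=True)
            reaped += 1
        except Exception:
            pass  # assumed policy: one stuck container should not stop the sweep
    return reaped


class FakeContainer:
    """Stand-in for a docker-py Container."""

    def __init__(self):
        self.removed = False

    def remove(self, force=False):
        self.removed = True


class FakeContainers:
    def __init__(self, listed):
        self._listed = listed
        self.last_filters = None

    def list(self, all=False, filters=None):  # noqa: A002 - matches docker-py
        self.last_filters = filters
        return list(self._listed)


class FakeClient:
    def __init__(self, listed):
        self.containers = FakeContainers(listed)
```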
| 376 |
+
|
| 377 |
+
def test_boot_image_not_found_raises_runtimeerror(monkeypatch):
|
| 378 |
+
container = _MockContainer()
|
| 379 |
+
client = _MockClient(container)
|
| 380 |
+
fake_docker = _patch_client(monkeypatch, client)
|
| 381 |
+
# make run() raise ImageNotFound
|
| 382 |
+
client.containers._run_raises = fake_docker.errors.ImageNotFound("nope")
|
| 383 |
+
with tempfile.TemporaryDirectory() as d:
|
| 384 |
+
sb = DockerSandbox(workdir=d)
|
| 385 |
+
with pytest.raises(RuntimeError, match="not found locally"):
|
| 386 |
+
sb.boot("ghost:latest")
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def test_boot_api_error_raises_runtimeerror(monkeypatch):
|
| 390 |
+
container = _MockContainer()
|
| 391 |
+
client = _MockClient(container)
|
| 392 |
+
fake_docker = _patch_client(monkeypatch, client)
|
| 393 |
+
client.containers._run_raises = fake_docker.errors.APIError("bad runtime")
|
| 394 |
+
with tempfile.TemporaryDirectory() as d:
|
| 395 |
+
sb = DockerSandbox(workdir=d, runtime="runsc")
|
| 396 |
+
with pytest.raises(RuntimeError, match="Docker API error"):
|
| 397 |
+
sb.boot(_TEST_IMAGE)
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
def test_require_docker_missing_sdk_raises(monkeypatch):
|
| 401 |
+
"""If the docker SDK is absent, a clear RuntimeError is raised (lazy import
|
| 402 |
+
means the FakeSandbox/core path never needs it)."""
|
| 403 |
+
import builtins
|
| 404 |
+
|
| 405 |
+
real_import = builtins.__import__
|
| 406 |
+
|
| 407 |
+
def fake_import(name, *args, **kwargs):
|
| 408 |
+
if name == "docker":
|
| 409 |
+
raise ImportError("No module named 'docker'")
|
| 410 |
+
return real_import(name, *args, **kwargs)
|
| 411 |
+
|
| 412 |
+
monkeypatch.setattr(builtins, "__import__", fake_import)
|
| 413 |
+
with pytest.raises(RuntimeError, match="requires the 'docker' Python SDK"):
|
| 414 |
+
ds_mod._require_docker()
|
| 415 |
+
|
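The lazy-import pattern exercised above can be sketched with `importlib`: the optional SDK is imported only when a Docker-backed sandbox is actually requested, and a missing package becomes a `RuntimeError` with an actionable message. `require_module` and its message wording are assumptions; the real `_require_docker` may differ.

```python
import importlib
from types import ModuleType


def require_module(name: str) -> ModuleType:
    """Import an optional dependency on first use, or fail with a clear error."""
    try:
        return importlib.import_module(name)
    except ImportError as exc:
        raise RuntimeError(
            f"this backend requires the '{name}' Python SDK; "
            f"install it with `pip install {name}`"
        ) from exc
```

Because nothing is imported at module load time, code paths that never touch Docker keep working on machines without the SDK.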
| 416 |
+
|
| 417 |
+
def test_make_client_dead_daemon_raises(monkeypatch):
|
| 418 |
+
"""A dead/unreachable daemon surfaces a clear RuntimeError at client build."""
|
| 419 |
+
fake_docker = types.ModuleType("docker")
|
| 420 |
+
|
| 421 |
+
def from_env():
|
| 422 |
+
raise RuntimeError("Cannot connect to the Docker daemon")
|
| 423 |
+
|
| 424 |
+
fake_docker.from_env = from_env
|
| 425 |
+
monkeypatch.setattr(ds_mod, "_require_docker", lambda: fake_docker)
|
| 426 |
+
with pytest.raises(RuntimeError, match="could not reach a Docker daemon"):
|
| 427 |
+
ds_mod._make_client()
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def test_runsc_available_false_when_only_runc(monkeypatch):
|
| 431 |
+
client = _MockClient(_MockContainer(), runtimes={"runc": {}})
|
| 432 |
+
monkeypatch.setattr(ds_mod, "_make_client", lambda: client)
|
| 433 |
+
assert ds_mod.runsc_available() is False
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def test_runsc_available_true_when_registered(monkeypatch):
|
| 437 |
+
client = _MockClient(_MockContainer(), runtimes={"runc": {}, "runsc": {}})
|
| 438 |
+
monkeypatch.setattr(ds_mod, "_make_client", lambda: client)
|
| 439 |
+
assert ds_mod.runsc_available() is True
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
# ===========================================================================
|
| 443 |
+
# TIER 2 — LIVE DOCKER (skipif on daemon availability)
|
| 444 |
+
# ===========================================================================
|
| 445 |
+
|
| 446 |
+
live = pytest.mark.skipif(
|
| 447 |
+
not _docker_available(),
|
| 448 |
+
reason="Docker daemon not available — DockerSandbox live tests are "
|
| 449 |
+
"hardware-gated (mirror test_docker_substrate_e2e.py).",
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
# Minimal synthetic FD task (same shape as test_docker_substrate_e2e.py).
|
| 453 |
+
_MODULE_SOLVED = textwrap.dedent('''\
|
| 454 |
+
def add(a, b):
|
| 455 |
+
return a + b
|
| 456 |
+
|
| 457 |
+
def mul(a, b):
|
| 458 |
+
return a * b
|
| 459 |
+
''')
|
| 460 |
+
_MODULE_BROKEN = textwrap.dedent('''\
|
| 461 |
+
def add(a, b):
|
| 462 |
+
return a + b
|
| 463 |
+
''')
|
| 464 |
+
|
| 465 |
+
# A stdlib-only pytest substitute: NO pip install, so --network none holds.
|
| 466 |
+
# Writes a tiny runner that imports `feature`, checks both add/mul, and prints
|
| 467 |
+
# pytest-style '<nodeid> PASSED/FAILED' lines that run_tests parses. We pass the
|
| 468 |
+
# NODE IDs to check on argv (trusted, test-author-controlled) and the runner
|
| 469 |
+
# evaluates FIXED expressions per node id — no eval() of untrusted input.
|
| 470 |
+
_RUNNER_TMPL = '''\
|
| 471 |
+
import sys
|
| 472 |
+
# Fixed expectations keyed by node id. The deleted-feature episode is detected
|
| 473 |
+
# by import-time AttributeError on `feature.mul`, never by evaluating a string.
|
| 474 |
+
CHECKS = {
|
| 475 |
+
"feature.py::test_add": lambda m: m.add(2, 3) == 5,
|
| 476 |
+
"feature.py::test_mul": lambda m: m.mul(2, 3) == 6,
|
| 477 |
+
}
|
| 478 |
+
nodeid = sys.argv[1]
|
| 479 |
+
try:
|
| 480 |
+
import feature
|
| 481 |
+
ok = bool(CHECKS[nodeid](feature))
|
| 482 |
+
except Exception as e:
|
| 483 |
+
print(nodeid, "FAILED", "(exc:", type(e).__name__, e, ")")
|
| 484 |
+
sys.exit(1)
|
| 485 |
+
print(nodeid, "PASSED" if ok else "FAILED")
|
| 486 |
+
sys.exit(0 if ok else 1)
|
| 487 |
+
'''
|
| 488 |
+
|
| 489 |
+
# Host-side network probe written into the workdir, then run inside the
|
| 490 |
+
# container as a plain file (avoids fragile inline `python -c` quoting through
|
| 491 |
+
# `sh -c`). Prints CONNECTED if egress works, BLOCKED otherwise.
|
| 492 |
+
_NETPROBE = '''\
|
| 493 |
+
import socket
|
| 494 |
+
s = socket.socket()
|
| 495 |
+
s.settimeout(3)
|
| 496 |
+
try:
|
| 497 |
+
s.connect(("1.1.1.1", 53))
|
| 498 |
+
print("CONNECTED")
|
| 499 |
+
except Exception as e:
|
| 500 |
+
print("BLOCKED", type(e).__name__)
|
| 501 |
+
'''
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
def _materialize(d: str, module_src: str) -> None:
|
| 505 |
+
with open(os.path.join(d, "feature.py"), "w") as f:
|
| 506 |
+
f.write(module_src)
|
| 507 |
+
with open(os.path.join(d, "runner.py"), "w") as f:
|
| 508 |
+
f.write(_RUNNER_TMPL)
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
@live
|
| 512 |
+
def test_live_image_present_guard():
|
| 513 |
+
"""The live tests run --network none and cannot pull; assert the image is
|
| 514 |
+
already on the host so a missing-image failure reads clearly."""
|
| 515 |
+
if not _image_present(_TEST_IMAGE):
|
| 516 |
+
pytest.skip(f"{_TEST_IMAGE} not present locally; `docker pull {_TEST_IMAGE}` to enable")
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
@live
|
| 520 |
+
def test_live_four_inversion_gates_in_hardened_container():
|
| 521 |
+
"""The 4 ADR-010 gates against a REAL hardened DockerSandbox container."""
|
| 522 |
+
if not _image_present(_TEST_IMAGE):
|
| 523 |
+
pytest.skip(f"{_TEST_IMAGE} not present locally")
|
| 524 |
+
|
| 525 |
+
target = "feature.py::test_mul" # FAIL_TO_PASS — exercises the deleted symbol
|
| 526 |
+
guard = "feature.py::test_add" # PASS_TO_PASS — must survive the deletion
|
| 527 |
+
|
| 528 |
+
def _run(module_src, node):
|
| 529 |
+
with tempfile.TemporaryDirectory() as d:
|
| 530 |
+
_materialize(d, module_src)
|
| 531 |
+
sb = DockerSandbox(workdir=d, exec_timeout_s=60)
|
| 532 |
+
sb.boot(_TEST_IMAGE)
|
| 533 |
+
try:
|
| 534 |
+
# run_tests appends the shlex-quoted node id to the command, and
|
| 535 |
+
# the runner uses it to pick which FIXED check to run.
|
| 536 |
+
res = sb.run_tests("python runner.py", (node,))
|
| 537 |
+
return node in res.passed, res.stdout
|
| 538 |
+
finally:
|
| 539 |
+
sb.close()
|
| 540 |
+
|
| 541 |
+
# Gate 1 — solved: both pass.
|
| 542 |
+
g1t, _ = _run(_MODULE_SOLVED, target)
|
| 543 |
+
g1g, _ = _run(_MODULE_SOLVED, guard)
|
| 544 |
+
assert g1t and g1g, "Gate 1 (baseline green) failed in hardened container"
|
| 545 |
+
|
| 546 |
+
# Gate 2 — broken: target FAILS (mul gone).
|
| 547 |
+
g2t, out2 = _run(_MODULE_BROKEN, target)
|
| 548 |
+
assert not g2t, f"Gate 2 (deletion breaks target) failed:\n{out2}"
|
| 549 |
+
|
| 550 |
+
# Gate 3 — broken: guard still PASSES.
|
| 551 |
+
g3g, out3 = _run(_MODULE_BROKEN, guard)
|
| 552 |
+
assert g3g, f"Gate 3 (remains functional) failed:\n{out3}"
|
| 553 |
+
|
| 554 |
+
# Gate 4 — gold restores: target passes again.
|
| 555 |
+
g4t, _ = _run(_MODULE_SOLVED, target)
|
| 556 |
+
assert g4t, "Gate 4 (gold restores) failed"
|
| 557 |
+
|
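The four gates in the test above can be summarised without any container. The sketch below runs the same solved and broken module sources in-process and evaluates the same fixed checks. It is an illustration of the gate logic only: the in-process `run` helper stands in for booting a `DockerSandbox`, and the gate names are paraphrased from the comments in the test.

```python
import types

SOLVED = "def add(a, b):\n    return a + b\n\ndef mul(a, b):\n    return a * b\n"
BROKEN = "def add(a, b):\n    return a + b\n"

CHECKS = {
    "feature.py::test_add": lambda m: m.add(2, 3) == 5,
    "feature.py::test_mul": lambda m: m.mul(2, 3) == 6,
}


def run(module_src: str, node: str) -> bool:
    """Load trusted source into a fresh module and evaluate one fixed check."""
    module = types.ModuleType("feature")
    try:
        exec(module_src, module.__dict__)  # test-author-controlled source only
        return bool(CHECKS[node](module))
    except Exception:
        return False  # e.g. AttributeError when the feature was deleted


def inversion_gates(target: str, guard: str) -> dict[str, bool]:
    """Evaluate the four gates for one FAIL_TO_PASS / PASS_TO_PASS pair."""
    return {
        "baseline_green": run(SOLVED, target) and run(SOLVED, guard),
        "deletion_breaks_target": not run(BROKEN, target),
        "remains_functional": run(BROKEN, guard),
        "gold_restores": run(SOLVED, target),
    }
```

A task is a valid feature-deletion episode only if all four values are true; a deletion that leaves the target passing, or that breaks the guard, is rejected.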
| 558 |
+
|
| 559 |
+
@live
|
| 560 |
+
def test_live_network_is_disabled():
|
| 561 |
+
"""--network none / network_disabled actually blocks egress in the live
|
| 562 |
+
container — the reward-hack egress kill-switch."""
|
| 563 |
+
if not _image_present(_TEST_IMAGE):
|
| 564 |
+
pytest.skip(f"{_TEST_IMAGE} not present locally")
|
| 565 |
+
with tempfile.TemporaryDirectory() as d:
|
| 566 |
+
_materialize(d, _MODULE_SOLVED)
|
| 567 |
+
with open(os.path.join(d, "netprobe.py"), "w") as f:
|
| 568 |
+
f.write(_NETPROBE)
|
| 569 |
+
sb = DockerSandbox(workdir=d, exec_timeout_s=30)
|
| 570 |
+
sb.boot(_TEST_IMAGE)
|
| 571 |
+
try:
|
| 572 |
+
out = sb.exec({"command": "python netprobe.py"})
|
| 573 |
+
assert "CONNECTED" not in out, f"network egress was NOT blocked:\n{out}"
|
| 574 |
+
assert "BLOCKED" in out, f"unexpected network probe output:\n{out}"
|
| 575 |
+
finally:
|
| 576 |
+
sb.close()
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
@live
|
| 580 |
+
def test_live_cache_scrub_removes_bytecode():
|
| 581 |
+
"""The cache scrub primary control holds on a real container: a stale .pyc
|
| 582 |
+
on the host mount is removed by boot() before the (broken) episode."""
|
| 583 |
+
with tempfile.TemporaryDirectory() as d:
|
| 584 |
+
_materialize(d, _MODULE_BROKEN)
|
| 585 |
+
os.makedirs(os.path.join(d, "__pycache__"), exist_ok=True)
|
| 586 |
+
with open(os.path.join(d, "__pycache__", "feature.cpython-311.pyc"), "wb") as f:
|
| 587 |
+
f.write(b"\x00stale-bytecode-with-mul-signature")
|
| 588 |
+
|
| 589 |
+
if not _image_present(_TEST_IMAGE):
|
| 590 |
+
# scrub is host-side and needs no daemon, but keep the live gate honest
|
| 591 |
+
pytest.skip(f"{_TEST_IMAGE} not present locally")
|
| 592 |
+
container = None
|
| 593 |
+
try:
|
| 594 |
+
sb = DockerSandbox(workdir=d)
|
| 595 |
+
sb.boot(_TEST_IMAGE)
|
| 596 |
+
container = sb
|
| 597 |
+
assert not os.path.exists(os.path.join(d, "__pycache__")), \
|
| 598 |
+
"cache scrub did not remove __pycache__ in DockerSandbox.boot()"
|
| 599 |
+
finally:
|
| 600 |
+
if container is not None:
|
| 601 |
+
container.close()
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
@live
|
| 605 |
+
def test_live_runsc_runtime():
|
| 606 |
+
"""If gVisor is registered, boot with runtime='runsc' and run a test in it;
|
| 607 |
+
else skip (runsc is not installed on most hosts)."""
|
| 608 |
+
if not ds_mod.runsc_available():
|
| 609 |
+
pytest.skip("gVisor 'runsc' runtime not registered with this daemon")
|
| 610 |
+
if not _image_present(_TEST_IMAGE):
|
| 611 |
+
pytest.skip(f"{_TEST_IMAGE} not present locally")
|
| 612 |
+
with tempfile.TemporaryDirectory() as d:
|
| 613 |
+
_materialize(d, _MODULE_SOLVED)
|
| 614 |
+
sb = DockerSandbox(workdir=d, runtime="runsc", exec_timeout_s=60)
|
| 615 |
+
sb.boot(_TEST_IMAGE)
|
| 616 |
+
try:
|
| 617 |
+
res = sb.run_tests("python runner.py", ("feature.py::test_mul",))
|
| 618 |
+
assert "feature.py::test_mul" in res.passed
|
| 619 |
+
finally:
|
| 620 |
+
sb.close()
|
|
@@ -47,6 +47,7 @@ from composer_replication.diloco.serverless.allreduce import (
|
|
| 47 |
MockManager,
|
| 48 |
ObjectStoreAllReduce,
|
| 49 |
)
|
|
|
|
| 50 |
from composer_replication.diloco.serverless.executor import (
|
| 51 |
LocalProcessExecutor,
|
| 52 |
ReplicaHandle,
|
|
@@ -55,13 +56,16 @@ from composer_replication.diloco.serverless.executor import (
|
|
| 55 |
from composer_replication.diloco.serverless.hf_jobs import HFJobsExecutor
|
| 56 |
from composer_replication.diloco.serverless.modal import ModalExecutor
|
| 57 |
from composer_replication.diloco.serverless.modal_spawn import ModalSpawnExecutor
|
|
|
|
| 58 |
|
| 59 |
__all__ = [
|
|
|
|
| 60 |
"HFJobsExecutor",
|
| 61 |
"LocalProcessExecutor",
|
| 62 |
"MockManager",
|
| 63 |
"ModalExecutor",
|
| 64 |
"ModalSpawnExecutor",
|
|
|
|
| 65 |
"ObjectStoreAllReduce",
|
| 66 |
"ReplicaHandle",
|
| 67 |
"ServerlessExecutor",
|
|
|
|
| 47 |
MockManager,
|
| 48 |
ObjectStoreAllReduce,
|
| 49 |
)
|
| 50 |
+
from composer_replication.diloco.serverless.eks import EKSExecutor
|
| 51 |
from composer_replication.diloco.serverless.executor import (
|
| 52 |
LocalProcessExecutor,
|
| 53 |
ReplicaHandle,
|
|
|
|
| 56 |
from composer_replication.diloco.serverless.hf_jobs import HFJobsExecutor
|
| 57 |
from composer_replication.diloco.serverless.modal import ModalExecutor
|
| 58 |
from composer_replication.diloco.serverless.modal_spawn import ModalSpawnExecutor
|
| 59 |
+
from composer_replication.diloco.serverless.sagemaker import SageMakerExecutor
|
| 60 |
|
| 61 |
__all__ = [
|
| 62 |
+
"EKSExecutor",
|
| 63 |
"HFJobsExecutor",
|
| 64 |
"LocalProcessExecutor",
|
| 65 |
"MockManager",
|
| 66 |
"ModalExecutor",
|
| 67 |
"ModalSpawnExecutor",
|
| 68 |
+
"SageMakerExecutor",
|
| 69 |
"ObjectStoreAllReduce",
|
| 70 |
"ReplicaHandle",
|
| 71 |
"ServerlessExecutor",
|
|
@@ -0,0 +1,674 @@
|
|
| 1 |
+
"""EKSExecutor — production Amazon EKS / Kubernetes-backed serverless executor.
|
| 2 |
+
|
| 3 |
+
This is the v0-finished k8s sibling of `ModalSpawnExecutor`. It implements
|
| 4 |
+
the `ServerlessExecutor` Protocol against the Kubernetes ``BatchV1Api`` using
|
| 5 |
+
the **single Indexed Job** topology recommended for gang-scheduled DiLoCo
|
| 6 |
+
replicas.
|
| 7 |
+
|
| 8 |
+
Topology (the load-bearing design choice)
|
| 9 |
+
------------------------------------------
|
| 10 |
+
There are two ways to map N replicas onto k8s:
|
| 11 |
+
|
| 12 |
+
(A) ONE Indexed Job — ``completions=N, parallelism=N,
|
| 13 |
+
completionMode='Indexed'``. The control plane assigns each pod a
|
| 14 |
+
``JOB_COMPLETION_INDEX`` 0..N-1 which IS the rank, all pods share one
|
| 15 |
+
rendezvous URI, scheduling is atomic, and a single delete cancels the
|
| 16 |
+
whole cohort.
|
| 17 |
+
(B) N separate non-indexed Jobs, one per rank.
|
| 18 |
+
|
| 19 |
+
`EKSExecutor` uses **(A)** because it is the better fit for DiLoCo: rank
|
| 20 |
+
assignment is free, scheduling is gang-atomic, and one delete tears down the
|
| 21 |
+
cohort — which matches ``ObjectStoreAllReduce``'s all-or-nothing barrier. The
|
| 22 |
+
reconciliation with the per-replica ``ReplicaHandle`` contract: ``launch_replicas``
|
| 23 |
+
creates ONE Indexed Job but still returns N ``ReplicaHandle`` objects
|
| 24 |
+
(``handles[i].rank == i``) whose ``metadata`` stores the SHARED
|
| 25 |
+
``job_name``/``namespace`` plus that rank.
|
| 26 |
+
|
| 27 |
+
This is materially different from ``ModalSpawnExecutor`` where each handle is
|
| 28 |
+
an independent ``FunctionCall``:
|
| 29 |
+
|
| 30 |
+
* ``poll(handle)`` reads the single Job status and checks whether
|
| 31 |
+
``handle.rank`` is in the run-length-compressed ``completed_indexes`` /
|
| 32 |
+
``failed_indexes`` strings.
|
| 33 |
+
* ``cancel(handle)`` on ANY handle deletes the WHOLE Job (intentional gang
|
| 34 |
+
semantics — cancelling one rank tears down the whole replica cohort).
|
| 35 |
+
|
| 36 |
+
Rank plumbing
|
| 37 |
+
-------------
|
| 38 |
+
The repo's ``replica_entrypoint`` reads ``REPLICA_RANK``. We bridge the k8s
|
| 39 |
+
completion-index to that env var via the downward API rather than relying on
|
| 40 |
+
the auto-injected ``JOB_COMPLETION_INDEX``::
|
| 41 |
+
|
| 42 |
+
V1EnvVar(
|
| 43 |
+
name="REPLICA_RANK",
|
| 44 |
+
value_from=V1EnvVarSource(field_ref=V1ObjectFieldSelector(
|
| 45 |
+
field_path="metadata.annotations['batch.kubernetes.io/job-completion-index']")),
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
so the unchanged entrypoint's ``REPLICA_RANK`` read just works. ``WORLD_SIZE``
|
| 49 |
+
is set as a literal env var.
|
| 50 |
+
|
| 51 |
+
S3 rendezvous via IRSA / Pod Identity
|
| 52 |
+
-------------------------------------
|
| 53 |
+
``EKSExecutor`` accepts ``service_account_name`` and references it on the
|
| 54 |
+
PodSpec. The EKS Pod Identity / IRSA mutating webhook then injects
|
| 55 |
+
``AWS_ROLE_ARN`` + ``AWS_WEB_IDENTITY_TOKEN_FILE`` (and a projected token
|
| 56 |
+
volume) into the pod, so ``boto3``/``s3fs``/``fsspec`` pick up credentials via
|
| 57 |
+
the web-identity provider with ZERO code change inside the replica — the
|
| 58 |
+
``s3://`` rendezvous works out of the box. ``EKSExecutor`` only REFERENCES a
|
| 59 |
+
pre-annotated ServiceAccount; it never creates IAM/OIDC resources.
|
| 60 |
+
|
| 61 |
+
Sandboxing (advanced, optional)
|
| 62 |
+
-------------------------------
|
| 63 |
+
``runtime_class_name`` references a pre-existing cluster-scoped ``RuntimeClass``
|
| 64 |
+
(for example ``gvisor``, whose handler is ``runsc``, or ``kata`` for Kata). It defaults to ``None``.
|
| 65 |
+
|
| 66 |
+
.. warning::
|
| 67 |
+
Combining ``gpu`` with ``runtime_class_name`` is advanced and unverified.
|
| 68 |
+
gVisor (runsc) needs ``nvproxy`` enabled and only supports a fixed allowlist
|
| 69 |
+
of NVIDIA driver families; Kata runs a microVM that caps CPU/mem and needs
|
| 70 |
+
GPU passthrough (PCIe/IOMMU + NVIDIA Kata Manager + CDI). Do not silently
|
| 71 |
+
combine the two without operator validation. ``EKSExecutor`` cannot create
|
| 72 |
+
the RuntimeClass — it only references one that already exists.
|
| 73 |
+
|
| 74 |
+
References
|
| 75 |
+
----------
|
| 76 |
+
- k8s Indexed Jobs: https://kubernetes.io/docs/tasks/job/indexed-parallel-processing-static/
|
| 77 |
+
- kubernetes-client/python job_crud example + V1JobSpec / V1JobStatus docs
|
| 78 |
+
- EKS IRSA: https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts.html
|
| 79 |
+
- ADR-005 (executor protocol design)
|
| 80 |
+
"""
|
| 81 |
+
from __future__ import annotations
|
| 82 |
+
|
| 83 |
+
import time
|
| 84 |
+
import uuid
|
| 85 |
+
from collections.abc import Callable, Mapping
|
| 86 |
+
from typing import Any
|
| 87 |
+
|
| 88 |
+
from composer_replication.diloco.serverless.executor import (
|
| 89 |
+
ReplicaHandle,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# Logical GPU spec ("A100"/"H100") -> (gpu_count_string, node_selector merge).
|
| 93 |
+
# The Protocol's `gpu` arg is a logical name; map it to a concrete EKS node
|
| 94 |
+
# class + GPU count rather than passing the opaque string straight through.
|
| 95 |
+
_GPU_SPEC_TABLE: dict[str, tuple[str, dict[str, str]]] = {
|
| 96 |
+
"A100": ("1", {"node.kubernetes.io/instance-type": "p4d.24xlarge"}),
|
| 97 |
+
"H100": ("1", {"node.kubernetes.io/instance-type": "p5.48xlarge"}),
|
| 98 |
+
"A10G": ("1", {"node.kubernetes.io/instance-type": "g5.xlarge"}),
|
| 99 |
+
"T4": ("1", {"node.kubernetes.io/instance-type": "g4dn.xlarge"}),
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _expand_indexes(spec: str | None) -> set[int]:
|
| 104 |
+
"""Expand a run-length-compressed completion-index string to a set.
|
| 105 |
+
|
| 106 |
+
The k8s ``V1JobStatus.completed_indexes`` / ``failed_indexes`` fields are
|
| 107 |
+
strings like ``"1,3-5,7"`` (comma-separated singletons and ``a-b`` ranges).
|
| 108 |
+
``_expand_indexes("1,3-5,7") == {1, 3, 4, 5, 7}``. Empty/None -> empty set.
|
| 109 |
+
"""
|
| 110 |
+
out: set[int] = set()
|
| 111 |
+
if not spec:
|
| 112 |
+
return out
|
| 113 |
+
for token in spec.split(","):
|
| 114 |
+
token = token.strip()
|
| 115 |
+
if not token:
|
| 116 |
+
continue
|
| 117 |
+
if "-" in token:
|
| 118 |
+
lo_s, _, hi_s = token.partition("-")
|
| 119 |
+
try:
|
| 120 |
+
lo, hi = int(lo_s), int(hi_s)
|
| 121 |
+
except ValueError:
|
| 122 |
+
continue
|
| 123 |
+
if hi < lo:
|
| 124 |
+
lo, hi = hi, lo
|
| 125 |
+
out.update(range(lo, hi + 1))
|
| 126 |
+
else:
|
| 127 |
+
try:
|
| 128 |
+
out.add(int(token))
|
| 129 |
+
except ValueError:
|
| 130 |
+
continue
|
| 131 |
+
return out
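To illustrate the string format parsed above, here is a minimal sketch of the opposite direction: compressing a set of ranks into a ``"1,3-5,7"``-style string. ``compress_indexes`` is a hypothetical helper and is not part of the module; it is mainly useful for building mock ``completed_indexes`` values in tests. The exact string the Job controller emits for short runs may differ, but any equivalent encoding expands to the same set.

```python
def compress_indexes(indexes):
    """Hypothetical inverse of _expand_indexes: {1, 3, 4, 5, 7} -> "1,3-5,7"."""
    ordered = sorted(set(indexes))
    parts = []
    i = 0
    while i < len(ordered):
        start = end = ordered[i]
        # Extend the current run while the next value is consecutive.
        while i + 1 < len(ordered) and ordered[i + 1] == end + 1:
            i += 1
            end = ordered[i]
        parts.append(str(start) if start == end else f"{start}-{end}")
        i += 1
    return ",".join(parts)


compressed = compress_indexes({1, 3, 4, 5, 7})
empty = compress_indexes([])
```

An empty input yields an empty string, which the parser above treats as an empty set.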
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class EKSExecutor:
|
| 135 |
+
"""Run N DiLoCo replicas as a single Kubernetes Indexed Job on EKS.
|
| 136 |
+
|
| 137 |
+
Implements the `ServerlessExecutor` Protocol. ``launch_replicas`` creates
|
| 138 |
+
ONE Indexed Job (``completions == parallelism == n_replicas``,
|
| 139 |
+
``completionMode='Indexed'``) and returns N ``ReplicaHandle`` objects that
|
| 140 |
+
all share the same ``job_name``/``namespace`` (gang semantics).
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
image: container image that has ``composer_replication`` installed and
|
| 144 |
+
runs the replica entrypoint.
|
| 145 |
+
namespace: k8s namespace for the Job. Default ``"default"``.
|
| 146 |
+
service_account_name: ServiceAccount to attach to the PodSpec for IRSA /
|
| 147 |
+
EKS Pod Identity S3 access. ``EKSExecutor`` references it; it does
|
| 148 |
+
NOT create it or any IAM/OIDC resources.
|
| 149 |
+
node_selector: extra node selector merged into the GPU node selector.
|
| 150 |
+
tolerations: PodSpec tolerations. If GPU is requested and the caller did
|
| 151 |
+
not supply tolerations, the standard ``nvidia.com/gpu`` NoSchedule
|
| 152 |
+
toleration is added automatically.
|
| 153 |
+
runtime_class_name: optional pre-existing RuntimeClass (e.g. ``"gvisor"``
|
| 154 |
+
/ ``"kata"``). Default ``None``. See the module-level warning before
|
| 155 |
+
combining with ``gpu``.
|
| 156 |
+
command: container command. Defaults to the repo replica entrypoint
|
| 157 |
+
module ``["python", "-m",
|
| 158 |
+
"composer_replication.diloco.serverless.replica_entrypoint"]``.
|
| 159 |
+
cpu_request / memory_request: PodSpec resource requests.
|
| 160 |
+
ttl_seconds_after_finished: auto-delete the finished Job (and its pods,
|
| 161 |
+
cascadingly) after this many seconds. Default 3600.
|
| 162 |
+
backoff_limit: Job retry budget. Default 0 (fail-fast — RL gangs usually
|
| 163 |
+
do NOT want the k8s default of 6 retries).
|
| 164 |
+
gpu_resource_key: the GPU resource key. Default ``"nvidia.com/gpu"``.
|
| 165 |
+
run_id: optional run id baked into the generated Job name.
|
| 166 |
+
batch_api / core_api: dependency-injected ``BatchV1Api`` / ``CoreV1Api``
|
| 167 |
+
instances. When ``None`` (the default), they are built lazily on
|
| 168 |
+
first use via in-cluster or kube-config loading. Tests inject mocks.
|
| 169 |
+
|
| 170 |
+
Raises:
|
| 171 |
+
RuntimeError: if the ``kubernetes`` client is not installed AND no api
|
| 172 |
+
was injected (the import is needed to construct V1 model objects).
|
| 173 |
+
"""
|
| 174 |
+
|
| 175 |
+
backend_name = "eks"
|
| 176 |
+
# Pods are network-isolated by default; rendezvous is S3 (ObjectStoreAllReduce).
|
| 177 |
+
supports_inter_replica_network = False
|
| 178 |
+
|
| 179 |
+
def __init__(
|
| 180 |
+
self,
|
| 181 |
+
image: str,
|
| 182 |
+
*,
|
| 183 |
+
namespace: str = "default",
|
| 184 |
+
service_account_name: str | None = None,
|
| 185 |
+
node_selector: dict[str, str] | None = None,
|
| 186 |
+
tolerations: list[Any] | None = None,
|
| 187 |
+
runtime_class_name: str | None = None,
|
| 188 |
+
command: list[str] | None = None,
|
| 189 |
+
cpu_request: str = "4",
|
| 190 |
+
memory_request: str = "16Gi",
|
| 191 |
+
ttl_seconds_after_finished: int = 3600,
|
| 192 |
+
backoff_limit: int = 0,
|
| 193 |
+
gpu_resource_key: str = "nvidia.com/gpu",
|
| 194 |
+
run_id: str | None = None,
|
| 195 |
+
batch_api: Any = None,
|
| 196 |
+
core_api: Any = None,
|
| 197 |
+
) -> None:
|
| 198 |
+
# `kubernetes` is only strictly required when we have to BUILD V1 model
|
| 199 |
+
# objects ourselves (launch_replicas) or load cluster config (when no
|
| 200 |
+
# api is injected). We surface a clear error here only if we definitely
|
| 201 |
+
# need it and it is absent — i.e. when no api was injected. When apis
|
| 202 |
+
# ARE injected (tests, or callers that pre-built clients), we tolerate a
|
| 203 |
+
# missing top-level package and lazy-import `client` per call.
|
| 204 |
+
if batch_api is None or core_api is None:
|
| 205 |
+
try:
|
| 206 |
+
import kubernetes # noqa: F401
|
| 207 |
+
except ImportError as e:
|
| 208 |
+
raise RuntimeError(
|
| 209 |
+
'EKSExecutor requires the kubernetes client: '
|
| 210 |
+
'pip install "kubernetes>=29" (or '
|
| 211 |
+
"`pip install -e .[serverless]`). Got: " + repr(e)
|
| 212 |
+
) from e
|
| 213 |
+
|
| 214 |
+
self.image = image
|
| 215 |
+
self.namespace = namespace
|
| 216 |
+
self.service_account_name = service_account_name
|
| 217 |
+
self.node_selector = dict(node_selector) if node_selector else None
|
| 218 |
+
self.tolerations = list(tolerations) if tolerations else None
|
| 219 |
+
self.runtime_class_name = runtime_class_name
|
| 220 |
+
self.command = command or [
|
| 221 |
+
"python",
|
| 222 |
+
"-m",
|
| 223 |
+
"composer_replication.diloco.serverless.replica_entrypoint",
|
| 224 |
+
]
|
| 225 |
+
self.cpu_request = cpu_request
|
| 226 |
+
self.memory_request = memory_request
|
| 227 |
+
self.ttl_seconds_after_finished = ttl_seconds_after_finished
|
| 228 |
+
self.backoff_limit = backoff_limit
|
| 229 |
+
self.gpu_resource_key = gpu_resource_key
|
| 230 |
+
self.run_id = run_id or "diloco"
|
| 231 |
+
|
| 232 |
+
self._batch_api = batch_api
|
| 233 |
+
self._core_api = core_api
|
| 234 |
+
# rank -> {"job_name", "namespace", "result"}; lets poll/collect cache.
|
| 235 |
+
self._handles: dict[int, dict[str, Any]] = {}
|
| 236 |
+
|
| 237 |
+
# -----------------------------------------------------------------
|
| 238 |
+
# Lazy client construction (config loading only when not injected)
|
| 239 |
+
# -----------------------------------------------------------------
|
| 240 |
+
|
| 241 |
+
def _load_config(self) -> None:
|
| 242 |
+
"""Load k8s config once: in-cluster first, then ~/.kube/config."""
|
| 243 |
+
from kubernetes import config
|
| 244 |
+
|
| 245 |
+
try:
|
| 246 |
+
config.load_incluster_config()
|
| 247 |
+
except config.ConfigException:
|
| 248 |
+
config.load_kube_config()
|
| 249 |
+
|
| 250 |
+
def _batch(self) -> Any:
|
| 251 |
+
if self._batch_api is None:
|
| 252 |
+
from kubernetes import client
|
| 253 |
+
|
| 254 |
+
self._load_config()
|
| 255 |
+
self._batch_api = client.BatchV1Api()
|
| 256 |
+
return self._batch_api
|
| 257 |
+
|
| 258 |
+
def _core(self) -> Any:
|
| 259 |
+
if self._core_api is None:
|
| 260 |
+
from kubernetes import client
|
| 261 |
+
|
| 262 |
+
self._load_config()
|
| 263 |
+
self._core_api = client.CoreV1Api()
|
| 264 |
+
return self._core_api
|
| 265 |
+
|
| 266 |
+
# -----------------------------------------------------------------
|
| 267 |
+
# Job-spec construction
|
| 268 |
+
# -----------------------------------------------------------------
|
| 269 |
+
|
| 270 |
+
def _build_env(
|
| 271 |
+
self, world_size: int, entrypoint_args: Mapping[str, Any]
|
| 272 |
+
) -> list[Any]:
|
| 273 |
+
"""Build the container env list, including the downward-API rank var."""
|
| 274 |
+
from kubernetes import client
|
| 275 |
+
|
| 276 |
+
env: list[Any] = [
|
| 277 |
+
# REPLICA_RANK from the per-pod completion-index annotation via the
|
| 278 |
+
# downward API — bridges k8s indexing to the repo entrypoint's
|
| 279 |
+
# REPLICA_RANK read with no entrypoint change.
|
| 280 |
+
client.V1EnvVar(
|
| 281 |
+
name="REPLICA_RANK",
|
| 282 |
+
value_from=client.V1EnvVarSource(
|
| 283 |
+
field_ref=client.V1ObjectFieldSelector(
|
| 284 |
+
field_path=(
|
| 285 |
+
"metadata.annotations["
|
| 286 |
+
"'batch.kubernetes.io/job-completion-index']"
|
| 287 |
+
)
|
| 288 |
+
)
|
| 289 |
+
),
|
| 290 |
+
),
|
| 291 |
+
client.V1EnvVar(name="WORLD_SIZE", value=str(world_size)),
|
| 292 |
+
]
|
| 293 |
+
# rendezvous_uri (and any other scalar kwargs) passed as literal env so
|
| 294 |
+
# the entrypoint / user code can read them. `rank_env` is the
|
| 295 |
+
# LocalProcessExecutor convention — drop it (same as ModalSpawnExecutor).
|
| 296 |
+
for key, value in entrypoint_args.items():
|
| 297 |
+
if key == "rank_env":
|
| 298 |
+
continue
|
| 299 |
+
if isinstance(value, (str, int, float, bool)):
|
| 300 |
+
env.append(
|
| 301 |
+
client.V1EnvVar(name=key.upper(), value=str(value))
|
| 302 |
+
)
|
| 303 |
+
return env
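The filtering rule in ``_build_env`` can be sketched without the ``kubernetes`` client by using a plain dict in place of ``V1EnvVar`` objects. ``scalar_env`` and the sample arguments are hypothetical names chosen for illustration; only the rule itself (drop ``rank_env``, keep scalars, upper-case the key, stringify the value) comes from the code above.

```python
def scalar_env(entrypoint_args):
    """Hypothetical stand-in for the loop in _build_env, returning a dict."""
    env = {}
    for key, value in entrypoint_args.items():
        if key == "rank_env":
            continue  # LocalProcessExecutor convention, not forwarded
        if isinstance(value, (str, int, float, bool)):
            env[key.upper()] = str(value)
        # Non-scalar values (dicts, lists, ...) are silently skipped.
    return env


env = scalar_env(
    {
        "rendezvous_uri": "s3://bucket/run",  # assumed example URI
        "rank_env": "RANK",
        "steps": 10,
        "debug": True,
        "nested": {"a": 1},
    }
)
```

Note that booleans are stringified as ``"True"``/``"False"``, so code reading them from the environment has to parse that spelling.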
|
| 304 |
+
|
| 305 |
+
def _build_resources(self, gpu: str | None) -> tuple[Any, dict[str, str], list[Any]]:
|
| 306 |
+
"""Build V1ResourceRequirements + (node_selector, tolerations) for GPU.
|
| 307 |
+
|
| 308 |
+
Returns (resources, node_selector, tolerations). The GPU count is
|
| 309 |
+
ALWAYS a STRING ('1', not int 1) — the OpenAPI type for the limits map
|
| 310 |
+
is dict[str, str] and an int can serialize wrong or raise.
|
| 311 |
+
"""
|
| 312 |
+
from kubernetes import client
|
| 313 |
+
|
| 314 |
+
requests = {"cpu": self.cpu_request, "memory": self.memory_request}
|
| 315 |
+
limits: dict[str, str] = {}
|
| 316 |
+
node_selector: dict[str, str] = dict(self.node_selector or {})
|
| 317 |
+
tolerations: list[Any] = list(self.tolerations or [])
|
| 318 |
+
|
| 319 |
+
if gpu is not None:
|
| 320 |
+
gpu_count, gpu_node_selector = _GPU_SPEC_TABLE.get(
|
| 321 |
+
gpu, ("1", {})
|
| 322 |
+
)
|
| 323 |
+
# STRING, always.
|
| 324 |
+
limits[self.gpu_resource_key] = str(gpu_count)
|
| 325 |
+
# Merge the mapped node selector under any caller-supplied one
|
| 326 |
+
# (caller wins on key conflicts).
|
| 327 |
+
for k, v in gpu_node_selector.items():
|
| 328 |
+
node_selector.setdefault(k, v)
|
| 329 |
+
# Auto-add the GPU NoSchedule toleration unless the caller overrode
|
| 330 |
+
# tolerations explicitly.
|
| 331 |
+
if not self.tolerations:
|
| 332 |
+
tolerations.append(
|
| 333 |
+
client.V1Toleration(
|
| 334 |
+
key=self.gpu_resource_key,
|
| 335 |
+
operator="Exists",
|
| 336 |
+
effect="NoSchedule",
|
| 337 |
+
)
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
resources = client.V1ResourceRequirements(
|
| 341 |
+
requests=requests,
|
| 342 |
+
limits=limits or None,
|
| 343 |
+
)
|
| 344 |
+
return resources, node_selector, tolerations
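The "caller wins on key conflicts" merge in ``_build_resources`` relies on ``dict.setdefault``. A minimal sketch with plain dicts follows; ``merge_node_selector`` and the ``pool`` label are hypothetical, while the instance-type key and values are taken from ``_GPU_SPEC_TABLE`` above.

```python
def merge_node_selector(caller, gpu_defaults):
    """Hypothetical helper: GPU defaults fill gaps, caller keys are kept."""
    merged = dict(caller or {})
    for key, value in gpu_defaults.items():
        merged.setdefault(key, value)  # only set when the caller did not
    return merged


key = "node.kubernetes.io/instance-type"
overridden = merge_node_selector({key: "p4de.24xlarge"}, {key: "p4d.24xlarge"})
filled = merge_node_selector({"pool": "train"}, {key: "p5.48xlarge"})
```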
|
| 345 |
+
|
| 346 |
+
def _build_job(
|
| 347 |
+
self,
|
| 348 |
+
*,
|
| 349 |
+
job_name: str,
|
| 350 |
+
n_replicas: int,
|
| 351 |
+
gpu: str | None,
|
| 352 |
+
timeout: int,
|
| 353 |
+
entrypoint_args: Mapping[str, Any],
|
| 354 |
+
) -> Any:
|
| 355 |
+
"""Assemble the full V1Job (Indexed) bottom-up."""
|
| 356 |
+
from kubernetes import client
|
| 357 |
+
|
| 358 |
+
env = self._build_env(n_replicas, entrypoint_args)
|
| 359 |
+
resources, node_selector, tolerations = self._build_resources(gpu)
|
| 360 |
+
|
| 361 |
+
container = client.V1Container(
|
| 362 |
+
name="replica",
|
| 363 |
+
image=self.image,
|
| 364 |
+
command=list(self.command),
|
| 365 |
+
env=env,
|
| 366 |
+
resources=resources,
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
pod_spec = client.V1PodSpec(
|
| 370 |
+
restart_policy="Never",  # Jobs accept only Never/OnFailure; Never keeps RL runs fail-fast
|
| 371 |
+
containers=[container],
|
| 372 |
+
service_account_name=self.service_account_name,
|
| 373 |
+
node_selector=node_selector or None,
|
| 374 |
+
tolerations=tolerations or None,
|
| 375 |
+
runtime_class_name=self.runtime_class_name,
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
labels = {"app": "composer-diloco", "job-name": job_name}
|
| 379 |
+
pod_template = client.V1PodTemplateSpec(
|
| 380 |
+
metadata=client.V1ObjectMeta(labels=labels),
|
| 381 |
+
spec=pod_spec,
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
job_spec = client.V1JobSpec(
|
| 385 |
+
template=pod_template,
|
| 386 |
+
completions=n_replicas,
|
| 387 |
+
parallelism=n_replicas,
|
| 388 |
+
completion_mode="Indexed",
|
| 389 |
+
backoff_limit=self.backoff_limit,
|
| 390 |
+
ttl_seconds_after_finished=self.ttl_seconds_after_finished,
|
| 391 |
+
active_deadline_seconds=timeout,
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
return client.V1Job(
|
| 395 |
+
api_version="batch/v1",
|
| 396 |
+
kind="Job",
|
| 397 |
+
metadata=client.V1ObjectMeta(name=job_name, labels=labels),
|
| 398 |
+
spec=job_spec,
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
# -----------------------------------------------------------------
|
| 402 |
+
# ServerlessExecutor Protocol
|
| 403 |
+
# -----------------------------------------------------------------
|
| 404 |
+
|
| 405 |
+
def launch_replicas(
|
| 406 |
+
self,
|
| 407 |
+
n_replicas: int,
|
| 408 |
+
entrypoint: str | Callable[..., Any],
|
| 409 |
+
entrypoint_args: Mapping[str, Any],
|
| 410 |
+
*,
|
| 411 |
+
gpu: str | None = None,
|
| 412 |
+
timeout: int = 3600,
|
| 413 |
+
) -> list[ReplicaHandle]:
|
| 414 |
+
"""Create ONE Indexed Job of N pods and return N rank-ordered handles.
|
| 415 |
+
|
| 416 |
+
``entrypoint`` is ignored when it names a Callable (k8s runs a container
|
| 417 |
+
command, not an in-process callable); the container command is fixed at
|
| 418 |
+
construction (``command`` ctor arg). The repo entrypoint module is the
|
| 419 |
+
default. ``entrypoint_args`` scalar kwargs are passed as upper-cased env
|
| 420 |
+
vars so ``replica_entrypoint`` / user code can read them. ``gpu`` maps to
|
| 421 |
+
a ``nvidia.com/gpu`` limit + node selector; ``timeout`` becomes the Job's
|
| 422 |
+
``active_deadline_seconds`` hard wall-clock kill.
|
| 423 |
+
"""
|
| 424 |
+
del entrypoint # k8s runs a container command, not an in-process fn
|
| 425 |
+
|
| 426 |
+
if n_replicas < 1:
|
| 427 |
+
raise ValueError(f"n_replicas must be >= 1, got {n_replicas}")
|
| 428 |
+
|
| 429 |
+
job_name = f"{self.run_id}-{uuid.uuid4().hex[:8]}"
|
| 430 |
+
job = self._build_job(
|
| 431 |
+
job_name=job_name,
|
| 432 |
+
n_replicas=n_replicas,
|
| 433 |
+
gpu=gpu,
|
| 434 |
+
timeout=timeout,
|
| 435 |
+
entrypoint_args=entrypoint_args,
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
self._batch().create_namespaced_job(namespace=self.namespace, body=job)
|
| 439 |
+
|
| 440 |
+
handles: list[ReplicaHandle] = []
|
| 441 |
+
for rank in range(n_replicas):
|
| 442 |
+
handles.append(
|
| 443 |
+
ReplicaHandle(
|
| 444 |
+
rank=rank,
|
| 445 |
+
backend_name=self.backend_name,
|
| 446 |
+
metadata={
|
| 447 |
+
"job_name": job_name,
|
| 448 |
+
"namespace": self.namespace,
|
| 449 |
+
"rank": rank,
|
| 450 |
+
},
|
| 451 |
+
)
|
| 452 |
+
)
|
| 453 |
+
self._handles[rank] = {
|
| 454 |
+
"job_name": job_name,
|
| 455 |
+
"namespace": self.namespace,
|
| 456 |
+
"result": None,
|
| 457 |
+
}
|
| 458 |
+
return handles
|
| 459 |
+
|
| 460 |
+
def poll(self, handle: ReplicaHandle) -> str:
|
| 461 |
+
"""Poll this rank's status off the shared Indexed Job.
|
| 462 |
+
|
| 463 |
+
Reads ``read_namespaced_job_status`` once, then maps the whole-job
|
| 464 |
+
status to this rank: ``rank in completed_indexes`` -> ``succeeded``;
|
| 465 |
+
``rank in failed_indexes`` -> ``failed``; ``active > 0`` -> ``running``;
|
| 466 |
+
else ``pending``. A 404 (Job deleted/cancelled) -> ``cancelled``.
|
| 467 |
+
|
| 468 |
+
Returns one of: ``pending`` | ``running`` | ``succeeded`` | ``failed`` |
|
| 469 |
+
``cancelled``.
|
| 470 |
+
"""
|
| 471 |
+
from kubernetes.client.exceptions import ApiException
|
| 472 |
+
|
| 473 |
+
job_name = handle.metadata["job_name"]
|
| 474 |
+
namespace = handle.metadata["namespace"]
|
| 475 |
+
rank = handle.metadata.get("rank", handle.rank)
|
| 476 |
+
|
| 477 |
+
try:
|
| 478 |
+
status = self._batch().read_namespaced_job_status(
|
| 479 |
+
name=job_name, namespace=namespace
|
| 480 |
+
).status
|
| 481 |
+
except ApiException as e:
|
| 482 |
+
if getattr(e, "status", None) == 404:
|
| 483 |
+
return "cancelled"
|
| 484 |
+
raise
|
| 485 |
+
|
| 486 |
+
completed = _expand_indexes(getattr(status, "completed_indexes", None))
|
| 487 |
+
if rank in completed:
|
| 488 |
+
return "succeeded"
|
| 489 |
+
|
| 490 |
+
failed = _expand_indexes(getattr(status, "failed_indexes", None))
|
| 491 |
+
if rank in failed:
|
| 492 |
+
return "failed"
|
| 493 |
+
|
| 494 |
+
# Whole-job terminal Failed (e.g. DeadlineExceeded / backoff) with no
|
| 495 |
+
# per-index attribution -> treat this rank as failed.
|
| 496 |
+
for cond in (getattr(status, "conditions", None) or []):
|
| 497 |
+
if (
|
| 498 |
+
getattr(cond, "type", None) == "Failed"
|
| 499 |
+
and getattr(cond, "status", None) == "True"
|
| 500 |
+
):
|
| 501 |
+
return "failed"
|
| 502 |
+
|
| 503 |
+
active = getattr(status, "active", None) or 0
|
| 504 |
+
if active > 0:
|
| 505 |
+
return "running"
|
| 506 |
+
return "pending"
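The precedence used by ``poll`` (completed, then failed, then a whole-job ``Failed`` condition, then ``active``) can be exercised without a cluster. The sketch below uses ``types.SimpleNamespace`` as a stand-in for ``V1JobStatus``; ``expand`` and ``rank_state`` are hypothetical simplifications that, unlike ``_expand_indexes``, do not tolerate malformed tokens.

```python
from types import SimpleNamespace


def expand(spec):
    """Simplified index expansion: "0,2-3" -> {0, 2, 3}."""
    out = set()
    for token in (spec or "").split(","):
        token = token.strip()
        if not token:
            continue
        lo, sep, hi = token.partition("-")
        out.update(range(int(lo), int(hi if sep else lo) + 1))
    return out


def rank_state(status, rank):
    """Map a whole-job status to one rank, in the same order as poll()."""
    if rank in expand(status.completed_indexes):
        return "succeeded"
    if rank in expand(status.failed_indexes):
        return "failed"
    for cond in status.conditions or []:
        if cond.type == "Failed" and cond.status == "True":
            return "failed"
    return "running" if (status.active or 0) > 0 else "pending"


status = SimpleNamespace(
    completed_indexes="0,2-3", failed_indexes="1", conditions=None, active=1
)
states = [rank_state(status, r) for r in range(5)]
```

Here rank 4 has no per-index attribution, so it falls through to the ``active`` count and is reported as ``running``.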
|
| 507 |
+
|
| 508 |
+
def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
|
| 509 |
+
"""Read recent logs for this rank's pod.
|
| 510 |
+
|
| 511 |
+
Finds the pod whose ``batch.kubernetes.io/job-completion-index``
|
| 512 |
+
annotation (or label) equals the rank, then reads its log tail. Returns
|
| 513 |
+
a placeholder string (rather than raising) when the pod has not started
|
| 514 |
+
or the Job is gone — mirrors ``LocalProcessExecutor``.
|
| 515 |
+
"""
|
| 516 |
+
from kubernetes.client.exceptions import ApiException
|
| 517 |
+
|
| 518 |
+
job_name = handle.metadata["job_name"]
|
| 519 |
+
namespace = handle.metadata["namespace"]
|
| 520 |
+
rank = handle.metadata.get("rank", handle.rank)
|
| 521 |
+
idx_key = "batch.kubernetes.io/job-completion-index"
|
| 522 |
+
|
| 523 |
+
try:
|
| 524 |
+
pods = self._core().list_namespaced_pod(
|
| 525 |
+
namespace=namespace, label_selector=f"job-name={job_name}"
|
| 526 |
+
)
|
| 527 |
+
except ApiException:
|
| 528 |
+
return f"<rank {rank}: job not found / no pods yet>"
|
| 529 |
+
|
| 530 |
+
pod_name = None
|
| 531 |
+
for pod in getattr(pods, "items", None) or []:
|
| 532 |
+
meta = getattr(pod, "metadata", None)
|
| 533 |
+
annotations = getattr(meta, "annotations", None) or {}
|
| 534 |
+
labels = getattr(meta, "labels", None) or {}
|
| 535 |
+
if annotations.get(idx_key) == str(rank) or labels.get(idx_key) == str(rank):
|
| 536 |
+
pod_name = getattr(meta, "name", None)
|
| 537 |
+
break
|
| 538 |
+
|
| 539 |
+
if pod_name is None:
|
| 540 |
+
# Fall back to the Indexed Job pod-name prefix ``<job_name>-<index>-``.
|
| 541 |
+
prefix = f"{job_name}-{rank}-"
|
| 542 |
+
for pod in getattr(pods, "items", None) or []:
|
| 543 |
+
name = getattr(getattr(pod, "metadata", None), "name", "") or ""
|
| 544 |
+
if name.startswith(prefix):
|
| 545 |
+
pod_name = name
|
| 546 |
+
break
|
| 547 |
+
|
| 548 |
+
if pod_name is None:
|
| 549 |
+
return f"<rank {rank}: pod not started / no logs yet>"
|
| 550 |
+
|
| 551 |
+
try:
|
| 552 |
+
return self._core().read_namespaced_pod_log(
|
| 553 |
+
name=pod_name,
|
| 554 |
+
namespace=namespace,
|
| 555 |
+
container="replica",
|
| 556 |
+
tail_lines=n_lines,
|
| 557 |
+
)
|
| 558 |
+
except ApiException as e:
|
| 559 |
+
if getattr(e, "status", None) in (400, 404):
|
| 560 |
+
return f"<rank {rank}: pod not started / no logs yet>"
|
| 561 |
+
raise
|
| 562 |
+
|
| 563 |
+
def cancel(self, handle: ReplicaHandle) -> None:
|
| 564 |
+
"""Delete the WHOLE shared Indexed Job (gang teardown).
|
| 565 |
+
|
| 566 |
+
Because ``EKSExecutor`` uses one shared Indexed Job, cancelling ANY rank
|
| 567 |
+
tears down the entire replica cohort — intentional gang semantics for
|
| 568 |
+
the DiLoCo all-reduce barrier (a single straggler being cancelled should
|
| 569 |
+
not leave the rest spinning and burning GPU).
|
| 570 |
+
|
| 571 |
+
Uses ``propagation_policy='Background'`` so the pods are cascadingly
|
| 572 |
+
deleted (the k8s default ORPHANS pods, which would keep burning GPU —
|
| 573 |
+
the exact failure mode for RL). Idempotent: a 404 (already deleted) is
|
| 574 |
+
swallowed, and an unknown handle never raises, honoring the Protocol's
|
| 575 |
+
"no exception if already terminated" contract.
|
| 576 |
+
"""
|
| 577 |
+
from kubernetes import client
|
| 578 |
+
from kubernetes.client.exceptions import ApiException
|
| 579 |
+
|
| 580 |
+
job_name = handle.metadata.get("job_name")
|
| 581 |
+
namespace = handle.metadata.get("namespace", self.namespace)
|
| 582 |
+
if not job_name:
|
| 583 |
+
return # unknown handle — no-op
|
| 584 |
+
|
| 585 |
+
try:
|
| 586 |
+
self._batch().delete_namespaced_job(
|
| 587 |
+
name=job_name,
|
| 588 |
+
namespace=namespace,
|
| 589 |
+
body=client.V1DeleteOptions(
|
| 590 |
+
propagation_policy="Background",
|
| 591 |
+
grace_period_seconds=0,
|
| 592 |
+
),
|
| 593 |
+
)
|
| 594 |
+
except ApiException as e:
|
| 595 |
+
if getattr(e, "status", None) == 404:
|
| 596 |
+
return # already deleted
|
| 597 |
+
# Best-effort: swallow other API errors (network blip, etc.).
|
| 598 |
+
return
|
| 599 |
+
except Exception:
|
| 600 |
+
return
|
| 601 |
+
|
| 602 |
+
def collect(
|
| 603 |
+
self,
|
| 604 |
+
handles: list[ReplicaHandle],
|
| 605 |
+
*,
|
| 606 |
+
timeout: int | None = None,
|
| 607 |
+
) -> list[dict[str, Any]]:
|
| 608 |
+
"""Poll until every rank reaches a terminal state or the deadline.
|
| 609 |
+
|
| 610 |
+
Sleeps between polls (Job status is eventually consistent — do not
|
| 611 |
+
hammer the API server). Returns per-rank result dicts in handles order::
|
| 612 |
+
|
| 613 |
+
{"rank", "status", "exit_code", "error", "job_name"}
|
| 614 |
+
|
| 615 |
+
``exit_code`` is 0 for succeeded, 1 for failed, ``None`` for
|
| 616 |
+
running/pending/cancelled — matching the Protocol's documented shape.
|
| 617 |
+
"""
|
| 618 |
+
deadline = time.time() + (timeout if timeout is not None else 86400)
|
| 619 |
+
poll_interval = float(self._collect_poll_interval())
|
| 620 |
+
terminal = {"succeeded", "failed", "cancelled"}
|
| 621 |
+
results_by_rank: dict[int, dict[str, Any]] = {}
|
| 622 |
+
|
| 623 |
+
pending = list(handles)
|
| 624 |
+
while pending and time.time() < deadline:
|
| 625 |
+
still_pending: list[ReplicaHandle] = []
|
| 626 |
+
for h in pending:
|
| 627 |
+
state = self.poll(h)
|
| 628 |
+
if state in terminal:
|
| 629 |
+
results_by_rank[h.rank] = self._result_dict(h, state)
|
| 630 |
+
else:
|
| 631 |
+
still_pending.append(h)
|
| 632 |
+
pending = still_pending
|
| 633 |
+
if not pending:
|
| 634 |
+
break
|
| 635 |
+
remaining = deadline - time.time()
|
| 636 |
+
if remaining <= 0:
|
| 637 |
+
break
|
| 638 |
+
time.sleep(min(poll_interval, max(0.0, remaining)))
|
| 639 |
+
|
| 640 |
+
# Any rank still non-terminal at the deadline -> report its last state.
|
| 641 |
+
for h in pending:
|
| 642 |
+
state = self.poll(h)
|
| 643 |
+
results_by_rank[h.rank] = self._result_dict(h, state)
|
| 644 |
+
|
| 645 |
+
return [results_by_rank[h.rank] for h in handles]
|
| 646 |
+
|
| 647 |
+
# -----------------------------------------------------------------
|
| 648 |
+
# Internals
|
| 649 |
+
# -----------------------------------------------------------------
|
| 650 |
+
|
| 651 |
+
def _collect_poll_interval(self) -> float:
|
| 652 |
+
"""Seconds between collect() polls. Overridable in tests."""
|
| 653 |
+
return 5.0
|
| 654 |
+
|
| 655 |
+
@staticmethod
|
| 656 |
+
def _result_dict(handle: ReplicaHandle, state: str) -> dict[str, Any]:
|
| 657 |
+
exit_code = {"succeeded": 0, "failed": 1}.get(state, None)
|
| 658 |
+
error = None
|
| 659 |
+
if state == "failed":
|
| 660 |
+
error = f"rank {handle.rank} reported failed by Job status"
|
| 661 |
+
elif state == "cancelled":
|
| 662 |
+
error = f"rank {handle.rank} Job no longer exists (cancelled)"
|
| 663 |
+
elif state in ("running", "pending"):
|
| 664 |
+
error = f"rank {handle.rank} not terminal at deadline (state={state})"
|
| 665 |
+
return {
|
| 666 |
+
"rank": handle.rank,
|
| 667 |
+
"status": state,
|
| 668 |
+
"exit_code": exit_code,
|
| 669 |
+
"error": error,
|
| 670 |
+
"job_name": handle.metadata.get("job_name"),
|
| 671 |
+
}
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
__all__ = ["EKSExecutor"]
|
|
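The deadline-bounded polling that `collect()` performs in the file above can be sketched in isolation. This is a minimal illustration under stated assumptions, not the module's code: `collect_states`, `poll_fn`, and the injectable `clock` / `sleep` parameters are hypothetical names introduced here so the loop can be exercised without a cluster.

```python
import time
from typing import Callable

TERMINAL = {"succeeded", "failed", "cancelled"}


def collect_states(
    ranks: list[int],
    poll_fn: Callable[[int], str],  # hypothetical stand-in for executor.poll
    *,
    timeout: float,
    poll_interval: float = 5.0,
    clock: Callable[[], float] = time.time,
    sleep: Callable[[float], None] = time.sleep,
) -> list[str]:
    """Poll every rank until all are terminal or the shared deadline passes."""
    deadline = clock() + timeout
    states: dict[int, str] = {}
    pending = list(ranks)
    while pending and clock() < deadline:
        still_pending = []
        for rank in pending:
            state = poll_fn(rank)
            if state in TERMINAL:
                states[rank] = state
            else:
                still_pending.append(rank)
        pending = still_pending
        if not pending:
            break
        remaining = deadline - clock()
        if remaining <= 0:
            break
        # Never sleep past the deadline.
        sleep(min(poll_interval, remaining))
    # Ranks that never became terminal report their last observed state.
    for rank in pending:
        states[rank] = poll_fn(rank)
    # Results are returned in the caller's rank order.
    return [states[r] for r in ranks]
```

Injecting the clock and sleep functions is one way to keep such a loop testable in milliseconds; the diff achieves a similar effect by making `_collect_poll_interval()` overridable.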
@@ -36,9 +36,10 @@ class ReplicaHandle:
|
|
| 36 |
class ServerlessExecutor(Protocol):
|
| 37 |
"""Uniform interface for launching N replicas on a serverless backend.
|
| 38 |
|
| 39 |
-
Implementations: `LocalProcessExecutor` (test/dev), `
|
| 40 |
-
(Modal,
|
| 41 |
-
|
|
|
|
| 42 |
|
| 43 |
Note on rank assignment: the Protocol guarantees that handles are
|
| 44 |
returned in rank order (`handles[i].rank == i`). The replica entrypoint
|
|
|
|
| 36 |
class ServerlessExecutor(Protocol):
|
| 37 |
"""Uniform interface for launching N replicas on a serverless backend.
|
| 38 |
|
| 39 |
+
Implementations: `LocalProcessExecutor` (test/dev), `ModalSpawnExecutor`
|
| 40 |
+
(Modal, production), `EKSExecutor` (Amazon EKS / Kubernetes Indexed Job,
|
| 41 |
+
production), `SageMakerExecutor` (SageMaker Training Jobs, production),
|
| 42 |
+
`ModalExecutor` / `HFJobsExecutor` (v0 skeletons). Future adapter: `RunPodExecutor`.
|
| 43 |
|
| 44 |
Note on rank assignment: the Protocol guarantees that handles are
|
| 45 |
returned in rank order (`handles[i].rank == i`). The replica entrypoint
|
|
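As an illustration of how a backend adapter can translate a provider's job states into the Protocol's status vocabulary (`pending`, `running`, `succeeded`, `failed`, `cancelled`), here is a minimal sketch modelled on the SageMaker adapter added in this commit. `to_protocol_status` is a hypothetical helper name introduced for the example; the status strings are the ones listed in that adapter's own tables.

```python
from __future__ import annotations

# Primary SageMaker status -> Protocol status.
STATUS_MAP = {
    "InProgress": "running",
    "Completed": "succeeded",
    "Failed": "failed",
    "Stopping": "running",  # still shutting down, so not terminal yet
    "Stopped": "cancelled",
}

# Secondary statuses that mean the job has not started running user code.
PENDING_SECONDARY = frozenset(
    {"Starting", "Pending", "LaunchingMLInstances", "PreparingTrainingStack"}
)


def to_protocol_status(primary: str, secondary: str | None = None) -> str:
    """Map a (primary, secondary) status pair to the Protocol vocabulary."""
    if primary == "InProgress" and secondary in PENDING_SECONDARY:
        return "pending"
    # Unknown primary statuses are treated as still running (an assumption).
    return STATUS_MAP.get(primary, "running")
```

Only `succeeded`, `failed`, and `cancelled` are terminal in this mapping, so a caller that caches results should cache only those three.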
@@ -0,0 +1,619 @@
|
| 1 |
+
"""SageMakerExecutor — production boto3-backed serverless executor.
|
| 2 |
+
|
| 3 |
+
This is a fully-working cloud adapter (the sibling of `ModalSpawnExecutor`,
|
| 4 |
+
not the loud-failing `modal.py` / `hf_jobs.py` skeletons). It implements the
|
| 5 |
+
`ServerlessExecutor` Protocol against Amazon SageMaker Training Jobs via the
|
| 6 |
+
boto3 low-level `sagemaker` client.
|
| 7 |
+
|
| 8 |
+
Design choices
|
| 9 |
+
--------------
|
| 10 |
+
|
| 11 |
+
1. **N independent single-instance jobs, NOT one multi-instance job.**
|
| 12 |
+
SageMaker's *native* distributed training (``ResourceConfig.InstanceCount > 1``)
|
| 13 |
+
groups instances into ONE job with an in-cluster NCCL/MPI fabric wired via
|
| 14 |
+
``/opt/ml/input/config/resourceconfig.json``. That is the WRONG model for
|
| 15 |
+
DiLoCo replicas — it would couple replicas through SageMaker's intra-job
|
| 16 |
+
network and break the "each replica is an independent DiLoCo worker that
|
| 17 |
+
syncs only through S3" design. So ``launch_replicas`` submits N **separate**
|
| 18 |
+
training jobs, each with ``ResourceConfig.InstanceCount == 1``, tagged with
|
| 19 |
+
``REPLICA_RANK=i`` / ``WORLD_SIZE=N`` via the ``Environment`` map. This
|
| 20 |
+
mirrors ``ModalSpawnExecutor`` spawning N independent Modal calls.
|
| 21 |
+
|
| 22 |
+
2. **Same S3 ``ObjectStoreAllReduce`` rendezvous — DiLoCo math untouched.**
|
| 23 |
+
Cross-replica communication is EXCLUSIVELY the object-store rendezvous; the
|
| 24 |
+
executor passes ``rendezvous_uri`` (an ``s3://...`` URI) through to
|
| 25 |
+
``replica_entrypoint.py`` unchanged. ``allreduce.py`` / ``MockManager`` /
|
| 26 |
+
``make_diloco_outer_loop`` / the trainer all stay byte-for-byte identical.
|
| 27 |
+
|
| 28 |
+
3. **Stateless after launch; rank via ``Environment``.** Handle metadata is the
|
| 29 |
+
``training_job_name`` (plus submit timestamp). ``replica_entrypoint.py``
|
| 30 |
+
already reads ``REPLICA_RANK`` from ``os.environ``, so the cleanest channel
|
| 31 |
+
is the ``Environment`` map (string->string, max 100 entries, value <= 512
|
| 32 |
+
chars). The container command is baked into the image entrypoint and the
|
| 33 |
+
rendezvous args are passed via ``AlgorithmSpecification.ContainerArguments``.
|
| 34 |
+
|
| 35 |
+
4. **``supports_inter_replica_network = False``.** Separate single-instance
|
| 36 |
+
training jobs have no mutual network path by design — they rendezvous only
|
| 37 |
+
through S3. (SageMaker's algo-N container fabric and
|
| 38 |
+
``EnableInterContainerTrafficEncryption`` only exist WITHIN a single
|
| 39 |
+
multi-instance job, which this design deliberately does not use.)
|
| 40 |
+
|
| 41 |
+
Load-bearing gotcha — ``EnableNetworkIsolation`` MUST stay ``False``
|
| 42 |
+
--------------------------------------------------------------------
|
| 43 |
+
When ``EnableNetworkIsolation=True`` the training *container* has no outbound
|
| 44 |
+
network access. SageMaker's host-side processes still stage input channels and
|
| 45 |
+
ship CloudWatch logs, but the container itself cannot make S3 GET/PUT calls.
|
| 46 |
+
``ObjectStoreAllReduce`` needs live S3 PUT+GET every outer round, so network
|
| 47 |
+
isolation would silently dead-lock the allreduce poll loop until its timeout.
|
| 48 |
+
This executor pins ``EnableNetworkIsolation=False`` (the API default) and never
|
| 49 |
+
exposes it as a knob. The rendezvous bucket access must instead be granted on
|
| 50 |
+
the execution ``RoleArn`` — the SageMaker analog of EKS IRSA.
|
| 51 |
+
|
| 52 |
+
HyperPod <-> EKS 1:1 control-plane mapping (recommended hybrid)
|
| 53 |
+
---------------------------------------------------------------
|
| 54 |
+
Per the SageMaker docs: *"The high-level architecture of Amazon EKS support in
|
| 55 |
+
HyperPod involves a 1-to-1 mapping between an EKS cluster (control plane) and a
|
| 56 |
+
HyperPod cluster (worker nodes) within a VPC."*
|
| 57 |
+
(https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-hyperpod-eks.html)
|
| 58 |
+
|
| 59 |
+
Consequence for this repo's hybrid: "use HyperPod for the inner GRPO trainer"
|
| 60 |
+
does NOT mean leaving EKS — it means attaching a HyperPod-managed
|
| 61 |
+
(auto-recovering, deep-health-checked, PyTorch-job auto-resume) node-group to
|
| 62 |
+
the SAME EKS control plane that runs the outer loop. The ``EKSExecutor``
|
| 63 |
+
(kubernetes client, Indexed Jobs) therefore targets both plain Karpenter GPU
|
| 64 |
+
nodes AND HyperPod nodes transparently. ``SageMakerExecutor`` (ephemeral
|
| 65 |
+
Training Jobs via boto3) is the SEPARATE bursty-fallback inner-loop path for
|
| 66 |
+
when you don't want a persistent cluster: Training Jobs suit periodic /
|
| 67 |
+
smaller-model / pay-per-use runs; HyperPod suits continuous / large-model /
|
| 68 |
+
persistent runs. Both share the IDENTICAL S3 rendezvous, so a run can move
|
| 69 |
+
between them with zero trainer / loss / DiLoCo changes.
|
| 70 |
+
|
| 71 |
+
References
|
| 72 |
+
----------
|
| 73 |
+
- create_training_job: https://docs.aws.amazon.com/boto3/latest/reference/services/sagemaker/client/create_training_job.html
|
| 74 |
+
- describe_training_job: https://docs.aws.amazon.com/boto3/latest/reference/services/sagemaker/client/describe_training_job.html
|
| 75 |
+
- stop_training_job: https://docs.aws.amazon.com/boto3/latest/reference/services/sagemaker/client/stop_training_job.html
|
| 76 |
+
- network isolation: https://repost.aws/knowledge-center/sagemaker-access-network-isolation
|
| 77 |
+
- HyperPod-EKS: https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-hyperpod-eks.html
|
| 78 |
+
- ADR-005 (executor protocol design)
|
| 79 |
+
"""
|
| 80 |
+
from __future__ import annotations
|
| 81 |
+
|
| 82 |
+
import json
|
| 83 |
+
import time
|
| 84 |
+
import uuid
|
| 85 |
+
from collections.abc import Callable, Mapping
|
| 86 |
+
from typing import Any
|
| 87 |
+
|
| 88 |
+
from composer_replication.diloco.serverless.executor import (
|
| 89 |
+
ReplicaHandle,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# SageMaker TrainingJobStatus -> Protocol status vocabulary.
|
| 93 |
+
# describe_training_job's TrainingJobStatus is EXACTLY one of:
|
| 94 |
+
# 'InProgress' | 'Completed' | 'Failed' | 'Stopping' | 'Stopped'.
|
| 95 |
+
# We map Stopping -> 'running' (transient; still terminating, so collect()
|
| 96 |
+
# keeps waiting) and Stopped -> 'cancelled'.
|
| 97 |
+
_STATUS_MAP = {
|
| 98 |
+
"InProgress": "running",
|
| 99 |
+
"Completed": "succeeded",
|
| 100 |
+
"Failed": "failed",
|
| 101 |
+
"Stopping": "running",
|
| 102 |
+
"Stopped": "cancelled",
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
# SecondaryStatus values that mean "queued / not yet executing user code" —
|
| 106 |
+
# used to refine an InProgress job into the Protocol's 'pending'.
|
| 107 |
+
_PENDING_SECONDARY = frozenset(
|
| 108 |
+
{"Starting", "Pending", "LaunchingMLInstances", "PreparingTrainingStack"}
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
# Abstract Protocol GPU strings -> SageMaker instance types.
|
| 112 |
+
_GPU_INSTANCE_MAP = {
|
| 113 |
+
"A100": "ml.p4d.24xlarge",
|
| 114 |
+
"H100": "ml.p5.48xlarge",
|
| 115 |
+
"H200": "ml.p5e.48xlarge",
|
| 116 |
+
"B200": "ml.p6-b200.48xlarge",
|
| 117 |
+
"L40S": "ml.g6e.12xlarge",
|
| 118 |
+
"A10G": "ml.g5.2xlarge",
|
| 119 |
+
"L4": "ml.g6.2xlarge",
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
_CLOUDWATCH_LOG_GROUP = "/aws/sagemaker/TrainingJobs"
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class SageMakerExecutor:
|
| 126 |
+
"""Run replicas as N independent SageMaker Training Jobs.
|
| 127 |
+
|
| 128 |
+
Implements the `ServerlessExecutor` Protocol against the boto3
|
| 129 |
+
``sagemaker`` client. Each replica is one single-instance training job;
|
| 130 |
+
cross-replica communication happens only through the shared S3
|
| 131 |
+
``ObjectStoreAllReduce`` rendezvous.
|
| 132 |
+
|
| 133 |
+
Args:
|
| 134 |
+
role_arn: IAM execution role SageMaker assumes for the job. Must grant
|
| 135 |
+
S3 access to the rendezvous + output buckets (the boto3 analog of
|
| 136 |
+
EKS IRSA). The caller's credentials need ``iam:PassRole`` on it.
|
| 137 |
+
image_uri: ECR image URI for the training container. The image must
|
| 138 |
+
bake an entrypoint that runs
|
| 139 |
+
``python -m composer_replication.diloco.serverless.replica_entrypoint``
|
| 140 |
+
(this executor also passes ``ContainerEntrypoint`` explicitly so a
|
| 141 |
+
generic image works too).
|
| 142 |
+
output_s3_path: ``s3://...`` prefix for ``OutputDataConfig.S3OutputPath``
|
| 143 |
+
(model artifacts / failure output).
|
| 144 |
+
instance_type: default SageMaker instance type when ``gpu`` is not
|
| 145 |
+
mapped (e.g. ``"ml.g5.2xlarge"``). ``gpu=None`` at launch falls
|
| 146 |
+
back to ``cpu_instance_type``.
|
| 147 |
+
cpu_instance_type: instance type used when ``gpu`` is ``None`` (CPU
|
| 148 |
+
smoke tests). Default ``"ml.m5.xlarge"``.
|
| 149 |
+
volume_size_gb: ``ResourceConfig.VolumeSizeInGB`` per job.
|
| 150 |
+
run_id: prefix for generated training-job names. Defaults to a short
|
| 151 |
+
random token so names are unique per region+account.
|
| 152 |
+
region: AWS region for the lazily-constructed boto3 clients. ``None``
|
| 153 |
+
uses the ambient boto3 default-region resolution.
|
| 154 |
+
sagemaker_client: inject a pre-built ``boto3.client('sagemaker')`` (or a
|
| 155 |
+
mock) instead of constructing one. Used by tests.
|
| 156 |
+
logs_client: inject a pre-built ``boto3.client('logs')`` (or a mock).
|
| 157 |
+
|
| 158 |
+
Raises:
|
| 159 |
+
RuntimeError: if boto3 is not installed and no client was injected.
|
| 160 |
+
"""
|
| 161 |
+
|
| 162 |
+
backend_name = "sagemaker"
|
| 163 |
+
# Separate single-instance jobs have no mutual network path — S3 only.
|
| 164 |
+
supports_inter_replica_network = False
|
| 165 |
+
|
| 166 |
+
def __init__(
|
| 167 |
+
self,
|
| 168 |
+
*,
|
| 169 |
+
role_arn: str,
|
| 170 |
+
image_uri: str,
|
| 171 |
+
output_s3_path: str,
|
| 172 |
+
instance_type: str = "ml.g5.2xlarge",
|
| 173 |
+
cpu_instance_type: str = "ml.m5.xlarge",
|
| 174 |
+
volume_size_gb: int = 100,
|
| 175 |
+
run_id: str | None = None,
|
| 176 |
+
region: str | None = None,
|
| 177 |
+
sagemaker_client: Any = None,
|
| 178 |
+
logs_client: Any = None,
|
| 179 |
+
) -> None:
|
| 180 |
+
self.role_arn = role_arn
|
| 181 |
+
self.image_uri = image_uri
|
| 182 |
+
self.output_s3_path = output_s3_path
|
| 183 |
+
self.instance_type = instance_type
|
| 184 |
+
self.cpu_instance_type = cpu_instance_type
|
| 185 |
+
self.volume_size_gb = volume_size_gb
|
| 186 |
+
self.run_id = run_id or f"diloco-{uuid.uuid4().hex[:8]}"
|
| 187 |
+
self._region = region
|
| 188 |
+
|
| 189 |
+
# Lazy boto3 — only constructed if the caller didn't inject a client.
|
| 190 |
+
# This keeps `import composer_replication.diloco.serverless` free of a
|
| 191 |
+
# hard boto3 dependency (boto3 lives in the optional [aws] extra), and
|
| 192 |
+
# lets tests inject a _MockSMClient with zero AWS calls.
|
| 193 |
+
if sagemaker_client is None:
|
| 194 |
+
sagemaker_client = self._make_boto3_client("sagemaker")
|
| 195 |
+
self._client = sagemaker_client
|
| 196 |
+
self._logs_client = logs_client # built lazily on first stream_logs()
|
| 197 |
+
|
| 198 |
+
# rank -> {"job_name": str, "result": dict | None}
|
| 199 |
+
self._handles: dict[int, dict[str, Any]] = {}
|
| 200 |
+
|
| 201 |
+
# -----------------------------------------------------------------
|
| 202 |
+
# boto3 plumbing (lazy)
|
| 203 |
+
# -----------------------------------------------------------------
|
| 204 |
+
|
| 205 |
+
def _make_boto3_client(self, service: str) -> Any:
|
| 206 |
+
try:
|
| 207 |
+
import boto3
|
| 208 |
+
except ImportError as e:
|
| 209 |
+
raise RuntimeError(
|
| 210 |
+
"SageMakerExecutor requires boto3. Install with "
|
| 211 |
+
"`pip install -e .[aws]` (or `pip install boto3`). "
|
| 212 |
+
f"Got: {e!r}"
|
| 213 |
+
) from e
|
| 214 |
+
if self._region is not None:
|
| 215 |
+
return boto3.client(service, region_name=self._region)
|
| 216 |
+
return boto3.client(service)
|
| 217 |
+
|
| 218 |
+
def _map_gpu(self, gpu: str | None) -> str:
|
| 219 |
+
"""Translate the Protocol's abstract gpu string to an instance type.
|
| 220 |
+
|
| 221 |
+
``gpu=None`` -> ``cpu_instance_type`` (smoke tests). A literal ``ml.*``
|
| 222 |
+
instance type that is not in the map is returned unchanged; any other
|
| 223 |
+
unrecognized gpu string falls back to ``instance_type``.
|
| 224 |
+
"""
|
| 225 |
+
if gpu is None:
|
| 226 |
+
return self.cpu_instance_type
|
| 227 |
+
if gpu in _GPU_INSTANCE_MAP:
|
| 228 |
+
return _GPU_INSTANCE_MAP[gpu]
|
| 229 |
+
# Caller may have passed a literal "ml.*" instance type.
|
| 230 |
+
if gpu.startswith("ml."):
|
| 231 |
+
return gpu
|
| 232 |
+
return self.instance_type
|
| 233 |
+
|
| 234 |
+
def _job_name(self, rank: int) -> str:
|
| 235 |
+
"""Build a unique, regex-safe training-job name (<= 63 chars).
|
| 236 |
+
|
| 237 |
+
Pattern required by the API: ``[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}``.
|
| 238 |
+
"""
|
| 239 |
+
name = f"{self.run_id}-r{rank:04d}-{int(time.time())}"
|
| 240 |
+
return name[:63].rstrip("-")  # truncation must not leave a trailing hyphen
|
| 241 |
+
|
| 242 |
+
# -----------------------------------------------------------------
|
| 243 |
+
# ServerlessExecutor Protocol
|
| 244 |
+
# -----------------------------------------------------------------
|
| 245 |
+
|
| 246 |
+
def launch_replicas(
|
| 247 |
+
self,
|
| 248 |
+
n_replicas: int,
|
| 249 |
+
entrypoint: str | Callable[..., Any],
|
| 250 |
+
entrypoint_args: Mapping[str, Any],
|
| 251 |
+
*,
|
| 252 |
+
gpu: str | None = None,
|
| 253 |
+
timeout: int = 3600,
|
| 254 |
+
) -> list[ReplicaHandle]:
|
| 255 |
+
"""Submit N independent single-instance SageMaker Training Jobs.
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
n_replicas: number of replicas (= number of training jobs).
|
| 259 |
+
entrypoint: ignored — the container command is baked into the
|
| 260 |
+
image / passed as ``ContainerEntrypoint``. Kept for Protocol
|
| 261 |
+
compatibility.
|
| 262 |
+
entrypoint_args: must contain ``rendezvous_uri`` (``s3://...``) and
|
| 263 |
+
``trainer_module``. Optional: ``trainer_fn`` (default
|
| 264 |
+
``"train"``), ``trainer_kwargs`` (dict, JSON-encoded into the
|
| 265 |
+
container args). The conventional ``rank_env`` key (from
|
| 266 |
+
``LocalProcessExecutor``) is ignored — rank goes through the
|
| 267 |
+
``Environment`` map instead.
|
| 268 |
+
gpu: abstract GPU spec mapped to an instance type via ``_map_gpu``.
|
| 269 |
+
``None`` -> CPU instance.
|
| 270 |
+
timeout: ``StoppingCondition.MaxRuntimeInSeconds`` per job.
|
| 271 |
+
|
| 272 |
+
Returns:
|
| 273 |
+
``list[ReplicaHandle]`` of length ``n_replicas`` in rank order
|
| 274 |
+
(``handles[i].rank == i``).
|
| 275 |
+
"""
|
| 276 |
+
del entrypoint # container command is baked / passed explicitly
|
| 277 |
+
|
| 278 |
+
if n_replicas < 1:
|
| 279 |
+
raise ValueError(f"n_replicas must be >= 1, got {n_replicas}")
|
| 280 |
+
|
| 281 |
+
rendezvous_uri = entrypoint_args.get("rendezvous_uri")
|
| 282 |
+
if not rendezvous_uri:
|
| 283 |
+
raise ValueError(
|
| 284 |
+
"entrypoint_args must include 'rendezvous_uri' (the s3:// "
|
| 285 |
+
"ObjectStoreAllReduce rendezvous prefix)."
|
| 286 |
+
)
|
| 287 |
+
trainer_module = entrypoint_args.get("trainer_module")
|
| 288 |
+
if not trainer_module:
|
| 289 |
+
raise ValueError(
|
| 290 |
+
"entrypoint_args must include 'trainer_module' (importable "
|
| 291 |
+
"module path of the user's train function)."
|
| 292 |
+
)
|
| 293 |
+
trainer_fn = entrypoint_args.get("trainer_fn", "train")
|
| 294 |
+
trainer_kwargs = entrypoint_args.get("trainer_kwargs", {})
|
| 295 |
+
|
| 296 |
+
instance_type = self._map_gpu(gpu)
|
| 297 |
+
|
| 298 |
+
# Container args: each element is a SINGLE token (StackOverflow
|
| 299 |
+
# 77994925 — `['--world-size', '4']` NOT `['--world-size 4']`).
|
| 300 |
+
container_args = [
|
| 301 |
+
"--rendezvous", str(rendezvous_uri),
|
| 302 |
+
"--world-size", str(n_replicas),
|
| 303 |
+
"--trainer-module", str(trainer_module),
|
| 304 |
+
"--trainer-fn", str(trainer_fn),
|
| 305 |
+
"--trainer-kwargs-json", json.dumps(trainer_kwargs),
|
| 306 |
+
]
|
| 307 |
+
|
| 308 |
+
handles: list[ReplicaHandle] = []
|
| 309 |
+
for rank in range(n_replicas):
|
| 310 |
+
job_name = self._job_name(rank)
|
| 311 |
+
request = {
|
| 312 |
+
"TrainingJobName": job_name,
|
| 313 |
+
"AlgorithmSpecification": {
|
| 314 |
+
"TrainingImage": self.image_uri,
|
| 315 |
+
"TrainingInputMode": "File",
|
| 316 |
+
"ContainerEntrypoint": [
|
| 317 |
+
"python", "-m",
|
| 318 |
+
"composer_replication.diloco.serverless.replica_entrypoint",
|
| 319 |
+
],
|
| 320 |
+
"ContainerArguments": container_args,
|
| 321 |
+
},
|
| 322 |
+
"RoleArn": self.role_arn,
|
| 323 |
+
# InputDataConfig intentionally omitted — the replica pulls
|
| 324 |
+
# data via its own code / the S3 rendezvous, not SM channels.
|
| 325 |
+
"OutputDataConfig": {"S3OutputPath": self.output_s3_path},
|
| 326 |
+
"ResourceConfig": {
|
| 327 |
+
"InstanceType": instance_type,
|
| 328 |
+
"InstanceCount": 1,
|
| 329 |
+
"VolumeSizeInGB": self.volume_size_gb,
|
| 330 |
+
},
|
| 331 |
+
"StoppingCondition": {"MaxRuntimeInSeconds": int(timeout)},
|
| 332 |
+
# REPLICA_RANK / WORLD_SIZE injected as container env vars;
|
| 333 |
+
# replica_entrypoint.py reads os.environ['REPLICA_RANK'].
|
| 334 |
+
"Environment": {
|
| 335 |
+
"REPLICA_RANK": str(rank),
|
| 336 |
+
"WORLD_SIZE": str(n_replicas),
|
| 337 |
+
"RENDEZVOUS_URI": str(rendezvous_uri),
|
| 338 |
+
},
|
| 339 |
+
# MUST stay False — True severs the container's S3 access and
|
| 340 |
+
# dead-locks the allreduce poll loop. See module docstring.
|
| 341 |
+
"EnableNetworkIsolation": False,
|
| 342 |
+
}
|
| 343 |
+
try:
|
| 344 |
+
self._client.create_training_job(**request)
|
| 345 |
+
except Exception as e:
|
| 346 |
+
# Best-effort stop of already-launched siblings, then raise.
|
| 347 |
+
for prior in handles:
|
| 348 |
+
try:
|
| 349 |
+
self.cancel(prior)
|
| 350 |
+
except Exception:
|
| 351 |
+
pass
|
| 352 |
+
raise RuntimeError(
|
| 353 |
+
f"SageMakerExecutor.launch_replicas failed at rank={rank} "
|
| 354 |
+
f"of {n_replicas} (already-launched siblings stopped). "
|
| 355 |
+
f"Underlying error: {e!r}"
|
| 356 |
+
) from e
|
| 357 |
+
|
| 358 |
+
handle = ReplicaHandle(
|
| 359 |
+
rank=rank,
|
| 360 |
+
backend_name=self.backend_name,
|
| 361 |
+
metadata={
|
| 362 |
+
"training_job_name": job_name,
|
| 363 |
+
"submit_ts": time.time(),
|
| 364 |
+
},
|
| 365 |
+
)
|
| 366 |
+
self._handles[rank] = {"job_name": job_name, "result": None}
|
| 367 |
+
handles.append(handle)
|
| 368 |
+
|
| 369 |
+
return handles
|
| 370 |
+
|
| 371 |
+
def poll(self, handle: ReplicaHandle) -> str:
|
| 372 |
+
"""Poll a training job's status.
|
| 373 |
+
|
| 374 |
+
Returns one of: ``"pending"`` | ``"running"`` | ``"succeeded"`` |
|
| 375 |
+
``"failed"`` | ``"cancelled"``.
|
| 376 |
+
|
| 377 |
+
Maps ``describe_training_job``'s ``TrainingJobStatus`` via
|
| 378 |
+
``_STATUS_MAP``; refines ``InProgress`` to ``"pending"`` while the job
|
| 379 |
+
is still queued (``SecondaryStatus`` in ``_PENDING_SECONDARY``). A
|
| 380 |
+
vanished job (``ResourceNotFound``) is treated as ``"cancelled"``.
|
| 381 |
+
"""
|
| 382 |
+
meta = self._handles.get(handle.rank)
|
| 383 |
+
if meta is None:
|
| 384 |
+
return "cancelled"
|
| 385 |
+
if meta["result"] is not None:
|
| 386 |
+
return meta["result"]["status"]
|
| 387 |
+
|
| 388 |
+
job_name = meta["job_name"]
|
| 389 |
+
try:
|
| 390 |
+
resp = self._client.describe_training_job(TrainingJobName=job_name)
|
| 391 |
+
except Exception as e:
|
| 392 |
+
if self._is_resource_not_found(e):
|
| 393 |
+
return "cancelled"
|
| 394 |
+
raise
|
| 395 |
+
|
| 396 |
+
sm_status = resp.get("TrainingJobStatus", "InProgress")
|
| 397 |
+
mapped = _STATUS_MAP.get(sm_status, "running")
|
| 398 |
+
|
| 399 |
+
if sm_status not in ("Completed", "Failed", "Stopped"):  # InProgress / Stopping are non-terminal
|
| 400 |
+
if resp.get("SecondaryStatus") in _PENDING_SECONDARY:
|
| 401 |
+
return "pending"
|
| 402 |
+
return "running"
|
| 403 |
+
|
| 404 |
+
# Terminal — cache a result dict so collect()/repeat-poll are cheap.
|
| 405 |
+
meta["result"] = self._terminal_result(handle.rank, sm_status, resp)
|
| 406 |
+
return mapped
|
| 407 |
+
|
| 408 |
+
def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
|
| 409 |
+
"""Read recent CloudWatch logs for this replica's training job.
|
| 410 |
+
|
| 411 |
+
SageMaker writes container stdout/stderr to the
|
| 412 |
+
``/aws/sagemaker/TrainingJobs`` log group, stream
|
| 413 |
+
``<job-name>/algo-<n>-<epoch>``. We discover the exact stream name by
|
| 414 |
+
prefix then read the tail. Falls back to a CloudWatch console pointer
|
| 415 |
+
on any error (mirrors ModalSpawnExecutor's dashboard-URL fallback).
|
| 416 |
+
"""
|
| 417 |
+
meta = self._handles.get(handle.rank)
|
| 418 |
+
if meta is None:
|
| 419 |
+
return f"<replica {handle.rank}: no metadata>"
|
| 420 |
+
job_name = meta["job_name"]
|
| 421 |
+
|
| 422 |
+
try:
|
| 423 |
+
logs = self._logs()
|
| 424 |
+
prefix = f"{job_name}/"
|
| 425 |
+
streams = logs.describe_log_streams(
|
| 426 |
+
logGroupName=_CLOUDWATCH_LOG_GROUP,
|
| 427 |
+
logStreamNamePrefix=prefix,
|
| 428 |
+
orderBy="LastEventTime",
|
| 429 |
+
descending=True,
|
| 430 |
+
limit=1,
|
| 431 |
+
)
|
| 432 |
+
stream_list = streams.get("logStreams", [])
|
| 433 |
+
if not stream_list:
|
| 434 |
+
return (
|
| 435 |
+
f"[rank {handle.rank}] job={job_name}: no CloudWatch log "
|
| 436 |
+
f"stream yet (job pending / not started)."
|
| 437 |
+
)
|
| 438 |
+
stream_name = stream_list[0]["logStreamName"]
|
| 439 |
+
events = logs.get_log_events(
|
| 440 |
+
logGroupName=_CLOUDWATCH_LOG_GROUP,
|
| 441 |
+
logStreamName=stream_name,
|
| 442 |
+
limit=n_lines,
|
| 443 |
+
startFromHead=False,
|
| 444 |
+
)
|
| 445 |
+
lines = [e.get("message", "") for e in events.get("events", [])]
|
| 446 |
+
body = "\n".join(lines) if lines else "<no log events>"
|
| 447 |
+
return f"[rank {handle.rank}] job={job_name} stream={stream_name}\n{body}"
|
| 448 |
+
except Exception as e:
|
| 449 |
+
region = self._region or "<region>"
|
| 450 |
+
url = (
|
| 451 |
+
f"https://{region}.console.aws.amazon.com/cloudwatch/home"
|
| 452 |
+
f"?region={region}#logsV2:log-groups/log-group/"
|
| 453 |
+
f"$252Faws$252Fsagemaker$252FTrainingJobs"
|
| 454 |
+
)
|
| 455 |
+
return (
|
| 456 |
+
f"[rank {handle.rank}] job={job_name}: log fetch failed "
|
| 457 |
+
f"({type(e).__name__}: {e!r}).\n CloudWatch console: {url}"
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
def cancel(self, handle: ReplicaHandle) -> None:
|
| 461 |
+
"""Best-effort stop of a training job.
|
| 462 |
+
|
| 463 |
+
Calls ``stop_training_job`` (SIGTERM + 120s grace), swallowing
|
| 464 |
+
``ResourceNotFound`` and "already terminal" ``ValidationException`` so
|
| 465 |
+
the contract — "no exception if already terminated" — holds.
|
| 466 |
+
"""
|
| 467 |
+
meta = self._handles.get(handle.rank)
|
| 468 |
+
if meta is None:
|
| 469 |
+
return
|
| 470 |
+
try:
|
| 471 |
+
self._client.stop_training_job(TrainingJobName=meta["job_name"])
|
| 472 |
+
except Exception:
|
| 473 |
+
# ResourceNotFound, already-Completed/Stopped ValidationException,
|
| 474 |
+
# transient network blip — all best-effort no-ops.
|
| 475 |
+
pass
|
| 476 |
+
|
| 477 |
+
def collect(
|
| 478 |
+
self,
|
| 479 |
+
handles: list[ReplicaHandle],
|
| 480 |
+
*,
|
| 481 |
+
timeout: int | None = None,
|
| 482 |
+
) -> list[dict[str, Any]]:
|
| 483 |
+
"""Block until all replicas finish; return per-replica result dicts.
|
| 484 |
+
|
| 485 |
+
Polls ``describe_training_job`` per handle until the job reaches a
|
| 486 |
+
terminal status (``Completed`` / ``Failed`` / ``Stopped``) or the
|
| 487 |
+
shared deadline elapses. Returns results aligned to the input handle
|
| 488 |
+
order (Protocol contract; mirrors ``LocalProcessExecutor.collect``).
|
| 489 |
+
|
| 490 |
+
Each result dict has at least
|
| 491 |
+
``{"rank", "status", "exit_code", "error"}``.
|
| 492 |
+
"""
|
| 493 |
+
deadline = time.time() + (timeout if timeout is not None else 86400)
|
| 494 |
+
poll_interval = 30.0
|
| 495 |
+
results: list[dict[str, Any]] = []
|
| 496 |
+
|
| 497 |
+
for h in handles:
|
| 498 |
+
meta = self._handles.get(h.rank)
|
| 499 |
+
if meta is None:
|
| 500 |
+
results.append({
|
| 501 |
+
"rank": h.rank,
|
| 502 |
+
"status": "cancelled",
|
| 503 |
+
"exit_code": None,
|
| 504 |
+
"error": "handle has no metadata (cancelled or unknown)",
|
| 505 |
+
"result": None,
|
| 506 |
+
"training_job_name": h.metadata.get("training_job_name"),
|
| 507 |
+
})
|
| 508 |
+
continue
|
| 509 |
+
|
| 510 |
+
# Already cached by an earlier poll()/collect().
|
| 511 |
+
if meta["result"] is not None:
|
| 512 |
+
results.append(meta["result"])
|
| 513 |
+
continue
|
| 514 |
+
|
| 515 |
+
job_name = meta["job_name"]
|
| 516 |
+
result_dict: dict[str, Any] | None = None
|
| 517 |
+
while True:
|
| 518 |
+
try:
|
| 519 |
+
resp = self._client.describe_training_job(
|
| 520 |
+
TrainingJobName=job_name
|
| 521 |
+
)
|
| 522 |
+
except Exception as e:
|
| 523 |
+
if self._is_resource_not_found(e):
|
| 524 |
+
result_dict = {
|
| 525 |
+
"rank": h.rank,
|
| 526 |
+
"status": "cancelled",
|
| 527 |
+
"exit_code": None,
|
| 528 |
+
"error": "training job not found (deleted?)",
|
| 529 |
+
"result": None,
|
| 530 |
+
"training_job_name": job_name,
|
| 531 |
+
}
|
| 532 |
+
break
|
| 533 |
+
raise
|
| 534 |
+
|
| 535 |
+
sm_status = resp.get("TrainingJobStatus", "InProgress")
|
| 536 |
+
if sm_status in ("Completed", "Failed", "Stopped"):
|
| 537 |
+
result_dict = self._terminal_result(h.rank, sm_status, resp)
|
| 538 |
+
break
|
| 539 |
+
|
| 540 |
+
if time.time() >= deadline:
|
| 541 |
+
result_dict = {
|
| 542 |
+
"rank": h.rank,
|
| 543 |
+
"status": "running",
|
| 544 |
+
"exit_code": None,
|
| 545 |
+
"error": "timeout before terminal",
|
| 546 |
+
"result": None,
|
| 547 |
+
"training_job_name": job_name,
|
| 548 |
+
}
|
| 549 |
+
break
|
| 550 |
+
|
| 551 |
+
# Sleep, but never overrun the deadline.
|
| 552 |
+
time.sleep(min(poll_interval, max(0.0, deadline - time.time())))
|
| 553 |
+
|
| 554 |
+
# Cache only terminal results (not the timeout 'running' sentinel,
|
| 555 |
+
# so a later collect() can re-check the job).
|
| 556 |
+
if result_dict["status"] in ("succeeded", "failed", "cancelled"):
|
| 557 |
+
meta["result"] = result_dict
|
| 558 |
+
results.append(result_dict)
|
| 559 |
+
|
| 560 |
+
return results
|
| 561 |
+
|
| 562 |
+
# -----------------------------------------------------------------
|
| 563 |
+
# Helpers
|
| 564 |
+
# -----------------------------------------------------------------
|
| 565 |
+
|
| 566 |
+
def _logs(self) -> Any:
|
| 567 |
+
"""Lazily build the CloudWatch Logs client (separate from sagemaker)."""
|
| 568 |
+
if self._logs_client is None:
|
| 569 |
+
self._logs_client = self._make_boto3_client("logs")
|
| 570 |
+
return self._logs_client
|
| 571 |
+
|
| 572 |
+
@staticmethod
|
| 573 |
+
def _terminal_result(
|
| 574 |
+
rank: int, sm_status: str, resp: Mapping[str, Any]
|
| 575 |
+
) -> dict[str, Any]:
|
| 576 |
+
"""Build a result dict from a terminal describe_training_job response."""
|
| 577 |
+
mapped = _STATUS_MAP.get(sm_status, "failed")
|
| 578 |
+
if sm_status == "Completed":
|
| 579 |
+
exit_code: int | None = 0
|
| 580 |
+
error = None
|
| 581 |
+
elif sm_status == "Stopped":
|
| 582 |
+
exit_code = None
|
| 583 |
+
error = resp.get("FailureReason")
|
| 584 |
+
else: # Failed
|
| 585 |
+
exit_code = 1
|
| 586 |
+
error = resp.get("FailureReason") or "training job failed"
|
| 587 |
+
artifacts = resp.get("ModelArtifacts", {}) or {}
|
| 588 |
+
return {
|
| 589 |
+
"rank": rank,
|
| 590 |
+
"status": mapped,
|
| 591 |
+
"exit_code": exit_code,
|
| 592 |
+
"error": error,
|
| 593 |
+
"result": artifacts.get("S3ModelArtifacts"),
|
| 594 |
+
"training_job_name": resp.get("TrainingJobName"),
|
| 595 |
+
}
|
| 596 |
+
|
| 597 |
+
def _is_resource_not_found(self, exc: Exception) -> bool:
|
| 598 |
+
"""True if ``exc`` is the boto3 ResourceNotFound for the sagemaker client.
|
| 599 |
+
|
| 600 |
+
Handles both the typed client exception
|
| 601 |
+
(``client.exceptions.ResourceNotFound``) and a generic botocore
|
| 602 |
+
``ClientError`` whose error code is ``ResourceNotFound`` /
|
| 603 |
+
``ValidationException`` naming a missing job — robust across whether a
|
| 604 |
+
real boto3 client or a mock is in use.
|
| 605 |
+
"""
|
| 606 |
+
rnf = getattr(getattr(self._client, "exceptions", None),
|
| 607 |
+
"ResourceNotFound", None)
|
| 608 |
+
if rnf is not None and isinstance(exc, rnf):
|
| 609 |
+
return True
|
| 610 |
+
# Generic botocore ClientError fallback.
|
| 611 |
+
resp = getattr(exc, "response", None)
|
| 612 |
+
if isinstance(resp, Mapping):
|
| 613 |
+
code = resp.get("Error", {}).get("Code", "")
|
| 614 |
+
if code in ("ResourceNotFound", "ValidationException"):
|
| 615 |
+
return True
|
| 616 |
+
return False
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
__all__ = ["SageMakerExecutor"]
|
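The `collect()` method above polls each job until it reaches a terminal status or a shared deadline passes, and it clamps every sleep so it never overruns that deadline. The following is a minimal standalone sketch of that pattern, not the executor's actual code. The `wait_for_terminal` name, the `describe` callable (standing in for `describe_training_job`), and the injectable `clock` / `sleep` parameters are assumptions made for illustration.

```python
import time
from typing import Any, Callable, Dict

TERMINAL = ("Completed", "Failed", "Stopped")


def wait_for_terminal(
    describe: Callable[[], Dict[str, Any]],
    timeout: float,
    poll_interval: float = 30.0,
    clock: Callable[[], float] = time.time,
    sleep: Callable[[float], None] = time.sleep,
) -> Dict[str, Any]:
    """Poll `describe` until a terminal status or until `timeout` elapses."""
    deadline = clock() + timeout
    while True:
        status = describe().get("TrainingJobStatus", "InProgress")
        if status in TERMINAL:
            return {"status": status, "timed_out": False}
        if clock() >= deadline:
            return {"status": status, "timed_out": True}
        # Sleep for one interval, but never past the deadline.
        sleep(min(poll_interval, max(0.0, deadline - clock())))


# Usage with a canned sequence of responses (hypothetical, no AWS calls).
_responses = iter(["InProgress", "InProgress", "Completed"])
demo = wait_for_terminal(
    lambda: {"TrainingJobStatus": next(_responses)},
    timeout=5.0,
    poll_interval=0.0,
)
```

Injecting the clock and sleep functions keeps the loop testable without real waiting, which is the same motivation behind the test suite overriding the poll interval.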
|
@@ -0,0 +1,625 @@
|
| 1 |
+
"""Tests for EKSExecutor — the Kubernetes Indexed-Job-backed executor.
|
| 2 |
+
|
| 3 |
+
These tests exercise the executor's contract WITHOUT a live cluster and
|
| 4 |
+
WITHOUT the `kubernetes` client actually being installed. They:
|
| 5 |
+
|
| 6 |
+
* inject a fake `kubernetes` module into ``sys.modules`` so the executor's
|
| 7 |
+
lazy ``from kubernetes import client`` / ``...client.exceptions`` calls
|
| 8 |
+
resolve to recording stand-in V1* model classes (this is the k8s analogue
|
| 9 |
+
of the modal test's ``_MockFunctionCall``), and
|
| 10 |
+
* pass mock ``batch_api`` / ``core_api`` via dependency injection (the
|
| 11 |
+
constructor's ``batch_api=`` / ``core_api=`` args), so no config loading or
|
| 12 |
+
cluster contact happens.
|
| 13 |
+
|
| 14 |
+
For real-cluster integration testing you would gate behind cluster
|
| 15 |
+
availability (e.g. ``config.load_kube_config()`` succeeding), exactly like
|
| 16 |
+
``test_modal_spawn_executor.py`` gates on ``_is_modal_installed()``.
|
| 17 |
+
|
| 18 |
+
Run: ``.venv/bin/python -m pytest <thisfile> -q``
|
| 19 |
+
"""
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import sys
|
| 23 |
+
import types
|
| 24 |
+
|
| 25 |
+
import pytest
|
| 26 |
+
|
| 27 |
+
from composer_replication.diloco.serverless import EKSExecutor, ReplicaHandle
|
| 28 |
+
from composer_replication.diloco.serverless.eks import _expand_indexes
|
| 29 |
+
|
| 30 |
+
# ---------------------------------------------------------------------
|
| 31 |
+
# Fake `kubernetes` module — recording V1* model stand-ins + ApiException
|
| 32 |
+
# ---------------------------------------------------------------------
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class _Rec:
|
| 36 |
+
"""Generic recording model: stores all ctor kwargs as attributes.
|
| 37 |
+
|
| 38 |
+
Stands in for the kubernetes client's ``V1*`` model classes (V1Job,
|
| 39 |
+
V1JobSpec, V1Container, V1EnvVar, ...). Every attr the executor sets is
|
| 40 |
+
inspectable by tests. Mirrors how the modal mock records ``.spawn`` args.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def __init__(self, **kwargs):
|
| 44 |
+
# Record every ctor kwarg as an attribute; fields that were never
|
| 45 |
+
# passed fall through to __getattr__ below and read as None.
|
| 46 |
+
for k, v in kwargs.items():
|
| 47 |
+
setattr(self, k, v)
|
| 48 |
+
|
| 49 |
+
def __getattr__(self, name): # only called when attr is genuinely absent
|
| 50 |
+
return None
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class _ApiException(Exception): # noqa: N818 — mirrors kubernetes.client.exceptions.ApiException name
|
| 54 |
+
"""Stand-in for kubernetes.client.exceptions.ApiException."""
|
| 55 |
+
|
| 56 |
+
def __init__(self, status=None, reason=None, body=None):
|
| 57 |
+
super().__init__(f"ApiException(status={status})")
|
| 58 |
+
self.status = status
|
| 59 |
+
self.reason = reason
|
| 60 |
+
self.body = body
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# The set of V1* names the executor constructs. Each maps to _Rec.
|
| 64 |
+
_V1_NAMES = [
|
| 65 |
+
"V1Job",
|
| 66 |
+
"V1JobSpec",
|
| 67 |
+
"V1ObjectMeta",
|
| 68 |
+
"V1PodTemplateSpec",
|
| 69 |
+
"V1PodSpec",
|
| 70 |
+
"V1Container",
|
| 71 |
+
"V1EnvVar",
|
| 72 |
+
"V1EnvVarSource",
|
| 73 |
+
"V1ObjectFieldSelector",
|
| 74 |
+
"V1ResourceRequirements",
|
| 75 |
+
"V1Toleration",
|
| 76 |
+
"V1DeleteOptions",
|
| 77 |
+
]
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@pytest.fixture
|
| 81 |
+
def fake_kubernetes(monkeypatch):
|
| 82 |
+
"""Install a fake `kubernetes` package into sys.modules for the test.
|
| 83 |
+
|
| 84 |
+
Provides:
|
| 85 |
+
- kubernetes.client.<V1*> -> recording _Rec classes
|
| 86 |
+
- kubernetes.client.exceptions.ApiException
|
| 87 |
+
- kubernetes.client.BatchV1Api / CoreV1Api (unused — apis are injected)
|
| 88 |
+
- kubernetes.config.load_incluster_config / load_kube_config / ConfigException
|
| 89 |
+
"""
|
| 90 |
+
kubernetes = types.ModuleType("kubernetes")
|
| 91 |
+
client = types.ModuleType("kubernetes.client")
|
| 92 |
+
exceptions = types.ModuleType("kubernetes.client.exceptions")
|
| 93 |
+
config = types.ModuleType("kubernetes.config")
|
| 94 |
+
|
| 95 |
+
for name in _V1_NAMES:
|
| 96 |
+
setattr(client, name, _Rec)
|
| 97 |
+
|
| 98 |
+
# Default api classes (only hit if NOT injected — we always inject).
|
| 99 |
+
client.BatchV1Api = lambda *a, **k: pytest.fail("BatchV1Api should be injected")
|
| 100 |
+
client.CoreV1Api = lambda *a, **k: pytest.fail("CoreV1Api should be injected")
|
| 101 |
+
|
| 102 |
+
exceptions.ApiException = _ApiException
|
| 103 |
+
client.exceptions = exceptions
|
| 104 |
+
|
| 105 |
+
class _ConfigException(Exception): # noqa: N818 — mirrors kubernetes.config.ConfigException name
|
| 106 |
+
pass
|
| 107 |
+
|
| 108 |
+
config.ConfigException = _ConfigException
|
| 109 |
+
config.load_incluster_config = lambda *a, **k: (_ for _ in ()).throw(
|
| 110 |
+
_ConfigException("not in cluster")
|
| 111 |
+
)
|
| 112 |
+
config.load_kube_config = lambda *a, **k: None
|
| 113 |
+
|
| 114 |
+
kubernetes.client = client
|
| 115 |
+
kubernetes.config = config
|
| 116 |
+
|
| 117 |
+
monkeypatch.setitem(sys.modules, "kubernetes", kubernetes)
|
| 118 |
+
monkeypatch.setitem(sys.modules, "kubernetes.client", client)
|
| 119 |
+
monkeypatch.setitem(sys.modules, "kubernetes.client.exceptions", exceptions)
|
| 120 |
+
monkeypatch.setitem(sys.modules, "kubernetes.config", config)
|
| 121 |
+
return kubernetes
|
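The fixture above works because the code under test imports `kubernetes` lazily, inside a function, so whatever module object sits in `sys.modules` at call time is what gets used. Below is a small self-contained sketch of that mechanism. The `fake_k8s_demo` package name, `Recorder`, `build_job`, and `install_fake` are hypothetical names chosen for the illustration and do not exist in the repository.

```python
import sys
import types


class Recorder:
    """Stores constructor kwargs as attributes, like a recording model."""

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)


def build_job(name: str):
    # Lazy import: resolved at call time, so a fake in sys.modules suffices.
    from fake_k8s_demo import client

    return client.V1Job(name=name)


def install_fake() -> None:
    """Register a fake package and submodule under hypothetical names."""
    pkg = types.ModuleType("fake_k8s_demo")
    client = types.ModuleType("fake_k8s_demo.client")
    client.V1Job = Recorder
    pkg.client = client
    sys.modules["fake_k8s_demo"] = pkg
    sys.modules["fake_k8s_demo.client"] = client


install_fake()
job = build_job("demo")
```

In a pytest suite, `monkeypatch.setitem(sys.modules, ...)` is preferable to assigning directly, as the fixture does, because the entries are restored automatically after each test.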
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# ---------------------------------------------------------------------
|
| 125 |
+
# Mock BatchV1Api / CoreV1Api (injected via batch_api= / core_api=)
|
| 126 |
+
# ---------------------------------------------------------------------
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class _MockBatchV1Api:
|
| 130 |
+
"""Records create/read-status/delete calls; returns a settable status."""
|
| 131 |
+
|
| 132 |
+
def __init__(self):
|
| 133 |
+
self.created_jobs: list[tuple[str, object]] = []
|
| 134 |
+
self.delete_calls: list[dict] = []
|
| 135 |
+
# status object returned by read_namespaced_job_status().status
|
| 136 |
+
self.status_obj = _Rec(
|
| 137 |
+
active=None,
|
| 138 |
+
succeeded=None,
|
| 139 |
+
failed=None,
|
| 140 |
+
completed_indexes=None,
|
| 141 |
+
failed_indexes=None,
|
| 142 |
+
conditions=None,
|
| 143 |
+
)
|
| 144 |
+
# Optional: raise this ApiException on read (e.g. 404 -> cancelled)
|
| 145 |
+
self.read_raises: Exception | None = None
|
| 146 |
+
|
| 147 |
+
def create_namespaced_job(self, namespace, body):
|
| 148 |
+
self.created_jobs.append((namespace, body))
|
| 149 |
+
return body
|
| 150 |
+
|
| 151 |
+
def read_namespaced_job_status(self, name, namespace):
|
| 152 |
+
if self.read_raises is not None:
|
| 153 |
+
raise self.read_raises
|
| 154 |
+
return _Rec(status=self.status_obj)
|
| 155 |
+
|
| 156 |
+
def delete_namespaced_job(self, name, namespace, body=None):
|
| 157 |
+
self.delete_calls.append(
|
| 158 |
+
{
|
| 159 |
+
"name": name,
|
| 160 |
+
"namespace": namespace,
|
| 161 |
+
"propagation_policy": getattr(body, "propagation_policy", None),
|
| 162 |
+
"grace_period_seconds": getattr(body, "grace_period_seconds", None),
|
| 163 |
+
}
|
| 164 |
+
)
|
| 165 |
+
return _Rec(status="Success")
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class _MockCoreV1Api:
|
| 169 |
+
"""Canned list_namespaced_pod + read_namespaced_pod_log."""
|
| 170 |
+
|
| 171 |
+
def __init__(self, pods=None, logs="line1\nline2\n"):
|
| 172 |
+
self._pods = pods if pods is not None else []
|
| 173 |
+
self._logs = logs
|
| 174 |
+
self.log_calls: list[dict] = []
|
| 175 |
+
self.list_calls: list[dict] = []
|
| 176 |
+
self.log_raises: Exception | None = None
|
| 177 |
+
|
| 178 |
+
def list_namespaced_pod(self, namespace, label_selector=None):
|
| 179 |
+
self.list_calls.append({"namespace": namespace, "label_selector": label_selector})
|
| 180 |
+
return _Rec(items=list(self._pods))
|
| 181 |
+
|
| 182 |
+
def read_namespaced_pod_log(self, name, namespace, container=None, tail_lines=None):
|
| 183 |
+
self.log_calls.append(
|
| 184 |
+
{
|
| 185 |
+
"name": name,
|
| 186 |
+
"namespace": namespace,
|
| 187 |
+
"container": container,
|
| 188 |
+
"tail_lines": tail_lines,
|
| 189 |
+
}
|
| 190 |
+
)
|
| 191 |
+
if self.log_raises is not None:
|
| 192 |
+
raise self.log_raises
|
| 193 |
+
return self._logs
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def _make_pod(name, rank):
|
| 197 |
+
"""Build a fake pod with the completion-index annotation set."""
|
| 198 |
+
return _Rec(
|
| 199 |
+
metadata=_Rec(
|
| 200 |
+
name=name,
|
| 201 |
+
annotations={"batch.kubernetes.io/job-completion-index": str(rank)},
|
| 202 |
+
labels={"job-name": name.rsplit("-", 2)[0]},
|
| 203 |
+
),
|
| 204 |
+
status=_Rec(phase="Running"),
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def _make_executor(fake_kubernetes, *, batch=None, core=None, **kwargs):
|
| 209 |
+
batch = batch or _MockBatchV1Api()
|
| 210 |
+
core = core or _MockCoreV1Api()
|
| 211 |
+
ex = EKSExecutor(
|
| 212 |
+
image="myrepo/composer-replica:latest",
|
| 213 |
+
batch_api=batch,
|
| 214 |
+
core_api=core,
|
| 215 |
+
**kwargs,
|
| 216 |
+
)
|
| 217 |
+
# Speed up collect() loops in tests.
|
| 218 |
+
ex._collect_poll_interval = lambda: 0.0
|
| 219 |
+
return ex, batch, core
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
# ---------------------------------------------------------------------
|
| 223 |
+
# _expand_indexes — the run-length-range parser
|
| 224 |
+
# ---------------------------------------------------------------------
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def test_expand_indexes_singletons_and_ranges():
|
| 228 |
+
assert _expand_indexes("1,3-5,7") == {1, 3, 4, 5, 7}
|
| 229 |
+
assert _expand_indexes("0") == {0}
|
| 230 |
+
assert _expand_indexes("0-3") == {0, 1, 2, 3}
|
| 231 |
+
assert _expand_indexes("") == set()
|
| 232 |
+
assert _expand_indexes(None) == set()
|
| 233 |
+
# Reversed range is tolerated.
|
| 234 |
+
assert _expand_indexes("5-3") == {3, 4, 5}
|
| 235 |
+
# Surrounding whitespace is tolerated.
|
| 236 |
+
assert _expand_indexes(" 2 , 4-6 ") == {2, 4, 5, 6}
|
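The assertions above pin down the behaviour expected of the index parser: comma-separated singletons and inclusive ranges, empty or `None` input yielding an empty set, reversed ranges accepted, and surrounding whitespace ignored. The real `_expand_indexes` in `eks.py` is not shown in this diff, so the function below is only a sketch that satisfies the same cases; the name `expand_indexes` and its internals are assumptions.

```python
from typing import Optional, Set


def expand_indexes(spec: Optional[str]) -> Set[int]:
    """Expand an index list such as "1,3-5,7" into a set of ints."""
    out: Set[int] = set()
    if not spec:
        return out
    for part in spec.split(","):
        part = part.strip()
        if not part:
            continue  # tolerate empty segments such as a trailing comma
        if "-" in part:
            lo_text, hi_text = part.split("-", 1)
            lo, hi = int(lo_text), int(hi_text)
            if lo > hi:
                lo, hi = hi, lo  # accept reversed ranges like "5-3"
            out.update(range(lo, hi + 1))
        else:
            out.add(int(part))
    return out


expanded = expand_indexes("1,3-5,7")
```

This sketch raises `ValueError` on non-numeric segments; whether the real parser does the same is not visible here.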
| 237 |
+
|
| 238 |
+
|
| 239 |
+
# ---------------------------------------------------------------------
|
| 240 |
+
# Construction / preconditions
|
| 241 |
+
# ---------------------------------------------------------------------
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def test_missing_kubernetes_raises_runtime_error_when_no_api_injected():
|
| 245 |
+
"""With kubernetes absent AND no injected api, ctor must raise clearly.
|
| 246 |
+
|
| 247 |
+
The import-guard path can ONLY be exercised when `kubernetes` is genuinely
|
| 248 |
+
not importable in this interpreter. When it IS installed (e.g. via the
|
| 249 |
+
`[eks]`/`[serverless]` extra in CI), the lazy import succeeds and the ctor
|
| 250 |
+
legitimately does not raise — so skip rather than assert a false precondition.
|
| 251 |
+
"""
|
| 252 |
+
import importlib.util
|
| 253 |
+
|
| 254 |
+
if importlib.util.find_spec("kubernetes") is not None:
|
| 255 |
+
pytest.skip("kubernetes is importable in this interpreter; the absent-path cannot be exercised")
|
| 256 |
+
with pytest.raises(RuntimeError, match="kubernetes"):
|
| 257 |
+
EKSExecutor(image="x")
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def test_construction_with_injected_apis_does_not_need_kubernetes():
|
| 261 |
+
"""When both apis are injected, ctor must not require the kubernetes import."""
|
| 262 |
+
batch = _MockBatchV1Api()
|
| 263 |
+
core = _MockCoreV1Api()
|
| 264 |
+
ex = EKSExecutor(image="img", batch_api=batch, core_api=core)
|
| 265 |
+
assert ex.backend_name == "eks"
|
| 266 |
+
assert ex.supports_inter_replica_network is False
|
| 267 |
+
assert ex.image == "img"
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
# ---------------------------------------------------------------------
|
| 271 |
+
# launch_replicas — N handles, indexed-job spec correctness
|
| 272 |
+
# ---------------------------------------------------------------------
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def test_launch_returns_n_rank_ordered_handles(fake_kubernetes):
|
| 276 |
+
ex, batch, _ = _make_executor(fake_kubernetes)
|
| 277 |
+
handles = ex.launch_replicas(
|
| 278 |
+
n_replicas=4,
|
| 279 |
+
entrypoint="ignored",
|
| 280 |
+
entrypoint_args={"rendezvous_uri": "s3://b/run42/", "world_size": 4},
|
| 281 |
+
)
|
| 282 |
+
assert len(handles) == 4
|
| 283 |
+
for i, h in enumerate(handles):
|
| 284 |
+
assert isinstance(h, ReplicaHandle)
|
| 285 |
+
assert h.rank == i
|
| 286 |
+
assert h.backend_name == "eks"
|
| 287 |
+
assert h.metadata["rank"] == i
|
| 288 |
+
# ALL handles share the same job_name / namespace (gang).
|
| 289 |
+
assert h.metadata["job_name"] == handles[0].metadata["job_name"]
|
| 290 |
+
assert h.metadata["namespace"] == "default"
|
| 291 |
+
|
| 292 |
+
# Exactly ONE job was created (single Indexed Job topology).
|
| 293 |
+
assert len(batch.created_jobs) == 1
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def test_launch_creates_indexed_job_spec(fake_kubernetes):
|
| 297 |
+
ex, batch, _ = _make_executor(fake_kubernetes)
|
| 298 |
+
ex.launch_replicas(
|
| 299 |
+
n_replicas=3,
|
| 300 |
+
entrypoint="ignored",
|
| 301 |
+
entrypoint_args={"rendezvous_uri": "s3://b/r/", "world_size": 3},
|
| 302 |
+
)
|
| 303 |
+
ns, job = batch.created_jobs[0]
|
| 304 |
+
assert ns == "default"
|
| 305 |
+
assert job.api_version == "batch/v1"
|
| 306 |
+
assert job.kind == "Job"
|
| 307 |
+
spec = job.spec
|
| 308 |
+
assert spec.completions == 3
|
| 309 |
+
assert spec.parallelism == 3
|
| 310 |
+
assert spec.completion_mode == "Indexed"
|
| 311 |
+
assert spec.backoff_limit == 0
|
| 312 |
+
assert spec.ttl_seconds_after_finished == 3600
|
| 313 |
+
# active_deadline_seconds == timeout (default 3600 here).
|
| 314 |
+
assert spec.active_deadline_seconds == 3600
|
| 315 |
+
# restart_policy is Never (Jobs accept only Never or OnFailure).
|
| 316 |
+
assert spec.template.spec.restart_policy == "Never"
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def test_launch_rank_env_uses_downward_api_field_ref(fake_kubernetes):
|
| 320 |
+
ex, batch, _ = _make_executor(fake_kubernetes)
|
| 321 |
+
ex.launch_replicas(
|
| 322 |
+
n_replicas=2,
|
| 323 |
+
entrypoint="ignored",
|
| 324 |
+
entrypoint_args={"rendezvous_uri": "s3://b/r/", "world_size": 2},
|
| 325 |
+
)
|
| 326 |
+
_, job = batch.created_jobs[0]
|
| 327 |
+
env = job.spec.template.spec.containers[0].env
|
| 328 |
+
by_name = {e.name: e for e in env}
|
| 329 |
+
|
| 330 |
+
# REPLICA_RANK from the downward-API annotation (NOT a literal value).
|
| 331 |
+
rr = by_name["REPLICA_RANK"]
|
| 332 |
+
assert rr.value is None
|
| 333 |
+
field_ref = rr.value_from.field_ref
|
| 334 |
+
assert (
|
| 335 |
+
field_ref.field_path
|
| 336 |
+
== "metadata.annotations['batch.kubernetes.io/job-completion-index']"
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
# WORLD_SIZE is a literal string.
|
| 340 |
+
assert by_name["WORLD_SIZE"].value == "2"
|
| 341 |
+
|
| 342 |
+
# rendezvous_uri passed through as an upper-cased literal env var.
|
| 343 |
+
assert by_name["RENDEZVOUS_URI"].value == "s3://b/r/"
|
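The two preceding tests check the shape of the Indexed Job the executor submits: one Job with `completions == parallelism == n_replicas`, `completion_mode` set to `Indexed`, and a `REPLICA_RANK` env var resolved through the downward API rather than a literal value. For orientation, here is roughly what such a Job looks like as a plain manifest dictionary. This is a sketch, not the executor's output: the `indexed_job_manifest` helper, the container name, and the example arguments are assumptions, and the real code builds `V1*` model objects instead of dicts.

```python
from typing import Any, Dict

# Annotation that Kubernetes sets on each pod of an Indexed Job.
RANK_FIELD_PATH = "metadata.annotations['batch.kubernetes.io/job-completion-index']"


def indexed_job_manifest(name: str, image: str, n_replicas: int) -> Dict[str, Any]:
    """Build a batch/v1 Indexed Job manifest with one pod per rank."""
    container = {
        "name": "replica",  # hypothetical container name
        "image": image,
        "env": [
            # Rank comes from the pod's completion-index annotation.
            {"name": "REPLICA_RANK",
             "valueFrom": {"fieldRef": {"fieldPath": RANK_FIELD_PATH}}},
            # Env values must be strings, so the world size is stringified.
            {"name": "WORLD_SIZE", "value": str(n_replicas)},
        ],
    }
    return {
        "apiVersion": "batch/v1",
        "kind": "Job",
        "metadata": {"name": name},
        "spec": {
            "completions": n_replicas,
            "parallelism": n_replicas,
            "completionMode": "Indexed",
            "backoffLimit": 0,
            "template": {
                "spec": {"restartPolicy": "Never", "containers": [container]}
            },
        },
    }


manifest = indexed_job_manifest("diloco-demo", "myrepo/composer-replica:latest", 2)
```

Because every pod shares one template, the per-pod rank has to come from something Kubernetes sets on the pod itself, which is why the field reference is used instead of a literal.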
| 344 |
+
|
| 345 |
+
|
| 346 |
+
def test_launch_strips_rank_env_kwarg(fake_kubernetes):
|
| 347 |
+
"""`rank_env` is the LocalProcessExecutor convention — must not become env."""
|
| 348 |
+
ex, batch, _ = _make_executor(fake_kubernetes)
|
| 349 |
+
ex.launch_replicas(
|
| 350 |
+
n_replicas=1,
|
| 351 |
+
entrypoint="ignored",
|
| 352 |
+
entrypoint_args={"rank_env": "REPLICA_RANK", "rendezvous_uri": "s3://x/"},
|
| 353 |
+
)
|
| 354 |
+
_, job = batch.created_jobs[0]
|
| 355 |
+
env_names = {e.name for e in job.spec.template.spec.containers[0].env}
|
| 356 |
+
assert "RANK_ENV" not in env_names
|
| 357 |
+
assert "RENDEZVOUS_URI" in env_names
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def test_launch_gpu_limit_is_string(fake_kubernetes):
|
| 361 |
+
ex, batch, _ = _make_executor(fake_kubernetes)
|
| 362 |
+
ex.launch_replicas(
|
| 363 |
+
n_replicas=2,
|
| 364 |
+
entrypoint="ignored",
|
| 365 |
+
entrypoint_args={"rendezvous_uri": "s3://x/"},
|
| 366 |
+
gpu="A100",
|
| 367 |
+
)
|
| 368 |
+
_, job = batch.created_jobs[0]
|
| 369 |
+
container = job.spec.template.spec.containers[0]
|
| 370 |
+
limits = container.resources.limits
|
| 371 |
+
assert limits["nvidia.com/gpu"] == "1"
|
| 372 |
+
# MUST be a string, not an int.
|
| 373 |
+
assert isinstance(limits["nvidia.com/gpu"], str)
|
| 374 |
+
# GPU node selector merged in.
|
| 375 |
+
node_selector = job.spec.template.spec.node_selector
|
| 376 |
+
assert node_selector["node.kubernetes.io/instance-type"] == "p4d.24xlarge"
|
| 377 |
+
# GPU NoSchedule toleration auto-added.
|
| 378 |
+
tols = job.spec.template.spec.tolerations
|
| 379 |
+
assert any(
|
| 380 |
+
t.key == "nvidia.com/gpu" and t.effect == "NoSchedule" for t in tols
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def test_launch_cpu_only_omits_gpu_limit(fake_kubernetes):
|
| 385 |
+
ex, batch, _ = _make_executor(fake_kubernetes)
|
| 386 |
+
ex.launch_replicas(
|
| 387 |
+
n_replicas=2,
|
| 388 |
+
entrypoint="ignored",
|
| 389 |
+
entrypoint_args={"rendezvous_uri": "s3://x/"},
|
| 390 |
+
gpu=None,
|
| 391 |
+
)
|
| 392 |
+
_, job = batch.created_jobs[0]
|
| 393 |
+
limits = job.spec.template.spec.containers[0].resources.limits
|
| 394 |
+
# No GPU -> no nvidia.com/gpu key at all (limits is None or empty).
|
| 395 |
+
assert not limits or "nvidia.com/gpu" not in (limits or {})
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def test_launch_passes_service_account_and_runtime_class(fake_kubernetes):
|
| 399 |
+
ex, batch, _ = _make_executor(
|
| 400 |
+
fake_kubernetes,
|
| 401 |
+
service_account_name="diloco-irsa-sa",
|
| 402 |
+
runtime_class_name="gvisor",
|
| 403 |
+
)
|
| 404 |
+
ex.launch_replicas(
|
| 405 |
+
n_replicas=1,
|
| 406 |
+
entrypoint="ignored",
|
| 407 |
+
entrypoint_args={"rendezvous_uri": "s3://x/"},
|
| 408 |
+
)
|
| 409 |
+
_, job = batch.created_jobs[0]
|
| 410 |
+
pod_spec = job.spec.template.spec
|
| 411 |
+
assert pod_spec.service_account_name == "diloco-irsa-sa"
|
| 412 |
+
assert pod_spec.runtime_class_name == "gvisor"
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def test_launch_timeout_becomes_active_deadline(fake_kubernetes):
|
| 416 |
+
ex, batch, _ = _make_executor(fake_kubernetes)
|
| 417 |
+
ex.launch_replicas(
|
| 418 |
+
n_replicas=1,
|
| 419 |
+
entrypoint="ignored",
|
| 420 |
+
entrypoint_args={"rendezvous_uri": "s3://x/"},
|
| 421 |
+
timeout=7200,
|
| 422 |
+
)
|
| 423 |
+
_, job = batch.created_jobs[0]
|
| 424 |
+
assert job.spec.active_deadline_seconds == 7200
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
def test_launch_uses_default_entrypoint_command(fake_kubernetes):
|
| 428 |
+
ex, batch, _ = _make_executor(fake_kubernetes)
|
| 429 |
+
ex.launch_replicas(
|
| 430 |
+
n_replicas=1, entrypoint="ignored", entrypoint_args={"rendezvous_uri": "s3://x/"}
|
| 431 |
+
)
|
| 432 |
+
_, job = batch.created_jobs[0]
|
| 433 |
+
cmd = job.spec.template.spec.containers[0].command
|
| 434 |
+
assert cmd == [
|
| 435 |
+
"python",
|
| 436 |
+
"-m",
|
| 437 |
+
"composer_replication.diloco.serverless.replica_entrypoint",
|
| 438 |
+
]
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
def test_launch_rejects_zero_or_negative(fake_kubernetes):
|
| 442 |
+
ex, _, _ = _make_executor(fake_kubernetes)
|
| 443 |
+
with pytest.raises(ValueError, match="n_replicas"):
|
| 444 |
+
ex.launch_replicas(n_replicas=0, entrypoint="x", entrypoint_args={})
|
| 445 |
+
with pytest.raises(ValueError, match="n_replicas"):
|
| 446 |
+
ex.launch_replicas(n_replicas=-1, entrypoint="x", entrypoint_args={})
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
# ---------------------------------------------------------------------
|
| 450 |
+
# poll — state mapping from completed/failed indexes + active count
|
| 451 |
+
# ---------------------------------------------------------------------
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def _launch_two(fake_kubernetes, batch=None, core=None):
|
| 455 |
+
ex, batch, core = _make_executor(fake_kubernetes, batch=batch, core=core)
|
| 456 |
+
handles = ex.launch_replicas(
|
| 457 |
+
n_replicas=4, entrypoint="x", entrypoint_args={"rendezvous_uri": "s3://x/"}
|
| 458 |
+
)
|
| 459 |
+
return ex, batch, core, handles
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
def test_poll_pending_when_nothing_active(fake_kubernetes):
|
| 463 |
+
ex, batch, _, handles = _launch_two(fake_kubernetes)
|
| 464 |
+
batch.status_obj = _Rec(active=0, completed_indexes=None, failed_indexes=None)
|
| 465 |
+
assert ex.poll(handles[0]) == "pending"
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
def test_poll_running_when_active(fake_kubernetes):
|
| 469 |
+
ex, batch, _, handles = _launch_two(fake_kubernetes)
|
| 470 |
+
batch.status_obj = _Rec(active=4, completed_indexes=None, failed_indexes=None)
|
| 471 |
+
assert ex.poll(handles[2]) == "running"
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def test_poll_succeeded_when_rank_in_completed_indexes(fake_kubernetes):
|
| 475 |
+
ex, batch, _, handles = _launch_two(fake_kubernetes)
|
| 476 |
+
# completed_indexes "0,2-3" -> ranks {0,2,3} succeeded; rank 1 still running.
|
| 477 |
+
batch.status_obj = _Rec(
|
| 478 |
+
active=1, completed_indexes="0,2-3", failed_indexes=None
|
| 479 |
+
)
|
| 480 |
+
assert ex.poll(handles[0]) == "succeeded"
|
| 481 |
+
assert ex.poll(handles[2]) == "succeeded"
|
| 482 |
+
assert ex.poll(handles[3]) == "succeeded"
|
| 483 |
+
assert ex.poll(handles[1]) == "running"
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def test_poll_failed_when_rank_in_failed_indexes(fake_kubernetes):
|
| 487 |
+
ex, batch, _, handles = _launch_two(fake_kubernetes)
|
| 488 |
+
batch.status_obj = _Rec(
|
| 489 |
+
active=0, completed_indexes="0", failed_indexes="1,3"
|
| 490 |
+
)
|
| 491 |
+
assert ex.poll(handles[1]) == "failed"
|
| 492 |
+
assert ex.poll(handles[3]) == "failed"
|
| 493 |
+
assert ex.poll(handles[0]) == "succeeded"
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
def test_poll_failed_on_whole_job_failed_condition(fake_kubernetes):
|
| 497 |
+
"""DeadlineExceeded etc.: a Failed condition with no per-index info -> failed."""
|
| 498 |
+
ex, batch, _, handles = _launch_two(fake_kubernetes)
|
| 499 |
+
batch.status_obj = _Rec(
|
| 500 |
+
active=0,
|
| 501 |
+
completed_indexes=None,
|
| 502 |
+
failed_indexes=None,
|
| 503 |
+
conditions=[_Rec(type="Failed", status="True", reason="DeadlineExceeded")],
|
| 504 |
+
)
|
| 505 |
+
assert ex.poll(handles[0]) == "failed"
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
def test_poll_cancelled_on_404(fake_kubernetes):
|
| 509 |
+
ex, batch, _, handles = _launch_two(fake_kubernetes)
|
| 510 |
+
batch.read_raises = _ApiException(status=404)
|
| 511 |
+
assert ex.poll(handles[0]) == "cancelled"
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
def test_poll_reraises_non_404_api_exception(fake_kubernetes):
|
| 515 |
+
ex, batch, _, handles = _launch_two(fake_kubernetes)
|
| 516 |
+
batch.read_raises = _ApiException(status=500)
|
| 517 |
+
with pytest.raises(_ApiException):
|
| 518 |
+
ex.poll(handles[0])
|
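Taken together, the `poll` tests above imply a precedence order for a replica's state: membership in the completed indexes wins, then the failed indexes or a whole-job `Failed` condition, then a non-zero active count, and otherwise the replica is pending. The sketch below encodes only that inferred ordering. `replica_state` and its parameters are hypothetical, the 404-to-cancelled path is omitted, and the real `EKSExecutor.poll` may resolve edge cases differently.

```python
from typing import Optional, Set


def replica_state(
    rank: int,
    completed: Set[int],
    failed: Set[int],
    active: Optional[int],
    job_failed: bool = False,
) -> str:
    """Map Job status fields to a per-rank state string."""
    if rank in completed:
        return "succeeded"
    if rank in failed or job_failed:
        return "failed"
    if active:  # None and 0 both mean no pod is currently active
        return "running"
    return "pending"


state = replica_state(1, completed={0, 2, 3}, failed=set(), active=1)
```

Note that the `active` count is job-wide, so a rank is reported as running whenever any pod is active and that rank has not yet finished.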
| 519 |
+
|
| 520 |
+
|
| 521 |
+
# ---------------------------------------------------------------------
|
| 522 |
+
# cancel — Background propagation on the shared job, idempotent
|
| 523 |
+
# ---------------------------------------------------------------------
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
def test_cancel_uses_background_propagation_on_shared_job(fake_kubernetes):
|
| 527 |
+
ex, batch, _, handles = _launch_two(fake_kubernetes)
|
| 528 |
+
ex.cancel(handles[2])
|
| 529 |
+
assert len(batch.delete_calls) == 1
|
| 530 |
+
call = batch.delete_calls[0]
|
| 531 |
+
assert call["propagation_policy"] == "Background"
|
| 532 |
+
assert call["grace_period_seconds"] == 0
|
| 533 |
+
# Cancelling ANY rank deletes the WHOLE shared job (gang semantics).
|
| 534 |
+
assert call["name"] == handles[0].metadata["job_name"]
|
| 535 |
+
assert call["namespace"] == "default"
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
def test_cancel_swallows_404(fake_kubernetes):
|
| 539 |
+
ex, batch, _, handles = _launch_two(fake_kubernetes)
|
| 540 |
+
|
| 541 |
+
def _raise_404(name, namespace, body=None):
|
| 542 |
+
raise _ApiException(status=404)
|
| 543 |
+
|
| 544 |
+
batch.delete_namespaced_job = _raise_404
|
| 545 |
+
# Must NOT raise (already deleted == success per the Protocol).
|
| 546 |
+
ex.cancel(handles[0])
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
def test_cancel_unknown_handle_is_noop(fake_kubernetes):
|
| 550 |
+
ex, batch, _, _ = _launch_two(fake_kubernetes)
|
| 551 |
+
fake = ReplicaHandle(rank=99, backend_name="eks", metadata={})
|
| 552 |
+
ex.cancel(fake) # no job_name in metadata -> no-op, no delete call
|
| 553 |
+
assert len(batch.delete_calls) == 0
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
# ---------------------------------------------------------------------
|
| 557 |
+
# stream_logs — find pod by completion-index annotation
|
| 558 |
+
# ---------------------------------------------------------------------
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
def test_stream_logs_reads_pod_for_rank(fake_kubernetes):
|
| 562 |
+
pods = [
|
| 563 |
+
_make_pod("diloco-abcd1234-0-xyz", 0),
|
| 564 |
+
_make_pod("diloco-abcd1234-1-xyz", 1),
|
| 565 |
+
]
|
| 566 |
+
core = _MockCoreV1Api(pods=pods, logs="hello from rank 1\n")
|
| 567 |
+
ex, _, core2, handles = _launch_two(fake_kubernetes, core=core)
|
| 568 |
+
out = ex.stream_logs(handles[1], n_lines=50)
|
| 569 |
+
assert out == "hello from rank 1\n"
|
| 570 |
+
# Read the right pod, container 'replica', tail_lines honored.
|
| 571 |
+
last = core.log_calls[-1]
|
| 572 |
+
assert last["name"] == "diloco-abcd1234-1-xyz"
|
| 573 |
+
assert last["container"] == "replica"
|
| 574 |
+
assert last["tail_lines"] == 50
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
def test_stream_logs_placeholder_when_pod_missing(fake_kubernetes):
|
| 578 |
+
core = _MockCoreV1Api(pods=[]) # no pods yet
|
| 579 |
+
ex, _, _, handles = _launch_two(fake_kubernetes, core=core)
|
| 580 |
+
out = ex.stream_logs(handles[0])
|
| 581 |
+
assert "rank 0" in out
|
| 582 |
+
assert "not started" in out or "no logs" in out
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
def test_stream_logs_placeholder_on_400(fake_kubernetes):
|
| 586 |
+
pods = [_make_pod("diloco-abcd1234-0-xyz", 0)]
|
| 587 |
+
core = _MockCoreV1Api(pods=pods)
|
| 588 |
+
core.log_raises = _ApiException(status=400) # pod not started yet
|
| 589 |
+
ex, _, _, handles = _launch_two(fake_kubernetes, core=core)
|
| 590 |
+
out = ex.stream_logs(handles[0])
|
| 591 |
+
assert "rank 0" in out
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
# ---------------------------------------------------------------------
|
| 595 |
+
# collect — per-rank result dicts in handles order
|
| 596 |
+
# ---------------------------------------------------------------------
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
def test_collect_returns_terminal_results_in_order(fake_kubernetes):
|
| 600 |
+
ex, batch, _, handles = _launch_two(fake_kubernetes)
|
| 601 |
+
# All four ranks done: 0-2 succeeded, 3 failed.
|
| 602 |
+
batch.status_obj = _Rec(
|
| 603 |
+
active=0, completed_indexes="0-2", failed_indexes="3"
|
| 604 |
+
)
|
| 605 |
+
results = ex.collect(handles, timeout=5)
|
| 606 |
+
assert len(results) == 4
|
| 607 |
+
for i, r in enumerate(results):
|
| 608 |
+
assert r["rank"] == i
|
| 609 |
+
assert r["job_name"] == handles[0].metadata["job_name"]
|
| 610 |
+
assert results[0]["status"] == "succeeded" and results[0]["exit_code"] == 0
|
| 611 |
+
assert results[1]["status"] == "succeeded"
|
| 612 |
+
assert results[2]["status"] == "succeeded"
|
| 613 |
+
assert results[3]["status"] == "failed" and results[3]["exit_code"] == 1
|
| 614 |
+
assert results[3]["error"] is not None
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
def test_collect_returns_non_terminal_state_at_deadline(fake_kubernetes):
|
| 618 |
+
ex, batch, _, handles = _launch_two(fake_kubernetes)
|
| 619 |
+
# Never finishes: active stays > 0.
|
| 620 |
+
batch.status_obj = _Rec(active=4, completed_indexes=None, failed_indexes=None)
|
| 621 |
+
results = ex.collect(handles, timeout=0) # immediate deadline
|
| 622 |
+
assert len(results) == 4
|
| 623 |
+
for r in results:
|
| 624 |
+
assert r["status"] in ("running", "pending")
|
| 625 |
+
assert r["exit_code"] is None
|
|
@@ -0,0 +1,244 @@
|
| 1 |
+
"""Tests for SageMakerExecutor (composer_replication.diloco.serverless.sagemaker).
|
| 2 |
+
|
| 3 |
+
The executor is exercised with an INJECTED mock boto3 sagemaker client (the
|
| 4 |
+
`sagemaker_client=` ctor arg), so these run on any host without boto3 or AWS
|
| 5 |
+
credentials — mirroring the _MockFunctionCall pattern in
|
| 6 |
+
test_modal_spawn_executor.py and the _MockBatchV1Api pattern in
|
| 7 |
+
test_eks_executor.py.
|
| 8 |
+
|
| 9 |
+
Closes the test-coverage gap left when the SageMakerExecutor was first written
|
| 10 |
+
without a test module (caught during Wave-2 integration, 2026-06-09).
|
| 11 |
+
"""
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import importlib.util
|
| 15 |
+
|
| 16 |
+
import pytest
|
| 17 |
+
|
| 18 |
+
from composer_replication.diloco.serverless import SageMakerExecutor
|
| 19 |
+
from composer_replication.diloco.serverless.executor import ReplicaHandle
|
| 20 |
+
|
| 21 |
+
# ---------------------------------------------------------------------
|
| 22 |
+
# Mock boto3 sagemaker client
|
| 23 |
+
# ---------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class _MockSMClient:
|
| 27 |
+
"""Records create/stop calls and serves a scripted status per job name."""
|
| 28 |
+
|
| 29 |
+
def __init__(self):
|
| 30 |
+
self.created: list[dict] = []
|
| 31 |
+
self.stopped: list[str] = []
|
| 32 |
+
# job_name -> (TrainingJobStatus, SecondaryStatus)
|
| 33 |
+
self._status: dict[str, tuple[str, str]] = {}
|
| 34 |
+
self.raise_not_found_on: set[str] = set()
|
| 35 |
+
|
| 36 |
+
def create_training_job(self, **request):
|
| 37 |
+
self.created.append(request)
|
| 38 |
+
# default a newly-created job to InProgress/Starting (== pending)
|
| 39 |
+
self._status[request["TrainingJobName"]] = ("InProgress", "Starting")
|
| 40 |
+
return {"TrainingJobArn": f"arn:aws:sagemaker:::training-job/{request['TrainingJobName']}"}
|
| 41 |
+
|
| 42 |
+
def describe_training_job(self, TrainingJobName): # noqa: N803 (boto3 casing)
|
| 43 |
+
if TrainingJobName in self.raise_not_found_on:
|
| 44 |
+
raise _ResourceNotFoundError(f"job {TrainingJobName} not found")
|
| 45 |
+
status, secondary = self._status.get(TrainingJobName, ("InProgress", "Training"))
|
| 46 |
+
return {
|
| 47 |
+
"TrainingJobName": TrainingJobName,
|
| 48 |
+
"TrainingJobStatus": status,
|
| 49 |
+
"SecondaryStatus": secondary,
|
| 50 |
+
"TrainingJobArn": f"arn:aws:sagemaker:::training-job/{TrainingJobName}",
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
def stop_training_job(self, TrainingJobName): # noqa: N803
|
| 54 |
+
self.stopped.append(TrainingJobName)
|
| 55 |
+
|
| 56 |
+
# test helper
|
| 57 |
+
def set_status(self, job_name, status, secondary="Completed"):
|
| 58 |
+
self._status[job_name] = (status, secondary)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class _ResourceNotFoundError(Exception):
|
| 62 |
+
"""Stand-in for botocore ResourceNotFound (the executor matches on name/text)."""
|
| 63 |
+
|
| 64 |
+
def __init__(self, msg):
|
| 65 |
+
super().__init__(msg)
|
| 66 |
+
# botocore-style response shape some impls check
|
| 67 |
+
self.response = {"Error": {"Code": "ResourceNotFound", "Message": msg}}
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _make_executor(client=None):
|
| 71 |
+
return SageMakerExecutor(
|
| 72 |
+
image_uri="123.dkr.ecr.us-east-1.amazonaws.com/trainer:latest",
|
| 73 |
+
role_arn="arn:aws:iam::123:role/SMRole",
|
| 74 |
+
output_s3_path="s3://bucket/out/",
|
| 75 |
+
region="us-east-1",
|
| 76 |
+
sagemaker_client=client or _MockSMClient(),
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
_VALID_ARGS = {
|
| 81 |
+
"rendezvous_uri": "s3://bucket/rendezvous/run1/",
|
| 82 |
+
"trainer_module": "my_pkg.trainer",
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# ---------------------------------------------------------------------
|
| 87 |
+
# Construction
|
| 88 |
+
# ---------------------------------------------------------------------
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def test_backend_identity():
|
| 92 |
+
ex = _make_executor()
|
| 93 |
+
assert ex.backend_name == "sagemaker"
|
| 94 |
+
assert ex.supports_inter_replica_network is False
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def test_missing_boto3_raises_when_no_client_injected():
|
| 98 |
+
"""The import-guard path only fires when boto3 is genuinely absent."""
|
| 99 |
+
if importlib.util.find_spec("boto3") is not None:
|
| 100 |
+
pytest.skip("boto3 importable; absent-path cannot be exercised")
|
| 101 |
+
with pytest.raises(RuntimeError, match="boto3"):
|
| 102 |
+
SageMakerExecutor(
|
| 103 |
+
image_uri="x", role_arn="r", output_s3_path="s3://b/o/",
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def test_construction_with_injected_client_needs_no_boto3():
|
| 108 |
+
ex = _make_executor()
|
| 109 |
+
assert ex is not None
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# ---------------------------------------------------------------------
|
| 113 |
+
# launch_replicas
|
| 114 |
+
# ---------------------------------------------------------------------
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def test_launch_returns_rank_ordered_handles():
|
| 118 |
+
client = _MockSMClient()
|
| 119 |
+
ex = _make_executor(client)
|
| 120 |
+
handles = ex.launch_replicas(3, entrypoint="ignored", entrypoint_args=_VALID_ARGS)
|
| 121 |
+
assert len(handles) == 3
|
| 122 |
+
assert [h.rank for h in handles] == [0, 1, 2]
|
| 123 |
+
assert all(isinstance(h, ReplicaHandle) and h.backend_name == "sagemaker" for h in handles)
|
| 124 |
+
assert len(client.created) == 3
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def test_launch_injects_rank_world_size_and_rendezvous_env():
|
| 128 |
+
client = _MockSMClient()
|
| 129 |
+
ex = _make_executor(client)
|
| 130 |
+
ex.launch_replicas(2, entrypoint="ignored", entrypoint_args=_VALID_ARGS)
|
| 131 |
+
for rank, req in enumerate(client.created):
|
| 132 |
+
env = req["Environment"]
|
| 133 |
+
assert env["REPLICA_RANK"] == str(rank)
|
| 134 |
+
assert env["WORLD_SIZE"] == "2"
|
| 135 |
+
assert env["RENDEZVOUS_URI"] == _VALID_ARGS["rendezvous_uri"]
|
| 136 |
+
# network isolation MUST stay False (else S3 rendezvous deadlocks)
|
| 137 |
+
assert req["EnableNetworkIsolation"] is False
|
| 138 |
+
assert req["OutputDataConfig"]["S3OutputPath"] == "s3://bucket/out/"
|
| 139 |
+
assert req["ResourceConfig"]["InstanceCount"] == 1
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def test_launch_validates_n_replicas():
|
| 143 |
+
ex = _make_executor()
|
| 144 |
+
with pytest.raises(ValueError, match="n_replicas"):
|
| 145 |
+
ex.launch_replicas(0, entrypoint="x", entrypoint_args=_VALID_ARGS)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def test_launch_requires_rendezvous_and_trainer_module():
|
| 149 |
+
ex = _make_executor()
|
| 150 |
+
with pytest.raises(ValueError, match="rendezvous_uri"):
|
| 151 |
+
ex.launch_replicas(1, entrypoint="x", entrypoint_args={"trainer_module": "m"})
|
| 152 |
+
with pytest.raises(ValueError, match="trainer_module"):
|
| 153 |
+
ex.launch_replicas(1, entrypoint="x", entrypoint_args={"rendezvous_uri": "s3://b/r/"})
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def test_launch_partial_failure_stops_siblings_and_raises():
|
| 157 |
+
class _FailingClient(_MockSMClient):
|
| 158 |
+
def create_training_job(self, **request):
|
| 159 |
+
if len(self.created) >= 2: # 3rd create fails
|
| 160 |
+
raise RuntimeError("ThrottlingException")
|
| 161 |
+
return super().create_training_job(**request)
|
| 162 |
+
|
| 163 |
+
client = _FailingClient()
|
| 164 |
+
ex = _make_executor(client)
|
| 165 |
+
with pytest.raises(RuntimeError, match="rank=2"):
|
| 166 |
+
ex.launch_replicas(3, entrypoint="x", entrypoint_args=_VALID_ARGS)
|
| 167 |
+
# the two already-launched siblings were best-effort stopped
|
| 168 |
+
assert len(client.stopped) == 2
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# ---------------------------------------------------------------------
|
| 172 |
+
# poll status mapping
|
| 173 |
+
# ---------------------------------------------------------------------
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def test_poll_status_mapping():
|
| 177 |
+
client = _MockSMClient()
|
| 178 |
+
ex = _make_executor(client)
|
| 179 |
+
handles = ex.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)
|
| 180 |
+
h = handles[0]
|
| 181 |
+
job = client.created[0]["TrainingJobName"]
|
| 182 |
+
|
| 183 |
+
client.set_status(job, "InProgress", "Starting")
|
| 184 |
+
assert ex.poll(h) == "pending"
|
| 185 |
+
client.set_status(job, "InProgress", "Training")
|
| 186 |
+
assert ex.poll(h) == "running"
|
| 187 |
+
client.set_status(job, "Completed")
|
| 188 |
+
assert ex.poll(h) == "succeeded"
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def test_poll_failed_and_stopped():
|
| 192 |
+
client = _MockSMClient()
|
| 193 |
+
ex = _make_executor(client)
|
| 194 |
+
h = ex.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0]
|
| 195 |
+
job = client.created[0]["TrainingJobName"]
|
| 196 |
+
client.set_status(job, "Failed")
|
| 197 |
+
assert ex.poll(h) == "failed"
|
| 198 |
+
|
| 199 |
+
client2 = _MockSMClient()
|
| 200 |
+
ex2 = _make_executor(client2)
|
| 201 |
+
h2 = ex2.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0]
|
| 202 |
+
job2 = client2.created[0]["TrainingJobName"]
|
| 203 |
+
client2.set_status(job2, "Stopped")
|
| 204 |
+
assert ex2.poll(h2) == "cancelled"
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def test_poll_vanished_job_is_cancelled():
|
| 208 |
+
client = _MockSMClient()
|
| 209 |
+
ex = _make_executor(client)
|
| 210 |
+
h = ex.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0]
|
| 211 |
+
client.raise_not_found_on.add(client.created[0]["TrainingJobName"])
|
| 212 |
+
assert ex.poll(h) == "cancelled"
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def test_poll_unknown_handle_is_cancelled():
|
| 216 |
+
ex = _make_executor()
|
| 217 |
+
orphan = ReplicaHandle(rank=99, backend_name="sagemaker", metadata={})
|
| 218 |
+
assert ex.poll(orphan) == "cancelled"
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
# ---------------------------------------------------------------------
|
| 222 |
+
# cancel
|
| 223 |
+
# ---------------------------------------------------------------------
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def test_cancel_calls_stop_training_job():
|
| 227 |
+
client = _MockSMClient()
|
| 228 |
+
ex = _make_executor(client)
|
| 229 |
+
h = ex.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0]
|
| 230 |
+
ex.cancel(h)
|
| 231 |
+
assert client.stopped == [client.created[0]["TrainingJobName"]]
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def test_cancel_swallows_errors():
|
| 235 |
+
class _RaisingStop(_MockSMClient):
|
| 236 |
+
def stop_training_job(self, TrainingJobName): # noqa: N803
|
| 237 |
+
raise _ResourceNotFoundError("already terminal")
|
| 238 |
+
|
| 239 |
+
client = _RaisingStop()
|
| 240 |
+
ex = _make_executor(client)
|
| 241 |
+
h = ex.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0]
|
| 242 |
+
ex.cancel(h) # must not raise
|
| 243 |
+
# unknown handle must also be a no-op
|
| 244 |
+
ex.cancel(ReplicaHandle(rank=42, backend_name="sagemaker", metadata={}))
|
|
@@ -0,0 +1,34 @@
|
| 1 |
+
"""composer_replication.safety — run-level collapse safeguards.
|
| 2 |
+
|
| 3 |
+
The #2 collapse safeguard for the self-evolving RL flywheel: a held-out disjoint
|
| 4 |
+
eval + a depth/generation kill-switch. The per-task controls live in
|
| 5 |
+
``composer_replication.datagen`` (4-gate validator, ``HackMonitor`` provenance,
|
| 6 |
+
sandbox denylist); this package adds the missing ACROSS-GENERATION / run-level
|
| 7 |
+
control that watches in-loop (proxy) reward against a disjoint held-out (real)
|
| 8 |
+
eval and HALTS the run when collapse / reward-hacking is caught in the act.
|
| 9 |
+
|
| 10 |
+
Public surface:
|
| 11 |
+
- HeldOutGuard — the stateful kill-switch (kill_switch.py)
|
| 12 |
+
- TripwireStatus — the structured per-update verdict (.fire / .halt / .reason /
|
| 13 |
+
.proxy_real_gap)
|
| 14 |
+
- CollapseStopError — typed exception for exception-based trainer control flow
|
| 15 |
+
- kl_token_trust_filter — per-token KL trust-region mask (torchrl KL-Mask analog)
|
| 16 |
+
|
| 17 |
+
Pure-Python, no torch / cloud deps. See docs/adrs/ADR-015-*.md and the
|
| 18 |
+
'holdout-killswitch' research digest.
|
| 19 |
+
"""
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
from composer_replication.safety.kill_switch import (
|
| 23 |
+
CollapseStopError,
|
| 24 |
+
HeldOutGuard,
|
| 25 |
+
TripwireStatus,
|
| 26 |
+
kl_token_trust_filter,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
__all__ = [
|
| 30 |
+
"HeldOutGuard",
|
| 31 |
+
"TripwireStatus",
|
| 32 |
+
"CollapseStopError",
|
| 33 |
+
"kl_token_trust_filter",
|
| 34 |
+
]
|
|
@@ -0,0 +1,447 @@
|
| 1 |
+
"""kill_switch.py — held-out collapse tripwire (the #2 collapse safeguard).
|
| 2 |
+
|
| 3 |
+
This is the missing RUN-LEVEL / across-generation control for the self-evolving
|
| 4 |
+
RL flywheel. The per-task controls already exist in ``composer_replication.datagen``
|
| 5 |
+
(the 4-gate solvability validator, the ``HackMonitor`` provenance check, and the
|
| 6 |
+
sandbox denylist); this module sits ABOVE them and watches the whole run.
|
| 7 |
+
|
| 8 |
+
Rationale (the literature is unambiguous that a held-out eval + hard stop is the
|
| 9 |
+
load-bearing control, not a nice-to-have):
|
| 10 |
+
|
| 11 |
+
- **Reward hacking rises monotonically with optimization depth.** Zhao et al.,
|
| 12 |
+
"Reward Hacking in Self-Improving Code Agents" (ICLR 2026 Workshop on RSI,
|
| 13 |
+
OpenReview ``ikrQWGgxYg``) show that going from 10 -> 100 optimization steps
|
| 14 |
+
drives the hacking rate from 26.4% to 57.8% (+31.4 points), and that
|
| 15 |
+
73.8% of KernelBench / 46.8% of ALE-Bench optimizations show *proxy gains
|
| 16 |
+
without real gains*. They define **Hacking Gap = proxy gain - real gain**;
|
| 17 |
+
this module's ``proxy_real_gap()`` is exactly that quantity. They label an
|
| 18 |
+
optimization reward-hacking when it "improves the public metric WITHOUT
|
| 19 |
+
improving the private metric" — the canonical signature this tripwire fires on.
|
| 20 |
+
|
| 21 |
+
- **Self-critique alone is insufficient.** The same paper's "retrospection"
|
| 22 |
+
self-critique sometimes *increased* hacking; their conclusion: "mitigating
|
| 23 |
+
reward hacking likely requires stronger evaluations and constraints beyond
|
| 24 |
+
self-critique alone." So we build a genuinely disjoint held-out eval plus a
|
| 25 |
+
hard stop, not a critique hook.
|
| 26 |
+
|
| 27 |
+
- **Held-out eval is necessary but NOT sufficient by itself.** EvilGenie
|
| 28 |
+
(arXiv 2511.21654) found "only minimal improvement from the use of held out
|
| 29 |
+
test cases" in isolation and that "holdout tests have many surprising failure
|
| 30 |
+
modes." This module is therefore explicitly *defense-in-depth*, layered ON
|
| 31 |
+
TOP of ``HackMonitor`` (provenance) — neither is sufficient alone, matching
|
| 32 |
+
the repo's existing defense-in-depth framing in ``datagen/monitor.py``.
|
| 33 |
+
|
| 34 |
+
- **Closed-loop RL on self-generated data collapses.** The self-evolving-agents
|
| 35 |
+
survey (Gao et al., TMLR 2026; arXiv 2507.21046 v4) §8.3 names "model
|
| 36 |
+
collapse from closed-loop RL on static synthetic data" and prescribes
|
| 37 |
+
"continuous monitoring ... to detect long-horizon value drift" — i.e. a
|
| 38 |
+
per-generation online tripwire, not a one-time eval. Shumailov et al. (Nature
|
| 39 |
+
2024, "AI models collapse when trained on recursively generated data") show
|
| 40 |
+
self-training first loses the distribution tails, then converges to a
|
| 41 |
+
low-variance point estimate; the mitigation that matters here is that the
|
| 42 |
+
held-out eval must stay anchored to REAL tasks that are NEVER fed back to the
|
| 43 |
+
generator (see ``HeldoutSplit``), otherwise the eval drifts with the train set.
|
| 44 |
+
|
| 45 |
+
- **KL-to-init hard stop.** The GRPO "healthy progression" band (Orchestra
|
| 46 |
+
Research GRPO SKILL) climbs 0.02 -> 0.05 -> 0.08 -> 0.12 nats/token over a
|
| 47 |
+
run, with 0.08 the top of the "good progression" band and just below the
|
| 48 |
+
code-generation drift zone (0.05-0.15 per-token); >0.5 is "diverging too
|
| 49 |
+
much." So 0.08 nats/token is a sound HARD-STOP default. Catastrophic Goodhart
|
| 50 |
+
(OpenReview ``UXuBzWoZGK``) proves KL regularization alone does NOT prevent
|
| 51 |
+
heavy-tailed reward misspecification, so the KL hard stop is ONE tripwire
|
| 52 |
+
among several, never the sole control.
|
| 53 |
+
|
| 54 |
+
UNITS GOTCHA (load-bearing): the ``kl_to_init`` this module consumes is
|
| 55 |
+
**token-mean KL in nats/token**, matching the repo convention in
|
| 56 |
+
``composer_replication.integrations.altered_minds.kl_logging.token_mean_kl``.
|
| 57 |
+
A token-mean KL is NOT comparable to a sequence-level / sequence-summed KL
|
| 58 |
+
(whose healthy band is ~0.05-10). The 0.08 default is per-token. Do not pass a
|
| 59 |
+
sequence-summed KL into the per-token hard stop — it will fire instantly.
|
| 60 |
+
|
| 61 |
+
This module is pure-Python: no torch, no cloud deps. ``kl_to_init`` is just a
|
| 62 |
+
float the caller passes (computed upstream by ``token_mean_kl``). It is fully
|
| 63 |
+
CPU-testable.
|
| 64 |
+
"""
|
| 65 |
+
from __future__ import annotations
|
| 66 |
+
|
| 67 |
+
from dataclasses import dataclass, field
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class CollapseStopError(RuntimeError):
|
| 71 |
+
"""Raised (by the caller, optionally) when the tripwire fires a hard stop.
|
| 72 |
+
|
| 73 |
+
The trainer loop can either check ``TripwireStatus.fire`` and stop softly,
|
| 74 |
+
or call ``HeldOutGuard.raise_if_fired(status)`` to convert a fired verdict
|
| 75 |
+
into this typed exception. Carries the structured verdict for logging.
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
def __init__(self, status: TripwireStatus) -> None:
|
| 79 |
+
super().__init__(status.reason)
|
| 80 |
+
self.status = status
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@dataclass(frozen=True)
|
| 84 |
+
class TripwireStatus:
|
| 85 |
+
"""Structured verdict returned by every ``HeldOutGuard.update(...)`` call.
|
| 86 |
+
|
| 87 |
+
Attributes:
|
| 88 |
+
fire: True => the run should HALT (collapse / reward-hacking detected).
|
| 89 |
+
reason: human-readable WHY (empty string when ``fire`` is False), so the
|
| 90 |
+
trainer can log exactly which tripwire tripped, mirroring how
|
| 91 |
+
``datagen/monitor.py`` logs suspected hacks for review.
|
| 92 |
+
step: the round/generation index this verdict was computed at.
|
| 93 |
+
proxy_real_gap: the RSI "Hacking Gap" at this step = (in-loop reward gain
|
| 94 |
+
since baseline) - (held-out score gain since baseline). Positive and
|
| 95 |
+
widening => the proxy is improving faster than the real score, or improving while the real score declines.
|
| 96 |
+
in_loop_ema: EMA of the in-loop / proxy reward at this step.
|
| 97 |
+
heldout_ema: EMA of the held-out / real eval score at this step.
|
| 98 |
+
kl_ema: EMA of ``kl_to_init`` (nats/token), or None if never supplied.
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
+
fire: bool
|
| 102 |
+
reason: str
|
| 103 |
+
step: int
|
| 104 |
+
proxy_real_gap: float
|
| 105 |
+
in_loop_ema: float
|
| 106 |
+
heldout_ema: float
|
| 107 |
+
kl_ema: float | None = None
|
| 108 |
+
|
| 109 |
+
# `halt` is a documented alias for `fire` — the task spec describes a
|
| 110 |
+
# `should_halt()` / verdict with a `halt` field; expose both names so callers
|
| 111 |
+
# reading either convention work.
|
| 112 |
+
@property
|
| 113 |
+
def halt(self) -> bool:
|
| 114 |
+
return self.fire
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@dataclass
|
| 118 |
+
class HeldOutGuard:
|
| 119 |
+
"""Across-generation collapse / reward-hacking kill-switch (HeldOutGuard).
|
| 120 |
+
|
| 121 |
+
Tracks, per generation/round: in-loop (proxy) oracle reward, held-out (real)
|
| 122 |
+
eval score, and optional KL-to-init / entropy / reward-std. Computes the
|
| 123 |
+
proxy-minus-real "Hacking Gap" tripwire and fires a structured ``halt``
|
| 124 |
+
verdict when collapse is caught in the act.
|
| 125 |
+
|
| 126 |
+
The guard is **stateful**: call ``update(round_idx, ...)`` once per checkpoint
|
| 127 |
+
in the trainer loop (the same cadence at which ``DifficultyCurriculum.update``
|
| 128 |
+
is called). It maintains denoised EMAs of every metric (raw single-step
|
| 129 |
+
values are too noisy to threshold — theneuralbase early-stopping guidance) and
|
| 130 |
+
returns a ``TripwireStatus``.
|
| 131 |
+
|
| 132 |
+
Fires (``fire=True``) when ANY of:
|
| 133 |
+
|
| 134 |
+
(a) **collapse-caught-in-the-act** — the in-loop reward EMA is RISING while
|
| 135 |
+
the held-out score EMA has DECLINED for >= ``decline_patience``
|
| 136 |
+
consecutive checkpoints (default 3, matching the "monotone for >=3
|
| 137 |
+
checkpoints" rule). This is the canonical reward-hacking signature.
|
| 138 |
+
|
| 139 |
+
(b) **KL breach** — the ``kl_to_init`` EMA exceeds ``kl_hard_stop`` (default
|
| 140 |
+
0.08 nats/token) on/after ``min_steps``.
|
| 141 |
+
|
| 142 |
+
(c) **proxy-real gap blowout** — the Hacking Gap (proxy gain - real gain
|
| 143 |
+
since baseline) widens beyond ``max_proxy_real_gap``, even if held-out
|
| 144 |
+
has not strictly declined for the full patience window (a fast
|
| 145 |
+
single-generation divergence).
|
| 146 |
+
|
| 147 |
+
No tripwire fires before ``min_steps`` (avoids halting on early-run noise,
|
| 148 |
+
when both signals are still warming up).
|
| 149 |
+
|
| 150 |
+
The guard is idempotent in the sense that re-querying ``last_status`` or
|
| 151 |
+
calling ``should_halt()`` does not advance state — only ``update`` does.
|
| 152 |
+
"""
|
| 153 |
+
|
| 154 |
+
# --- thresholds (calibratable; see calibrate_kl_threshold) ---------------
|
| 155 |
+
kl_hard_stop: float = 0.08 # nats/token; top of GRPO "good" band
|
| 156 |
+
max_proxy_real_gap: float = 0.10 # absolute Hacking-Gap blowout ceiling
|
| 157 |
+
# --- temporal gates ------------------------------------------------------
|
| 158 |
+
min_steps: int = 20 # no fire before this many updates
|
| 159 |
+
decline_patience: int = 3 # consecutive held-out declines to fire (a)
|
| 160 |
+
# --- denoising -----------------------------------------------------------
|
| 161 |
+
ema_alpha: float = 0.9 # EMA weight on the PRIOR (0.9 => slow)
|
| 162 |
+
rise_eps: float = 1e-4 # min EMA delta to count as "rising"/"declining"
|
| 163 |
+
|
| 164 |
+
# --- internal state (do not set directly) --------------------------------
|
| 165 |
+
_n: int = field(default=0, init=False)
|
| 166 |
+
_in_loop_ema: float | None = field(default=None, init=False)
|
| 167 |
+
_heldout_ema: float | None = field(default=None, init=False)
|
| 168 |
+
_kl_ema: float | None = field(default=None, init=False)
|
| 169 |
+
_entropy_ema: float | None = field(default=None, init=False)
|
| 170 |
+
_reward_std_ema: float | None = field(default=None, init=False)
|
| 171 |
+
_in_loop_baseline: float | None = field(default=None, init=False)
|
| 172 |
+
_heldout_baseline: float | None = field(default=None, init=False)
|
| 173 |
+
_prev_in_loop_ema: float | None = field(default=None, init=False)
|
| 174 |
+
_prev_heldout_ema: float | None = field(default=None, init=False)
|
| 175 |
+
_heldout_decline_streak: int = field(default=0, init=False)
|
| 176 |
+
_last_status: TripwireStatus | None = field(default=None, init=False)
|
| 177 |
+
_fired: bool = field(default=False, init=False)
|
| 178 |
+
|
| 179 |
+
def __post_init__(self) -> None:
|
| 180 |
+
if not (0.0 <= self.ema_alpha < 1.0):
|
| 181 |
+
raise ValueError(
|
| 182 |
+
f"ema_alpha must be in [0, 1), got {self.ema_alpha!r} "
|
| 183 |
+
"(it is the weight on the PRIOR EMA)."
|
| 184 |
+
)
|
| 185 |
+
if self.kl_hard_stop <= 0.0:
|
| 186 |
+
raise ValueError(f"kl_hard_stop must be > 0, got {self.kl_hard_stop!r}")
|
| 187 |
+
if self.decline_patience < 1:
|
| 188 |
+
raise ValueError(
|
| 189 |
+
f"decline_patience must be >= 1, got {self.decline_patience!r}"
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
# ------------------------------------------------------------------------
|
| 193 |
+
# core API
|
| 194 |
+
# ------------------------------------------------------------------------
|
| 195 |
+
def update(
|
| 196 |
+
self,
|
| 197 |
+
round_idx: int,
|
| 198 |
+
in_loop_reward: float,
|
| 199 |
+
heldout_score: float,
|
| 200 |
+
kl_to_init: float | None = None,
|
| 201 |
+
entropy: float | None = None,
|
| 202 |
+
reward_std: float | None = None,
|
| 203 |
+
) -> TripwireStatus:
|
| 204 |
+
"""Fold one checkpoint's metrics in and return the current verdict.
|
| 205 |
+
|
| 206 |
+
Args:
|
| 207 |
+
round_idx: the generation / round index (for logging; not used for
|
| 208 |
+
gating — the internal update counter ``_n`` drives ``min_steps``
|
| 209 |
+
so the guard is robust to non-contiguous round indices).
|
| 210 |
+
in_loop_reward: mean in-loop (proxy / oracle) reward this round. This
|
| 211 |
+
is what the policy is optimizing against.
|
| 212 |
+
heldout_score: mean score on the DISJOINT held-out eval pool this
|
| 213 |
+
round — REAL tasks the generator never trains on. See
|
| 214 |
+
``composer_replication.safety.holdout`` design notes / the
|
| 215 |
+
``HeldoutSplit`` discipline; if held-out drifts with the train
|
| 216 |
+
set the gap signal is meaningless.
|
| 217 |
+
kl_to_init: optional token-mean KL(policy || init) in nats/token
|
| 218 |
+
(this repo's ``token_mean_kl`` convention). NOT sequence-level KL.
|
| 219 |
+
entropy: optional policy entropy (early-warning of entropy collapse,
|
| 220 |
+
"the silent killer of RLVR generalization"). Tracked + exposed,
|
| 221 |
+
not currently a hard gate.
|
| 222 |
+
reward_std: optional std of the reward distribution (tracked; a
|
| 223 |
+
collapsing std is an early collapse signal).
|
| 224 |
+
|
| 225 |
+
Returns:
|
| 226 |
+
A ``TripwireStatus``. Once the guard has fired, every subsequent
|
| 227 |
+
``update`` keeps ``fire=True`` (latched) so a transient recovery
|
| 228 |
+
after a detected collapse cannot silently un-halt the run.
|
| 229 |
+
"""
|
| 230 |
+
self._n += 1
|
| 231 |
+
|
| 232 |
+
# --- EMA folds (alpha on the prior; first sample seeds the EMA) -------
|
| 233 |
+
self._in_loop_ema = self._fold(self._in_loop_ema, float(in_loop_reward))
|
| 234 |
+
self._heldout_ema = self._fold(self._heldout_ema, float(heldout_score))
|
| 235 |
+
if kl_to_init is not None:
|
| 236 |
+
self._kl_ema = self._fold(self._kl_ema, float(kl_to_init))
|
| 237 |
+
if entropy is not None:
|
| 238 |
+
self._entropy_ema = self._fold(self._entropy_ema, float(entropy))
|
| 239 |
+
if reward_std is not None:
|
| 240 |
+
self._reward_std_ema = self._fold(self._reward_std_ema, float(reward_std))
|
| 241 |
+
|
| 242 |
+
# --- baselines: seed on the first update so gains are measured from
|
| 243 |
+
# run start (the RSI Hacking-Gap is a gain-since-baseline quantity). -
|
| 244 |
+
if self._in_loop_baseline is None:
|
| 245 |
+
self._in_loop_baseline = self._in_loop_ema
|
| 246 |
+
if self._heldout_baseline is None:
|
| 247 |
+
self._heldout_baseline = self._heldout_ema
|
| 248 |
+
|
| 249 |
+
# --- track the held-out decline streak (uses EMA deltas, denoised) ----
|
| 250 |
+
in_loop_rising = (
|
| 251 |
+
self._prev_in_loop_ema is not None
|
| 252 |
+
and (self._in_loop_ema - self._prev_in_loop_ema) > self.rise_eps
|
| 253 |
+
)
|
| 254 |
+
heldout_declining = (
|
| 255 |
+
self._prev_heldout_ema is not None
|
| 256 |
+
and (self._heldout_ema - self._prev_heldout_ema) < -self.rise_eps
|
| 257 |
+
)
|
| 258 |
+
# The collapse signature is held-out DOWN while in-loop UP. We only count
|
| 259 |
+
# a decline toward the streak when in-loop is simultaneously rising — a
|
| 260 |
+
# held-out dip during an in-loop dip is just noise / a hard batch, not
|
| 261 |
+
# reward hacking.
|
| 262 |
+
if heldout_declining and in_loop_rising:
|
| 263 |
+
self._heldout_decline_streak += 1
|
| 264 |
+
elif not heldout_declining:
|
| 265 |
+
self._heldout_decline_streak = 0
|
| 266 |
+
# If held-out declines while in-loop is flat or falling, the streak is held:
|
| 267 |
+
# it neither grows nor resets. Any checkpoint where held-out does not decline
|
| 268 |
+
# resets it to zero, so a single clean checkpoint clears the streak.
|
| 269 |
+
|
| 270 |
+
gap = self.proxy_real_gap()
|
| 271 |
+
status = self._evaluate(round_idx, gap)
|
| 272 |
+
|
| 273 |
+
# advance "previous EMA" trackers AFTER evaluation
|
| 274 |
+
self._prev_in_loop_ema = self._in_loop_ema
|
| 275 |
+
self._prev_heldout_ema = self._heldout_ema
|
| 276 |
+
self._last_status = status
|
| 277 |
+
if status.fire:
|
| 278 |
+
self._fired = True
|
| 279 |
+
return status
|
| 280 |
+
|
| 281 |
+
def _evaluate(self, round_idx: int, gap: float) -> TripwireStatus:
|
| 282 |
+
"""Decide the verdict from current state. Pure (no state mutation)."""
|
| 283 |
+
assert self._in_loop_ema is not None and self._heldout_ema is not None
|
| 284 |
+
|
| 285 |
+
base = dict(
|
| 286 |
+
step=round_idx,
|
| 287 |
+
proxy_real_gap=gap,
|
| 288 |
+
in_loop_ema=self._in_loop_ema,
|
| 289 |
+
heldout_ema=self._heldout_ema,
|
| 290 |
+
kl_ema=self._kl_ema,
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
# Latched: once fired, stay fired (cannot silently un-halt).
|
| 294 |
+
if self._fired:
|
| 295 |
+
prev_reason = (self._last_status.reason if self._last_status else "collapse").replace("latched: ", "")  # do not stack the prefix on repeated updates
|
| 296 |
+
return TripwireStatus(fire=True, reason=f"latched: {prev_reason}", **base)
|
| 297 |
+
|
| 298 |
+
# Warm-up guard: never fire on early-run noise.
|
| 299 |
+
if self._n < self.min_steps:
|
| 300 |
+
return TripwireStatus(fire=False, reason="", **base)
|
| 301 |
+
|
| 302 |
+
# (b) KL hard stop — checked first; it's the cheapest unambiguous breach.
|
| 303 |
+
if self._kl_ema is not None and self._kl_ema > self.kl_hard_stop:
|
| 304 |
+
return TripwireStatus(
|
| 305 |
+
fire=True,
|
| 306 |
+
reason=(
|
| 307 |
+
f"kl_to_init EMA {self._kl_ema:.4f} nats/token exceeds hard "
|
| 308 |
+
f"stop {self.kl_hard_stop:.4f} (policy drifting from init)"
|
| 309 |
+
),
|
| 310 |
+
**base,
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
# (a) collapse caught in the act — held-out declines while in-loop rises.
|
| 314 |
+
if self._heldout_decline_streak >= self.decline_patience:
|
| 315 |
+
return TripwireStatus(
|
| 316 |
+
fire=True,
|
| 317 |
+
reason=(
|
| 318 |
+
f"reward-hacking signature: held-out score declined while "
|
| 319 |
+
f"in-loop reward rose for {self._heldout_decline_streak} "
|
| 320 |
+
f"consecutive checkpoints (Hacking Gap {gap:.4f})"
|
| 321 |
+
),
|
| 322 |
+
**base,
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
# (c) proxy-real gap blowout — fast single-generation divergence.
|
| 326 |
+
if gap > self.max_proxy_real_gap:
|
| 327 |
+
return TripwireStatus(
|
| 328 |
+
fire=True,
|
| 329 |
+
reason=(
|
| 330 |
+
f"proxy-real Hacking Gap {gap:.4f} exceeds ceiling "
|
| 331 |
+
f"{self.max_proxy_real_gap:.4f} (proxy reward improving far "
|
| 332 |
+
f"faster than real held-out eval)"
|
| 333 |
+
),
|
| 334 |
+
**base,
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
return TripwireStatus(fire=False, reason="", **base)
|
| 338 |
+
|
| 339 |
+
# ------------------------------------------------------------------------
|
| 340 |
+
# query helpers (do NOT advance state — idempotent)
|
| 341 |
+
# ------------------------------------------------------------------------
|
| 342 |
+
def should_halt(self) -> bool:
|
| 343 |
+
"""True if the most recent ``update`` produced a halt verdict.
|
| 344 |
+
|
| 345 |
+
Idempotent: querying does not advance the EMA state.
|
| 346 |
+
"""
|
| 347 |
+
return self._last_status is not None and self._last_status.fire
|
| 348 |
+
|
| 349 |
+
@property
|
| 350 |
+
def last_status(self) -> TripwireStatus | None:
|
| 351 |
+
"""The most recent verdict, or None if ``update`` was never called."""
|
| 352 |
+
return self._last_status
|
| 353 |
+
|
| 354 |
+
def raise_if_fired(self, status: TripwireStatus | None = None) -> None:
|
| 355 |
+
"""Convert a fired verdict into a typed ``CollapseStopError`` exception.
|
| 356 |
+
|
| 357 |
+
Pass the status returned by ``update`` (or omit to use ``last_status``).
|
| 358 |
+
Trainer loops that prefer exception-based control flow call this right
|
| 359 |
+
after ``update``; loops that prefer flag-checking just read
|
| 360 |
+
``status.fire`` / ``should_halt()``.
|
| 361 |
+
"""
|
| 362 |
+
st = status if status is not None else self._last_status
|
| 363 |
+
if st is not None and st.fire:
|
| 364 |
+
raise CollapseStopError(st)
|
| 365 |
+
|
| 366 |
+
def proxy_real_gap(self) -> float:
|
| 367 |
+
"""The RSI Hacking Gap = (in-loop gain) - (held-out gain), both measured
|
| 368 |
+
as EMA-minus-baseline since run start.
|
| 369 |
+
|
| 370 |
+
Returns 0.0 before the first ``update`` (no baseline yet). A positive,
|
| 371 |
+
widening value is the reward-hacking fingerprint: the proxy the policy
|
| 372 |
+
optimizes is improving more than the real held-out objective.
|
| 373 |
+
"""
|
| 374 |
+
if (
|
| 375 |
+
self._in_loop_ema is None
|
| 376 |
+
or self._heldout_ema is None
|
| 377 |
+
or self._in_loop_baseline is None
|
| 378 |
+
or self._heldout_baseline is None
|
| 379 |
+
):
|
| 380 |
+
return 0.0
|
| 381 |
+
in_loop_gain = self._in_loop_ema - self._in_loop_baseline
|
| 382 |
+
heldout_gain = self._heldout_ema - self._heldout_baseline
|
| 383 |
+
return in_loop_gain - heldout_gain
|
| 384 |
+
|
| 385 |
+
# ------------------------------------------------------------------------
|
| 386 |
+
# calibration
|
| 387 |
+
# ------------------------------------------------------------------------
|
| 388 |
+
def calibrate_kl_threshold(
|
| 389 |
+
self, baseline_kls: list[float], factor: float = 3.0
|
| 390 |
+
) -> float:
|
| 391 |
+
"""Set ``kl_hard_stop`` to ``factor`` x the mean of early-run baseline KLs.
|
| 392 |
+
|
| 393 |
+
theneuralbase guidance: "Record baseline KL during the first ~100 steps,
|
| 394 |
+
set max to 3x that." Single fixed thresholds are dataset-dependent; this
|
| 395 |
+
adapts to the run's own KL scale.
|
| 396 |
+
|
| 397 |
+
SAFETY CLAMP: calibration may only ever TIGHTEN the hard stop, never
|
| 398 |
+
loosen it past the documented collapse band. The returned (and stored)
|
| 399 |
+
threshold is ``min(factor x baseline mean, current kl_hard_stop)``, so a noisy /
|
| 400 |
+
already-drifting baseline cannot raise the ceiling above its current value (0.08 nats/token by default).
|
| 401 |
+
|
| 402 |
+
Args:
|
| 403 |
+
baseline_kls: per-step token-mean KL values from early in the run.
|
| 404 |
+
factor: multiplier on the baseline mean (default 3.0).
|
| 405 |
+
|
| 406 |
+
Returns:
|
| 407 |
+
The new ``kl_hard_stop`` (also stored on the instance).
|
| 408 |
+
|
| 409 |
+
Raises:
|
| 410 |
+
ValueError: if ``baseline_kls`` is empty.
|
| 411 |
+
"""
|
| 412 |
+
if not baseline_kls:
|
| 413 |
+
raise ValueError("baseline_kls must be non-empty to calibrate")
|
| 414 |
+
mean_kl = sum(baseline_kls) / len(baseline_kls)
|
| 415 |
+
calibrated = factor * mean_kl
|
| 416 |
+
# Only tighten: never let calibration loosen past the current ceiling.
|
| 417 |
+
self.kl_hard_stop = min(calibrated, self.kl_hard_stop)
|
| 418 |
+
return self.kl_hard_stop
|
| 419 |
+
|
| 420 |
+
# ------------------------------------------------------------------------
|
| 421 |
+
# internals
|
| 422 |
+
# ------------------------------------------------------------------------
|
| 423 |
+
def _fold(self, prev: float | None, x: float) -> float:
|
| 424 |
+
"""EMA fold; the first observation seeds the EMA (no warm-up bias)."""
|
| 425 |
+
if prev is None:
|
| 426 |
+
return x
|
| 427 |
+
return self.ema_alpha * prev + (1.0 - self.ema_alpha) * x
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def kl_token_trust_filter(logratio_sq_half: float, threshold: float = 0.08) -> bool:
|
| 431 |
+
"""Per-token KL trust-region mask, mirroring torchrl's GRPO "KL-Mask".
|
| 432 |
+
|
| 433 |
+
torchrl masks any TOKEN whose ``0.5 * (log pi/pi_ref)^2`` (the Schulman k2
|
| 434 |
+
estimator of per-token KL) exceeds a threshold, forming a per-token trust
|
| 435 |
+
region. This helper returns True when the token should be MASKED OUT (its
|
| 436 |
+
KL contribution is too large), so it can be wired into a loss later without
|
| 437 |
+
pulling torch into this module — the caller computes ``0.5 * logratio**2``.
|
| 438 |
+
|
| 439 |
+
Args:
|
| 440 |
+
logratio_sq_half: ``0.5 * (log pi/pi_ref)^2`` for one token (nats).
|
| 441 |
+
threshold: per-token KL ceiling (default 0.08 nats, the same band as the
|
| 442 |
+
run-level hard stop).
|
| 443 |
+
|
| 444 |
+
Returns:
|
| 445 |
+
True if the token exceeds the trust region and should be masked.
|
| 446 |
+
"""
|
| 447 |
+
return logratio_sq_half > threshold
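
For the caller side of `kl_token_trust_filter`, here is a minimal sketch (not part of this commit) of how the `0.5 * logratio**2` input might be computed from per-token log-probabilities. The helper names `k2_per_token` and `trust_mask` and the sample log-probabilities are assumptions made for illustration; the strict `>` comparison mirrors the helper above.

```python
# Hypothetical caller-side sketch: pure Python, no torch, illustrative names only.

def k2_per_token(logp_policy, logp_ref):
    """Schulman k2 estimate per token: 0.5 * (log pi - log pi_ref)^2, in nats."""
    return [0.5 * (lp - lr) ** 2 for lp, lr in zip(logp_policy, logp_ref)]


def trust_mask(logp_policy, logp_ref, threshold=0.08):
    """Return one bool per token; True means the token should be masked out."""
    # Strictly greater than the ceiling, so a value equal to the threshold is kept.
    return [k2 > threshold for k2 in k2_per_token(logp_policy, logp_ref)]


# Made-up log-probabilities for three tokens (assumed values, not real data).
logp_policy = [-1.0, -0.5, -2.0]
logp_ref = [-1.1, -1.0, -2.0]

# Log-ratios are roughly 0.1, 0.5 and 0.0, so only the middle token
# (k2 = 0.125) exceeds the 0.08 ceiling.
mask = trust_mask(logp_policy, logp_ref)
```

In a real loss the mask would more likely be applied to a tensor of per-token terms; the list form here only keeps the sketch free of dependencies.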
|
|
File without changes
|
|
@@ -0,0 +1,320 @@
|
| 1 |
+
"""Tests for the held-out collapse kill-switch (HeldOutGuard).
|
| 2 |
+
|
| 3 |
+
CPU-only, pure-Python — no torch, no cloud. Mirrors the
|
| 4 |
+
``datagen/tests/test_feature_deletion.py`` style (small helpers, behavioral
|
| 5 |
+
asserts). Covers:
|
| 6 |
+
- no-halt on a healthy co-rising run (the held-out-twin "within noise" case);
|
| 7 |
+
- HALT on the canonical signature: held-out declines while in-loop rises;
|
| 8 |
+
- HALT on KL-to-init hard-stop breach;
|
| 9 |
+
- HALT on a fast proxy-real Hacking-Gap blowout;
|
| 10 |
+
- window / patience behavior (min_steps warm-up; decline_patience streak);
|
| 11 |
+
- calibration tightens-only;
|
| 12 |
+
- idempotent query + latched-fire edge cases.
|
| 13 |
+
"""
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import pytest
|
| 17 |
+
|
| 18 |
+
from composer_replication.safety import (
|
| 19 |
+
CollapseStopError,
|
| 20 |
+
HeldOutGuard,
|
| 21 |
+
TripwireStatus,
|
| 22 |
+
kl_token_trust_filter,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _guard(**kw) -> HeldOutGuard:
|
| 27 |
+
# Small min_steps keeps tests fast while still exercising the warm-up gate.
|
| 28 |
+
base = dict(min_steps=3, decline_patience=3, ema_alpha=0.5, kl_hard_stop=0.08)
|
| 29 |
+
base.update(kw)
|
| 30 |
+
return HeldOutGuard(**base)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# --- healthy run: both rise => never halt -----------------------------------
|
| 34 |
+
|
| 35 |
+
def test_no_halt_when_both_rise():
|
| 36 |
+
"""Clean run: in-loop and held-out rise together, KL stays in band. The
|
| 37 |
+
held-out twin scores within noise of the proxy => no fire (the well-behaved
|
| 38 |
+
case the literature says a clean model exhibits)."""
|
| 39 |
+
g = _guard()
|
| 40 |
+
status = None
|
| 41 |
+
for i in range(30):
|
| 42 |
+
status = g.update(
|
| 43 |
+
i,
|
| 44 |
+
in_loop_reward=0.30 + 0.01 * i,
|
| 45 |
+
heldout_score=0.28 + 0.01 * i, # tracks proxy within noise
|
| 46 |
+
kl_to_init=0.03,
|
| 47 |
+
)
|
| 48 |
+
assert not status.fire, f"fired unexpectedly at step {i}: {status.reason}"
|
| 49 |
+
assert not g.should_halt()
|
| 50 |
+
# Gap stays near zero because both gained equally.
|
| 51 |
+
assert abs(g.proxy_real_gap()) < 0.05
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# --- canonical signature: held-out declines while in-loop rises -------------
|
| 55 |
+
|
| 56 |
+
def test_halt_on_heldout_declines_while_reward_rises():
|
| 57 |
+
g = _guard(max_proxy_real_gap=10.0) # disable gap-blowout path to isolate (a)
|
| 58 |
+
# Warm up past min_steps with a stable healthy stretch.
|
| 59 |
+
for i in range(6):
|
| 60 |
+
s = g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)
|
| 61 |
+
assert not s.fire
|
| 62 |
+
# Now: proxy reward climbs, held-out eval falls — the reward-hacking
|
| 63 |
+
# fingerprint. Should fire once the decline streak hits decline_patience (3).
|
| 64 |
+
fired_at = None
|
| 65 |
+
for j, i in enumerate(range(6, 12)):
|
| 66 |
+
s = g.update(
|
| 67 |
+
i,
|
| 68 |
+
in_loop_reward=0.40 + 0.05 * (j + 1), # rising
|
| 69 |
+
heldout_score=0.40 - 0.05 * (j + 1), # declining
|
| 70 |
+
kl_to_init=0.03, # KL stays in band
|
| 71 |
+
)
|
| 72 |
+
if s.fire:
|
| 73 |
+
fired_at = i
|
| 74 |
+
break
|
| 75 |
+
assert fired_at is not None, "tripwire never fired on the collapse signature"
|
| 76 |
+
assert g.should_halt()
|
| 77 |
+
s = g.last_status
|
| 78 |
+
assert "held-out" in s.reason and "consecutive" in s.reason
|
| 79 |
+
assert s.proxy_real_gap > 0.0 # proxy gained while real lost
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def test_does_not_fire_before_patience_window():
|
| 83 |
+
"""Held-out declining while in-loop rises for FEWER than decline_patience
|
| 84 |
+
checkpoints must NOT fire (window behavior)."""
|
| 85 |
+
g = _guard(decline_patience=3, max_proxy_real_gap=10.0)
|
| 86 |
+
for i in range(6):
|
| 87 |
+
g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)
|
| 88 |
+
# Only 2 divergent checkpoints (< patience of 3) => no fire.
|
| 89 |
+
s1 = g.update(6, in_loop_reward=0.45, heldout_score=0.35, kl_to_init=0.03)
|
| 90 |
+
s2 = g.update(7, in_loop_reward=0.50, heldout_score=0.30, kl_to_init=0.03)
|
| 91 |
+
assert not s1.fire and not s2.fire
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def test_decline_streak_resets_on_recovery():
|
| 95 |
+
"""A clean checkpoint (held-out recovers) resets the decline streak, so a
|
| 96 |
+
later short divergence does not inherit prior declines."""
|
| 97 |
+
g = _guard(decline_patience=3, max_proxy_real_gap=10.0)
|
| 98 |
+
for i in range(6):
|
| 99 |
+
g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)
|
| 100 |
+
# 2 declines...
|
| 101 |
+
g.update(6, in_loop_reward=0.45, heldout_score=0.35, kl_to_init=0.03)
|
| 102 |
+
g.update(7, in_loop_reward=0.50, heldout_score=0.30, kl_to_init=0.03)
|
| 103 |
+
# ...then held-out recovers (resets streak)...
|
| 104 |
+
s = g.update(8, in_loop_reward=0.50, heldout_score=0.45, kl_to_init=0.03)
|
| 105 |
+
assert not s.fire
|
| 106 |
+
# ...one more decline is only streak=1, still below patience.
|
| 107 |
+
s = g.update(9, in_loop_reward=0.55, heldout_score=0.40, kl_to_init=0.03)
|
| 108 |
+
assert not s.fire
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# --- KL hard-stop ------------------------------------------------------------
|
| 112 |
+
|
| 113 |
+
def test_halt_on_kl_hard_stop_breach():
|
| 114 |
+
g = _guard(kl_hard_stop=0.08, max_proxy_real_gap=10.0)
|
| 115 |
+
# Healthy KL through the warm-up; both metrics flat so only KL can fire.
|
| 116 |
+
for i in range(5):
|
| 117 |
+
s = g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.04)
|
| 118 |
+
assert not s.fire
|
| 119 |
+
# KL spikes well above 0.08; EMA climbs across a couple steps then crosses.
|
| 120 |
+
fired = False
|
| 121 |
+
for i in range(5, 12):
|
| 122 |
+
s = g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.20)
|
| 123 |
+
if s.fire:
|
| 124 |
+
fired = True
|
| 125 |
+
assert "kl_to_init" in s.reason and "hard stop" in s.reason
|
| 126 |
+
break
|
| 127 |
+
assert fired, "KL hard-stop never fired despite KL EMA crossing the ceiling"
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def test_kl_none_never_fires_kl_path():
|
| 131 |
+
"""If the caller never supplies kl_to_init, the KL path must be inert (and
|
| 132 |
+
kl_ema stays None) — KL is an optional float."""
|
| 133 |
+
g = _guard(max_proxy_real_gap=10.0)
|
| 134 |
+
s = None
|
| 135 |
+
for i in range(20):
|
| 136 |
+
s = g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=None)
|
| 137 |
+
assert s is not None and not s.fire
|
| 138 |
+
assert s.kl_ema is None
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# --- proxy-real gap blowout (fast divergence) -------------------------------
|
| 142 |
+
|
| 143 |
+
def test_halt_on_proxy_real_gap_blowout():
|
| 144 |
+
"""A large single-generation divergence (proxy jumps, real stays flat) fires
|
| 145 |
+
via the gap-blowout path even before the decline streak reaches patience."""
|
| 146 |
+
g = _guard(max_proxy_real_gap=0.10, decline_patience=100) # disable (a)
|
| 147 |
+
for i in range(5):
|
| 148 |
+
g.update(i, in_loop_reward=0.30, heldout_score=0.30, kl_to_init=0.03)
|
| 149 |
+
# Proxy blows up; held-out flat. With ema_alpha=0.5 the gap crosses 0.10 fast.
|
| 150 |
+
fired = False
|
| 151 |
+
for i in range(5, 12):
|
| 152 |
+
s = g.update(i, in_loop_reward=0.90, heldout_score=0.30, kl_to_init=0.03)
|
| 153 |
+
if s.fire:
|
| 154 |
+
fired = True
|
| 155 |
+
assert "Hacking Gap" in s.reason
|
| 156 |
+
assert s.proxy_real_gap > 0.10
|
| 157 |
+
break
|
| 158 |
+
assert fired, "gap-blowout tripwire never fired"
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# --- warm-up window (min_steps) ---------------------------------------------
|
| 162 |
+
|
| 163 |
+
def test_respects_min_steps_no_early_fire():
|
| 164 |
+
"""Even with every signal tripped, no fire before min_steps (avoids halting
|
| 165 |
+
on early-run noise)."""
|
| 166 |
+
g = _guard(min_steps=10, decline_patience=2, kl_hard_stop=0.08,
|
| 167 |
+
max_proxy_real_gap=0.01)
|
| 168 |
+
# Egregiously bad signals from step 0: KL huge, proxy up, held-out down.
|
| 169 |
+
for i in range(9): # 9 updates, all < min_steps=10
|
| 170 |
+
s = g.update(i, in_loop_reward=0.10 + 0.1 * i, heldout_score=0.90 - 0.1 * i,
|
| 171 |
+
kl_to_init=0.9)
|
| 172 |
+
assert not s.fire, f"fired during warm-up at step {i}: {s.reason}"
|
| 173 |
+
# The 10th update (n==10, not < min_steps) is now allowed to fire.
|
| 174 |
+
s = g.update(9, in_loop_reward=1.5, heldout_score=0.0, kl_to_init=0.9)
|
| 175 |
+
assert s.fire
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# --- calibration -------------------------------------------------------------
|
| 179 |
+
|
| 180 |
+
def test_calibrate_kl_threshold_tightens_only():
|
| 181 |
+
g = _guard(kl_hard_stop=0.08)
|
| 182 |
+
# Baseline mean 0.01 => 3x = 0.03 < 0.08 => tightens to 0.03.
|
| 183 |
+
new = g.calibrate_kl_threshold([0.008, 0.010, 0.012], factor=3.0)
|
| 184 |
+
assert new == pytest.approx(0.03, abs=1e-9)
|
| 185 |
+
assert g.kl_hard_stop == pytest.approx(0.03, abs=1e-9)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def test_calibrate_never_loosens_past_band():
|
| 189 |
+
g = _guard(kl_hard_stop=0.08)
|
| 190 |
+
# A drifting baseline (mean 0.05 => 3x = 0.15) must NOT loosen past 0.08.
|
| 191 |
+
new = g.calibrate_kl_threshold([0.05, 0.05, 0.05], factor=3.0)
|
| 192 |
+
assert new == pytest.approx(0.08, abs=1e-9)
|
| 193 |
+
assert g.kl_hard_stop == pytest.approx(0.08, abs=1e-9)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def test_calibrate_empty_raises():
|
| 197 |
+
g = _guard()
|
| 198 |
+
with pytest.raises(ValueError, match="non-empty"):
|
| 199 |
+
g.calibrate_kl_threshold([])
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# --- proxy_real_gap definition ----------------------------------------------
|
| 203 |
+
|
| 204 |
+
def test_proxy_real_gap_is_gain_difference():
|
| 205 |
+
g = _guard(min_steps=100, max_proxy_real_gap=10.0) # disable firing
|
| 206 |
+
g.update(0, in_loop_reward=0.20, heldout_score=0.20, kl_to_init=0.02) # baseline
|
| 207 |
+
# With ema_alpha=0.5 the second sample moves each EMA halfway.
|
| 208 |
+
g.update(1, in_loop_reward=0.60, heldout_score=0.30, kl_to_init=0.02)
|
| 209 |
+
# in_loop EMA: 0.5*0.20 + 0.5*0.60 = 0.40; gain = 0.40-0.20 = 0.20
|
| 210 |
+
# heldout EMA: 0.5*0.20 + 0.5*0.30 = 0.25; gain = 0.25-0.20 = 0.05
|
| 211 |
+
# gap = 0.20 - 0.05 = 0.15
|
| 212 |
+
assert g.proxy_real_gap() == pytest.approx(0.15, abs=1e-9)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def test_proxy_real_gap_zero_before_update():
|
| 216 |
+
g = _guard()
|
| 217 |
+
assert g.proxy_real_gap() == 0.0
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# --- idempotency / edge cases -----------------------------------------------
|
| 221 |
+
|
| 222 |
+
def test_should_halt_is_idempotent_query():
|
| 223 |
+
g = _guard(max_proxy_real_gap=10.0)
|
| 224 |
+
for i in range(6):
|
| 225 |
+
g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)
|
| 226 |
+
# Querying repeatedly must not advance state or change the verdict.
|
| 227 |
+
snap_gap = g.proxy_real_gap()
|
| 228 |
+
assert g.should_halt() is False
|
| 229 |
+
assert g.should_halt() is False
|
| 230 |
+
assert g.proxy_real_gap() == snap_gap # unchanged by querying
|
| 231 |
+
assert g.last_status is not None and not g.last_status.fire
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def test_fire_is_latched():
|
| 235 |
+
"""Once fired, a subsequent recovery cannot silently un-halt the run."""
|
| 236 |
+
g = _guard(kl_hard_stop=0.08, max_proxy_real_gap=10.0)
|
| 237 |
+
for i in range(5):
|
| 238 |
+
g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.04)
|
| 239 |
+
# Drive a KL breach.
|
| 240 |
+
fired = False
|
| 241 |
+
for i in range(5, 12):
|
| 242 |
+
s = g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.30)
|
| 243 |
+
if s.fire:
|
| 244 |
+
fired = True
|
| 245 |
+
break
|
| 246 |
+
assert fired
|
| 247 |
+
# Now KL recovers to healthy — verdict must stay fired (latched).
|
| 248 |
+
s = g.update(99, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.01)
|
| 249 |
+
assert s.fire and s.reason.startswith("latched:")
|
| 250 |
+
assert g.should_halt()
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def test_raise_if_fired_raises_typed_exception():
|
| 254 |
+
g = _guard(kl_hard_stop=0.08, max_proxy_real_gap=10.0)
|
| 255 |
+
for i in range(5):
|
| 256 |
+
g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.04)
|
| 257 |
+
status = None
|
| 258 |
+
for i in range(5, 12):
|
| 259 |
+
status = g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.30)
|
| 260 |
+
if status.fire:
|
| 261 |
+
break
|
| 262 |
+
assert status is not None and status.fire
|
| 263 |
+
with pytest.raises(CollapseStopError) as exc:
|
| 264 |
+
g.raise_if_fired(status)
|
| 265 |
+
assert exc.value.status is status
|
| 266 |
+
assert isinstance(str(exc.value), str) and str(exc.value)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def test_raise_if_fired_noop_when_clean():
|
| 270 |
+
g = _guard(max_proxy_real_gap=10.0)
|
| 271 |
+
s = g.update(0, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)
|
| 272 |
+
# No fire => no raise (uses last_status when arg omitted).
|
| 273 |
+
g.raise_if_fired(s)
|
| 274 |
+
g.raise_if_fired()
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def test_status_halt_alias_matches_fire():
|
| 278 |
+
g = _guard(max_proxy_real_gap=10.0)
|
| 279 |
+
s = g.update(0, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)
|
| 280 |
+
assert s.halt == s.fire and s.fire is False
|
| 281 |
+
assert isinstance(s, TripwireStatus)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def test_non_contiguous_round_idx_uses_internal_counter():
|
| 285 |
+
"""min_steps gates on the internal update counter, not round_idx, so a caller
|
| 286 |
+
that logs sparse / non-contiguous round indices still warms up correctly."""
|
| 287 |
+
g = _guard(min_steps=3, max_proxy_real_gap=0.01, decline_patience=1)
|
| 288 |
+
# Pass huge round_idx values; only the 3rd UPDATE clears warm-up.
|
| 289 |
+
g.update(1000, in_loop_reward=0.10, heldout_score=0.90, kl_to_init=0.9)
|
| 290 |
+
g.update(2000, in_loop_reward=0.50, heldout_score=0.50, kl_to_init=0.9)
|
| 291 |
+
s = g.update(3000, in_loop_reward=0.90, heldout_score=0.10, kl_to_init=0.9)
|
| 292 |
+
assert s.fire # 3rd update, n==3 not < min_steps
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
# --- config validation -------------------------------------------------------
|
| 296 |
+
|
| 297 |
+
def test_bad_ema_alpha_rejected():
|
| 298 |
+
with pytest.raises(ValueError, match="ema_alpha"):
|
| 299 |
+
HeldOutGuard(ema_alpha=1.0)
|
| 300 |
+
with pytest.raises(ValueError, match="ema_alpha"):
|
| 301 |
+
HeldOutGuard(ema_alpha=-0.1)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def test_bad_kl_hard_stop_rejected():
|
| 305 |
+
with pytest.raises(ValueError, match="kl_hard_stop"):
|
| 306 |
+
HeldOutGuard(kl_hard_stop=0.0)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def test_bad_decline_patience_rejected():
|
| 310 |
+
with pytest.raises(ValueError, match="decline_patience"):
|
| 311 |
+
HeldOutGuard(decline_patience=0)
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
# --- kl_token_trust_filter helper -------------------------------------------
|
| 315 |
+
|
| 316 |
+
def test_kl_token_trust_filter_masks_above_threshold():
|
| 317 |
+
# 0.5 * logratio^2; mask when it exceeds the per-token KL ceiling.
|
| 318 |
+
assert kl_token_trust_filter(0.20, threshold=0.08) is True # too large -> mask
|
| 319 |
+
assert kl_token_trust_filter(0.05, threshold=0.08) is False # within trust region
|
| 320 |
+
assert kl_token_trust_filter(0.08, threshold=0.08) is False # boundary, not masked
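
The arithmetic that `test_proxy_real_gap_is_gain_difference` and the two calibration tests check can be reproduced without importing the package. The sketch below is a standalone restatement under stated assumptions: `ema_fold`, `hacking_gap`, and `tighten_only` are hypothetical free functions that mirror `_fold`, `proxy_real_gap`, and `calibrate_kl_threshold`; they are not the module's API.

```python
# Standalone sketch of the guard's arithmetic; function names are illustrative.

def ema_fold(prev, x, alpha):
    """EMA with `alpha` as the weight on the prior; the first sample seeds it."""
    if prev is None:
        return x
    return alpha * prev + (1.0 - alpha) * x


def hacking_gap(in_loop_ema, in_loop_base, heldout_ema, heldout_base):
    """(in-loop gain) minus (held-out gain), both measured from their baselines."""
    return (in_loop_ema - in_loop_base) - (heldout_ema - heldout_base)


def tighten_only(current_stop, baseline_kls, factor=3.0):
    """Calibrated KL ceiling that can only move down from `current_stop`."""
    if not baseline_kls:
        raise ValueError("baseline_kls must be non-empty to calibrate")
    mean_kl = sum(baseline_kls) / len(baseline_kls)
    return min(factor * mean_kl, current_stop)


# Same numbers as the gap test: alpha = 0.5, two updates.
alpha = 0.5
in_loop = heldout = None
for reward, score in [(0.20, 0.20), (0.60, 0.30)]:
    in_loop = ema_fold(in_loop, reward, alpha)
    heldout = ema_fold(heldout, score, alpha)

# in_loop is about 0.40 and heldout about 0.25; baselines are the first samples.
gap = hacking_gap(in_loop, 0.20, heldout, 0.20)  # about 0.15

tightened = tighten_only(0.08, [0.008, 0.010, 0.012])  # about 0.03
clamped = tighten_only(0.08, [0.05, 0.05, 0.05])       # stays at 0.08
```

With the default `max_proxy_real_gap` of 0.10, a gap of about 0.15 would trip path (c) once the warm-up window has passed, which is why the gap test raises `min_steps` and the ceiling to keep the guard from firing.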
|
|
@@ -52,6 +52,31 @@ Goal-driven systematic resolution of every pending item. This doc is the live au
|
|
| 52 |
| F1 (`…-cb74`) | **ROTATE exposed HF write-token** — USER-ONLY (requires HF account access). AUDIT done: no live token in tracked tree (only env-var reads). Action = user rotates on huggingface.co. | P1 | DOCUMENTED (user-only) |
|
| 53 |
| F2 | Real 8B LMA run (A2/A3/A4 arms `…-42f5`,`…-dd7b`) + higher-lr sweep RUNS — GPU + budget + user go/no-go. Harness buildable (E1/E2); the spend is user-only. | — | GATED (harness only) |
|
| 54 |
|
| 55 |
## Wave plan
|
| 56 |
- **Wave 1 (parallel):** B1, B2, B3, B4, B5, B6, B7, B8 (bugs + doc debt) ‖ D1 (Docker E2E) ‖ research fan-out (Tavily/Exa/DeepWiki) for C1/C2/E1/E2 best practices.
|
| 57 |
- **Wave 2 (parallel, after research):** C1 (held-out eval + kill-switch) ‖ C2 (EKSExecutor) ‖ C3 (containerized sandbox) ‖ E1/E2/E3 harnesses.
|
|
|
|
| 52 |
| F1 (`…-cb74`) | **ROTATE exposed HF write-token** — USER-ONLY (requires HF account access). AUDIT done: no live token in tracked tree (only env-var reads). Action = user rotates on huggingface.co. | P1 | DOCUMENTED (user-only) |
|
| 53 |
| F2 | Real 8B LMA run (A2/A3/A4 arms `…-42f5`,`…-dd7b`) + higher-lr sweep RUNS — GPU + budget + user go/no-go. Harness buildable (E1/E2); the spend is user-only. | — | GATED (harness only) |
|
| 54 |
|
| 55 |
+
## Status log
|
| 56 |
+
|
| 57 |
+
**Wave 1 — DONE (commit `c11cf49`):** B1 ✅ (fixture generated, 8 tests pass), B2 ✅ ([dev] installs on arm64), B3 ✅ ([serverless] deps), B4 ✅ (266/62 canonical), B5 ✅ (WSL footers), B6 ✅ (dead ADR link), B7 ✅ (config factories re-exported + documented), B8 ✅ (refine-summary + OVERVIEW xref), **D1 ✅ (Docker substrate E2E GREEN — 2/2 gates on real container; long-blocked item closed)**. F1 (token rotation) audited — no live token in tracked tree; user-only action documented.
|
| 58 |
+
|
| 59 |
+
**Wave 2 — DONE (built + integrated + tested):** C1 ✅ HeldOutGuard kill-switch (`composer_replication/safety/`, 23 tests), C2 ✅ EKSExecutor (single Indexed Job → N handles, gang-cancel; `eks.py` + 28 tests), C3 ✅ DockerSandbox (`docker_sandbox.py` + shared `scrub_tree` refactor; live Docker tests pass), E3 ✅ SageMakerExecutor (`sagemaker.py`; +13-test module I added — the build agent shipped it test-less, gap closed during integration). All 4 modules lint-clean, re-exported, 90/3 on targeted suite. Grounded in Phase-3 research.
|
| 60 |
+
|
| 61 |
+
**Wave 3 — Phase-7 reconciliation (from the concurrent review team `research/review-*.json`):**
|
| 62 |
+
| ID | Item | Sev | Status |
|
| 63 |
+
|---|---|---|---|
|
| 64 |
+
| R1 | **Wire `HeldOutGuard` into `composer_trainer.py`** at per-checkpoint cadence (alongside `DifficultyCurriculum.update`), feeding `token_mean_kl` as `kl_to_init`, converting a fired verdict to halt via `raise_if_fired`. Currently dead code — the #2 safeguard never fires in production. | HIGH | OPEN |
|
| 65 |
+
| R2 | **Build `composer_replication/safety/holdout.py` `HeldoutSplit`** disjointness enforcer (id/hash set-difference, raises on train↔held-out overlap) — the un-built second half of C1; the guard's gap signal is meaningless without it. | HIGH | OPEN |
|
| 66 |
+
| R3 | **EKS contract bug:** `launch_replicas` default container command runs `replica_entrypoint __main__` (argparse needs `--rendezvous/--world-size/--trainer-module`) but the indexed-job spec passes rank/world via env, not argv → a real run would fail arg-parsing. Reconcile the entrypoint contract. | HIGH | OPEN |
|
| 67 |
+
| R4 | `calibrate_kl_threshold` can yield a NEGATIVE `kl_hard_stop` on `factor<=0`/negative baseline → fires every healthy step. Guard inputs / clamp to positive floor. | LOW | OPEN |
|
| 68 |
+
| R5 | EKS/SageMaker `cancel()` swallow ALL exceptions (report success on real failure). Narrow to already-terminated (404/ResourceNotFound). | LOW | OPEN |
|
| 69 |
+
| R6 | `EKSExecutor.collect()` result dicts miss the `result` key the other backends include — cross-backend shape uniformity. | LOW | OPEN |
|
| 70 |
+
| R7 | **Doc-debt:** the 4 new Wave-2 public symbols (EKSExecutor, SageMakerExecutor, DockerSandbox, HeldOutGuard/safety) are undocumented in API_REFERENCE.md; add §12 + `.eks`/`.aws` extras. | MED | OPEN |
|
| 71 |
+
| R8 | **ADR-015** for the held-out kill-switch — referenced by `safety/__init__.py:17` + kill_switch docstrings but doesn't exist (dangling refs). Author it or drop the refs. | LOW | OPEN |
|
| 72 |
+
| R9 | Re-measure + refresh canonical test count in V1_V8_COVERAGE (Wave 2 added ~93 tests; 328→~420 collected). | LOW | OPEN |
|
| 73 |
+
| R10 | Add a test pinning the kill-switch path-(c) both-rising gap-blowout behavior; document path-(c) as a divergence-rate gate. | LOW | OPEN |
|
| 74 |
+
|
| 75 |
+
| R11 | Flaky test `spikes/006-real-hf-model-smoke/tests/test_strict.py::test_alternating_batches_loss_decreases` — fails under CPU contention (full suite w/ concurrent pytest + Docker), PASSES in isolation (verified 3x). Loss-trend assertion is timing/noise-sensitive. Pin seed / widen tolerance / mark flaky. Pre-existing, not a Wave-2 regression. | LOW | OPEN |
|
| 76 |
+
| R12 | B7-complete ✅ (top-level `__all__` now includes the 3 factories) + B4-complete ✅ (the 4 surviving "115" claims → 266/62). | — | DONE |
|
| 77 |
+
|
| 78 |
+
Sandbox refactor verdict: **clean** (no regression to LocalSubprocessSandbox/FeatureDeletionEnv).
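Item R2 in the table above describes a `HeldoutSplit` disjointness enforcer (id/hash set-difference, raising on train↔held-out overlap) that is not built yet. As a rough illustration of that idea only, here is a minimal sketch; the class name `HeldoutSplitSketch`, the exception name, and the `build` constructor are all hypothetical and are not the repository's API.

```python
from dataclasses import dataclass
from typing import FrozenSet, Iterable


class HeldoutOverlapError(ValueError):
    """Raised when train and held-out ID sets intersect (hypothetical name)."""


@dataclass(frozen=True)
class HeldoutSplitSketch:
    train_ids: FrozenSet[str]
    heldout_ids: FrozenSet[str]

    @classmethod
    def build(cls, train_ids: Iterable[str], heldout_ids: Iterable[str]) -> "HeldoutSplitSketch":
        train, heldout = frozenset(train_ids), frozenset(heldout_ids)
        overlap = train & heldout
        if overlap:
            # Report a small, deterministic sample so the error stays readable.
            sample = sorted(overlap)[:5]
            raise HeldoutOverlapError(
                f"{len(overlap)} id(s) appear in both train and held-out, e.g. {sample}"
            )
        return cls(train, heldout)
```

Validating at construction time means a guard could only ever receive scores from a split that has already passed the disjointness check.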
|
| 79 |
+
|
| 80 |
## Wave plan
|
| 81 |
- **Wave 1 (parallel):** B1, B2, B3, B4, B5, B6, B7, B8 (bugs + doc debt) ‖ D1 (Docker E2E) ‖ research fan-out (Tavily/Exa/DeepWiki) for C1/C2/E1/E2 best practices.
|
| 82 |
- **Wave 2 (parallel, after research):** C1 (held-out eval + kill-switch) ‖ C2 (EKSExecutor) ‖ C3 (containerized sandbox) ‖ E1/E2/E3 harnesses.
|
|
@@ -52,8 +52,8 @@ where channel 1 is real GRPO rather than the LM-CE stub. See
|
|
| 52 |
trainer on a real reasoning benchmark.
|
| 53 |
- **Economic feasibility of channel 3.** 150 real OpenRouter calls, $0.98/trace mean, 0
|
| 54 |
errors (Spike 001).
|
| 55 |
-
- **Installable + tested.** `pip install -e .` works; **
|
| 56 |
-
|
| 57 |
|
| 58 |
## What's gapped (honest, NOT closed)
|
| 59 |
|
|
|
|
| 52 |
trainer on a real reasoning benchmark.
|
| 53 |
- **Economic feasibility of channel 3.** 150 real OpenRouter calls, $0.98/trace mean, 0
|
| 54 |
errors (Spike 001).
|
| 55 |
+
- **Installable + tested.** `pip install -e .` works; **266 passing / 62 skipped** (measured 2026-06-09;
|
| 56 |
+
canonical count + why skips vary by env: [`docs/V1_V8_COVERAGE.md`](V1_V8_COVERAGE.md)).
|
| 57 |
|
| 58 |
## What's gapped (honest, NOT closed)
|
| 59 |
|
|
@@ -1,7 +1,7 @@
|
|
| 1 |
# Vision Validation: Does the Framework Encapsulate the Original Brief?
|
| 2 |
|
| 3 |
> **## Status as of 2026-06 (current through ADR-014)**
|
| 4 |
-
> The framework is past-skeleton: 8 subpackages (`composer_replication/*`),
|
| 5 |
> tests + 1 skip-marked (see [`docs/V1_V8_COVERAGE.md`](V1_V8_COVERAGE.md) for the
|
| 6 |
> canonical count), and operational end-to-end examples (`gsm8k_grpo`,
|
| 7 |
> `sdpo_with_real_traces_production`). The 3-channel loss, layered hint-generation,
|
|
|
|
| 1 |
# Vision Validation: Does the Framework Encapsulate the Original Brief?
|
| 2 |
|
| 3 |
> **## Status as of 2026-06 (current through ADR-014)**
|
| 4 |
+
> The framework is past-skeleton: 8 subpackages (`composer_replication/*`), 266 passing (canonical count + env-variance note in docs/V1_V8_COVERAGE.md)
|
| 5 |
> tests + 1 skip-marked (see [`docs/V1_V8_COVERAGE.md`](V1_V8_COVERAGE.md) for the
|
| 6 |
> canonical count), and operational end-to-end examples (`gsm8k_grpo`,
|
| 7 |
> `sdpo_with_real_traces_production`). The 3-channel loss, layered hint-generation,
|
|
@@ -69,6 +69,15 @@ serverless = [
|
|
| 69 |
"boto3>=1.34", # SageMakerExecutor (create_training_job) + S3 IAM
|
| 70 |
"kubernetes>=29.0", # EKSExecutor (indexed k8s Jobs via BatchV1Api)
|
| 71 |
]
|
| 72 |
# Replaysim dataset normalization (per ADR-004)
|
| 73 |
#
|
| 74 |
# NOTE: data-juicer is intentionally NOT pinned as an extra. The package
|
|
|
|
| 69 |
"boto3>=1.34", # SageMakerExecutor (create_training_job) + S3 IAM
|
| 70 |
"kubernetes>=29.0", # EKSExecutor (indexed k8s Jobs via BatchV1Api)
|
| 71 |
]
|
| 72 |
+
# Amazon EKS / Kubernetes Indexed-Job executor (EKSExecutor, per ADR-005).
|
| 73 |
+
# kubernetes is lazy-imported at adapter-init/method time (not at package import).
|
| 74 |
+
eks = [
|
| 75 |
+
"kubernetes>=29",
|
| 76 |
+
]
|
| 77 |
+
# Amazon SageMaker training-job executor (SageMakerExecutor, per ADR-005).
|
| 78 |
+
aws = [
|
| 79 |
+
"boto3>=1.34",
|
| 80 |
+
]
|
| 81 |
# Replaysim dataset normalization (per ADR-004)
|
| 82 |
#
|
| 83 |
# NOTE: data-juicer is intentionally NOT pinned as an extra. The package
|
|
@@ -0,0 +1,44 @@
|
|
| 1 |
+
{
|
| 2 |
+
"area": "composer_replication/diloco/serverless: EKSExecutor + SageMakerExecutor vs ServerlessExecutor Protocol",
|
| 3 |
+
"verdict": "minor-issues",
|
| 4 |
+
"findings": [
|
| 5 |
+
{
|
| 6 |
+
"severity": "high",
|
| 7 |
+
"what": "EKS rank/arg plumbing contract mismatch: launch_replicas defaults the container command to ['python','-m','composer_replication.diloco.serverless.replica_entrypoint'] with NO container args, and plumbs rendezvous_uri/world_size as UPPER-CASED env vars (RENDEZVOUS_URI, WORLD_SIZE). But replica_entrypoint.py's __main__ block uses argparse with --rendezvous, --world-size, --trainer-module ALL required=True and reads NONE of those env vars (only REPLICA_RANK via os.environ). It also reads trainer_module via --trainer-module, which EKS never plumbs in any form. A pod launched with the documented EKS defaults therefore SystemExits at startup ('the following arguments are required: --rendezvous, --world-size, --trainer-module'). SageMakerExecutor does this correctly (passes ContainerArguments=['--rendezvous',...,'--world-size',...,'--trainer-module',...,'--trainer-fn',...,'--trainer-kwargs-json',...] matching the entrypoint argparse exactly). This is an end-to-end run correctness bug, not a Protocol-signature gap, and it is untested (test_launch_uses_default_entrypoint_command only asserts the command vector; no test asserts the entrypoint can actually parse what EKS supplies; trainer_module is never asserted to reach the container).",
|
| 8 |
+
"where": "composer_replication/diloco/serverless/eks.py:220-224 (default command), :296-303 (_build_env upper-cases scalars, drops nothing else), :405-458 (launch_replicas passes no ContainerArguments); contract owner composer_replication/diloco/serverless/replica_entrypoint.py:91-109 (argparse required=True, no env fallback)",
|
| 9 |
+
"recommendation": "Pick ONE: (a) make EKS pass the same arg vector SageMaker does by appending args to self.command (e.g. command + ['--rendezvous', uri, '--world-size', N, '--trainer-module', tm, ...]); OR (b) add an env-var fallback to replica_entrypoint.__main__ (read RENDEZVOUS_URI/WORLD_SIZE/TRAINER_MODULE/etc. from os.environ when CLI args are absent) so the env-only EKS plumbing works unchanged. Add a test that constructs the entrypoint argv/env exactly as EKS would and asserts main() can be invoked (e.g. argparse parse over the supplied tokens, or env-driven path). Ensure trainer_module is plumbed in whichever channel is chosen."
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"severity": "low",
|
| 13 |
+
"what": "EKS cancel() swallows ALL non-404 ApiExceptions and even generic Exceptions with a bare 'return', reporting success even when the gang delete genuinely failed (e.g. 403 RBAC-denied, 409 conflict). Because the whole point of gang-cancel is to stop the entire GPU-burning cohort, a silently-swallowed real teardown failure leaves the cohort running while the caller believes it was cancelled — the exact failure mode the design calls out. SageMakerExecutor.cancel has the same broad swallow. The Protocol only requires 'no exception if already terminated', which a 404 satisfies; swallowing 403/409 is broader than the contract needs.",
|
| 14 |
+
"where": "composer_replication/diloco/serverless/eks.py:594-600 (except ApiException -> swallow non-404; except Exception -> swallow); composer_replication/diloco/serverless/sagemaker.py:470-475 (bare except Exception: pass)",
|
| 15 |
+
"recommendation": "Narrow the swallow to the 'already terminated' cases (404 / ResourceNotFound, and SageMaker's already-terminal ValidationException) and at minimum log/warn (or re-raise) on other API errors so a failed gang-teardown of GPU resources is observable rather than silent. Best-effort can still mean 'do not raise', but it should emit a warning on a non-idempotent failure."
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"severity": "low",
|
| 19 |
+
"what": "Result-dict shape inconsistency across executors in collect(): SageMaker and Modal/Local include a 'result' key (SageMaker surfaces ModelArtifacts.S3ModelArtifacts path; Local/Modal include the in-process return value), but EKS _result_dict omits 'result' entirely and instead adds 'job_name'. The Protocol only mandates {rank,status,exit_code,error} 'at least', so this is conformant, but the divergence makes a backend-agnostic caller that reads result['result'] KeyError on EKS. Note: this is NOT a 'collect() not reading S3' Protocol violation — the Protocol/ADR-005 do not require collect() to read S3 contents; the payload flows through ObjectStoreAllReduce/S3 written by the replica itself, and collect() correctly returns status metadata (the reference LocalProcessExecutor returns an in-process value, not S3). SageMaker surfacing the S3 artifact path is a nice-to-have, not a requirement.",
|
| 20 |
+
"where": "composer_replication/diloco/serverless/eks.py:655-671 (_result_dict: no 'result' key, adds 'job_name'); compare sagemaker.py:588-595 (includes 'result': artifacts.get('S3ModelArtifacts')) and executor.py:104-107 (Protocol documents only the 4 required keys)",
|
| 21 |
+
"recommendation": "For cross-backend uniformity, add a 'result': None (or the rendezvous output URI if known) key to EKS _result_dict so callers can read result['result'] uniformly across executors. Optionally document in the Protocol docstring that 'result' is an optional, backend-specific extra key so callers use .get('result')."
|
| 22 |
+
}
|
| 23 |
+
],
|
| 24 |
+
"confirmed_good": [
|
| 25 |
+
"Both EKSExecutor and SageMakerExecutor satisfy the runtime_checkable ServerlessExecutor Protocol: isinstance(EKSExecutor(image=...,batch_api=...,core_api=...), ServerlessExecutor) is True and isinstance(SageMakerExecutor(...), ServerlessExecutor) is True (verified at runtime); both expose backend_name ('eks'/'sagemaker'), supports_inter_replica_network (both False, correct — S3-only rendezvous), and all five methods launch_replicas/poll/stream_logs/cancel/collect.",
|
| 26 |
+
"Both are exported from serverless/__init__.py and present in __all__ (EKSExecutor line 50/62, SageMakerExecutor line 59/68).",
|
| 27 |
+
"EKS single-Indexed-Job -> N-handles topology is correct: exactly one create_namespaced_job, completions==parallelism==n_replicas, completionMode='Indexed', restartPolicy='Never', backoffLimit=0, active_deadline_seconds==timeout, ttl_seconds_after_finished set; returns N rank-ordered handles (handles[i].rank==i) all sharing job_name/namespace (test_launch_returns_n_rank_ordered_handles, test_launch_creates_indexed_job_spec).",
|
| 28 |
+
"EKS gang-cancel is correct: cancel(any handle) deletes the WHOLE shared Job with propagation_policy='Background' (cascading pod deletion, not the k8s default Orphan) and grace_period_seconds=0; idempotent on 404 (test_cancel_uses_background_propagation_on_shared_job, test_cancel_swallows_404, test_cancel_unknown_handle_is_noop).",
|
| 29 |
+
"EKS rank plumbing via downward API is correct: REPLICA_RANK set via V1EnvVarSource.field_ref field_path metadata.annotations['batch.kubernetes.io/job-completion-index'] (value is None, value_from set), bridging k8s completion-index to the entrypoint's REPLICA_RANK read without modifying the entrypoint; rank_env LocalProcessExecutor convention is stripped (test_launch_rank_env_uses_downward_api_field_ref, test_launch_strips_rank_env_kwarg). NOTE: this rank channel works; the BROKEN channel is rendezvous_uri/trainer_module (see high finding).",
|
| 30 |
+
"EKS poll status mapping covers all five Protocol states: rank in completed_indexes->succeeded (checked first, so a succeeded rank is not mis-flagged by a whole-job Failed condition), rank in failed_indexes->failed, whole-job Failed condition->failed (DeadlineExceeded/backoff), active>0->running, else pending, 404->cancelled, non-404 ApiException re-raised; run-length index strings expanded correctly incl. reversed ranges and whitespace (test_poll_* x7, test_expand_indexes_*).",
|
| 31 |
+
"EKS GPU resource limit is always a STRING ('1' not int 1) per OpenAPI dict[str,str] typing; GPU node selector merged (caller wins) and nvidia.com/gpu NoSchedule toleration auto-added; CPU-only omits the gpu limit (test_launch_gpu_limit_is_string, test_launch_cpu_only_omits_gpu_limit).",
|
| 32 |
+
"EKS partial-failure sibling cleanup is correctly N/A: launch issues exactly ONE create_namespaced_job (atomic gang scheduling), so there are no siblings to clean up if it fails — a genuine advantage of the single-Indexed-Job topology over N-job designs.",
|
| 33 |
+
"SageMaker correctly uses N independent single-instance training jobs (ResourceConfig.InstanceCount==1) with rank via the Environment map (REPLICA_RANK/WORLD_SIZE/RENDEZVOUS_URI), and correctly passes the entrypoint args via ContainerArguments matching replica_entrypoint argparse; EnableNetworkIsolation pinned False (else S3 rendezvous deadlocks) — verified in test_launch_injects_rank_world_size_and_rendezvous_env.",
|
| 34 |
+
"SageMaker partial-failure sibling cleanup is correct: a create_training_job failure at rank k best-effort stops the k already-launched siblings then raises with rank context (test_launch_partial_failure_stops_siblings_and_raises asserts 2 siblings stopped).",
|
| 35 |
+
"SageMaker poll status mapping covers all 5 documented TrainingJobStatus values (InProgress->running, with SecondaryStatus refinement to pending for Starting/Pending/LaunchingMLInstances/PreparingTrainingStack; Completed->succeeded; Failed->failed; Stopping->running; Stopped->cancelled), vanished job (ResourceNotFound)->cancelled, unknown handle->cancelled; collect() correctly checks RAW SM status for terminality so Stopping keeps polling until Stopped (test_poll_status_mapping, test_poll_failed_and_stopped, test_poll_vanished_job_is_cancelled, test_poll_unknown_handle_is_cancelled).",
|
| 36 |
+
"collect() reading S3: NOT a violation. The Protocol (executor.py:96-108) and ADR-005 require collect() to return status/exit metadata, not S3 contents — the result payload flows through ObjectStoreAllReduce written to S3 by each replica. SageMaker even surfaces the ModelArtifacts S3 path in result['result']. The reference LocalProcessExecutor returns an in-process value, confirming collect is not contractually an S3 reader.",
|
| 37 |
+
"Full suite green: .venv/bin/python -m pytest composer_replication/diloco/serverless -q => 53 passed, 17 skipped (skips are the boto3/kubernetes/modal absent-path guards that cannot fire when the package is importable in this interpreter, plus integration gates)."
|
| 38 |
+
],
|
| 39 |
+
"new_backlog_items": [
|
| 40 |
+
"EKS end-to-end run bug: default container command runs replica_entrypoint __main__ (argparse --rendezvous/--world-size/--trainer-module required) but EKSExecutor supplies env vars + no args and never plumbs trainer_module -> pod crashes on startup. Fix by passing ContainerArguments-equivalent args OR adding an env-var fallback to replica_entrypoint.__main__; add a test that the EKS-supplied argv/env actually parses. (Not in BACKLOG_RESOLUTION_2026-06-09; C2 only tracked building EKSExecutor, not the entrypoint contract.)",
|
| 41 |
+
"Tighten EKSExecutor.cancel and SageMakerExecutor.cancel exception handling: only swallow 'already-terminated' errors (404/ResourceNotFound, already-terminal ValidationException); log/warn on other API errors so a failed gang-teardown of GPU resources is observable instead of silently leaving the cohort burning compute.",
|
| 42 |
+
"Add a 'result' key to EKSExecutor.collect() result dicts (None or the rendezvous output URI) for cross-backend uniformity with Local/Modal/SageMaker, OR document in the Protocol that 'result' is an optional backend-specific extra so callers use .get('result')."
|
| 43 |
+
]
|
| 44 |
+
}
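Option (b) in the high-severity finding above, an env-var fallback for the entrypoint's required arguments, can be sketched with the standard library alone. This is not the actual `replica_entrypoint` code: the function name `parse_entrypoint_args` is hypothetical, and only the three flags and env var names quoted in the review are covered.

```python
import argparse
import os
from typing import Mapping, Optional, Sequence

# (CLI flag, env var, type). Names are taken from the review text above.
_ARGS = (
    ("--rendezvous", "RENDEZVOUS_URI", str),
    ("--world-size", "WORLD_SIZE", int),
    ("--trainer-module", "TRAINER_MODULE", str),
)


def parse_entrypoint_args(
    argv: Sequence[str], env: Optional[Mapping[str, str]] = None
) -> argparse.Namespace:
    """Parse entrypoint args, falling back to env vars when a flag is absent."""
    env = os.environ if env is None else env
    parser = argparse.ArgumentParser(prog="replica_entrypoint_sketch")
    for flag, env_name, typ in _ARGS:
        raw = env.get(env_name)
        default = typ(raw) if raw is not None else None
        # A flag stays required only when the env var does not supply a value.
        parser.add_argument(flag, type=typ, default=default, required=default is None)
    return parser.parse_args(argv)
```

With this shape, the SageMaker-style argv path and the EKS-style env-only path both parse, and an explicit CLI flag still overrides the environment.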
|
|
@@ -0,0 +1,39 @@
|
|
| 1 |
+
{
|
| 2 |
+
"area": "Wave-1+2 broad sweep for NEW gaps (imports/laziness, unfinished-work markers, doc-debt, ADR-015, optional-dep eager-load)",
|
| 3 |
+
"verdict": "minor-issues",
|
| 4 |
+
"findings": [
|
| 5 |
+
{
|
| 6 |
+
"severity": "medium",
|
| 7 |
+
"what": "Doc-debt: the 4 NEW Wave-2 public symbols are entirely undocumented in docs/API_REFERENCE.md. grep for EKSExecutor / SageMakerExecutor / DockerSandbox / HeldOutGuard / TripwireStatus / CollapseStopError / kl_token_trust_filter all return 0 hits. API_REFERENCE §12 (serverless) header (line 23) lists `.modal`, `.hf_jobs` but not `.eks` / `.sagemaker`, and documents the loud-failing ModalExecutor/HFJobsExecutor stubs while omitting the two NEW *production* executors. There is no `safety` section at all, and no `datagen` section (DockerSandbox + its LocalSubprocessSandbox/FakeSandbox siblings are all undocumented). All four are real, exported public API (in their package __all__) and Protocol-conformant (isinstance(eks, ServerlessExecutor) == True).",
|
| 8 |
+
"where": "docs/API_REFERENCE.md (§12 line 1153-1376; header line 23); new public symbols in composer_replication/diloco/serverless/{eks,sagemaker}.py, composer_replication/datagen/docker_sandbox.py, composer_replication/safety/kill_switch.py",
|
| 9 |
+
"recommendation": "Add API_REFERENCE entries: under §12 add `class EKSExecutor` and `class SageMakerExecutor` (and update the §12 line-23 module list to include `.eks`, `.sagemaker`); add a `composer_replication.safety` section documenting HeldOutGuard / TripwireStatus / CollapseStopError / kl_token_trust_filter; and a `composer_replication.datagen` section documenting DockerSandbox (alongside the existing-but-also-undocumented LocalSubprocessSandbox/FakeSandbox)."
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"severity": "low",
|
| 13 |
+
"what": "Dangling ADR reference: composer_replication/safety/__init__.py:17 says 'See docs/adrs/ADR-015-*.md' but no ADR-015 file exists (docs/adrs/ stops at ADR-014). The research plan called for ADR-015 to document the safety/kill-switch design decision; the module docstring already cites the literature (Zhao et al. RSI, EvilGenie, Gao self-evolving survey, Shumailov collapse, Catastrophic Goodhart, GRPO KL band) so the design rationale exists in-code but is not captured as an ADR, and the __init__ points readers to a file that isn't there.",
|
| 14 |
+
"where": "composer_replication/safety/__init__.py:17 (the dangling 'docs/adrs/ADR-015-*.md' pointer); docs/adrs/ (ADR-015 absent)",
|
| 15 |
+
"recommendation": "Either author docs/adrs/ADR-015-holdout-killswitch.md (the kill_switch.py module docstring is effectively the ADR draft already — proxy_real_gap Hacking-Gap, KL 0.08 nats/token hard stop, decline-patience collapse signature, defense-in-depth-over-HackMonitor) and index it in docs/adrs/README.md, OR remove the forward reference from safety/__init__.py until the ADR lands."
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"severity": "low",
|
| 19 |
+
"what": "Test-count drift re-introduced by Wave 2. docs/V1_V8_COVERAGE.md:117 still states the canonical count as '266 passed / 62 skipped / 328 collected (measured 2026-06-09)' — that was the Wave-1 figure. Wave 2 added 93 tests across 4 new files (test_kill_switch 23, test_eks_executor 28, test_sagemaker_executor 14, test_docker_sandbox 28); the tree now collects 420 tests (328 -> 420, +92 net). B4 closed test-drift in Wave 1 but the doc is stale again post-Wave-2.",
|
| 20 |
+
"where": "docs/V1_V8_COVERAGE.md:117-134 (canonical count claim) vs actual `pytest --collect-only` = 420 collected",
|
| 21 |
+
"recommendation": "Re-run `.venv/bin/python -m pytest` to get the post-Wave-2 passed/skipped split and update the single canonical figure in V1_V8_COVERAGE.md (the doc explicitly says this line is 'the one canonical figure' that other docs reference)."
|
| 22 |
+
}
|
| 23 |
+
],
|
| 24 |
+
"confirmed_good": [
|
| 25 |
+
"Required import smoke test passes: `import composer_replication; from composer_replication.diloco.serverless import EKSExecutor, SageMakerExecutor; from composer_replication.datagen import DockerSandbox; from composer_replication.safety import HeldOutGuard` -> exit 0, 'ALL IMPORTS OK'.",
|
| 26 |
+
"Optional-dep laziness (question 5) is CORRECT for all 4 new modules: no top-level `import kubernetes/boto3/docker` in eks.py / sagemaker.py / docker_sandbox.py / kill_switch.py (grep for eager imports returns empty). Blocking kubernetes+docker at import time and importing the new modules in isolation succeeds. EKSExecutor lazy-imports `kubernetes` only when no api injected / per-method; SageMakerExecutor lazy-imports boto3 in _make_boto3_client (construction-time, not import-time); DockerSandbox lazy-imports docker via _require_docker() inside methods.",
|
| 27 |
+
"NOTE on the whole-package blocked-import failure: blocking boto3 breaks `import composer_replication`, but the cause is PRE-EXISTING and NOT a Wave-2 regression — composer_replication/__init__.py:98 imports the trainer, which imports `trl.GRPOTrainer` -> accelerate.commands.config.sagemaker -> `import boto3`. boto3 is already a hard transitive dependency of the base trainer stack on main; Wave 2 did not introduce it.",
|
| 28 |
+
"No NEW unfinished-work markers (question 2): all NotImplementedError/TODO/FIXME/STUB hits in composer_replication/ are PRE-EXISTING and intentional (prime_rl/composer_loss.py deferred SDPO channel-2, recipes/monarch/actors.py v0 skeleton per ADR-006, diloco/serverless/{modal,hf_jobs,modal_spawn}.py documented loud-failing stubs). The 4 new modules contain ZERO NotImplementedError/TODO/FIXME/STUB — they are finished, not skeletons. SageMakerExecutor's docstring explicitly contrasts itself as 'fully-working, not the loud-failing modal.py/hf_jobs.py skeletons'.",
|
| 29 |
+
"Both new executors satisfy the runtime_checkable ServerlessExecutor Protocol (isinstance checks pass), expose correct backend_name ('eks'/'sagemaker') and supports_inter_replica_network=False (S3-only rendezvous).",
|
| 30 |
+
"All 90 collectable Wave-2 tests pass (3 skipped, the live-docker-daemon gated ones) via `pytest composer_replication/safety/tests composer_replication/diloco/serverless/tests/test_{eks,sagemaker}_executor.py composer_replication/datagen/tests/test_docker_sandbox.py`. Whole suite still collects cleanly (420 tests, no collection errors).",
|
| 31 |
+
"DockerSandbox.run_tests pytest-pass heuristic (`f\"{t} PASSED\" in out or (returncode==0 and not failed)`) is a faithful copy of the established LocalSubprocessSandbox.run_tests (sandbox.py:214) — not a new bug, consistent with the documented sibling behavior.",
|
| 32 |
+
"safety/ not being in the top-level composer_replication.__all__ is consistent with existing structure (datagen/diloco subpackages aren't fully surfaced at top level either); `composer_replication.safety` imports correctly as a subpackage."
|
| 33 |
+
],
|
| 34 |
+
"new_backlog_items": [
|
| 35 |
+
"DOC: Document the 4 NEW Wave-2 public symbols in docs/API_REFERENCE.md — add EKSExecutor + SageMakerExecutor under §12 (and add .eks/.sagemaker to the §12 module list at line 23), add a new `composer_replication.safety` section (HeldOutGuard, TripwireStatus, CollapseStopError, kl_token_trust_filter), and a `composer_replication.datagen` section covering DockerSandbox (+ the also-undocumented LocalSubprocessSandbox/FakeSandbox).",
|
| 36 |
+
"ADR: Author docs/adrs/ADR-015-holdout-killswitch.md (the safety kill-switch / held-out-guard design) — currently referenced by composer_replication/safety/__init__.py:17 as 'docs/adrs/ADR-015-*.md' but the file does not exist; index it in docs/adrs/README.md. The kill_switch.py module docstring is the ready-made draft.",
|
| 37 |
+
"DOC: Refresh the canonical test count in docs/V1_V8_COVERAGE.md:117 — Wave 2 added 93 tests (collection 328 -> 420); the stated '266 passed / 62 skipped / 328 collected' is the Wave-1 figure and is now stale."
|
| 38 |
+
]
|
| 39 |
+
}
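The optional-dependency laziness confirmed in the review above (no top-level `import kubernetes/boto3/docker`; the import happens at construction or method time) follows a common pattern. The sketch below illustrates that pattern in isolation. `require_optional` and `LazyBackendSketch` are hypothetical names, not the repository's `_require_docker()` or `_make_boto3_client`.

```python
import importlib
from types import ModuleType


def require_optional(module_name: str, extra: str) -> ModuleType:
    """Import an optional dependency on demand, with an actionable error."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(
            f"{module_name!r} is required for this backend; install the {extra!r} extra"
        ) from exc


class LazyBackendSketch:
    """Constructing the object never imports the optional dependency."""

    def __init__(self, module_name: str, extra: str) -> None:
        self._module_name = module_name
        self._extra = extra

    def client(self) -> ModuleType:
        # The import is deferred until a method actually needs the library.
        return require_optional(self._module_name, self._extra)
```

Deferring the import this way keeps the defining module importable without the optional library, provided nothing else on the import path pulls it in eagerly.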
|
|
@@ -0,0 +1,54 @@
|
|
| 1 |
+
{
|
| 2 |
+
"area": "composer_replication/safety/kill_switch.py + test_kill_switch.py (Wave-2 C1)",
|
| 3 |
+
"verdict": "material-issues",
|
| 4 |
+
"findings": [
|
| 5 |
+
{
|
| 6 |
+
"severity": "high",
|
| 7 |
+
"what": "C1 was scoped as 'Held-out disjoint eval + depth/generation kill-switch' but ONLY the kill-switch half (HeldOutGuard) was built. The HeldoutSplit disjointness-enforcer does not exist anywhere in the tree (no composer_replication/safety/holdout.py, no HeldoutSplit class). The guard's heldout_score is an unvalidated caller-supplied float; nothing enforces that the held-out pool is actually disjoint from the train/generator set. The module's own docstring (kill_switch.py:41-43, 214-216) states this is load-bearing: 'if held-out drifts with the train set the gap signal is meaningless.' So the kill-switch's central proxy-real-gap and decline-streak signals can be silently meaningless with no guard rail.",
|
| 8 |
+
"where": "composer_replication/safety/ (missing holdout.py / HeldoutSplit); referenced at kill_switch.py:43, kill_switch.py:214-216",
|
| 9 |
+
"recommendation": "Build the HeldoutSplit disjointness enforcer (hash/id-based set-difference check that the held-out eval IDs never intersect the generator/train IDs, raising on overlap) as the second half of C1, OR explicitly re-scope C1 to two items and track the disjointness enforcer as a distinct OPEN backlog item. Do not mark C1 done with only the guard built."
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"severity": "high",
|
| 13 |
+
"what": "HeldOutGuard is NOT wired into the trainer. Zero references to HeldOutGuard / kill_switch / CollapseStopError / should_halt / raise_if_fired in composer_replication/trainer/composer_trainer.py (or anywhere outside the safety package + its own test). The 'most load-bearing collapse safeguard (#2)' for the self-evolving flywheel exists as dead, never-invoked code. The trainer's GRPO loop never calls update() per checkpoint, so the run-level tripwire cannot fire in production.",
|
| 14 |
+
"where": "composer_replication/trainer/composer_trainer.py (no integration); HeldOutGuard defined composer_replication/safety/kill_switch.py:117",
|
| 15 |
+
"recommendation": "Wire HeldOutGuard.update(round_idx, in_loop_reward, heldout_score, kl_to_init=token_mean_kl(...)) into the trainer loop at the same checkpoint cadence DifficultyCurriculum.update is called (curriculum.py:78), and convert a fired verdict to a halt via raise_if_fired / should_halt. token_mean_kl already exists (kl_logging.py:53) to supply the per-token KL. Until wired, C1's safety claim is unrealized."
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"severity": "low",
|
| 19 |
+
"what": "calibrate_kl_threshold does not re-validate the > 0 invariant that __post_init__ enforces. A negative factor (or negative baseline_kls) yields min(negative, 0.08) = a NEGATIVE kl_hard_stop, after which the KL tripwire fires on EVERY healthy step (any positive KL EMA > negative ceiling). Verified empirically: factor=-3.0 on baseline [0.01] sets kl_hard_stop=-0.03 and a healthy KL of 0.01 then fires. The min() 'tighten-only' clamp is satisfied in the literal numeric sense but violates the documented collapse-band semantics.",
|
| 20 |
+
"where": "composer_replication/safety/kill_switch.py:412-418 (calibrate_kl_threshold)",
|
| 21 |
+
"recommendation": "Validate factor > 0 and all(k >= 0 for k in baseline_kls) at the top of calibrate_kl_threshold, and/or clamp the result to a small positive floor (e.g. assert calibrated > 0). KL values are non-negative by definition so a negative factor is nonsensical input, but the invariant should be guarded since the method mutates a field __post_init__ otherwise protects."
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"severity": "low",
|
| 25 |
+
"what": "Dangling cross-references in docstrings to artifacts that do not exist: safety/__init__.py:17-18 cites 'docs/adrs/ADR-015-*.md' (highest existing ADR is ADR-014; no ADR-015 file) and a \"'holdout-killswitch' research digest\" (no such file under research/). kill_switch.py:43,214 cite composer_replication.safety.holdout / HeldoutSplit 'design notes' that do not exist (same missing module as the high finding).",
|
| 26 |
+
"where": "composer_replication/safety/__init__.py:17-18; composer_replication/safety/kill_switch.py:43, 214-216",
|
| 27 |
+
"recommendation": "Either author ADR-015 documenting the kill-switch design decision (the module is substantial enough to warrant one and the docstring already promises it), or drop the dangling citations. Keep doc references honest to avoid the stale-cross-ref foot-guns the backlog (B5/B6/B8) is already cleaning up."
|
| 28 |
+
},
|
| 29 |
+
{
|
| 30 |
+
"severity": "low",
|
| 31 |
+
"what": "Gap-blowout path (c) fires when the proxy gain exceeds real gain by max_proxy_real_gap EVEN WHEN the held-out (real) score is still genuinely RISING. Verified: with both rising but proxy faster, it halts the run while real improvement is ongoing. This is defensible per the docstring ('fast single-generation divergence', lines 144-145), and the reason string is accurate, but it is a potential false-positive halt on a healthy-but-fast-proxy run and is not covered by a test asserting the desired behavior in the both-rising case (only the real-flat case is tested at test_kill_switch.py:143).",
|
| 32 |
+
"where": "composer_replication/safety/kill_switch.py:326-335 (path c); test gap test_kill_switch.py:143-158 only exercises real-flat",
|
| 33 |
+
"recommendation": "Add a test pinning the intended behavior when BOTH rise but proxy outpaces real beyond the ceiling (assert whether it should fire), and document in the docstring that path (c) is a divergence-RATE gate, not a real-decline gate, so future readers do not mistake a fired path-(c) for confirmed real regression."
|
| 34 |
+
}
|
| 35 |
+
],
|
| 36 |
+
"confirmed_good": [
|
| 37 |
+
"All 23 tests in composer_replication/safety pass (.venv pytest, 23 passed).",
|
| 38 |
+
"Latched-fire is correct and cannot un-halt: _fired flips True in update() (lines 277-278) and _evaluate() short-circuits with a 'latched:' verdict carrying the original reason before any threshold re-check (lines 294-296). Verified a full KL/gap recovery after fire stays fire=True.",
|
| 39 |
+
"Three halt conditions are individually correct: (b) KL EMA > kl_hard_stop checked first; (a) held-out-declines-while-in-loop-rises only increments the streak when BOTH conditions hold (a both-declining 'hard batch' correctly does NOT count, verified), fires at decline_patience; (c) proxy-real gap > ceiling. min_steps warm-up gate uses the internal _n counter (robust to non-contiguous round_idx, tested).",
|
| 40 |
+
"EMA denoising is sound: _fold seeds on first sample (no warm-up bias), alpha is weight-on-prior validated to [0,1); first-sample baseline seeding makes proxy_real_gap a gain-since-baseline quantity exactly matching the RSI Hacking-Gap definition. proxy_real_gap math verified (0.15 expected case) and returns 0.0 before first update.",
|
| 41 |
+
"CollapseStopError raise path: raise_if_fired raises the typed exception carrying .status only when fired, is a no-op when clean, and is a safe no-op before any update (last_status None). Strict > boundary on gap/KL confirmed (gap==ceiling does not fire).",
|
| 42 |
+
"calibrate tighten-only works for the intended (positive) inputs: min(3x baseline, current) so a drifting baseline cannot loosen past 0.08 (tested), and only tightens for a clean low baseline.",
|
| 43 |
+
"kl_token_trust_filter boundary correct (strict >, so threshold value itself is not masked).",
|
| 44 |
+
"Docstring cross-refs that DO resolve: DifficultyCurriculum.update (curriculum.py:78) and token_mean_kl (kl_logging.py:53) both exist, so the claimed cadence and KL-units convention are anchored to real code.",
|
| 45 |
+
"No false claim anywhere in examples/ or docs that the kill-switch is already wired/used (grep clean)."
|
| 46 |
+
],
|
| 47 |
+
"new_backlog_items": [
|
| 48 |
+
"Build composer_replication/safety/holdout.py with a HeldoutSplit disjointness enforcer (id/hash set-difference, raises on train/held-out overlap) — the un-built second half of C1 that the kill-switch's gap/decline signals depend on for validity.",
|
| 49 |
+
"Wire HeldOutGuard into composer_replication/trainer/composer_trainer.py at the per-checkpoint cadence (alongside DifficultyCurriculum.update), feeding token_mean_kl as kl_to_init and converting a fired verdict to a halt via raise_if_fired/should_halt — the C1 safeguard is currently dead code.",
|
| 50 |
+
"Guard calibrate_kl_threshold against factor<=0 / negative baseline_kls (or clamp result to a positive floor) so calibration cannot drive kl_hard_stop negative and make the KL tripwire fire on every healthy step.",
|
| 51 |
+
"Author docs/adrs/ADR-015 for the held-out kill-switch (referenced by safety/__init__.py:17 but nonexistent) or remove the dangling ADR-015 + 'holdout-killswitch research digest' citations.",
|
| 52 |
+
"Add a test pinning path-(c) gap-blowout behavior in the BOTH-rising case (proxy outpaces a still-rising real) to lock the intended false-positive/true-positive decision."
|
| 53 |
+
]
|
| 54 |
+
}
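The first backlog item above asks for a disjointness enforcer that does not yet exist in the tree. A minimal sketch of what such a check might look like follows. Everything in it is an assumption for illustration: the `HeldoutSplit.build` constructor, the `HeldoutOverlapError` exception, and the choice to fingerprint normalized string IDs with SHA-256 are hypothetical and are not taken from the repository.

```python
import hashlib
from dataclasses import dataclass
from typing import FrozenSet, Iterable


class HeldoutOverlapError(ValueError):
    """Hypothetical error: held-out eval IDs intersect the train/generator IDs."""


def _fingerprint(example_id: str) -> str:
    # Assumption: IDs are strings; normalize case/whitespace before hashing so
    # trivially different spellings of the same ID still collide.
    return hashlib.sha256(example_id.strip().lower().encode("utf-8")).hexdigest()


@dataclass(frozen=True)
class HeldoutSplit:
    """Hypothetical container that can only be built from disjoint ID sets."""

    train_ids: FrozenSet[str]
    heldout_ids: FrozenSet[str]

    @classmethod
    def build(cls, train_ids: Iterable[str], heldout_ids: Iterable[str]) -> "HeldoutSplit":
        train = frozenset(_fingerprint(i) for i in train_ids)
        heldout = frozenset(_fingerprint(i) for i in heldout_ids)
        overlap = train & heldout  # set intersection; must be empty
        if overlap:
            raise HeldoutOverlapError(
                f"{len(overlap)} held-out ID(s) also appear in the train/generator set"
            )
        return cls(train_ids=train, heldout_ids=heldout)
```

Under this sketch, a caller would construct the split once before training and only score held-out examples drawn from it, so an overlapping pool fails loudly at setup instead of silently invalidating the guard's gap and decline signals.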
|
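The path-(c) finding above turns on the fact that the gap is a difference of gains since baseline, so it can exceed the ceiling while the real score is still rising. The arithmetic can be sketched as below. This is a simplified illustration only: it omits the EMA denoising the report describes, the function names are hypothetical, and the 0.2 ceiling is an arbitrary example value, not the module's default.

```python
def proxy_real_gap(proxy0: float, real0: float, proxy: float, real: float) -> float:
    """Gain-since-baseline gap: proxy gain minus real (held-out) gain."""
    return (proxy - proxy0) - (real - real0)


def gap_fires(gap: float, max_proxy_real_gap: float) -> bool:
    # Strict '>' boundary, as the review confirms: gap == ceiling does not fire.
    return gap > max_proxy_real_gap


# Both scores rise, but the proxy rises faster (illustrative numbers).
gap = proxy_real_gap(proxy0=0.5, real0=0.5, proxy=0.9, real=0.6)
both_rising_fires = gap_fires(gap, max_proxy_real_gap=0.2)
```

Here the real score improved, yet the gap of roughly 0.3 exceeds the example ceiling, which is why a fired path (c) indicates a divergence rate and not a confirmed real regression.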
|
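The calibrate_kl_threshold finding above recommends validating inputs before applying the tighten-only clamp. A standalone sketch of that guard is given here. It is not the repository's method: the free function name, its signature, and the use of the baseline mean as the aggregate are assumptions made for illustration.

```python
from typing import Sequence


def calibrate_kl_hard_stop(
    current: float, baseline_kls: Sequence[float], factor: float = 3.0
) -> float:
    """Return a tightened KL ceiling, rejecting inputs that could make it non-positive.

    Hypothetical stand-in for the guarded version of calibrate_kl_threshold.
    """
    if current <= 0:
        raise ValueError("current kl_hard_stop must be > 0")
    if factor <= 0:
        raise ValueError("factor must be > 0")
    if not baseline_kls:
        raise ValueError("baseline_kls must be non-empty")
    if any(k < 0 for k in baseline_kls):
        raise ValueError("KL values are non-negative by definition")

    # Assumption: the baseline is summarized by its mean.
    baseline = sum(baseline_kls) / len(baseline_kls)
    calibrated = min(factor * baseline, current)  # tighten-only clamp
    if calibrated <= 0:
        # e.g. an all-zero baseline would otherwise make every positive KL fire
        raise ValueError("calibrated ceiling must stay > 0")
    return calibrated
```

With these checks, the reported case (factor=-3.0 on baseline [0.01]) raises instead of setting a negative ceiling, while a drifting baseline still cannot loosen the ceiling past its current value.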
@@ -0,0 +1,29 @@
| 1 |
+
{
|
| 2 |
+
"area": "composer_replication/datagen/docker_sandbox.py + sandbox.py scrub_tree refactor",
|
| 3 |
+
"verdict": "clean",
|
| 4 |
+
"findings": [
|
| 5 |
+
{
|
| 6 |
+
"severity": "low",
|
| 7 |
+
"what": "run_tests pass/fail parse carries the order-dependent fallback clause `if f\"{t} PASSED\" in out or (returncode == 0 and not failed)` verbatim from LocalSubprocessSandbox. If a runner exits 0 but does not print '<nodeid> PASSED' for every node id, the first un-printed node is marked passed solely on the exit code (and `not failed` is true only until the first failure is recorded). This is a pre-existing pattern (identical on main's LocalSubprocessSandbox at sandbox.py:214) faithfully mirrored into DockerSandbox, NOT a new regression — flagged only for completeness.",
|
| 8 |
+
"where": "composer_replication/datagen/docker_sandbox.py:272-276 (and the source LocalSubprocessSandbox at sandbox.py:212-217)",
|
| 9 |
+
"recommendation": "No action required for this review. If ever hardened, require an explicit PASSED token per node id and stop trusting the bare exit code; do it in both sandboxes together so they stay in lock-step."
|
| 10 |
+
}
|
| 11 |
+
],
|
| 12 |
+
"confirmed_good": [
|
| 13 |
+
"REFACTOR DID NOT BREAK LocalSubprocessSandbox: boot() still scrubs — boot() (sandbox.py:169-172) calls self._scrub_tree() which delegates to the shared module-level scrub_tree() free function (sandbox.py:174-177). Smoke test confirmed __pycache__, .git, and *.pyc are removed on boot while real source (keep.py) survives.",
|
| 14 |
+
"No broken/dangling references to the old per-class _scrub_tree: the only remaining _scrub_tree occurrences are (a) the intentional back-compat delegating method + its self-call in boot, and (b) one descriptive comment in test_docker_substrate_e2e.py:161. grep for external callers of .SCRUB_NAMES/._SCRUB_NAMES/.SCRUB_SUFFIXES returned EMPTY.",
|
| 15 |
+
"Back-compat preserved: LocalSubprocessSandbox._SCRUB_NAMES / ._SCRUB_SUFFIXES class aliases still point at the module-level SCRUB_NAMES/SCRUB_SUFFIXES; the _scrub_tree() method is retained.",
|
| 16 |
+
"FeatureDeletionEnv unaffected: env.py uses the Sandbox Protocol generically (boot/exec/run_tests/trajectory at env.py:59,69,86,89) — agnostic to the scrub refactor.",
|
| 17 |
+
"SCRUB-BEFORE-MOUNT ORDERING IS CORRECT (no security bug): DockerSandbox.boot() runs scrub_tree(self.workdir) at line 190 BEFORE self._client.containers.run(**kwargs) at line 198. The container (and thus the RW bind mount) does not exist when the host-side scrub runs, so the scrub is provably pre-mount. The scrub-AFTER-mount security bug the audit asked to rule out is NOT present.",
|
| 18 |
+
"--network none: both network_disabled=True AND network_mode='none' set (docker_sandbox.py:154-155); live test_live_network_is_disabled actually ran on a real container and asserted egress BLOCKED / not CONNECTED.",
|
| 19 |
+
"Resource limits: mem_limit == memswap_limit (forbids swap), pids_limit (fork-bomb guard), nano_cpus (CPU quota); all present, configurable, and unit-asserted.",
|
| 20 |
+
"Ephemeral teardown: close() force-removes (idempotent, swallows errors), reap_leaked() sweeps label-filtered orphan containers at boot and shutdown, __enter__/__exit__/__del__ wired. Verified by test_close_removes_container_force, test_context_manager_closes, test_reap_leaked_sweeps_labelled_containers.",
|
| 21 |
+
"gVisor runtime option: runtime defaults to None (=> 'runtime' kwarg omitted, daemon-default runc); 'runsc' is only passed through when explicitly set (docker_sandbox.py:178-179) and gated by runsc_available(). test_live_runsc_runtime correctly SKIPPED (gVisor not installed on host).",
|
| 22 |
+
"Lazy docker import: _require_docker() imports `docker` inside the function with a clear RuntimeError on ImportError; docker SDK is never required by the FakeSandbox/pure-core path. Verified by test_require_docker_missing_sdk_raises.",
|
| 23 |
+
"Privilege lockdown: cap_drop=['ALL'], security_opt=['no-new-privileges:true'], user='1000:1000' (non-root), read_only root fs with tmpfs /tmp (noexec,nosuid), keep_root_writable escape hatch.",
|
| 24 |
+
"shlex.quote applied to every test node id in run_tests (shell-injection guard, matches LocalSubprocessSandbox); non-UTF-8 output decoded with errors='replace' (test_exec_decodes_non_utf8_bytes); exec wraps commands in coreutils `timeout`.",
|
| 25 |
+
"TEST SUITE: `.venv/bin/python -m pytest composer_replication/datagen -q` => 61 passed, 1 skipped (runsc only). The LIVE Docker E2E genuinely RAN (not skipped): test_live_four_inversion_gates_in_hardened_container, test_live_network_is_disabled, test_live_cache_scrub_removes_bytecode all PASSED on a real python:3.11-slim container. The long-blocked D1 substrate E2E (test_docker_substrate_e2e.py) is also GREEN (2/2). Broader regression datagen+safety+serverless => 137 passed, 18 skipped, no failures.",
|
| 26 |
+
"Public surface re-exports DockerSandbox and scrub_tree from composer_replication/datagen/__init__.py and __all__; package imports cleanly."
|
| 27 |
+
],
|
| 28 |
+
"new_backlog_items": []
|
| 29 |
+
}
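The single low-severity finding in this report describes an order-dependent fallback in the pass/fail parse. The sketch below contrasts a reconstruction of that pattern with the stricter parse the recommendation describes. Both functions are illustrative: `parse_legacy` is rebuilt from the quoted clause and is only assumed to resemble the sandbox code, and `parse_strict` is a hypothetical hardening, not something present in either sandbox.

```python
from typing import Dict, Iterable, List, Tuple


def parse_legacy(out: str, returncode: int, node_ids: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Reconstruction of the quoted clause: exit code 0 can mark an unprinted node passed."""
    passed: List[str] = []
    failed: List[str] = []
    for t in node_ids:
        # 'not failed' holds only until the first failure is recorded,
        # so the outcome depends on the order of node_ids.
        if f"{t} PASSED" in out or (returncode == 0 and not failed):
            passed.append(t)
        else:
            failed.append(t)
    return passed, failed


def parse_strict(out: str, node_ids: Iterable[str]) -> Dict[str, bool]:
    """Hypothetical hardening: require an explicit PASSED token per node id."""
    return {t: f"{t} PASSED" in out for t in node_ids}


# A runner that exits 0 but prints a result for only one of two nodes.
output = "tests/test_x.py::test_one PASSED\n"
nodes = ["tests/test_x.py::test_one", "tests/test_x.py::test_two"]
legacy_passed, legacy_failed = parse_legacy(output, 0, nodes)
strict = parse_strict(output, nodes)
```

In this example the legacy parse credits `test_two` on the exit code alone, while the strict parse marks it as not passed. As the recommendation notes, any such change would need to land in both sandboxes together.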
|