"""build_corpus.py — the local Stage-0 stage-driver (architecture step 6-7).

One function wires the whole local pipeline: holdout-split the task pool
(holdout tasks are NEVER rolled out — they are the eval anchor), roll out a
policy over each train task, admit + type + route trajectories (sft /
dpo-candidate / quarantine), dedup the SFT rows (within-run AND against a
prior generation's signatures), and write everything through the single
s3_contract layout with a manifest + dataset card.

Deliberately LOCAL-first (finding D-9): the five-service AWS orchestration is
Stage 4; this driver must produce one real corpus end-to-end on a laptop with
a FakeSandbox before anything is distributed.

Write-once per layout (finding D-21): refuses to run if the manifest already
exists.
"""

from __future__ import annotations

from typing import Callable, Sequence

from composer_replication.datagen.env import FeatureDeletionEnv
from composer_replication.datagen.rollout_harness import (
    RolloutPolicy,
    admit,
    collect_trajectory,
)
from composer_replication.datagen.schema import FeatureDeletionTask
from composer_replication.datagen.trajectory import to_policy_row
from composer_replication.pipeline import s3_contract
from composer_replication.pipeline.dedup import dedup
from composer_replication.pipeline.s3_contract import RunLayout, RunManifest
from composer_replication.safety.holdout import HeldoutSplit


def _dedup_key(row: dict) -> str:
    """Return the text an SFT row is deduplicated on.

    The text is the row's message ``content`` values joined by single spaces.
    A missing or ``None`` content contributes ``""``. Any non-string content
    is passed through ``str()`` so ``str.join`` never sees a non-string.
    """
    # `m.get("content", "")` alone returns None for a present-but-None key,
    # which makes `" ".join` raise TypeError — normalise explicitly.
    # `or ()` likewise tolerates a row whose "messages" value is None.
    return " ".join(
        "" if (content := message.get("content")) is None else str(content)
        for message in (row.get("messages") or ())
    )


def build_corpus(
    source_tasks: Sequence[FeatureDeletionTask],
    env_factory: Callable[[], FeatureDeletionEnv],
    policy_factory: Callable[[], RolloutPolicy],
    layout: RunLayout,
    manifest: RunManifest,
    *,
    holdout_frac: float = 0.2,
    holdout_seed: int = 0,
    max_tasks: int | None = None,
    cost_per_rollout_usd: float = 0.0,
    prior_signatures: Sequence[Sequence[int]] | None = None,
    dedup_threshold: float = 0.85,
) -> RunManifest:
    """Run the Stage-0 pipeline over `source_tasks`; returns the final manifest.

    Args:
        source_tasks: gate_repo-admitted FeatureDeletionTasks (the caller runs
            `datagen.repo_gate.gate_repo` BEFORE this — the driver assumes the
            license/decontamination gates already passed).
        env_factory: fresh `FeatureDeletionEnv` per rollout (a sandbox is
            stateful; sharing one across episodes leaks trajectory state).
        policy_factory: fresh policy per rollout (ScriptedPolicy is stateful).
        layout: the run layout that every artifact and the manifest are
            written to via `s3_contract`.
        manifest: a `RunManifest` with run_id/created_at/budget set by the
            caller (created_at is caller-passed for reproducibility). It is
            mutated in place (spend, counts, status), written to `layout`,
            and returned.
        holdout_frac: holdout fraction forwarded to `HeldoutSplit.split`.
        holdout_seed: seed forwarded to `HeldoutSplit.split`.
        max_tasks: optional cap on how many train tasks are rolled out,
            applied after the holdout split. Must be >= 0.
        cost_per_rollout_usd: accounting hook — API policies should report
            real cost; the driver enforces `manifest.budget_usd` with it.
            Must be >= 0.
        prior_signatures: previous generation's MinHash signatures
            (cross-generation dedup, finding D-12).
        dedup_threshold: similarity threshold forwarded to `dedup`.

    Returns:
        The same `manifest` object, with `counts` filled in and `status` set
        to "partial" if the budget stopped rollouts early, else "building".

    Raises:
        ValueError: if `max_tasks` is negative, or `cost_per_rollout_usd` is
            negative or NaN.
        FileExistsError: if the layout already has a manifest (write-once).
    """
    # Validate arguments before touching the layout. A negative cap would
    # silently drop tasks from the END of the train list (negative slicing).
    # A negative/NaN cost would let the budget check below never trip.
    if max_tasks is not None and max_tasks < 0:
        raise ValueError(f"max_tasks must be >= 0 or None, got {max_tasks}")
    if not cost_per_rollout_usd >= 0:  # written this way to also reject NaN
        raise ValueError(
            f"cost_per_rollout_usd must be >= 0, got {cost_per_rollout_usd}"
        )

    if s3_contract.manifest_exists(layout):
        raise FileExistsError(
            f"Run layout already has a manifest at {layout.manifest_path} — "
            "runs are write-once per (root, run_id); mint a new run_id "
            "(finding D-21 idempotency)."
        )

    # 1. Holdout split FIRST — held-out tasks are never rolled out, so no
    # training signal can derive from them (the HeldoutSplit discipline).
    split = HeldoutSplit.split(
        source_tasks,
        holdout_frac=holdout_frac,
        seed=holdout_seed,
        check_content=True,
    )
    by_id = {t.task_id: t for t in source_tasks}
    holdout_tasks = [by_id[i] for i in sorted(split.holdout_ids)]
    train_tasks = [by_id[i] for i in sorted(split.train_ids)]
    if max_tasks is not None:
        train_tasks = train_tasks[:max_tasks]

    # 2. Rollouts + admission routing.
    sft_rows: list[dict] = []
    dpo_rows: list[dict] = []
    quarantine_rows: list[dict] = []
    traj_rows: list[dict] = []
    partial = False
    for task in train_tasks:
        # Hard ceiling (Wave-21 review P1): a rollout only starts if its cost
        # still fits. The check counts this rollout's cost up front, which
        # prevents the one-rollout overshoot the old check-then-spend
        # ordering allowed.
        if manifest.budget_usd is not None and (
            manifest.cost_usd + cost_per_rollout_usd > manifest.budget_usd
        ):
            partial = True
            break
        traj = collect_trajectory(
            env_factory(),
            task,
            policy_factory(),
            provenance={"run_id": manifest.run_id},
        )
        manifest.spend(cost_per_rollout_usd)
        verdict = admit(traj)
        row = to_policy_row(traj, task)
        traj_rows.append({**row, "admission": list(verdict.reasons)})
        if verdict.sft_admitted:
            sft_rows.append(row)
        elif verdict.dpo_candidate:
            dpo_rows.append(row)
        else:
            quarantine_rows.append({**row, "reasons": list(verdict.reasons)})

    # 3. Dedup the SFT corpus (within-run + cross-generation).
    sft_rows, dedup_stats = dedup(
        sft_rows, _dedup_key, dedup_threshold, prior_signatures=prior_signatures
    )

    # 4. Write everything through the contract. The manifest is written after
    # the data files so its presence (the write-once marker) follows them.
    s3_contract.write_tasks(layout, train_tasks)
    s3_contract.write_tasks_full(layout, train_tasks)
    s3_contract.write_holdout(layout, holdout_tasks)
    s3_contract.write_trajectories(layout, traj_rows)
    s3_contract.write_sft_rows(layout, sft_rows)
    s3_contract.write_dpo_rows(layout, dpo_rows)
    s3_contract.write_quarantine(layout, quarantine_rows)

    manifest.counts = {
        "tasks_train": len(train_tasks),
        "tasks_holdout": len(holdout_tasks),
        "rollouts": len(traj_rows),
        "sft_rows": len(sft_rows),
        "dpo_rows": len(dpo_rows),
        "quarantined": len(quarantine_rows),
        **{f"dedup_{k}": v for k, v in dedup_stats.items()},
    }
    # NOTE(review): "building" is presumably promoted to a terminal status by
    # a later stage — confirm against the RunManifest status contract.
    manifest.status = "partial" if partial else "building"
    manifest.write(layout)
    s3_contract.write_dataset_card(layout, manifest, dedup_stats=dedup_stats)
    return manifest


__all__ = ["build_corpus"]