Baladithya Balamurugan
Wave 21: adversarial-review fixes — all 9 verified findings closed
3bbcf21
Raw
History Blame Contribute Delete
6.2 kB
"""build_corpus.py — the local Stage-0 stage-driver (architecture step 6-7).
One function wires the whole local pipeline: holdout-split the task pool
(holdout tasks are NEVER rolled out — they are the eval anchor), roll out a
policy over each train task, admit + type + route trajectories
(sft / dpo-candidate / quarantine), dedup the SFT rows (within-run AND against
a prior generation's signatures), and write everything through the single
s3_contract layout with a manifest + dataset card.
Deliberately LOCAL-first (finding D-9): the five-service AWS orchestration is
Stage 4; this driver must produce one real corpus end-to-end on a laptop with a
FakeSandbox before anything is distributed. Write-once per layout
(finding D-21): refuses to run if the manifest already exists.
"""
from __future__ import annotations
from typing import Callable, Sequence
from composer_replication.datagen.env import FeatureDeletionEnv
from composer_replication.datagen.rollout_harness import (
RolloutPolicy,
admit,
collect_trajectory,
)
from composer_replication.datagen.schema import FeatureDeletionTask
from composer_replication.datagen.trajectory import to_policy_row
from composer_replication.pipeline import s3_contract
from composer_replication.pipeline.dedup import dedup
from composer_replication.pipeline.s3_contract import RunLayout, RunManifest
from composer_replication.safety.holdout import HeldoutSplit
def build_corpus(
    source_tasks: Sequence[FeatureDeletionTask],
    env_factory: Callable[[], FeatureDeletionEnv],
    policy_factory: Callable[[], RolloutPolicy],
    layout: RunLayout,
    manifest: RunManifest,
    *,
    holdout_frac: float = 0.2,
    holdout_seed: int = 0,
    max_tasks: int | None = None,
    cost_per_rollout_usd: float = 0.0,
    prior_signatures: Sequence[Sequence[int]] | None = None,
    dedup_threshold: float = 0.85,
) -> RunManifest:
    """Run the Stage-0 pipeline over `source_tasks`; returns the final manifest.

    Steps, in order: refuse to overwrite an existing run, split off the
    holdout tasks, roll out one trajectory per train task (budget permitting),
    route each trajectory to sft / dpo-candidate / quarantine, dedup the SFT
    rows, then write every artifact plus the manifest and dataset card
    through `s3_contract`.

    Args:
        source_tasks: gate_repo-admitted FeatureDeletionTasks (the caller runs
            `datagen.repo_gate.gate_repo` BEFORE this — the driver assumes the
            license/decontamination gates already passed).
        env_factory: fresh `FeatureDeletionEnv` per rollout (a sandbox is
            stateful; sharing one across episodes leaks trajectory state).
        policy_factory: fresh policy per rollout (ScriptedPolicy is stateful).
        layout: the `RunLayout` every artifact is written to; also the
            location checked for an existing manifest.
        manifest: a `RunManifest` with run_id/created_at/budget set by the
            caller (created_at is caller-passed for reproducibility). It is
            mutated in place (spend, counts, status) and written to `layout`.
        holdout_frac: forwarded to `HeldoutSplit.split` as the holdout
            fraction.
        holdout_seed: forwarded to `HeldoutSplit.split` as the split seed.
        max_tasks: if not None, only the first `max_tasks` train tasks (in
            sorted task-id order) are kept; must be >= 0.
        cost_per_rollout_usd: accounting hook — API policies should report
            real cost; the driver enforces `manifest.budget_usd` with it.
        prior_signatures: previous generation's MinHash signatures
            (cross-generation dedup, finding D-12).
        dedup_threshold: forwarded to `dedup` as its threshold argument.

    Returns:
        The same `manifest` object, with `counts` filled in and `status` set
        to "partial" when the budget stopped the rollout loop early, else
        "building".

    Raises:
        ValueError: if `max_tasks` is negative (a negative slice bound would
            silently drop tasks from the END of the train list instead of
            capping it).
        FileExistsError: if the layout already has a manifest (write-once).
    """
    # Validate cheap arguments before touching storage or rolling anything out.
    if max_tasks is not None and max_tasks < 0:
        raise ValueError(f"max_tasks must be >= 0 or None, got {max_tasks!r}")

    if s3_contract.manifest_exists(layout):
        raise FileExistsError(
            f"Run layout already has a manifest at {layout.manifest_path} — "
            "runs are write-once per (root, run_id); mint a new run_id "
            "(finding D-21 idempotency)."
        )

    # 1. Holdout split FIRST — held-out tasks are never rolled out, so no
    #    training signal can derive from them (the HeldoutSplit discipline).
    split = HeldoutSplit.split(source_tasks, holdout_frac=holdout_frac,
                               seed=holdout_seed, check_content=True)
    by_id = {t.task_id: t for t in source_tasks}
    # Sorting the ids makes task order (and therefore the max_tasks cut and
    # the rollout order) independent of the iteration order of the id sets.
    holdout_tasks = [by_id[i] for i in sorted(split.holdout_ids)]
    train_tasks = [by_id[i] for i in sorted(split.train_ids)]
    if max_tasks is not None:
        train_tasks = train_tasks[:max_tasks]

    # 2. Rollouts + admission routing.
    sft_rows: list[dict] = []
    dpo_rows: list[dict] = []
    quarantine_rows: list[dict] = []
    traj_rows: list[dict] = []
    partial = False
    for task in train_tasks:
        # Hard ceiling (Wave-21 review P1): a rollout only starts if its cost
        # still fits — pre-charging prevents the one-rollout overshoot the
        # old check-then-spend ordering allowed.
        if manifest.budget_usd is not None and (
            manifest.cost_usd + cost_per_rollout_usd > manifest.budget_usd
        ):
            partial = True
            break
        traj = collect_trajectory(env_factory(), task, policy_factory(),
                                  provenance={"run_id": manifest.run_id})
        manifest.spend(cost_per_rollout_usd)
        verdict = admit(traj)
        row = to_policy_row(traj, task)
        # Every rollout is recorded, whatever its routing, with the
        # admission reasons attached.
        traj_rows.append({**row, "admission": list(verdict.reasons)})
        if verdict.sft_admitted:
            sft_rows.append(row)
        elif verdict.dpo_candidate:
            dpo_rows.append(row)
        else:
            quarantine_rows.append({**row, "reasons": list(verdict.reasons)})

    # 3. Dedup the SFT corpus (within-run + cross-generation).
    def _key(r: dict) -> str:
        # Dedup text is the space-joined message contents. `or ""` (rather
        # than a .get default) also covers a key that is present but None —
        # e.g. a tool-call-only turn — which would make str.join raise
        # TypeError.
        return " ".join(m.get("content") or "" for m in r.get("messages", []))

    sft_rows, dedup_stats = dedup(sft_rows, _key, dedup_threshold,
                                  prior_signatures=prior_signatures)

    # 4. Write everything through the contract. Note that on a partial run
    #    the task files still list every selected train task, including the
    #    ones the budget prevented from being rolled out.
    s3_contract.write_tasks(layout, train_tasks)
    s3_contract.write_tasks_full(layout, train_tasks)
    s3_contract.write_holdout(layout, holdout_tasks)
    s3_contract.write_trajectories(layout, traj_rows)
    s3_contract.write_sft_rows(layout, sft_rows)
    s3_contract.write_dpo_rows(layout, dpo_rows)
    s3_contract.write_quarantine(layout, quarantine_rows)

    manifest.counts = {
        "tasks_train": len(train_tasks),
        "tasks_holdout": len(holdout_tasks),
        "rollouts": len(traj_rows),
        "sft_rows": len(sft_rows),  # post-dedup count
        "dpo_rows": len(dpo_rows),
        "quarantined": len(quarantine_rows),
        **{f"dedup_{k}": v for k, v in dedup_stats.items()},
    }
    # NOTE(review): a fully rolled-out run is left as "building", not a
    # terminal status — presumably a later stage finalizes it; confirm.
    manifest.status = "partial" if partial else "building"
    # The manifest is written only after all data files, so its presence
    # (the write-once guard above) implies the data writes were attempted.
    manifest.write(layout)
    s3_contract.write_dataset_card(layout, manifest, dedup_stats=dedup_stats)
    return manifest
# Public API of this module: only the stage driver is exported.
__all__ = ["build_corpus"]