# Baladithya Balamurugan
# Wave 21: adversarial-review fixes — all 9 verified findings closed
# Commit 3bbcf21
# (file-viewer residue: Raw · History · Blame · Contribute · Delete · 10.4 kB)
"""s3_contract.py — THE single dataset layout + manifest (finding V8/D-7/D-8).
Supersedes BOTH prior contracts: design-F1's `runs/<id>/{sft_corpus,dpo_pairs,
rl_task_pool,divergence_pairs,wm_tuples,holdout,diloco_rendezvous}` and
design-F2's `{traces,tasks,replay,task_grades,corpus}/v1/run_id=<id>` — the two
were never reconciled and coexisted in the grounding doc. One layout, one
manifest, two explicit serializers with a unit-tested leak guard.
Deliberate exclusions from the run layout:

  * `diloco_rendezvous/` — training-comms state, not dataset; lives in its own
    prefix/bucket (finding D-19).
  * `wm_tuples/` — emitted only when the P4 world-model ablation is scheduled
    (finding D-14); not part of Stage 0.

Layout (root = any local path or fsspec URI):

    <root>/runs/<run_id>/
        tasks/manifest.jsonl        policy-safe task rows (golden_diff -> sha256)
        tasks_full/manifest.jsonl   construction-side full rows (RESTRICTED prefix)
        traj/*.jsonl                CanonicalTrajectory records (audit trail)
        corpus_sft/rows.jsonl       admitted SFT rows (to_policy_row output)
        corpus_dpo/rows.jsonl       DPO-candidate rows
        holdout/tasks.jsonl         held-out task ids+rows (never rolled out)
        quarantine/*.jsonl          rejected trajectories w/ reasons (audit)
        manifest.json               RunManifest
        DATASET_CARD.md             human-readable card
"""
from __future__ import annotations
import dataclasses
import hashlib
import json
from dataclasses import dataclass, field
from typing import IO, Iterable
from composer_replication.datagen.schema import FeatureDeletionTask
# Manifest schema version; used as the default for RunManifest.schema_version.
SCHEMA_VERSION = "1"
def _is_local(root: str) -> bool:
return "://" not in root or root.startswith("file://")
def _open(path: str, mode: str = "w") -> IO[str]:
    """Open a path for text IO; plain `open` locally, fsspec for s3:// etc.

    fsspec is lazy so the module (and all local-corpus runs) need no extra dep.

    For local paths the parent directory is created on demand, but only when
    the mode can create the file (no ``r`` in ``mode``) and the path has a
    directory component. A read of a missing file therefore fails cleanly
    without leaving an empty directory behind, and a bare filename such as
    ``"manifest.json"`` no longer trips ``os.makedirs("")``.

    Args:
        path: Local path, ``file://`` URI, or any other fsspec URI.
        mode: Text mode handed to the opener; defaults to ``"w"``.

    Returns:
        An open text-mode file object using UTF-8 encoding.

    Raises:
        RuntimeError: ``path`` is non-local and fsspec cannot be imported.
    """
    if _is_local(path):
        import os

        local = path.removeprefix("file://")
        parent = os.path.dirname(local)
        # "r" and "r+" both require an existing file, so creating the parent
        # directory for them would only be a stray side effect.
        if parent and "r" not in mode:
            os.makedirs(parent, exist_ok=True)
        return open(local, mode, encoding="utf-8")
    try:
        import fsspec  # noqa: PLC0415 — lazy heavy dep
    except ImportError as e:
        raise RuntimeError(
            "Non-local corpus roots require fsspec; install with "
            "`pip install -e .[serverless]`. Got: " + repr(e)
        ) from e
    return fsspec.open(path, mode, encoding="utf-8").open()
def _exists(path: str) -> bool:
    """Return whether ``path`` exists, locally or through fsspec.

    Local paths (plain or ``file://``) are checked with ``os.path.exists``.
    Anything else is resolved through fsspec, imported lazily. A missing
    fsspec is reported with the same ``RuntimeError`` and install hint that
    ``_open`` uses, so the first non-local call a driver makes (typically the
    manifest existence check) gives the actionable message rather than a bare
    ``ImportError``.

    Raises:
        RuntimeError: ``path`` is non-local and fsspec cannot be imported.
    """
    if _is_local(path):
        import os

        return os.path.exists(path.removeprefix("file://"))
    try:
        import fsspec  # noqa: PLC0415 — lazy heavy dep
    except ImportError as e:
        raise RuntimeError(
            "Non-local corpus roots require fsspec; install with "
            "`pip install -e .[serverless]`. Got: " + repr(e)
        ) from e
    fs, _, paths = fsspec.get_fs_token_paths(path)
    return bool(fs.exists(paths[0]))
@dataclass(frozen=True)
class RunLayout:
    """Pure-path logic for one run's prefixes — testable without any IO.

    Every path has the form ``<root>/runs/<run_id>/...``; trailing ``/`` on
    ``root`` is stripped before joining.

    Raises:
        ValueError: if ``run_id`` is empty, is exactly ``.``, or contains a
            forward slash, a backslash, or ``..``.
    """

    root: str
    run_id: str

    def __post_init__(self) -> None:
        # Defense-in-depth (Wave-21 review P2): run_id is operator-supplied,
        # but a separator or `..` would silently escape the corpus root.
        # A lone "." is rejected too: `runs/./x` collapses to `runs/x`, which
        # would drop this run's files into the prefix shared by all runs.
        rid = self.run_id
        forbidden = ("/", "\\", "..")
        if not rid or rid == "." or any(tok in rid for tok in forbidden):
            raise ValueError(
                f"run_id {self.run_id!r} must be a single non-empty path "
                "segment (no separators, no '..')."
            )

    def _p(self, *parts: str) -> str:
        """Join ``parts`` with ``/`` beneath this run's prefix."""
        base = self.root.rstrip("/")
        return f"{base}/runs/{self.run_id}/" + "/".join(parts)

    @property
    def tasks_path(self) -> str:
        return self._p("tasks", "manifest.jsonl")

    @property
    def tasks_full_path(self) -> str:
        # RESTRICTED prefix: carries golden_diff/deleted_symbols. On S3 this
        # prefix gets a deny-by-default policy; locally it is still separated
        # so a naive `corpus_*` glob can never sweep it up.
        return self._p("tasks_full", "manifest.jsonl")

    @property
    def traj_path(self) -> str:
        return self._p("traj", "trajectories.jsonl")

    @property
    def sft_path(self) -> str:
        return self._p("corpus_sft", "rows.jsonl")

    @property
    def dpo_path(self) -> str:
        return self._p("corpus_dpo", "rows.jsonl")

    @property
    def holdout_path(self) -> str:
        return self._p("holdout", "tasks.jsonl")

    @property
    def quarantine_path(self) -> str:
        return self._p("quarantine", "rejected.jsonl")

    @property
    def manifest_path(self) -> str:
        return self._p("manifest.json")

    @property
    def card_path(self) -> str:
        return self._p("DATASET_CARD.md")
@dataclass
class RunManifest:
    """Run-level metadata: counts, cost, lineage, budget, acceptance status.

    `created_at` is supplied by the caller (this class never calls
    datetime.now()) so manifests stay reproducible in tests. `parent_run_id`
    carries flywheel lineage so cross-generation dedup (finding D-12) can
    locate prior signatures.
    """

    run_id: str
    created_at: str
    source: str = ""
    counts: dict = field(default_factory=dict)
    cost_usd: float = 0.0
    parent_run_id: str | None = None
    schema_version: str = SCHEMA_VERSION
    status: str = "building"  # building | accepted | rejected | partial
    budget_usd: float | None = None

    def spend(self, usd: float) -> None:
        """Add ``usd`` to the running cost total."""
        self.cost_usd = self.cost_usd + usd

    @property
    def over_budget(self) -> bool:
        """True once a budget is set and the cost has reached or passed it."""
        if self.budget_usd is None:
            return False
        return self.cost_usd >= self.budget_usd

    def write(self, layout: RunLayout) -> None:
        """Serialize this manifest as indented JSON at ``layout.manifest_path``."""
        with _open(layout.manifest_path) as sink:
            payload = dataclasses.asdict(self)
            json.dump(payload, sink, indent=2)

    @classmethod
    def read(cls, layout: RunLayout) -> RunManifest:
        """Load the manifest stored at ``layout.manifest_path``."""
        with _open(layout.manifest_path, "r") as src:
            stored = json.load(src)
            return cls(**stored)
# ---------------------------------------------------------------------
# Writers — the leak guard lives here (finding D-8)
# ---------------------------------------------------------------------
def _task_row_policy_safe(task: FeatureDeletionTask) -> dict:
"""Task row with the construction-side secrets REPLACED, not just hidden.
`asdict()` includes `golden_diff` despite `repr=False` — that is exactly
the leak D-8 flagged. We keep provenance via a sha256 (verifiable, not
recoverable) and drop `deleted_symbols` entirely (they name the answer).
"""
row = dataclasses.asdict(task)
gold = row.pop("golden_diff", "")
row.pop("deleted_symbols", None)
row["golden_diff_sha256"] = hashlib.sha256(gold.encode()).hexdigest() if gold else ""
return row
def write_tasks(layout: RunLayout, tasks: Iterable[FeatureDeletionTask]) -> int:
    """Write the POLICY-SAFE task manifest (the default everything reads).

    Each task is passed through `_task_row_policy_safe` before it is
    serialized as one JSON line at `layout.tasks_path`. Returns the number of
    rows written.
    """
    safe_rows = (_task_row_policy_safe(task) for task in tasks)
    return _write_jsonl(layout.tasks_path, safe_rows)
def write_tasks_full(layout: RunLayout, tasks: Iterable[FeatureDeletionTask]) -> int:
    """Write FULL task rows (incl. golden_diff) to the RESTRICTED prefix.

    Intended for the validator/monitor side only; corpus consumers must never
    read this file. Rows are the unfiltered `dataclasses.asdict` output, one
    JSON line each, at `layout.tasks_full_path`. Returns the number of rows
    written.
    """
    full_rows = (dataclasses.asdict(task) for task in tasks)
    return _write_jsonl(layout.tasks_full_path, full_rows)
def _write_jsonl(path: str, rows: Iterable[dict]) -> int:
    """Serialize ``rows`` to ``path`` as JSON Lines and return the row count.

    The destination is opened (and thus created) even when ``rows`` is empty.
    """
    written = 0
    with _open(path) as sink:
        for written, row in enumerate(rows, start=1):
            sink.write(json.dumps(row) + "\n")
    return written
def write_sft_rows(layout: RunLayout, rows: Iterable[dict]) -> int:
    """Write SFT rows as JSON Lines to ``layout.sft_path``; return the count."""
    destination = layout.sft_path
    return _write_jsonl(destination, rows)
def write_dpo_rows(layout: RunLayout, rows: Iterable[dict]) -> int:
    """Write DPO rows as JSON Lines to ``layout.dpo_path``; return the count."""
    destination = layout.dpo_path
    return _write_jsonl(destination, rows)
def write_quarantine(layout: RunLayout, rows: Iterable[dict]) -> int:
    """Write quarantine rows as JSON Lines to ``layout.quarantine_path``.

    Returns the number of rows written.
    """
    destination = layout.quarantine_path
    return _write_jsonl(destination, rows)
def write_holdout(layout: RunLayout, tasks: Iterable[FeatureDeletionTask]) -> int:
    """Write held-out tasks as policy-safe JSON Lines to ``layout.holdout_path``.

    Every task goes through `_task_row_policy_safe` first, so the file never
    contains `golden_diff` or `deleted_symbols`. Returns the row count.
    """
    safe_rows = map(_task_row_policy_safe, tasks)
    return _write_jsonl(layout.holdout_path, safe_rows)
def write_trajectories(layout: RunLayout, rows: Iterable[dict]) -> int:
    """Write trajectory rows as JSON Lines to ``layout.traj_path``.

    Returns the number of rows written.
    """
    destination = layout.traj_path
    return _write_jsonl(destination, rows)
def write_dataset_card(layout: RunLayout, manifest: RunManifest,
                       *, license_tiers: dict[str, int] | None = None,
                       dedup_stats: dict | None = None,
                       decontamination_note: str = "") -> None:
    """Write a small human-readable dataset card (finding D-18).

    The card is Markdown written to ``layout.card_path``. It lists the
    manifest header fields, the sorted ``manifest.counts``, optional license
    tier and dedup sections, a decontamination note, and the policy-safety
    statement.

    Args:
        layout: Supplies ``card_path``, the destination of the card.
        manifest: Source of the header fields and counts.
        license_tiers: Optional mapping rendered under "License tiers seen";
            the section is omitted when empty or ``None``.
        dedup_stats: Optional mapping rendered under "Dedup"; the section is
            omitted when empty or ``None``.
        decontamination_note: Replaces the default decontamination sentence
            when non-empty.
    """
    cost_line = f"- **cost (USD):** {manifest.cost_usd:.2f}"
    # Compare against None rather than truthiness, matching
    # RunManifest.over_budget: a budget of 0.0 is still a configured budget
    # and must appear on the card.
    if manifest.budget_usd is not None:
        cost_line += f" / budget {manifest.budget_usd:.2f}"

    lines = [
        f"# Dataset card — run `{manifest.run_id}`",
        "",
        f"- **created:** {manifest.created_at}",
        f"- **source:** {manifest.source}",
        f"- **status:** {manifest.status}",
        f"- **schema_version:** {manifest.schema_version}",
        cost_line,
        f"- **lineage:** parent_run_id={manifest.parent_run_id or 'none'}",
        "",
        "## Counts",
        "",
    ]
    lines += [f"- {k}: {v}" for k, v in sorted(manifest.counts.items())]
    if license_tiers:
        lines += ["", "## License tiers seen", ""]
        lines += [f"- {k}: {v}" for k, v in sorted(license_tiers.items())]
    lines += ["", "## Decontamination", "",
              decontamination_note or
              "All source repos checked against the SWE-bench-family eval list "
              "(datagen.repo_gate.DECONTAMINATION_LIST) at ingest."]
    if dedup_stats:
        lines += ["", "## Dedup", ""]
        lines += [f"- {k}: {v}" for k, v in sorted(dedup_stats.items())]
    # The trailing "" makes the joined text end with a newline.
    lines += ["", "Policy-safe rows only: `golden_diff` is sha256-hashed and "
              "`deleted_symbols` dropped in `tasks/`, `corpus_*/`, `holdout/` "
              "(full rows live in the restricted `tasks_full/`).", ""]
    with _open(layout.card_path) as f:
        f.write("\n".join(lines))
def manifest_exists(layout: RunLayout) -> bool:
    """Write-once guard for the driver (finding D-21 idempotency).

    Reports whether a manifest is already present at ``layout.manifest_path``.
    """
    manifest_location = layout.manifest_path
    return _exists(manifest_location)
# Public API of this module; the underscore-prefixed helpers are not exported.
__all__ = [
    "SCHEMA_VERSION",
    "RunLayout",
    "RunManifest",
    "manifest_exists",
    "write_dataset_card",
    "write_dpo_rows",
    "write_holdout",
    "write_quarantine",
    "write_sft_rows",
    "write_tasks",
    "write_tasks_full",
    "write_trajectories",
]