File size: 10,430 Bytes
9a2ce20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3bbcf21
 
 
 
 
 
 
 
 
 
9a2ce20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
"""s3_contract.py — THE single dataset layout + manifest (finding V8/D-7/D-8).

Supersedes BOTH prior contracts: design-F1's `runs/<id>/{sft_corpus,dpo_pairs,
rl_task_pool,divergence_pairs,wm_tuples,holdout,diloco_rendezvous}` and
design-F2's `{traces,tasks,replay,task_grades,corpus}/v1/run_id=<id>` — the two
were never reconciled and coexisted in the grounding doc. One layout, one
manifest, two explicit serializers with a unit-tested leak guard.

Deliberate exclusions from the run layout:
  * `diloco_rendezvous/` — training-comms state, not dataset; lives in its own
    prefix/bucket (finding D-19).
  * `wm_tuples/` — emitted only when the P4 world-model ablation is scheduled
    (finding D-14); not part of Stage 0.

Layout (root = any local path or fsspec URI):
    <root>/runs/<run_id>/
        tasks/manifest.jsonl       policy-safe task rows (golden_diff -> sha256)
        tasks_full/manifest.jsonl  construction-side full rows (RESTRICTED prefix)
        traj/trajectories.jsonl    CanonicalTrajectory records (audit trail)
        corpus_sft/rows.jsonl      admitted SFT rows (to_policy_row output)
        corpus_dpo/rows.jsonl      DPO-candidate rows
        holdout/tasks.jsonl        held-out task ids+rows (never rolled out)
        quarantine/rejected.jsonl  rejected trajectories w/ reasons (audit)
        manifest.json              RunManifest
        DATASET_CARD.md            human-readable card
"""
from __future__ import annotations

import dataclasses
import hashlib
import json
from dataclasses import dataclass, field
from typing import IO, Iterable

from composer_replication.datagen.schema import FeatureDeletionTask

# Version tag of this on-disk contract; every RunManifest records it as its
# `schema_version` default.
SCHEMA_VERSION: str = "1"


def _is_local(root: str) -> bool:
    """Return True for plain filesystem paths and `file://` URIs.

    Anything else carrying a `scheme://` prefix (s3://, gs://, ...) is remote.
    """
    if root.startswith("file://"):
        return True
    return "://" not in root


def _open(path: str, mode: str = "w") -> IO[str]:
    """Open a path for text IO; plain `open` locally, fsspec for s3:// etc.

    fsspec is imported lazily so the module (and all local-corpus runs) need
    no extra dependency.

    Args:
        path: Local path, `file://` URI, or any fsspec-supported URI.
        mode: Text mode passed to `open` / `fsspec.open` (default "w").

    Returns:
        An open UTF-8 text file object; the caller closes it (every caller in
        this module uses `with`).

    Raises:
        RuntimeError: `path` is non-local and fsspec is not installed.
        OSError: propagated from the underlying open (e.g. FileNotFoundError
            when reading a missing local file).
    """
    if _is_local(path):
        import os
        local = path.removeprefix("file://")
        # Only writers need the parent directory. Creating it on a read would
        # leave an empty `runs/<id>/` tree behind when the file is missing, and
        # `os.makedirs("")` (bare filename, no directory part) raises.
        parent = os.path.dirname(local)
        if parent and any(flag in mode for flag in "wax+"):
            os.makedirs(parent, exist_ok=True)
        return open(local, mode, encoding="utf-8")
    try:
        import fsspec  # noqa: PLC0415 — lazy heavy dep
    except ImportError as e:
        raise RuntimeError(
            "Non-local corpus roots require fsspec; install with "
            "`pip install -e .[serverless]`. Got: " + repr(e)
        ) from e
    # NOTE(review): fsspec documents that the OpenFile should stay referenced
    # while its file-like is in use; confirm closing the returned object is
    # sufficient for every backend in use.
    return fsspec.open(path, mode, encoding="utf-8").open()


def _exists(path: str) -> bool:
    """Return True if `path` (local path, `file://` URI or fsspec URI) exists.

    Remote URLs are resolved with `fsspec.core.url_to_fs`, which maps the URL
    to a filesystem plus one literal path. Unlike `get_fs_token_paths`, it
    never glob-expands the path, so the check is always on exactly the object
    named.
    """
    if _is_local(path):
        import os
        return os.path.exists(path.removeprefix("file://"))
    from fsspec.core import url_to_fs  # noqa: PLC0415 — lazy heavy dep
    fs, fs_path = url_to_fs(path)
    return bool(fs.exists(fs_path))


@dataclass(frozen=True)
class RunLayout:
    """Pure-path logic for one run's prefixes — testable without any IO.

    Every path is `<root>/runs/<run_id>/<...>`; `root` may be a local path or
    an fsspec URI (a trailing slash on `root` is ignored).
    """

    root: str
    run_id: str

    def __post_init__(self) -> None:
        # Defense-in-depth (Wave-21 review P2): run_id is operator-supplied,
        # but a separator or `..` would silently escape the corpus root, and a
        # bare `.` would collapse `runs/./` onto the shared `runs/` prefix
        # (every such run would write `runs/manifest.json`).
        rid = self.run_id
        if (not rid or rid == "." or "/" in rid or "\\" in rid
                or ".." in rid):
            raise ValueError(
                f"run_id {self.run_id!r} must be a single non-empty path "
                "segment (no separators, not '.', no '..')."
            )

    def _p(self, *parts: str) -> str:
        """Join `parts` under this run's prefix with '/' separators."""
        base = self.root.rstrip("/")
        return f"{base}/runs/{self.run_id}/" + "/".join(parts)

    @property
    def tasks_path(self) -> str:
        """Policy-safe task manifest."""
        return self._p("tasks", "manifest.jsonl")

    @property
    def tasks_full_path(self) -> str:
        """Full task rows (RESTRICTED prefix)."""
        # RESTRICTED prefix: carries golden_diff/deleted_symbols. On S3 this
        # prefix gets a deny-by-default policy; locally it is still separated
        # so a naive `corpus_*` glob can never sweep it up.
        return self._p("tasks_full", "manifest.jsonl")

    @property
    def traj_path(self) -> str:
        """Trajectory audit trail."""
        return self._p("traj", "trajectories.jsonl")

    @property
    def sft_path(self) -> str:
        """Admitted SFT rows."""
        return self._p("corpus_sft", "rows.jsonl")

    @property
    def dpo_path(self) -> str:
        """DPO-candidate rows."""
        return self._p("corpus_dpo", "rows.jsonl")

    @property
    def holdout_path(self) -> str:
        """Held-out task rows."""
        return self._p("holdout", "tasks.jsonl")

    @property
    def quarantine_path(self) -> str:
        """Rejected trajectories with reasons."""
        return self._p("quarantine", "rejected.jsonl")

    @property
    def manifest_path(self) -> str:
        """Run-level RunManifest JSON."""
        return self._p("manifest.json")

    @property
    def card_path(self) -> str:
        """Human-readable dataset card."""
        return self._p("DATASET_CARD.md")


@dataclass
class RunManifest:
    """Run-level metadata: counts, cost, lineage, budget, acceptance status.

    `created_at` is CALLER-passed (never datetime.now() in here) so manifests
    are reproducible in tests. `parent_run_id` threads flywheel lineage so
    cross-generation dedup (finding D-12) can find prior signatures.
    """

    run_id: str
    created_at: str
    source: str = ""
    counts: dict = field(default_factory=dict)
    cost_usd: float = 0.0
    parent_run_id: str | None = None
    schema_version: str = SCHEMA_VERSION
    status: str = "building"          # building | accepted | rejected | partial
    budget_usd: float | None = None   # None = no budget (over_budget stays False)

    def spend(self, usd: float) -> None:
        """Add `usd` to the running `cost_usd` total."""
        self.cost_usd += usd

    @property
    def over_budget(self) -> bool:
        """True once `cost_usd` has reached (>=) a set budget; never without one."""
        return self.budget_usd is not None and self.cost_usd >= self.budget_usd

    def write(self, layout: RunLayout) -> None:
        """Serialize this manifest as indented JSON to `layout.manifest_path`.

        The JSON is rendered BEFORE the file is opened. `manifest_exists`
        treats the mere presence of manifest.json as "run already written", so
        a serialization failure (e.g. a non-JSON value inside `counts`) must
        not leave an empty or truncated file behind.
        """
        payload = json.dumps(dataclasses.asdict(self), indent=2)
        with _open(layout.manifest_path) as f:
            f.write(payload)

    @classmethod
    def read(cls, layout: RunLayout) -> RunManifest:
        """Load a manifest from `layout.manifest_path`.

        The JSON keys are passed straight to the constructor, so a manifest
        carrying a key this class does not define raises TypeError.
        """
        with _open(layout.manifest_path, "r") as f:
            return cls(**json.load(f))


# ---------------------------------------------------------------------
# Writers — the leak guard lives here (finding D-8)
# ---------------------------------------------------------------------


def _task_row_policy_safe(task: FeatureDeletionTask) -> dict:
    """Build a task row whose construction-side secrets are REPLACED, not hidden.

    `asdict()` still emits `golden_diff` even though the field is `repr=False`,
    which is exactly the leak D-8 flagged. The diff is swapped for its sha256
    (verifiable, not recoverable), and `deleted_symbols` is removed outright
    because it names the answer. An empty or missing diff hashes to "".
    """
    secret_fields = ("golden_diff", "deleted_symbols")
    full = dataclasses.asdict(task)
    diff_text = full.get("golden_diff", "")
    safe = {key: val for key, val in full.items() if key not in secret_fields}
    if diff_text:
        safe["golden_diff_sha256"] = hashlib.sha256(diff_text.encode()).hexdigest()
    else:
        safe["golden_diff_sha256"] = ""
    return safe


def write_tasks(layout: RunLayout, tasks: Iterable[FeatureDeletionTask]) -> int:
    """Write the POLICY-SAFE task manifest (the default everything reads).

    Each task goes through `_task_row_policy_safe` before serialization;
    returns the number of rows written.
    """
    safe_rows = map(_task_row_policy_safe, tasks)
    return _write_jsonl(layout.tasks_path, safe_rows)


def write_tasks_full(layout: RunLayout, tasks: Iterable[FeatureDeletionTask]) -> int:
    """Write FULL task rows (incl. golden_diff) to the RESTRICTED prefix.

    Only the validator/monitor side reads this; never corpus consumers.
    Returns the number of rows written.
    """
    full_rows = map(dataclasses.asdict, tasks)
    return _write_jsonl(layout.tasks_full_path, full_rows)


def _write_jsonl(path: str, rows: Iterable[dict]) -> int:
    """Write each dict in `rows` as one JSON line to `path`; return how many."""
    written = 0
    with _open(path) as sink:
        for written, record in enumerate(rows, start=1):
            sink.write(json.dumps(record) + "\n")
    return written


def write_sft_rows(layout: RunLayout, rows: Iterable[dict]) -> int:
    """Write admitted SFT rows to `corpus_sft/`; returns the row count."""
    target = layout.sft_path
    return _write_jsonl(target, rows)


def write_dpo_rows(layout: RunLayout, rows: Iterable[dict]) -> int:
    """Write DPO-candidate rows to `corpus_dpo/`; returns the row count."""
    target = layout.dpo_path
    return _write_jsonl(target, rows)


def write_quarantine(layout: RunLayout, rows: Iterable[dict]) -> int:
    """Write rejected-trajectory audit rows to `quarantine/`; returns the count."""
    target = layout.quarantine_path
    return _write_jsonl(target, rows)


def write_holdout(layout: RunLayout, tasks: Iterable[FeatureDeletionTask]) -> int:
    """Write held-out tasks (POLICY-SAFE rows) to `holdout/`; returns the count."""
    safe_rows = map(_task_row_policy_safe, tasks)
    return _write_jsonl(layout.holdout_path, safe_rows)


def write_trajectories(layout: RunLayout, rows: Iterable[dict]) -> int:
    """Write trajectory audit-trail records to `traj/`; returns the row count."""
    target = layout.traj_path
    return _write_jsonl(target, rows)


def write_dataset_card(layout: RunLayout, manifest: RunManifest,
                       *, license_tiers: dict[str, int] | None = None,
                       dedup_stats: dict | None = None,
                       decontamination_note: str = "") -> None:
    """Write a small human-readable dataset card (finding D-18).

    Args:
        layout: Run layout; the card is written to `layout.card_path`.
        manifest: Supplies run id, created_at, source, status, schema version,
            cost/budget, lineage and the `counts` mapping.
        license_tiers: Optional tier -> count table; rendered only if non-empty.
        dedup_stats: Optional dedup statistics; rendered only if non-empty.
        decontamination_note: Replaces the default decontamination text when
            non-empty.
    """
    budget = manifest.budget_usd
    # `is not None`, not truthiness: a 0.0 budget is a real (already exhausted)
    # budget — RunManifest.over_budget treats it that way — so it must show.
    budget_note = f" / budget {budget:.2f}" if budget is not None else ""
    lines = [
        f"# Dataset card — run `{manifest.run_id}`",
        "",
        f"- **created:** {manifest.created_at}",
        f"- **source:** {manifest.source}",
        f"- **status:** {manifest.status}",
        f"- **schema_version:** {manifest.schema_version}",
        f"- **cost (USD):** {manifest.cost_usd:.2f}" + budget_note,
        f"- **lineage:** parent_run_id={manifest.parent_run_id or 'none'}",
        "",
        "## Counts",
        "",
    ]
    for k, v in sorted(manifest.counts.items()):
        lines.append(f"- {k}: {v}")
    if license_tiers:
        lines += ["", "## License tiers seen", ""]
        lines += [f"- {k}: {v}" for k, v in sorted(license_tiers.items())]
    lines += ["", "## Decontamination", "",
              decontamination_note or
              "All source repos checked against the SWE-bench-family eval list "
              "(datagen.repo_gate.DECONTAMINATION_LIST) at ingest."]
    if dedup_stats:
        lines += ["", "## Dedup", ""]
        lines += [f"- {k}: {v}" for k, v in sorted(dedup_stats.items())]
    lines += ["", "Policy-safe rows only: `golden_diff` is sha256-hashed and "
              "`deleted_symbols` dropped in `tasks/`, `corpus_*/`, `holdout/` "
              "(full rows live in the restricted `tasks_full/`).", ""]
    # Trailing "" above makes the joined card end with a newline.
    with _open(layout.card_path) as f:
        f.write("\n".join(lines))


def manifest_exists(layout: RunLayout) -> bool:
    """Return True if this run's manifest.json is already present.

    Serves as the driver's write-once guard (finding D-21 idempotency).
    """
    target = layout.manifest_path
    return _exists(target)


# Public API of this module; the order below is kept stable.
__all__ = [
    # contract version
    "SCHEMA_VERSION",
    # layout + run metadata
    "RunLayout",
    "RunManifest",
    "manifest_exists",
    # writers
    "write_dataset_card",
    "write_dpo_rows",
    "write_holdout",
    "write_quarantine",
    "write_sft_rows",
    "write_tasks",
    "write_tasks_full",
    "write_trajectories",
]