File size: 2,797 Bytes
6a47c48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import random
import sys
from pathlib import Path
from typing import Any

# The repository root is the directory above the folder that holds this
# script. It is prepended to sys.path (only if not already present) so the
# `scripts.*` imports below resolve when this file is executed directly.
REPO_ROOT = Path(__file__).resolve().parents[1]
if sys.path.count(str(REPO_ROOT)) == 0:
    sys.path.insert(0, str(REPO_ROOT))

from scripts.benchmark_datasets import resolve_benchmark_dataset
from scripts.benchmark_suite import validate_sample


def load_rows(path: Path) -> list[dict[str, Any]]:
    """Parse a JSON Lines file into a list of rows.

    Whitespace-only lines are skipped. Lines are split by iterating the file
    in text mode (universal newlines: ``\\n``, ``\\r``, ``\\r\\n``) rather than
    with ``str.splitlines()``. ``splitlines()`` also breaks on U+2028, U+2029
    and U+0085, which JSON permits unescaped inside string values, so it
    would cut a valid row in half.

    Args:
        path: JSONL file to read, decoded as UTF-8.

    Returns:
        The decoded value of every non-blank line, in file order (each line is
        expected to hold a JSON object).

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON. The
            message is prefixed with ``<path>:<line_number>:`` so the bad line
            can be located.
    """
    rows: list[dict[str, Any]] = []
    with path.open(encoding="utf-8") as handle:
        for line_number, line in enumerate(handle, start=1):
            if not line.strip():
                continue
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError as exc:
                # Re-raise with the same exception type so callers catching
                # JSONDecodeError/ValueError keep working, but add location info.
                raise json.JSONDecodeError(
                    f"{path}:{line_number}: {exc.msg}", exc.doc, exc.pos
                ) from exc
    return rows


def build_subset_rows(*, source: Path, benchmark: str, n: int, seed: int) -> list[dict[str, Any]]:
    """Select a seeded random subset of ``n`` rows from a JSONL dataset.

    The chosen row indices come from ``random.Random(seed).sample`` and are
    returned in ascending source order. Each selected row is shallow-copied
    and passed to ``validate_sample(benchmark, row)``, then tagged with its
    0-based position in the source file under ``"source_row_id"``.

    Args:
        source: JSONL file to sample from.
        benchmark: benchmark name forwarded to ``validate_sample``.
        n: number of rows to select.
        seed: seed for the row selection.

    Returns:
        The selected rows, each carrying a ``"source_row_id"`` key.

    Raises:
        ValueError: if ``n`` exceeds the number of rows in ``source``
            (``random.sample`` also raises ``ValueError`` for negative ``n``).
        Anything raised by ``validate_sample`` for a row it rejects
            (its behavior is defined in ``scripts.benchmark_suite``).
    """
    all_rows = load_rows(source)
    total = len(all_rows)
    if total < n:
        raise ValueError(f"requested subset size {n} exceeds dataset size {total}")

    rng = random.Random(seed)
    picked = rng.sample(range(total), n)
    picked.sort()  # keep the subset in the same order as the source file

    def _annotate(position: int) -> dict[str, Any]:
        # Copy so the tag is not written into the loaded source row.
        record = dict(all_rows[position])
        validate_sample(benchmark, record)
        record["source_row_id"] = position
        return record

    return [_annotate(position) for position in picked]


def write_subset(*, source: Path, benchmark: str, n: int, seed: int, out: Path) -> Path:
    """Build a subset and write it, plus a JSON manifest, to disk.

    The subset comes from ``build_subset_rows`` and is written to ``out`` as
    JSON Lines (one ``json.dumps`` object per line). Missing parent
    directories are created. A manifest recording the parameters, both paths
    and the selected ``source_row_ids`` is written next to it as
    ``<out name>.manifest.json`` (pretty-printed, keys sorted).

    Args:
        source: JSONL file to sample from.
        benchmark: benchmark name forwarded to ``build_subset_rows``.
        n: number of rows to select.
        seed: seed for the row selection.
        out: destination JSONL path.

    Returns:
        Path of the manifest file that was written.
    """
    rows = build_subset_rows(source=source, benchmark=benchmark, n=n, seed=seed)

    out.parent.mkdir(parents=True, exist_ok=True)
    jsonl_text = "".join(f"{json.dumps(record)}\n" for record in rows)
    out.write_text(jsonl_text, encoding="utf-8")

    # Append to the existing suffix, e.g. "subset.jsonl" -> "subset.jsonl.manifest.json".
    manifest_path = out.with_suffix(f"{out.suffix}.manifest.json")
    manifest = dict(
        benchmark=benchmark,
        n=n,
        seed=seed,
        source_path=str(source),
        out_path=str(out),
        source_row_ids=[record["source_row_id"] for record in rows],
    )
    manifest_path.write_text(json.dumps(manifest, indent=2, sort_keys=True), encoding="utf-8")
    return manifest_path


def _non_negative_int(text: str) -> int:
    """Argparse ``type=`` converter for ``--n``: accept integers >= 0 only.

    Rejecting negative sizes here gives a clean usage error. Otherwise the
    value would reach ``random.sample`` in ``build_subset_rows``, which raises
    a bare ``ValueError`` traceback for a negative sample size.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected a non-negative integer, got {text!r}") from None
    if value < 0:
        raise argparse.ArgumentTypeError(f"expected a non-negative integer, got {value}")
    return value


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of this script.

    Args:
        argv: arguments to parse; ``None`` lets argparse read ``sys.argv[1:]``.

    Returns:
        Namespace with ``benchmark`` (str), ``samples`` (``Path`` or ``None``),
        ``n`` (non-negative int), ``seed`` (int) and ``out`` (``Path``).

    Raises:
        SystemExit: on invalid or missing arguments (argparse exits with code 2).
    """
    parser = argparse.ArgumentParser(description="Build a deterministic benchmark subset JSONL and manifest")
    parser.add_argument(
        "--benchmark",
        required=True,
        choices=["MBPP", "GSM8K", "HumanEval", "ARC-Challenge"],
        help="benchmark whose rows are sampled and validated",
    )
    parser.add_argument(
        "--samples",
        type=Path,
        help="optional source JSONL path, passed to the benchmark dataset resolver",
    )
    parser.add_argument(
        "--n",
        type=_non_negative_int,
        required=True,
        help="number of rows to sample (>= 0)",
    )
    parser.add_argument("--seed", type=int, required=True, help="seed for the random row selection")
    parser.add_argument(
        "--out",
        type=Path,
        required=True,
        help="output JSONL path; a manifest is written alongside as <out>.manifest.json",
    )
    return parser.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    """Command-line entry point: parse options, resolve the dataset, write the subset.

    The source path is obtained from ``resolve_benchmark_dataset`` using the
    benchmark name and the optional ``--samples`` path. (Presumably it falls
    back to a default dataset location when ``--samples`` is omitted; that
    logic lives in ``scripts.benchmark_datasets``, so confirm there.)

    Args:
        argv: arguments to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit status; always 0 when no exception is raised.
    """
    options = parse_args(argv)
    dataset_path = resolve_benchmark_dataset(options.benchmark, options.samples)
    write_subset(
        source=dataset_path,
        benchmark=options.benchmark,
        n=options.n,
        seed=options.seed,
        out=options.out,
    )
    return 0


if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return value as the exit status.
    sys.exit(main())