# roverdevkit/scripts/build_dataset.py
# jjreif's picture
# Deploy roverdevkit @ 2676a67
# b3d14e3
# Raw
# History Blame Contribute Delete
# 7.15 kB
"""Generate a surrogate-training analytical dataset (LHS sampler -> evaluator -> Parquet).
Single canonical entry point for any dataset rebuild: pilot, full
training run, or release-time benchmark slice. The same flags drive
all three; see ``--help``.
Examples
--------
::
# 200-sample smoke pilot (canonical reproduction)
python scripts/build_dataset.py \\
--n-per-scenario 50 \\
--out data/analytical/lhs_pilot.parquet \\
--seed 42 \\
--workers 1 \\
--notes "baseline-surrogate pilot rebuild under v2."
# Full 40k training set on the v3 widened bounds (current canonical)
python scripts/build_dataset.py \\
--n-per-scenario 10000 \\
--out data/analytical/lhs_v3.parquet \\
--seed 42 \\
--notes "v3 widened LHS bounds (chassis 3-50 kg, wheel_width 0.03-0.20 m, grouser 0-0.020 m)."
The script writes a single Parquet file with the schema documented in
``data/analytical/SCHEMA.md`` (``SCHEMA_VERSION`` constant in
``roverdevkit.surrogate.dataset`` is the source of truth). Dataset-
level metadata (seed, n_per_scenario, fidelity, build timestamp,
free-form notes) is written to the file footer so re-runs are
reproducible from disk alone.
"""
from __future__ import annotations
import argparse
import logging
import os
import sys
import time
from pathlib import Path
from roverdevkit.surrogate.dataset import (
SCHEMA_VERSION,
DatasetMetadata,
build_and_write,
)
from roverdevkit.surrogate.sampling import FAMILIES, generate_samples
# Every scenario family name known to the sampler, in the iteration order of
# ``FAMILIES``; used as both the allowed choices and the default of
# ``--families``.
DEFAULT_FAMILIES: tuple[str, ...] = tuple(FAMILIES)
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse and validate the command-line arguments of the dataset build.

    Parameters
    ----------
    argv:
        Argument strings to parse. ``None`` lets argparse fall back to
        ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        Parsed options: ``n_per_scenario``, ``out``, ``seed``, ``workers``,
        ``families``, ``val_frac``, ``test_frac``, ``chunksize``,
        ``no_progress``, ``notes`` and ``log_level``.

    Notes
    -----
    Values that can never yield a usable dataset (non-positive sample count
    or chunk size, split fractions outside ``[0, 1]`` or summing to more than
    1) are rejected through ``ArgumentParser.error``, which prints the usage
    line and exits with status 2 instead of failing later inside the build.
    """
    p = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument(
        "--n-per-scenario",
        type=int,
        required=True,
        help="LHS samples per scenario family. Total rows = n * len(families).",
    )
    p.add_argument(
        "--out",
        type=Path,
        required=True,
        help="Output Parquet path (parent dirs are created).",
    )
    p.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Sampler RNG seed. Same seed -> same samples (default: 42).",
    )
    p.add_argument(
        "--workers",
        type=int,
        default=0,
        help=(
            # The caller resolves 0 with max(1, cpu_count - 1): a floor of 1,
            # not a cap.
            "Worker process count. 0 (default) => os.cpu_count() - 1 (at least 1). "
            "1 => serial; useful for debugging or if multiprocessing/spawn misbehaves."
        ),
    )
    p.add_argument(
        "--families",
        nargs="+",
        choices=list(DEFAULT_FAMILIES),
        default=list(DEFAULT_FAMILIES),
        # Derive the count so the help text cannot drift from FAMILIES.
        help=f"Scenario families to include (default: all {len(DEFAULT_FAMILIES)}).",
    )
    p.add_argument(
        "--val-frac",
        type=float,
        default=0.1,
        help="Validation split fraction (default: 0.1).",
    )
    p.add_argument(
        "--test-frac",
        type=float,
        default=0.1,
        help="Test split fraction (default: 0.1).",
    )
    p.add_argument(
        "--chunksize",
        type=int,
        default=32,
        help="multiprocessing.imap_unordered chunk size (default: 32).",
    )
    p.add_argument(
        "--no-progress",
        action="store_true",
        help="Disable the tqdm progress bar.",
    )
    p.add_argument(
        "--notes",
        type=str,
        default="",
        help="Free-form notes string written to the Parquet metadata.",
    )
    p.add_argument(
        "--log-level",
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Logging level for the build run (default: INFO).",
    )
    args = p.parse_args(argv)

    # Fail fast on values that cannot produce a meaningful dataset.
    if args.n_per_scenario < 1:
        p.error(f"--n-per-scenario must be >= 1 (got {args.n_per_scenario})")
    if args.chunksize < 1:
        p.error(f"--chunksize must be >= 1 (got {args.chunksize})")
    for flag, frac in (("--val-frac", args.val_frac), ("--test-frac", args.test_frac)):
        # Written as a chained comparison so that NaN is rejected as well.
        if not 0.0 <= frac <= 1.0:
            p.error(f"{flag} must be within [0, 1] (got {frac})")
    # Small tolerance so float round-off on an exact-1.0 sum is not rejected.
    if args.val_frac + args.test_frac > 1.0 + 1e-9:
        p.error(
            "--val-frac + --test-frac must not exceed 1 "
            f"(got {args.val_frac} + {args.test_frac})"
        )
    return args
def main(argv: list[str] | None = None) -> int:
    """Run one dataset build: sample, evaluate, write Parquet, log a summary.

    Parameters
    ----------
    argv:
        Command-line arguments, forwarded to ``_parse_args``. ``None`` means
        the process arguments are used.

    Returns
    -------
    int
        Process exit status: ``0`` when at least one row was written, ``1``
        when the resulting frame is empty. Per-sample graceful failures do
        not affect the exit status.
    """
    args = _parse_args(argv)
    logging.basicConfig(
        level=args.log_level,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )
    log = logging.getLogger("build_dataset")

    # Non-positive --workers means "auto": leave one core free but never
    # drop below one worker. ``os.cpu_count()`` may return None, hence ``or 2``.
    workers = args.workers if args.workers > 0 else max(1, (os.cpu_count() or 2) - 1)
    log.info(
        "schema=%s n_per_scenario=%d families=%d total_samples=%d workers=%d seed=%d "
        "out=%s",
        SCHEMA_VERSION,
        args.n_per_scenario,
        len(args.families),
        args.n_per_scenario * len(args.families),
        workers,
        args.seed,
        args.out,
    )

    log.info("generating LHS samples...")
    samples = generate_samples(
        n_per_scenario=args.n_per_scenario,
        seed=args.seed,
        scenario_names=list(args.families),
        val_frac=args.val_frac,
        test_frac=args.test_frac,
    )
    log.info("generated %d samples", len(samples))

    # Metadata handed to the writer together with the samples.
    meta = DatasetMetadata(
        sampler_seed=args.seed,
        n_per_scenario=args.n_per_scenario,
        scenario_families=tuple(args.families),
        val_frac=args.val_frac,
        test_frac=args.test_frac,
        notes=args.notes,
    )

    log.info("evaluating samples (workers=%d)...", workers)
    t0 = time.perf_counter()
    df, path = build_and_write(
        samples,
        args.out,
        metadata=meta,
        build_kwargs={
            "n_workers": workers,
            "chunksize": args.chunksize,
            "progress": not args.no_progress,
        },
    )
    elapsed = time.perf_counter() - t0

    n_ok = int((df["status"] == "ok").sum())
    n_total = len(df)
    n_failed = n_total - n_ok
    # ``max(1, n_total)`` keeps the ratios finite when no rows came back.
    log.info(
        "wrote %d rows x %d cols to %s (ok=%d/%d, %.2f%%) in %.1f s (%.2f s/sample, "
        "%.2f s/sample/worker)",
        n_total,
        len(df.columns),
        path,
        n_ok,
        n_total,
        100 * n_ok / max(1, n_total),
        elapsed,
        elapsed / max(1, n_total),
        elapsed * workers / max(1, n_total),
    )

    if n_failed > 0:
        # Per-sample exceptions are recorded in the ``status`` column and
        # are part of the documented graceful-failure contract, not a
        # script-level error. Surface them so they're visible without
        # tripping CI.
        top_reasons = df.loc[df["status"] != "ok", "status"].value_counts().head(5)
        log.warning(
            "%d sample(s) hit a graceful-failure path (still recorded with NaN metrics). "
            "Top reasons: %s",
            n_failed,
            ", ".join(f"{k}={v}" for k, v in top_reasons.items()),
        )

    if "stalled" in df.columns:
        # Schema v6 flipped the infeasibility flag from ``motor_torque_ok``
        # to ``stalled`` (positive class = bad). Report the *non-stalled*
        # rate so the headline number stays comparable to pre-v6 datasets
        # ("higher is better").
        #
        # Missing values are dropped first: ``astype(bool)`` turns NaN into
        # True (and raises on a nullable-boolean column holding NA), which
        # would count unevaluated rows as stalled.
        # NOTE(review): assumes gracefully-failed rows carry NaN in
        # ``stalled`` like the other metrics -- confirm against the schema.
        stalled = df["stalled"].dropna()
        if stalled.empty:
            log.info("feasibility (non-stalled) rate: n/a (no rows with a stalled flag)")
        else:
            feas_rate = 1.0 - float(stalled.astype(bool).mean())
            log.info("feasibility (non-stalled) rate: %.2f%%", 100 * feas_rate)

    # Non-zero exit only on catastrophic build failure (no rows written).
    # Per-sample graceful failures are by design and must not masquerade
    # as a script error, otherwise CI / shell wrappers misclassify
    # successful runs.
    if n_total == 0:
        log.error("build produced no rows (output: %s); exiting with status 1", path)
        return 1
    return 0
# Script entry point: propagate main()'s return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())