#!/usr/bin/env python3
"""
run_experiments.py
------------------
CLI orchestrator for SpatialBench experiments.
Run on the cluster with SLURM:
python run_experiments.py --tasks maze_navigation point_reuse compositional_distance --mode slurm
Run directly on the local machine (no SLURM required):
python run_experiments.py --tasks maze_navigation --models gemini-2.5-flash --mode direct
Dry-run (print commands without executing):
python run_experiments.py --tasks maze_navigation point_reuse compositional_distance --dry-run
Filter experiments:
python run_experiments.py --tasks maze_navigation \\
--models gemini-2.5-flash claude-haiku-4-5 \\
--grid-sizes 5 6 7 \\
--formats raw \\
--strategies cot reasoning
Show status of running SLURM jobs:
python run_experiments.py --status
"""
from __future__ import annotations
import argparse
import os
import tempfile
import time
from pathlib import Path
# Load .env if present (before importing pipeline modules)
_env_file = Path(__file__).parent / ".env"
if _env_file.exists():
with open(_env_file) as _f:
for _line in _f:
_line = _line.strip()
if _line and not _line.startswith("#") and "=" in _line:
_k, _v = _line.split("=", 1)
os.environ.setdefault(_k.strip(), _v.strip())
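# A .env file holds one KEY=VALUE pair per line; values already present in the
# environment win, because of setdefault. Hypothetical example (the key names
# are illustrative; the real ones come from each model's api_key_env):
#   GEMINI_API_KEY=...
#   ANTHROPIC_API_KEY=...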
from pipeline.task_builder import (
load_config, build_all_jobs, make_sbatch_script, ExperimentJob,
)
from pipeline.job_monitor import JobMonitor, submit_sbatch, submit_direct
CONFIG_PATH = Path(__file__).parent / "configs" / "experiments.yaml"
REPO_ROOT = CONFIG_PATH.parent.parent.parent # llm-maze-solver/
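# Assumed repository layout, inferred from the paths above (illustrative):
#   llm-maze-solver/                     <- REPO_ROOT
#   ├── SpatialBench/
#   │   ├── run_experiments.py           <- this file
#   │   └── configs/experiments.yaml     <- CONFIG_PATH
#   └── maze-solver/eval_llm_logs/       <- SLURM log dir (created on demand)
# experiments.yaml is consumed by pipeline.task_builder.load_config; a
# hypothetical minimal shape (the real schema lives in that module):
#   models:
#     - id: gemini-2.5-flash
#       api_key_env: GEMINI_API_KEY
#   tasks:
#     maze_navigation:
#       grid_sizes: [5, 6, 7]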
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _check_api_key(job: ExperimentJob) -> bool:
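    """Return True if the job's API-key env var is set; warn and skip otherwise."""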
val = os.environ.get(job.api_key_env, "")
if not val:
print(f" [WARN] {job.api_key_env} not set — skipping: {job.label}")
return False
return True
def _print_job(job: ExperimentJob) -> None:
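    """Print a compact human-readable summary of one job."""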
print(f"\n {job.label}")
print(f" cmd : {' '.join(job.python_cmd[:4])} ...")
print(f" wdir: {job.working_dir}")
print(f" out : {job.output_dir}")
# ---------------------------------------------------------------------------
# Run modes
# ---------------------------------------------------------------------------
def run_slurm(jobs: list[ExperimentJob], monitor: JobMonitor, dry_run: bool) -> None:
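    """Render an sbatch script per job, submit it, and register it with the monitor."""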
log_dir = REPO_ROOT / "maze-solver" / "eval_llm_logs"
log_dir.mkdir(parents=True, exist_ok=True)
for job in jobs:
if not _check_api_key(job):
continue
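        # make_sbatch_script renders the batch script for this job. Illustrative
        # shape only (the real template lives in pipeline.task_builder); it is
        # assumed to log to %j.out / %j.err so the paths registered below match:
        #   #!/bin/bash
        #   #SBATCH --job-name=<job.label>
        #   #SBATCH --output=<log_dir>/%j.out
        #   #SBATCH --error=<log_dir>/%j.err
        #   <job.python_cmd ...>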
script_text = make_sbatch_script(job, log_dir)
if dry_run:
_print_job(job)
print(" --- sbatch script ---")
print(script_text)
continue
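        # Persist the rendered script with delete=False so it survives the
        # context manager and sbatch can read it from disk.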
with tempfile.NamedTemporaryFile(
mode="w", suffix=".sh", prefix="spatialbench_",
dir=log_dir, delete=False
) as tmp:
tmp.write(script_text)
script_path = tmp.name
job_id = submit_sbatch(script_path)
if job_id:
monitor.add(
job_id=job_id,
label=job.label,
task_id=job.task_id,
model=job.model,
output_dir=str(job.output_dir),
log_out=str(log_dir / f"{job_id}.out"),
log_err=str(log_dir / f"{job_id}.err"),
)
print(f" Submitted {job.label} → SLURM job {job_id}")
else:
print(f" [ERROR] Failed to submit: {job.label}")
def run_direct(jobs: list[ExperimentJob], monitor: JobMonitor, dry_run: bool) -> None:
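    """Run each job as a local subprocess (no SLURM), staggering the launches."""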
for job in jobs:
if not _check_api_key(job):
continue
if dry_run:
_print_job(job)
print(f" cmd: {' '.join(job.python_cmd)}\n")
continue
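        # Forward only the key this job needs; submit_direct is assumed to
        # merge this mapping into the child process environment.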
env_patch = {job.api_key_env: os.environ[job.api_key_env]}
job.output_dir.mkdir(parents=True, exist_ok=True)
print(f" Starting: {job.label}")
proc = submit_direct(
cmd=job.python_cmd,
working_dir=str(job.working_dir),
env=env_patch,
)
monitor.add_direct(
proc=proc,
label=job.label,
task_id=job.task_id,
model=job.model,
output_dir=str(job.output_dir),
)
# Small gap to avoid hammering APIs simultaneously
time.sleep(2)
# ---------------------------------------------------------------------------
# Status display
# ---------------------------------------------------------------------------
def show_status(monitor: JobMonitor) -> None:
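    """Refresh the monitor and print per-status counts plus one line per job."""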
monitor.refresh()
summary = monitor.summary()
print(f"\nTotal jobs: {summary['total']}")
for status, count in summary["counts"].items():
print(f" {status:12s}: {count}")
print()
for r in summary["records"]:
print(f" [{r['status']:9s}] {r['label']:<60s} elapsed: {r['elapsed']}")
# ---------------------------------------------------------------------------
# Argument parsing
# ---------------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
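    """Define and parse the CLI arguments."""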
parser = argparse.ArgumentParser(
description="SpatialBench experiment orchestrator",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=__doc__,
)
parser.add_argument(
"--tasks", nargs="+",
default=["maze_navigation", "point_reuse", "compositional_distance"],
choices=["maze_navigation", "point_reuse", "compositional_distance"],
help="Which tasks to run (default: all three)",
)
parser.add_argument(
"--models", nargs="+", default=None,
help="Model IDs to run (default: all models in config)",
)
parser.add_argument(
"--grid-sizes", nargs="+", type=int, default=None,
dest="grid_sizes",
help="Grid sizes to evaluate, e.g. --grid-sizes 5 6 7 (default: per-task config)",
)
parser.add_argument(
"--formats", nargs="+", default=None,
choices=["raw", "visual"],
help="Input formats for Task 1 (default: both raw and visual)",
)
parser.add_argument(
"--strategies", nargs="+", default=None,
choices=["base", "cot", "reasoning"],
help="Prompt strategies (default: all)",
)
parser.add_argument(
"--mode", default="slurm", choices=["slurm", "direct"],
help="Execution mode: 'slurm' submits sbatch jobs, 'direct' runs inline (default: slurm)",
)
parser.add_argument(
"--dry-run", action="store_true",
help="Print commands without executing them",
)
parser.add_argument(
"--no-wait", action="store_true",
help="Return immediately after submission (don't poll for completion)",
)
    parser.add_argument(
        "--status", action="store_true",
        help="Query and display SLURM job status (pass --job-ids to select the jobs)",
    )
parser.add_argument(
"--job-ids", nargs="+", default=None,
help="SLURM job IDs to check status for (used with --status)",
)
parser.add_argument(
"--config", default=str(CONFIG_PATH),
help=f"Path to experiments.yaml (default: {CONFIG_PATH})",
)
parser.add_argument(
"--poll-interval", type=int, default=60,
dest="poll_interval",
help="Seconds between SLURM status polls when waiting (default: 60)",
)
return parser.parse_args()
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
    """Entry point: show status, or build the requested jobs and run them."""
    args = parse_args()
    # Status-only mode: handled first, since it does not need the experiment config
if args.status:
monitor = JobMonitor(mode="slurm")
if args.job_ids:
for jid in args.job_ids:
monitor.add(job_id=jid, label=jid, task_id="?", model="?")
show_status(monitor)
return
    # Build jobs from the experiment config
    cfg = load_config(args.config)
jobs = build_all_jobs(
cfg=cfg,
tasks=args.tasks,
models=args.models,
grid_sizes=args.grid_sizes,
input_formats=args.formats,
prompt_strategies=args.strategies,
config_path=Path(args.config),
)
if not jobs:
print("No jobs matched the requested filters.")
return
print(f"\nSpatialBench — {len(jobs)} job(s) to run")
print(f" mode : {args.mode}")
print(f" tasks : {args.tasks}")
print(f" models : {args.models or 'all'}")
print(f" grids : {args.grid_sizes or 'per-task default'}")
print(f" formats : {args.formats or 'per-task default'}")
print(f" strategies: {args.strategies or 'all'}")
print(f" dry-run : {args.dry_run}")
print()
monitor = JobMonitor(mode=args.mode)
if args.mode == "slurm":
run_slurm(jobs, monitor, dry_run=args.dry_run)
else:
run_direct(jobs, monitor, dry_run=args.dry_run)
if args.dry_run or args.no_wait:
if not args.dry_run:
print(f"\nSubmitted {len(monitor.all_records())} job(s). Use --status to check progress.")
return
# Wait for completion
print("\nWaiting for jobs to complete...")
def _progress(summary: dict) -> None:
counts = summary["counts"]
parts = [f"{s}: {n}" for s, n in counts.items()]
print(f" [{time.strftime('%H:%M:%S')}] {' | '.join(parts)}")
monitor.wait_all(poll_interval=args.poll_interval, callback=_progress)
# Final summary
summary = monitor.summary()
print(f"\nDone. {summary['counts'].get('completed', 0)} completed, "
f"{summary['counts'].get('failed', 0)} failed.")
failed = [r for r in summary["records"] if r["status"] == "failed"]
if failed:
print("\nFailed jobs:")
for r in failed:
print(f" {r['label']} (job_id={r['job_id']})")
if __name__ == "__main__":
main()