#!/usr/bin/env python3
"""
run_experiments.py
------------------

CLI orchestrator for SpatialBench experiments.

Run on the cluster with SLURM:

    python run_experiments.py --tasks maze_navigation point_reuse compositional_distance --mode slurm

Run directly (uses API keys, no SLURM required):

    python run_experiments.py --tasks maze_navigation --models gemini-2.5-flash --mode direct

Dry-run (print commands without executing):

    python run_experiments.py --tasks maze_navigation point_reuse compositional_distance --dry-run

Filter experiments:

    python run_experiments.py --tasks maze_navigation \\
        --models gemini-2.5-flash claude-haiku-4-5 \\
        --grid-sizes 5 6 7 \\
        --formats raw \\
        --strategies cot reasoning

Show status of running SLURM jobs:

    python run_experiments.py --status
"""
from __future__ import annotations

import argparse
import os
import tempfile
import time
from pathlib import Path
# Load .env if present (before importing pipeline modules)
_env_file = Path(__file__).parent / ".env"
if _env_file.exists():
    with open(_env_file) as _f:
        for _line in _f:
            _line = _line.strip()
            if _line and not _line.startswith("#") and "=" in _line:
                _k, _v = _line.split("=", 1)
                os.environ.setdefault(_k.strip(), _v.strip())
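# A .env for direct mode might look like the following. The key names are
# illustrative only -- the variable name each job actually requires comes
# from its `api_key_env` field in the experiment config:
#
#   GOOGLE_API_KEY=...
#   ANTHROPIC_API_KEY=...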
from pipeline.task_builder import (
    load_config, build_all_jobs, make_sbatch_script, ExperimentJob,
)
from pipeline.job_monitor import JobMonitor, submit_sbatch, submit_direct

CONFIG_PATH = Path(__file__).parent / "configs" / "experiments.yaml"
REPO_ROOT = CONFIG_PATH.parent.parent.parent  # llm-maze-solver/
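# The path arithmetic above assumes a layout roughly like this (illustrative):
#
#   llm-maze-solver/                <- REPO_ROOT
#       <this directory>/
#           run_experiments.py
#           configs/experiments.yaml
#       maze-solver/
#           eval_llm_logs/          <- sbatch scripts + SLURM logs (see run_slurm)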
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _check_api_key(job: ExperimentJob) -> bool:
    val = os.environ.get(job.api_key_env, "")
    if not val:
        print(f"  [WARN] {job.api_key_env} not set — skipping: {job.label}")
        return False
    return True


def _print_job(job: ExperimentJob) -> None:
    print(f"\n  {job.label}")
    print(f"    cmd : {' '.join(job.python_cmd[:4])} ...")
    print(f"    wdir: {job.working_dir}")
    print(f"    out : {job.output_dir}")
# ---------------------------------------------------------------------------
# Run modes
# ---------------------------------------------------------------------------

def run_slurm(jobs: list[ExperimentJob], monitor: JobMonitor, dry_run: bool) -> None:
    log_dir = REPO_ROOT / "maze-solver" / "eval_llm_logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    for job in jobs:
        if not _check_api_key(job):
            continue
        script_text = make_sbatch_script(job, log_dir)
        if dry_run:
            _print_job(job)
            print("    --- sbatch script ---")
            print(script_text)
            continue
        # Keep the generated script on disk (delete=False) so failed
        # submissions can be inspected and resubmitted by hand.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".sh", prefix="spatialbench_",
            dir=log_dir, delete=False,
        ) as tmp:
            tmp.write(script_text)
            script_path = tmp.name
        job_id = submit_sbatch(script_path)
        if job_id:
            monitor.add(
                job_id=job_id,
                label=job.label,
                task_id=job.task_id,
                model=job.model,
                output_dir=str(job.output_dir),
                log_out=str(log_dir / f"{job_id}.out"),
                log_err=str(log_dir / f"{job_id}.err"),
            )
            print(f"  Submitted {job.label} → SLURM job {job_id}")
        else:
            print(f"  [ERROR] Failed to submit: {job.label}")
def run_direct(jobs: list[ExperimentJob], monitor: JobMonitor, dry_run: bool) -> None:
    for job in jobs:
        if not _check_api_key(job):
            continue
        if dry_run:
            _print_job(job)
            print(f"    cmd: {' '.join(job.python_cmd)}\n")
            continue
        # Forward only the one API key this job needs, not the whole environment
        env_patch = {job.api_key_env: os.environ[job.api_key_env]}
        job.output_dir.mkdir(parents=True, exist_ok=True)
        print(f"  Starting: {job.label}")
        proc = submit_direct(
            cmd=job.python_cmd,
            working_dir=str(job.working_dir),
            env=env_patch,
        )
        monitor.add_direct(
            proc=proc,
            label=job.label,
            task_id=job.task_id,
            model=job.model,
            output_dir=str(job.output_dir),
        )
        # Small gap to avoid hammering the APIs simultaneously
        time.sleep(2)
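# Note that direct mode launches every job as a concurrent subprocess (the
# 2 s sleep only staggers the starts), so use the --tasks/--models filters to
# keep the fan-out manageable on a single machine.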
# ---------------------------------------------------------------------------
# Status display
# ---------------------------------------------------------------------------

def show_status(monitor: JobMonitor) -> None:
    monitor.refresh()
    summary = monitor.summary()
    print(f"\nTotal jobs: {summary['total']}")
    for status, count in summary["counts"].items():
        print(f"  {status:12s}: {count}")
    print()
    for r in summary["records"]:
        print(f"  [{r['status']:9s}] {r['label']:<60s} elapsed: {r['elapsed']}")
# ---------------------------------------------------------------------------
# Argument parsing
# ---------------------------------------------------------------------------

def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="SpatialBench experiment orchestrator",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "--tasks", nargs="+",
        default=["maze_navigation", "point_reuse", "compositional_distance"],
        choices=["maze_navigation", "point_reuse", "compositional_distance"],
        help="Which tasks to run (default: all three)",
    )
    parser.add_argument(
        "--models", nargs="+", default=None,
        help="Model IDs to run (default: all models in config)",
    )
    parser.add_argument(
        "--grid-sizes", nargs="+", type=int, default=None,
        dest="grid_sizes",
        help="Grid sizes to evaluate, e.g. --grid-sizes 5 6 7 (default: per-task config)",
    )
    parser.add_argument(
        "--formats", nargs="+", default=None,
        choices=["raw", "visual"],
        help="Input formats for Task 1 (default: both raw and visual)",
    )
    parser.add_argument(
        "--strategies", nargs="+", default=None,
        choices=["base", "cot", "reasoning"],
        help="Prompt strategies (default: all)",
    )
    parser.add_argument(
        "--mode", default="slurm", choices=["slurm", "direct"],
        help="Execution mode: 'slurm' submits sbatch jobs, 'direct' runs inline (default: slurm)",
    )
    parser.add_argument(
        "--dry-run", action="store_true",
        help="Print commands without executing them",
    )
    parser.add_argument(
        "--no-wait", action="store_true",
        help="Return immediately after submission (don't poll for completion)",
    )
    parser.add_argument(
        "--status", action="store_true",
        help="Query and display SLURM job status (requires --job-ids or a running monitor)",
    )
    parser.add_argument(
        "--job-ids", nargs="+", default=None,
        help="SLURM job IDs to check status for (used with --status)",
    )
    parser.add_argument(
        "--config", default=str(CONFIG_PATH),
        help=f"Path to experiments.yaml (default: {CONFIG_PATH})",
    )
    parser.add_argument(
        "--poll-interval", type=int, default=60,
        dest="poll_interval",
        help="Seconds between SLURM status polls when waiting (default: 60)",
    )
    return parser.parse_args()
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main() -> None:
    args = parse_args()

    # Status-only mode (does not need the experiment config)
    if args.status:
        monitor = JobMonitor(mode="slurm")
        if args.job_ids:
            for jid in args.job_ids:
                monitor.add(job_id=jid, label=jid, task_id="?", model="?")
        show_status(monitor)
        return

    cfg = load_config(args.config)

    # Build jobs
    jobs = build_all_jobs(
        cfg=cfg,
        tasks=args.tasks,
        models=args.models,
        grid_sizes=args.grid_sizes,
        input_formats=args.formats,
        prompt_strategies=args.strategies,
        config_path=Path(args.config),
    )
    if not jobs:
        print("No jobs matched the requested filters.")
        return

    print(f"\nSpatialBench — {len(jobs)} job(s) to run")
    print(f"  mode      : {args.mode}")
    print(f"  tasks     : {args.tasks}")
    print(f"  models    : {args.models or 'all'}")
    print(f"  grids     : {args.grid_sizes or 'per-task default'}")
    print(f"  formats   : {args.formats or 'per-task default'}")
    print(f"  strategies: {args.strategies or 'all'}")
    print(f"  dry-run   : {args.dry_run}")
    print()

    monitor = JobMonitor(mode=args.mode)
    if args.mode == "slurm":
        run_slurm(jobs, monitor, dry_run=args.dry_run)
    else:
        run_direct(jobs, monitor, dry_run=args.dry_run)

    if args.dry_run or args.no_wait:
        if not args.dry_run:
            print(f"\nSubmitted {len(monitor.all_records())} job(s). Use --status to check progress.")
        return

    # Wait for completion
    print("\nWaiting for jobs to complete...")

    def _progress(summary: dict) -> None:
        counts = summary["counts"]
        parts = [f"{s}: {n}" for s, n in counts.items()]
        print(f"  [{time.strftime('%H:%M:%S')}] {' | '.join(parts)}")

    monitor.wait_all(poll_interval=args.poll_interval, callback=_progress)

    # Final summary
    summary = monitor.summary()
    print(f"\nDone. {summary['counts'].get('completed', 0)} completed, "
          f"{summary['counts'].get('failed', 0)} failed.")
    failed = [r for r in summary["records"] if r["status"] == "failed"]
    if failed:
        print("\nFailed jobs:")
        for r in failed:
            print(f"  {r['label']} (job_id={r['job_id']})")
if __name__ == "__main__":
    main()