#!/usr/bin/env python3
"""
run_experiments.py
------------------
CLI orchestrator for SpatialBench experiments.

Run on the cluster with SLURM:
    python run_experiments.py --tasks maze_navigation point_reuse compositional_distance --mode slurm

Run directly (uses API keys, no SLURM required):
    python run_experiments.py --tasks maze_navigation --models gemini-2.5-flash --mode direct

Dry-run (print commands without executing):
    python run_experiments.py --tasks maze_navigation point_reuse compositional_distance --dry-run

Filter experiments:
    python run_experiments.py --tasks maze_navigation \\
        --models gemini-2.5-flash claude-haiku-4-5 \\
        --grid-sizes 5 6 7 \\
        --formats raw \\
        --strategies cot reasoning

Show status of running SLURM jobs:
    python run_experiments.py --status
"""

from __future__ import annotations

import argparse
import os
import tempfile
import time
from pathlib import Path

# Load .env if present (before importing pipeline modules)
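# (Minimal parser: one KEY=VALUE per line, '#' comments skipped; values are
# taken verbatim, so surrounding quotes are not stripped.)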
_env_file = Path(__file__).parent / ".env"
if _env_file.exists():
    with open(_env_file) as _f:
        for _line in _f:
            _line = _line.strip()
            if _line and not _line.startswith("#") and "=" in _line:
                _k, _v = _line.split("=", 1)
                os.environ.setdefault(_k.strip(), _v.strip())

from pipeline.task_builder import (
    load_config, build_all_jobs, make_sbatch_script, ExperimentJob,
)
from pipeline.job_monitor import JobMonitor, submit_sbatch, submit_direct

CONFIG_PATH = Path(__file__).parent / "configs" / "experiments.yaml"
REPO_ROOT   = CONFIG_PATH.parent.parent.parent  # llm-maze-solver/


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _check_api_key(job: ExperimentJob) -> bool:
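    """Return True if the job's API-key env var is set, else warn and return False."""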
    val = os.environ.get(job.api_key_env, "")
    if not val:
        print(f"  [WARN] {job.api_key_env} not set — skipping: {job.label}")
        return False
    return True


def _print_job(job: ExperimentJob) -> None:
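    """Print a compact, human-readable summary of a job (dry-run display)."""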
    print(f"\n  {job.label}")
    print(f"    cmd : {' '.join(job.python_cmd[:4])} ...")
    print(f"    wdir: {job.working_dir}")
    print(f"    out : {job.output_dir}")


# ---------------------------------------------------------------------------
# Run modes
# ---------------------------------------------------------------------------

def run_slurm(jobs: list[ExperimentJob], monitor: JobMonitor, dry_run: bool) -> None:
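    """Write an sbatch script for each job, submit it, and register it with the monitor."""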
    log_dir = REPO_ROOT / "maze-solver" / "eval_llm_logs"
    log_dir.mkdir(parents=True, exist_ok=True)

    for job in jobs:
        if not _check_api_key(job):
            continue
        script_text = make_sbatch_script(job, log_dir)
        if dry_run:
            _print_job(job)
            print("  --- sbatch script ---")
            print(script_text)
            continue

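        # delete=False keeps the script on disk after the context manager
        # exits, so sbatch can still read it when invoked below.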
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".sh", prefix="spatialbench_",
            dir=log_dir, delete=False
        ) as tmp:
            tmp.write(script_text)
            script_path = tmp.name

        job_id = submit_sbatch(script_path)
        if job_id:
            monitor.add(
                job_id=job_id,
                label=job.label,
                task_id=job.task_id,
                model=job.model,
                output_dir=str(job.output_dir),
                log_out=str(log_dir / f"{job_id}.out"),
                log_err=str(log_dir / f"{job_id}.err"),
            )
            print(f"  Submitted {job.label}  →  SLURM job {job_id}")
        else:
            print(f"  [ERROR] Failed to submit: {job.label}")


def run_direct(jobs: list[ExperimentJob], monitor: JobMonitor, dry_run: bool) -> None:
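    """Run each job as a local subprocess, with a small stagger between launches."""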
    for job in jobs:
        if not _check_api_key(job):
            continue
        if dry_run:
            _print_job(job)
            print(f"  cmd: {' '.join(job.python_cmd)}\n")
            continue

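        # Pass only the API key this job needs through to the subprocess.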
        env_patch = {job.api_key_env: os.environ[job.api_key_env]}
        job.output_dir.mkdir(parents=True, exist_ok=True)

        print(f"  Starting: {job.label}")
        proc = submit_direct(
            cmd=job.python_cmd,
            working_dir=str(job.working_dir),
            env=env_patch,
        )
        monitor.add_direct(
            proc=proc,
            label=job.label,
            task_id=job.task_id,
            model=job.model,
            output_dir=str(job.output_dir),
        )
        # Small gap to avoid hammering APIs simultaneously
        time.sleep(2)


# ---------------------------------------------------------------------------
# Status display
# ---------------------------------------------------------------------------

def show_status(monitor: JobMonitor) -> None:
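    """Refresh the monitor, then print aggregate counts and one line per job."""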
    monitor.refresh()
    summary = monitor.summary()
    print(f"\nTotal jobs: {summary['total']}")
    for status, count in summary["counts"].items():
        print(f"  {status:12s}: {count}")
    print()
    for r in summary["records"]:
        print(f"  [{r['status']:9s}] {r['label']:<60s}  elapsed: {r['elapsed']}")


# ---------------------------------------------------------------------------
# Argument parsing
# ---------------------------------------------------------------------------

def parse_args() -> argparse.Namespace:
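    """Define and parse the orchestrator's command-line interface."""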
    parser = argparse.ArgumentParser(
        description="SpatialBench experiment orchestrator",
        formatter_class=argparse.RawDescriptionHelpFormatter,
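        # The raw formatter preserves the module docstring's example layout
        # when it is rendered as the --help epilog below.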
        epilog=__doc__,
    )

    parser.add_argument(
        "--tasks", nargs="+",
        default=["maze_navigation", "point_reuse", "compositional_distance"],
        choices=["maze_navigation", "point_reuse", "compositional_distance"],
        help="Which tasks to run (default: all three)",
    )
    parser.add_argument(
        "--models", nargs="+", default=None,
        help="Model IDs to run (default: all models in config)",
    )
    parser.add_argument(
        "--grid-sizes", nargs="+", type=int, default=None,
        dest="grid_sizes",
        help="Grid sizes to evaluate, e.g. --grid-sizes 5 6 7 (default: per-task config)",
    )
    parser.add_argument(
        "--formats", nargs="+", default=None,
        choices=["raw", "visual"],
        help="Input formats for Task 1 (default: both raw and visual)",
    )
    parser.add_argument(
        "--strategies", nargs="+", default=None,
        choices=["base", "cot", "reasoning"],
        help="Prompt strategies (default: all)",
    )
    parser.add_argument(
        "--mode", default="slurm", choices=["slurm", "direct"],
        help="Execution mode: 'slurm' submits sbatch jobs, 'direct' runs inline (default: slurm)",
    )
    parser.add_argument(
        "--dry-run", action="store_true",
        help="Print commands without executing them",
    )
    parser.add_argument(
        "--no-wait", action="store_true",
        help="Return immediately after submission (don't poll for completion)",
    )
    parser.add_argument(
        "--status", action="store_true",
        help="Query and display SLURM job status (requires --job-ids or a running monitor)",
    )
    parser.add_argument(
        "--job-ids", nargs="+", default=None,
        help="SLURM job IDs to check status for (used with --status)",
    )
    parser.add_argument(
        "--config", default=str(CONFIG_PATH),
        help=f"Path to experiments.yaml (default: {CONFIG_PATH})",
    )
    parser.add_argument(
        "--poll-interval", type=int, default=60,
        dest="poll_interval",
        help="Seconds between SLURM status polls when waiting (default: 60)",
    )

    return parser.parse_args()


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main() -> None:
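    """Parse arguments, build the filtered job list, dispatch, and optionally wait."""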
    args = parse_args()

    # Status-only mode; handle before loading the config, which --status
    # does not need.
    if args.status:
        monitor = JobMonitor(mode="slurm")
        if args.job_ids:
            # Register placeholder records so the monitor can poll these IDs.
            for jid in args.job_ids:
                monitor.add(job_id=jid, label=jid, task_id="?", model="?")
        show_status(monitor)
        return

    cfg = load_config(args.config)

    # Build jobs
    jobs = build_all_jobs(
        cfg=cfg,
        tasks=args.tasks,
        models=args.models,
        grid_sizes=args.grid_sizes,
        input_formats=args.formats,
        prompt_strategies=args.strategies,
        config_path=Path(args.config),
    )

    if not jobs:
        print("No jobs matched the requested filters.")
        return

    print(f"\nSpatialBench — {len(jobs)} job(s) to run")
    print(f"  mode     : {args.mode}")
    print(f"  tasks    : {args.tasks}")
    print(f"  models   : {args.models or 'all'}")
    print(f"  grids    : {args.grid_sizes or 'per-task default'}")
    print(f"  formats  : {args.formats or 'per-task default'}")
    print(f"  strategies: {args.strategies or 'all'}")
    print(f"  dry-run  : {args.dry_run}")
    print()

    monitor = JobMonitor(mode=args.mode)

    if args.mode == "slurm":
        run_slurm(jobs, monitor, dry_run=args.dry_run)
    else:
        run_direct(jobs, monitor, dry_run=args.dry_run)

    if args.dry_run or args.no_wait:
        if not args.dry_run:
            print(f"\nSubmitted {len(monitor.all_records())} job(s). Use --status to check progress.")
        return

    # Wait for completion
    print("\nWaiting for jobs to complete...")

    def _progress(summary: dict) -> None:
        counts = summary["counts"]
        parts = [f"{s}: {n}" for s, n in counts.items()]
        print(f"  [{time.strftime('%H:%M:%S')}] {' | '.join(parts)}")

    monitor.wait_all(poll_interval=args.poll_interval, callback=_progress)

    # Final summary
    summary = monitor.summary()
    print(f"\nDone. {summary['counts'].get('completed', 0)} completed, "
          f"{summary['counts'].get('failed', 0)} failed.")

    failed = [r for r in summary["records"] if r["status"] == "failed"]
    if failed:
        print("\nFailed jobs:")
        for r in failed:
            print(f"  {r['label']}  (job_id={r['job_id']})")


if __name__ == "__main__":
    main()