SpatialBench / pipeline /task_builder.py
weijiang99's picture
Upload folder using huggingface_hub
cffeecf verified
"""
task_builder.py
---------------
Translates experiments.yaml into concrete shell commands (direct or sbatch).
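Each build_*_jobs function returns a list of ExperimentJob dataclasses: one
per (model × input format × prompt strategy) combination for Maze Navigation,
and one per (model × prompt strategy) for the other tasks. Every job covers all
selected grid sizes in a single run. The caller decides whether to run them
directly or wrap them in sbatch.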
"""
from __future__ import annotations

import itertools
import os
import shlex
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import yaml
# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------
@dataclass
class ExperimentJob:
    """One runnable unit of work produced by the job builders.

    Bundles everything needed to launch an experiment, either directly or
    through an sbatch wrapper: the argv, the directory to run it from, the
    environment variable holding the API key, the results directory, and the
    Slurm resource request.
    """

    # Task identifier, e.g. "maze_navigation".
    task_id: str
    # Model name, e.g. "gemini-2.5-flash".
    model: str
    # Human-readable description (also used to name sbatch log files).
    label: str
    # Directory to cd into before running the command.
    working_dir: Path
    # Full argv: [python, script.py, --arg, value, ...].
    python_cmd: list[str]
    # Name of the environment variable that must hold the API key.
    api_key_env: str
    # Directory the experiment writes its results to.
    output_dir: Path
    # Slurm settings such as mem, time, cpus, partition, log_dir.
    sbatch_cfg: dict[str, Any]
    # Grid sizes covered by this job; informational (display / filtering).
    grid_sizes: list[int]
# ---------------------------------------------------------------------------
# Config loader
# ---------------------------------------------------------------------------
def load_config(config_path: str | Path) -> dict:
    """Load and parse the experiments YAML configuration.

    Args:
        config_path: Path to experiments.yaml.

    Returns:
        The parsed top-level mapping. An empty file yields ``{}`` instead of
        ``None`` so the return value always matches the annotation.
    """
    # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows)
    # would mis-decode non-ASCII characters in the YAML file.
    with open(config_path, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    return data if data is not None else {}
def _repo_root(config_path: Path) -> Path:
"""pipeline/configs/experiments.yaml → llm-maze-solver/"""
return config_path.parent.parent.parent
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _merge_sbatch(defaults: dict, override: dict) -> dict:
merged = dict(defaults)
merged.update(override)
return merged
def _grid_str(grid_sizes: list[int]) -> str:
return ",".join(str(g) for g in grid_sizes)
def _output_subdir(base: str, model: str, tag: str) -> str:
"""Produce a deterministic output subdirectory path."""
return f"{base}/{model.replace('.', '_').replace('-', '_')}/{tag}"
# ---------------------------------------------------------------------------
# Maze Navigation
# ---------------------------------------------------------------------------
def build_maze_navigation_jobs(
    cfg: dict,
    models: list[str] | None = None,
    grid_sizes: list[int] | None = None,
    input_formats: list[str] | None = None,
    prompt_strategies: list[str] | None = None,
    config_path: Path | None = None,
) -> list[ExperimentJob]:
    """Build jobs for Maze Navigation (planning, k-shot).

    One job is produced per (model, input format, prompt strategy)
    combination. Each job runs all selected grid sizes in one invocation.

    Args:
        cfg: Parsed experiments.yaml; reads the ``maze_navigation``,
            ``defaults`` and ``models`` sections.
        models: Models to run; defaults to every key of ``cfg["models"]``.
            Names missing from ``cfg["models"]`` are skipped.
        grid_sizes: Grid sizes; defaults to the task's ``grid_sizes``.
        input_formats: Input formats; defaults to the task's ``input_formats``.
        prompt_strategies: Strategy keys; defaults to every key of the task's
            ``prompt_strategies``.
        config_path: Path of experiments.yaml, used to locate the repo root.
            When omitted, paths are relative to the current directory.

    Returns:
        The jobs in ``itertools.product(models, formats, strategies)`` order.

    Raises:
        KeyError: if a required config key is missing or a requested strategy
            is not defined for the task.
    """
    task = cfg["maze_navigation"]
    defaults = cfg["defaults"]
    model_cfg = cfg["models"]
    strategies_cfg = task["prompt_strategies"]

    selected_models = models or list(model_cfg.keys())
    selected_formats = input_formats or task["input_formats"]
    selected_strategies = prompt_strategies or list(strategies_cfg.keys())
    selected_grids = grid_sizes or task["grid_sizes"]

    repo = _repo_root(config_path) if config_path else Path(".")
    script = repo / task["script"]
    wdir = repo / task["working_dir"]

    jobs: list[ExperimentJob] = []
    for model, fmt, strat in itertools.product(
        selected_models, selected_formats, selected_strategies
    ):
        # Requested models without a config entry (no api_key_env) are skipped.
        if model not in model_cfg:
            continue
        # A strategy written as a bare YAML key (``cot:``) loads as None.
        strat_cfg = strategies_cfg[strat] or {}
        tag = f"{fmt}_input_{strat}"
        out_dir = repo / _output_subdir(task["output_base"], model, tag)
        # Every argv element must be a str: YAML loads e.g. ``k_shots: 3`` as
        # an int, which breaks subprocess and the join in make_sbatch_script.
        cmd = [
            "python", str(script),
            "--model_name", model,
            "--input_format", str(fmt),
            "--k_shots", str(task["k_shots"]),
            "--n_test_mazes", str(defaults["n_test_mazes"]),
            "--test_grid_sizes", _grid_str(selected_grids),
            "--maze_type", str(task["maze_type"]),
            "--seed", str(defaults["seed"]),
            "--output_dir", str(out_dir),
        ]
        # Strategy-specific switches; a strategy may define none.
        cmd.extend(str(flag) for flag in strat_cfg.get("flags") or [])
        if task.get("visualize"):
            cmd.append("--visualize")
        jobs.append(ExperimentJob(
            task_id="maze_navigation",
            model=model,
            label=f"Maze Navigation | {model} | {fmt} | {strat}",
            working_dir=wdir,
            python_cmd=cmd,
            api_key_env=model_cfg[model]["api_key_env"],
            output_dir=out_dir,
            sbatch_cfg=_merge_sbatch(defaults["sbatch"], task.get("sbatch", {})),
            # Per-job copy: jobs must not share (or alias the caller's) list.
            grid_sizes=list(selected_grids),
        ))
    return jobs
# ---------------------------------------------------------------------------
# Sequential Reasoning with Point Reuse (Q3 = Q0)
# ---------------------------------------------------------------------------
def build_point_reuse_jobs(
    cfg: dict,
    models: list[str] | None = None,
    grid_sizes: list[int] | None = None,
    prompt_strategies: list[str] | None = None,
    config_path: Path | None = None,
) -> list[ExperimentJob]:
    """Build jobs for Sequential Reasoning with Point Reuse (Q3=Q0).

    One job is produced per (model, prompt strategy) pair. The input format,
    strategy and reuse pattern are fixed by the task config. Each job runs
    all selected grid sizes in one invocation.

    Args:
        cfg: Parsed experiments.yaml; reads the ``point_reuse``, ``defaults``
            and ``models`` sections.
        models: Models to run; defaults to every key of ``cfg["models"]``.
            Names missing from ``cfg["models"]`` are skipped.
        grid_sizes: Grid sizes; defaults to the task's ``grid_sizes``.
        prompt_strategies: Strategy keys; defaults to every key of the task's
            ``prompt_strategies``.
        config_path: Path of experiments.yaml, used to locate the repo root.
            When omitted, paths are relative to the current directory.

    Returns:
        The jobs in ``itertools.product(models, strategies)`` order.

    Raises:
        KeyError: if a required config key is missing or a requested strategy
            is not defined for the task.
    """
    task = cfg["point_reuse"]
    defaults = cfg["defaults"]
    model_cfg = cfg["models"]
    selected_models = models or list(model_cfg.keys())
    selected_strategies = prompt_strategies or list(task["prompt_strategies"].keys())
    selected_grids = grid_sizes or task["grid_sizes"]
    repo = _repo_root(config_path) if config_path else Path(".")
    script = repo / task["script"]
    wdir = repo / task["working_dir"]

    # Boolean switches enabled by truthy task keys; identical for every job.
    optional_flags = [
        f"--{key}"
        for key in ("sequential_questions", "visualize", "save_details")
        if task.get(key)
    ]

    jobs: list[ExperimentJob] = []
    for model, strat in itertools.product(selected_models, selected_strategies):
        # Requested models without a config entry (no api_key_env) are skipped.
        if model not in model_cfg:
            continue
        strat_cfg = task["prompt_strategies"][strat]
        tag = f"point_reuse_q3q0_{strat}"
        out_dir = repo / _output_subdir(task["output_base"], model, tag)
        # Every argv element must be a str (YAML may load scalars as int).
        cmd = [
            "python", str(script),
            "--model_name", model,
            "--input_format", str(task["input_format"]),
            "--strategy", str(task["strategy"]),
            "--reuse_pattern", str(task["reuse_pattern"]),
            "--prompt_type", str(strat_cfg["prompt_type"]),
            "--n_questions_per_maze", str(task["n_questions_per_maze"]),
            "--n_test_mazes", str(defaults["n_test_mazes"]),
            "--test_grid_sizes", _grid_str(selected_grids),
            "--output_dir", str(out_dir),
        ]
        cmd.extend(optional_flags)
        jobs.append(ExperimentJob(
            task_id="point_reuse",
            model=model,
            label=f"Point Reuse | {model} | {strat}",
            working_dir=wdir,
            python_cmd=cmd,
            api_key_env=model_cfg[model]["api_key_env"],
            output_dir=out_dir,
            sbatch_cfg=_merge_sbatch(defaults["sbatch"], task.get("sbatch", {})),
            # Per-job copy: jobs must not share (or alias the caller's) list.
            grid_sizes=list(selected_grids),
        ))
    return jobs
# ---------------------------------------------------------------------------
# Compositional Distance Comparison
# ---------------------------------------------------------------------------
def build_compositional_distance_jobs(
    cfg: dict,
    models: list[str] | None = None,
    grid_sizes: list[int] | None = None,
    prompt_strategies: list[str] | None = None,
    config_path: Path | None = None,
) -> list[ExperimentJob]:
    """Build jobs for Compositional Distance Comparison (corners-to-center).

    One job is produced per (model, prompt strategy) pair. The input format,
    strategy and corner pattern are fixed by the task config. Each job runs
    all selected grid sizes in one invocation.

    Args:
        cfg: Parsed experiments.yaml; reads the ``compositional_distance``,
            ``defaults`` and ``models`` sections.
        models: Models to run; defaults to every key of ``cfg["models"]``.
            Names missing from ``cfg["models"]`` are skipped.
        grid_sizes: Grid sizes; defaults to the task's ``grid_sizes``.
        prompt_strategies: Strategy keys; defaults to every key of the task's
            ``prompt_strategies``.
        config_path: Path of experiments.yaml, used to locate the repo root.
            When omitted, paths are relative to the current directory.

    Returns:
        The jobs in ``itertools.product(models, strategies)`` order.

    Raises:
        KeyError: if a required config key is missing or a requested strategy
            is not defined for the task.
    """
    task = cfg["compositional_distance"]
    defaults = cfg["defaults"]
    model_cfg = cfg["models"]
    selected_models = models or list(model_cfg.keys())
    selected_strategies = prompt_strategies or list(task["prompt_strategies"].keys())
    selected_grids = grid_sizes or task["grid_sizes"]
    repo = _repo_root(config_path) if config_path else Path(".")
    script = repo / task["script"]
    wdir = repo / task["working_dir"]

    # Boolean switches enabled by truthy task keys; identical for every job.
    optional_flags = [
        f"--{key}" for key in ("visualize", "save_details") if task.get(key)
    ]

    jobs: list[ExperimentJob] = []
    for model, strat in itertools.product(selected_models, selected_strategies):
        # Requested models without a config entry (no api_key_env) are skipped.
        if model not in model_cfg:
            continue
        strat_cfg = task["prompt_strategies"][strat]
        tag = f"orthogonal_{task['corner_pattern']}_{strat}"
        out_dir = repo / _output_subdir(task["output_base"], model, tag)
        # Every argv element must be a str (YAML may load scalars as int).
        cmd = [
            "python", str(script),
            "--model_name", model,
            "--input_format", str(task["input_format"]),
            "--strategy", str(task["strategy"]),
            "--corner_pattern", str(task["corner_pattern"]),
            "--prompt_type", str(strat_cfg["prompt_type"]),
            "--n_questions_per_maze", str(task["n_questions_per_maze"]),
            "--n_test_mazes", str(defaults["n_test_mazes"]),
            "--test_grid_sizes", _grid_str(selected_grids),
            "--output_dir", str(out_dir),
        ]
        cmd.extend(optional_flags)
        jobs.append(ExperimentJob(
            task_id="compositional_distance",
            model=model,
            label=f"Compositional Distance | {model} | {strat}",
            working_dir=wdir,
            python_cmd=cmd,
            api_key_env=model_cfg[model]["api_key_env"],
            output_dir=out_dir,
            sbatch_cfg=_merge_sbatch(defaults["sbatch"], task.get("sbatch", {})),
            # Per-job copy: jobs must not share (or alias the caller's) list.
            grid_sizes=list(selected_grids),
        ))
    return jobs
# ---------------------------------------------------------------------------
# Unified builder
# ---------------------------------------------------------------------------
def build_all_jobs(
    cfg: dict,
    tasks: list[str] | None = None,
    models: list[str] | None = None,
    grid_sizes: list[int] | None = None,
    input_formats: list[str] | None = None,
    prompt_strategies: list[str] | None = None,
    config_path: Path = None,
) -> list[ExperimentJob]:
    """Collect jobs from every requested task builder.

    Tasks are always built in the fixed order maze_navigation, point_reuse,
    compositional_distance, whatever order ``tasks`` lists them in. A falsy
    ``tasks`` selects all three. Unrecognized task names are ignored.
    ``input_formats`` is forwarded only to the Maze Navigation builder.
    """
    requested = tasks or ["maze_navigation", "point_reuse", "compositional_distance"]
    shared = {
        "models": models,
        "grid_sizes": grid_sizes,
        "prompt_strategies": prompt_strategies,
        "config_path": config_path,
    }
    # (task name, deferred builder call) in canonical build order.
    builders = (
        ("maze_navigation",
         lambda: build_maze_navigation_jobs(cfg, input_formats=input_formats, **shared)),
        ("point_reuse",
         lambda: build_point_reuse_jobs(cfg, **shared)),
        ("compositional_distance",
         lambda: build_compositional_distance_jobs(cfg, **shared)),
    )
    collected: list[ExperimentJob] = []
    for name, build in builders:
        if name in requested:
            collected.extend(build())
    return collected
# ---------------------------------------------------------------------------
# sbatch script generator
# ---------------------------------------------------------------------------
def make_sbatch_script(job: ExperimentJob, log_dir: Path) -> str:
    """Return the text of an sbatch submission script for a job.

    Resource directives are read from ``job.sbatch_cfg`` (keys ``cpus``,
    ``time``, ``partition``, ``mem``). Missing keys fall back to 2 CPUs,
    10:00:00, partition "short" and 8G. Slurm stdout/stderr are written to
    ``<log_dir>/<safe_label>_%j.out`` / ``.err``, where ``safe_label`` is
    the job label with spaces → '_', '|' removed and '/' → '_'.

    Args:
        job: The experiment to wrap.
        log_dir: Directory for Slurm log files; created (with parents) if
            missing. A plain string path is also accepted.

    Returns:
        The script text, terminated by a newline.
    """
    s = job.sbatch_cfg
    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    safe_label = job.label.replace(" ", "_").replace("|", "").replace("/", "_")
    # Shell-quote every argv token so paths or values containing spaces or
    # shell metacharacters reach Python intact. str() also tolerates
    # non-string tokens (e.g. an int loaded from YAML).
    quoted_cmd = [shlex.quote(str(arg)) for arg in job.python_cmd]
    lines = [
        "#!/bin/bash",
        f"#SBATCH -c {s.get('cpus', 2)}",
        f"#SBATCH -t {s.get('time', '10:00:00')}",
        f"#SBATCH -p {s.get('partition', 'short')}",
        f"#SBATCH --mem={s.get('mem', '8G')}",
        f"#SBATCH -o {log_dir}/{safe_label}_%j.out",
        f"#SBATCH -e {log_dir}/{safe_label}_%j.err",
        "",
        f"# {job.label}",
        f"export {job.api_key_env}=${{{job.api_key_env}}}",
        "",
        # Abort rather than run the experiment from the wrong directory.
        f"cd {shlex.quote(str(job.working_dir))} || exit 1",
        " \\\n ".join(quoted_cmd),
    ]
    return "\n".join(lines) + "\n"