#!/usr/bin/env python3
"""
Run a suite of ablation experiments (generation + evaluation) and summarise results.
"""
from __future__ import annotations
import argparse
import json
import os
import shlex
import subprocess
import sys
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Sequence, Tuple
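# Paths are resolved relative to this file so the suite can run from any CWD.
# STANDARD_RESULTS_ROOT climbs two directories above the repo root into a
# shared results/ tree; pass --output_root to store artefacts elsewhere.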
PROJECT_ROOT = Path(__file__).resolve().parent.parent
STANDARD_RESULTS_ROOT = PROJECT_ROOT.parent.parent / "results" / "Agora-Opt"
GENERATE_SCRIPT = PROJECT_ROOT / "scripts" / "generate_with_memory.py"
EXECUTE_SCRIPT = PROJECT_ROOT / "scripts" / "execute.py"
PYTHON_BIN = os.environ.get("PYTHON_BIN", sys.executable)
@dataclass
class Variant:
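    """One ablation setting: a name, a human-readable description, and the
    generation overrides that define it."""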
name: str
description: str
overrides: Dict[str, object]
def parse_args() -> argparse.Namespace:
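    """Parse CLI options covering the model, datasets, memory settings,
    evaluation knobs, and output layout."""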
parser = argparse.ArgumentParser(
description="Run generate+evaluate ablations and emit a summary table."
)
parser.add_argument("--model", type=str, default="gpt-4o", help="LLM to query.")
parser.add_argument(
"--datasets",
nargs="+",
default=["IndustryOR", "ComplexLP"],
help="Datasets to evaluate (space-separated, omit .jsonl).",
)
parser.add_argument("--temperature", type=float, default=0.01)
parser.add_argument(
"--max_problems",
type=int,
default=None,
help="Limit number of problems per dataset (omit for full set).",
)
parser.add_argument("--memory_dir", type=str, default="memory_storage")
parser.add_argument(
"--memory_top_k",
type=int,
default=3,
help="Base episodic memory retrieval count for the full variant.",
)
parser.add_argument(
"--max_retries",
type=int,
default=5,
help="Base retry budget for the full variant.",
)
parser.add_argument(
"--debug_case_top_k",
type=int,
default=3,
help="Base debug-case retrieval count.",
)
parser.add_argument(
"--parallel",
type=int,
default=64,
help="Workers for generation (passed to --parallel).",
)
parser.add_argument(
"--execution_timeout",
type=int,
default=90,
help="Timeout per execution attempt in generate_with_memory.",
)
parser.add_argument(
"--debug_memory_path",
type=str,
default="memory_storage/debug_memory.jsonl",
help="Path to debug memory JSONL.",
)
parser.add_argument(
"--debug_case_dir",
type=str,
default="debug_case_memory",
help="Directory containing consolidated debug-case memory.",
)
parser.add_argument(
"--output_root",
type=str,
default=str(STANDARD_RESULTS_ROOT / "ablations"),
help="Root folder for storing ablation artefacts.",
)
parser.add_argument(
"--eval_timeout",
type=int,
default=90,
help="Timeout for scripts/execute.py.",
)
parser.add_argument(
"--num_workers",
type=int,
default=64,
help="ProcessPool workers for evaluation.",
)
parser.add_argument("--tolerance", type=float, default=0.05)
parser.add_argument(
"--relative_tolerance",
action="store_true",
help="Use relative tolerance in evaluation.",
)
parser.add_argument(
"--dry_run",
action="store_true",
help="Print commands without executing or aggregating results.",
)
return parser.parse_args()
def build_variants(args: argparse.Namespace) -> List[Variant]:
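    """Build the ablation grid: the full system as reference, then variants
    that each disable one helper (LLM refinement, debug-case memory, retries,
    episodic memory), down to a vanilla single-shot LLM."""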
base = {
"memory_top_k": args.memory_top_k,
"use_llm_refinement": True,
"debug_case_memory_top_k": args.debug_case_top_k,
"max_retries": args.max_retries,
"auto_debug": True,
}
return [
Variant(
name="full_system",
description="All helpers enabled (reference).",
overrides={**base},
),
Variant(
name="no_llm_refine",
description="Skip LLM summarisation of retrieved cases.",
overrides={**base, "use_llm_refinement": False},
),
Variant(
name="no_debug_case_memory",
description="Disable historical debug-case retrieval.",
overrides={**base, "debug_case_memory_top_k": 0},
),
Variant(
name="no_self_healing",
description="Single attempt (max_retries=1) but still executes locally once.",
overrides={**base, "max_retries": 1},
),
Variant(
name="no_memory",
description="Disable episodic retrieval, keep retries on.",
overrides={**base, "memory_top_k": 0, "use_llm_refinement": False},
),
Variant(
name="vanilla_llm",
description="Pure single-shot LLM (no memory, no auto-debug).",
overrides={
**base,
"memory_top_k": 0,
"use_llm_refinement": False,
"debug_case_memory_top_k": 0,
"max_retries": 1,
"auto_debug": False,
},
),
]
def run_command(cmd: Sequence[str], dry_run: bool = False) -> None:
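    """Echo the shell-quoted command, then execute it unless dry_run is set;
    a non-zero exit raises CalledProcessError and aborts the suite."""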
pretty = " ".join(shlex.quote(part) for part in cmd)
print(f" → {pretty}")
if dry_run:
return
subprocess.run(cmd, check=True)
def compute_attempt_stats(path: Path) -> Tuple[float, int]:
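    """Scan a generation JSONL and return (mean total_attempts per record,
    count of records that needed more than one attempt)."""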
if not path.exists():
return 0.0, 0
total = 0
total_attempts = 0
multi_attempt = 0
with path.open("r", encoding="utf-8") as handle:
for line in handle:
line = line.strip()
if not line:
continue
record = json.loads(line)
attempts = record.get("total_attempts", 1)
total_attempts += attempts
total += 1
if attempts > 1:
multi_attempt += 1
avg = (total_attempts / total) if total else 0.0
return avg, multi_attempt
def format_percent(value: float) -> str:
return f"{value * 100:.1f}%"
def build_generate_args(
dataset: str,
output_file: Path,
debug_dir: Path,
args: argparse.Namespace,
cfg: Dict[str, object],
) -> List[str]:
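    """Translate a variant's override config into CLI arguments for
    scripts/generate_with_memory.py; the caller prepends the interpreter."""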
cmd = [
os.fspath(GENERATE_SCRIPT),
"--dataset",
dataset,
"--model",
args.model,
"--temperature",
str(args.temperature),
"--output",
os.fspath(output_file),
"--memory_dir",
os.fspath(Path(args.memory_dir).resolve()),
"--parallel",
str(args.parallel),
"--execution_timeout",
str(args.execution_timeout),
"--debug_memory_path",
os.fspath(Path(args.debug_memory_path).resolve()),
"--debug_case_memory_dir",
os.fspath(Path(args.debug_case_dir).resolve()),
"--debug_case_memory_top_k",
str(int(cfg.get("debug_case_memory_top_k", 0))),
"--memory_top_k",
str(int(cfg.get("memory_top_k", 0))),
"--max_retries",
str(int(cfg.get("max_retries", 1))),
]
    if args.max_problems is not None:  # truthiness would silently drop an explicit 0
        cmd += ["--max_problems", str(args.max_problems)]
if cfg.get("use_llm_refinement"):
cmd.append("--use_llm_refinement")
    if not cfg.get("filter_perfect", True):
        # No variant overrides filter_perfect today, so this branch is
        # currently inert; kept for configs that disable the filter.
        cmd.append("--no_filter_perfect")
if not cfg.get("auto_debug", True):
cmd.append("--no_auto_debug")
if debug_dir:
cmd += ["--debug_output_dir", os.fspath(debug_dir)]
return [os.fspath(part) for part in cmd]
def build_execute_args(input_file: Path, output_dir: Path, args: argparse.Namespace) -> List[str]:
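    """Build CLI arguments for scripts/execute.py to evaluate one generation file."""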
cmd = [
os.fspath(EXECUTE_SCRIPT),
"--input_file",
os.fspath(input_file),
"--output_dir",
os.fspath(output_dir),
"--timeout",
str(args.eval_timeout),
"--tolerance",
str(args.tolerance),
"--num_workers",
str(args.num_workers),
"--memory_dir",
os.fspath(Path(args.memory_dir).resolve()),
"--debug_memory_path",
os.fspath(Path(args.debug_memory_path).resolve()),
]
if args.relative_tolerance:
cmd.append("--use_relative_tolerance")
return cmd
def summarise_records(records: List[Dict], summary_path: Path) -> None:
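    """Write a markdown summary table to summary_path and a CSV export with
    the same rows alongside it (same stem, .csv suffix)."""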
if not records:
return
md_lines = [
"| Dataset | Variant | Accuracy | Correct/Total | Exec Err % | Timeout % | No-Code % | Avg Attempts | Notes |",
"| --- | --- | --- | --- | --- | --- | --- | --- | --- |",
]
csv_lines = [
"dataset,variant,accuracy,correct,total,exec_error_pct,timeout_pct,no_code_pct,avg_attempts,notes"
]
for record in records:
dataset = record["dataset"]
variant = record["variant"]
report = record["report"]
status_counts = report.get("status_counts", {})
total = report.get("total_problems", 0)
accuracy_pct = format_percent(report.get("accuracy", 0.0))
correct = report.get("correct", 0)
exec_err_pct = (
(status_counts.get("execution_error", 0) / total) if total else 0.0
)
timeout_pct = (status_counts.get("timeout", 0) / total) if total else 0.0
no_code_pct = (status_counts.get("no_code", 0) / total) if total else 0.0
avg_attempts = record.get("avg_attempts", 0.0)
notes = record["notes"]
md_lines.append(
f"| {dataset} | {variant} | {accuracy_pct} | {correct}/{total} | "
f"{exec_err_pct*100:.1f}% | {timeout_pct*100:.1f}% | {no_code_pct*100:.1f}% | "
f"{avg_attempts:.2f} | {notes} |"
)
safe_notes = notes.replace('"', '""')
csv_lines.append(
f"{dataset},{variant},{report.get('accuracy',0.0):.4f},{correct},{total},"
f"{exec_err_pct:.4f},{timeout_pct:.4f},{no_code_pct:.4f},{avg_attempts:.4f},\"{safe_notes}\""
)
summary_path.write_text("\n".join(md_lines) + "\n", encoding="utf-8")
csv_path = summary_path.with_suffix(".csv")
csv_path.write_text("\n".join(csv_lines) + "\n", encoding="utf-8")
print(f"\n✅ Summary table written to: {summary_path}")
print(f"📄 CSV export written to: {csv_path}")
def main() -> None:
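    """Drive the suite: for each dataset x variant pair, generate, evaluate,
    then aggregate the evaluation reports into a summary."""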
args = parse_args()
variants = build_variants(args)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
run_root = Path(args.output_root).resolve() / timestamp
if not args.dry_run:
run_root.mkdir(parents=True, exist_ok=True)
print("========================================")
print("Ablation Runner")
print("========================================")
print(f"Model: {args.model}")
print(f"Datasets: {', '.join(args.datasets)}")
print(f"Output root: {run_root if not args.dry_run else args.output_root}")
print(f"Dry run: {args.dry_run}")
print("========================================\n")
records: List[Dict] = []
for dataset in args.datasets:
print(f"Dataset: {dataset}")
for variant in variants:
cfg = variant.overrides
variant_name = variant.name
print(f" Variant: {variant_name}{variant.description}")
dataset_slug = dataset.replace("/", "_")
gen_output = (
run_root / f"{dataset_slug}_{variant_name}.jsonl"
if not args.dry_run
else Path(f"{dataset_slug}_{variant_name}.jsonl")
)
debug_dir = (
run_root / "debug" / dataset_slug / variant_name
if not args.dry_run
else Path(f"debug/{dataset_slug}/{variant_name}")
)
eval_dir = (
run_root / f"{dataset_slug}_{variant_name}_eval"
if not args.dry_run
else Path(f"{dataset_slug}_{variant_name}_eval")
)
if not args.dry_run:
debug_dir.mkdir(parents=True, exist_ok=True)
gen_cmd = [PYTHON_BIN] + build_generate_args(
dataset, gen_output, debug_dir, args, cfg
)
run_command(gen_cmd, dry_run=args.dry_run)
            exec_cmd = [PYTHON_BIN] + build_execute_args(gen_output, eval_dir, args)
run_command(exec_cmd, dry_run=args.dry_run)
if args.dry_run:
continue
report_path = eval_dir / "evaluation_report.json"
if not report_path.exists():
raise FileNotFoundError(
f"Missing evaluation report for {dataset} / {variant_name}: {report_path}"
)
with report_path.open("r", encoding="utf-8") as handle:
report = json.load(handle)
avg_attempts, _ = compute_attempt_stats(gen_output)
records.append(
{
"dataset": dataset,
"variant": variant_name,
"report": report,
"avg_attempts": avg_attempts,
"notes": variant.description,
}
)
print("")
if args.dry_run:
print("Dry run completed. No commands were executed.")
return
summary_path = run_root / "ablation_summary.md"
summarise_records(records, summary_path)
if __name__ == "__main__":
main()