"""Full evaluation: every method × every instance + statistical battery.

Writes the raw benchmark table, a per-method summary and the statistical
test results to ``--out-dir``. (No plots are produced by this script.)
"""

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

import pandas as pd

from dash_jsp.benchmarks import taillard, lawrence, dmu
from dash_jsp.bandit.linucb import LinUCBDispatcher
from dash_jsp.bandit.thompson import ThompsonDispatcher
from dash_jsp.eval.benchmark import MethodSpec, run_benchmark, aggregate_by_method
from dash_jsp.eval.statistical import full_statistical_battery, write_results
from dash_jsp.heuristics.rules import ALL_RULES

# Benchmark families selectable via --families, in the order they are loaded.
# Each family is read from ``<data-dir>/<family name>``.
_FAMILY_LOADERS = {
    "taillard": taillard.load_all,
    "lawrence": lawrence.load_all,
    "dmu": dmu.load_all,
}


def _bandit_factory(loader, path: Path):
    """Create a fresh dispatcher per seed, loading from disk if available.

    If ``path`` exists, returns ``loader(str(path))``. Otherwise, when
    ``loader`` is a bound method (it has ``__self__``), calls its owner with
    no arguments; failing that, calls ``loader()``.

    NOTE(review): not used by ``main`` (see ``_make_bandit_factory``); kept
    unchanged in case something else imports it.
    """
    if path.exists():
        return loader(str(path))
    return loader.__self__() if hasattr(loader, "__self__") else loader()


def _make_bandit_factory(cls, path: Path):
    """Return a per-seed factory ``factory(seed)`` for bandit class ``cls``.

    Each call loads the saved model with ``cls.load(str(path))`` if ``path``
    exists at that moment. Otherwise it builds a fresh ``cls(rng_seed=seed)``.
    The seed is not applied to a model loaded from disk.

    Prints a warning to stderr once, at creation time, when ``path`` is
    missing. The benchmark then evaluates an untrained dispatcher.
    """
    if not path.exists():
        print(
            f"warning: {path} not found; {cls.__name__} will run untrained",
            file=sys.stderr,
        )

    def factory(seed: int):
        # Check on every call, as before, so a model written mid-run is picked up.
        if path.exists():
            return cls.load(str(path))
        return cls(rng_seed=seed)

    return factory


def main() -> None:
    """Run every method on every instance/seed, then write results.

    Outputs in ``--out-dir``:
      * ``benchmark_full.csv``: one row per benchmark run,
      * ``benchmark_summary.csv``: ``aggregate_by_method`` of the above,
      * ``statistical_tests.json``: statistical battery on ``makespan``
        (lower is better) with ``dash_linucb`` as the headline method.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--data-dir", default="data")
    parser.add_argument("--models-dir", default="models")
    parser.add_argument("--out-dir", default="results")
    parser.add_argument(
        "--families",
        nargs="+",
        # Reject typos instead of silently loading nothing for them.
        choices=list(_FAMILY_LOADERS),
        default=["taillard", "lawrence"],
    )
    parser.add_argument("--seeds", nargs="+", type=int, default=[0, 1, 2])
    args = parser.parse_args()

    data_dir = Path(args.data_dir)
    models_dir = Path(args.models_dir)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Each family is loaded at most once, in _FAMILY_LOADERS order,
    # regardless of how --families is ordered or repeated.
    instances = []
    for family, load_all in _FAMILY_LOADERS.items():
        if family in args.families:
            instances.extend(load_all(data_dir / family))
    if not instances:
        parser.error(
            f"no instances loaded for families {args.families} under {data_dir}"
        )

    methods = [
        MethodSpec.from_fixed_rule(name, rule) for name, rule in ALL_RULES.items()
    ]

    # Bandit method specs: trained weights if present on disk, else untrained.
    methods.append(
        MethodSpec(
            name="dash_linucb",
            factory=_make_bandit_factory(
                LinUCBDispatcher, models_dir / "bandit_linucb.npz"
            ),
            is_bandit=True,
        )
    )
    methods.append(
        MethodSpec(
            name="dash_thompson",
            factory=_make_bandit_factory(
                ThompsonDispatcher, models_dir / "bandit_thompson.npz"
            ),
            is_bandit=True,
        )
    )

    print(f"Running {len(methods)} methods on {len(instances)} instances "
          f"× {len(args.seeds)} seeds")
    df = run_benchmark(methods, instances, seeds=args.seeds)

    out_csv = out_dir / "benchmark_full.csv"
    df.to_csv(out_csv, index=False)
    print(f"Wrote {out_csv} ({len(df)} rows)")

    summary = aggregate_by_method(df)
    summary.to_csv(out_dir / "benchmark_summary.csv", index=False)
    print(summary.to_string(index=False))

    # Statistical battery
    stats = full_statistical_battery(
        df,
        headline_method="dash_linucb",
        metric="makespan",
        direction="lower",
    )
    stats_path = out_dir / "statistical_tests.json"
    write_results(stats, str(stats_path))
    print(f"Wrote {stats_path}")


if __name__ == "__main__":
    main()