Spaces:
Runtime error
Runtime error
| """Full evaluation: every method × every instance + statistical battery + plots.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| import pandas as pd | |
| from dash_jsp.benchmarks import taillard, lawrence, dmu | |
| from dash_jsp.bandit.linucb import LinUCBDispatcher | |
| from dash_jsp.bandit.thompson import ThompsonDispatcher | |
| from dash_jsp.eval.benchmark import MethodSpec, run_benchmark, aggregate_by_method | |
| from dash_jsp.eval.statistical import full_statistical_battery, write_results | |
| from dash_jsp.heuristics.rules import ALL_RULES | |
def _bandit_factory(loader, path: Path):
    """Return a dispatcher loaded from *path*, or a freshly constructed one.

    Args:
        loader: Callable that takes a single path string and returns a
            dispatcher. Typically a bound classmethod such as
            ``SomeDispatcher.load``; in that case the owning class is
            instantiated with no arguments when *path* does not exist.
        path: Location of a previously saved model.

    Returns:
        ``loader(str(path))`` when *path* exists. Otherwise, a
        default-constructed instance of the class the loader is bound to
        (bound classmethod), or ``loader()`` for any other callable.
    """
    if path.exists():
        return loader(str(path))
    # Only a classmethod bound to a class exposes that class as ``__self__``.
    # Builtin functions (``__self__`` is their module) and bound instance
    # methods (``__self__`` is the instance) also have the attribute, so a
    # bare ``hasattr`` check would end up calling a module or an instance.
    owner = getattr(loader, "__self__", None)
    if isinstance(owner, type):
        return owner()
    return loader()
def main() -> None:
    """Benchmark every method on every requested instance and write results.

    Builds one method per fixed dispatching rule in ``ALL_RULES`` plus the two
    bandit dispatchers (``dash_linucb`` and ``dash_thompson``). It runs them
    over the requested benchmark families for each seed and writes three
    files to ``--out-dir``:

    * ``benchmark_full.csv``: the frame returned by ``run_benchmark``.
    * ``benchmark_summary.csv``: ``aggregate_by_method`` of that frame.
    * ``statistical_tests.json``: the result of ``full_statistical_battery``
      on the ``makespan`` metric with ``direction="lower"``.

    Command-line options:
        --data-dir: Root directory with one sub-directory per family.
        --models-dir: Searched for ``bandit_linucb.npz`` and
            ``bandit_thompson.npz``. If a file is missing, that bandit is
            constructed fresh with ``rng_seed`` set to the run's seed.
        --out-dir: Output directory, created if needed.
        --families: Any of taillard / lawrence / dmu. They are always loaded
            in that order, each at most once.
        --seeds: Seeds passed to ``run_benchmark``.
        --headline-method: Passed as ``headline_method`` to
            ``full_statistical_battery`` (default ``dash_linucb``).

    Exits with a usage error if no instances were loaded.
    """
    # Family name -> loader module. The name doubles as the sub-directory
    # under --data-dir, and dict order fixes the load order.
    family_loaders = {"taillard": taillard, "lawrence": lawrence, "dmu": dmu}

    parser = argparse.ArgumentParser()
    parser.add_argument("--data-dir", default="data")
    parser.add_argument("--models-dir", default="models")
    parser.add_argument("--out-dir", default="results")
    # choices= rejects misspelled families instead of silently skipping them.
    parser.add_argument(
        "--families",
        nargs="+",
        choices=list(family_loaders),
        default=["taillard", "lawrence"],
    )
    parser.add_argument("--seeds", nargs="+", type=int, default=[0, 1, 2])
    parser.add_argument("--headline-method", default="dash_linucb")
    args = parser.parse_args()

    data_dir = Path(args.data_dir)
    models_dir = Path(args.models_dir)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    instances = []
    for family, loader in family_loaders.items():
        if family in args.families:
            instances.extend(loader.load_all(data_dir / family))
    if not instances:
        # Fail fast with a usage error rather than running the benchmark and
        # the statistics on an empty instance list.
        parser.error(
            f"no instances loaded from {data_dir} for families {args.families}"
        )

    def seeded_factory(dispatcher_cls, model_path: Path):
        """Return ``factory(seed)`` for a bandit dispatcher class.

        The factory loads *model_path* when it exists; the seed is unused in
        that case. Otherwise it builds ``dispatcher_cls(rng_seed=seed)``.
        """
        def factory(seed: int):
            if model_path.exists():
                return dispatcher_cls.load(str(model_path))
            return dispatcher_cls(rng_seed=seed)
        return factory

    methods = [
        MethodSpec.from_fixed_rule(name, rule) for name, rule in ALL_RULES.items()
    ]
    # Bandit method specs
    methods.append(MethodSpec(
        name="dash_linucb",
        factory=seeded_factory(LinUCBDispatcher, models_dir / "bandit_linucb.npz"),
        is_bandit=True,
    ))
    methods.append(MethodSpec(
        name="dash_thompson",
        factory=seeded_factory(ThompsonDispatcher, models_dir / "bandit_thompson.npz"),
        is_bandit=True,
    ))

    print(f"Running {len(methods)} methods on {len(instances)} instances "
          f"× {len(args.seeds)} seeds")
    df = run_benchmark(methods, instances, seeds=args.seeds)
    out_csv = out_dir / "benchmark_full.csv"
    df.to_csv(out_csv, index=False)
    print(f"Wrote {out_csv} ({len(df)} rows)")

    summary = aggregate_by_method(df)
    summary.to_csv(out_dir / "benchmark_summary.csv", index=False)
    print(summary.to_string(index=False))

    # Statistical battery
    stats = full_statistical_battery(
        df, headline_method=args.headline_method, metric="makespan", direction="lower",
    )
    write_results(stats, str(out_dir / "statistical_tests.json"))
    print(f"Wrote {out_dir / 'statistical_tests.json'}")
if __name__ == "__main__":
    # Script entry point. main() returns None, so SystemExit(None) ends the
    # process with status 0 on success; exceptions propagate unchanged.
    raise SystemExit(main())