# dash-jsp-trainer/scripts/run_evaluation.py
# Committed by: Vittal-M
# Last commit 52c82e4: "Trainer Space: download -> train -> push -> sleep"
"""Full evaluation: every method × every instance + statistical battery + plots."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import pandas as pd
from dash_jsp.benchmarks import taillard, lawrence, dmu
from dash_jsp.bandit.linucb import LinUCBDispatcher
from dash_jsp.bandit.thompson import ThompsonDispatcher
from dash_jsp.eval.benchmark import MethodSpec, run_benchmark, aggregate_by_method
from dash_jsp.eval.statistical import full_statistical_battery, write_results
from dash_jsp.heuristics.rules import ALL_RULES
def _bandit_factory(loader, path: Path, seed: int | None = None):
    """Return a dispatcher loaded from *path*, or a fresh one if *path* is absent.

    Args:
        loader: Callable that takes a path string and returns a dispatcher,
            typically a bound classmethod such as ``LinUCBDispatcher.load``.
        path: Location of a previously saved model.
        seed: Optional RNG seed. When a fresh dispatcher is built from
            ``loader``'s owning class, it is passed as ``rng_seed=seed``.
            It is ignored when a saved model is loaded. ``None`` (the default)
            calls the constructor with no arguments, as before.

    Returns:
        The loaded dispatcher, or a newly constructed one.
    """
    if path.exists():
        return loader(str(path))
    # For a bound classmethod (``Cls.load``), ``__self__`` is the class itself,
    # so calling it constructs a fresh, untrained instance. Builtins also have
    # ``__self__`` (a module), so only treat it as a constructor when it is a class.
    owner = getattr(loader, "__self__", None)
    if isinstance(owner, type):
        return owner() if seed is None else owner(rng_seed=seed)
    return loader()
# Benchmark families this script can load, in the order they are loaded.
_FAMILY_LOADERS = {"taillard": taillard, "lawrence": lawrence, "dmu": dmu}


def _load_instances(data_dir: Path, families) -> list:
    """Load instances for each requested family from ``data_dir / <family>``.

    Families are loaded in the fixed order of ``_FAMILY_LOADERS``, whatever
    order they were given in, and each family is loaded at most once.
    """
    requested = set(families)
    instances = []
    for family, module in _FAMILY_LOADERS.items():
        if family in requested:
            instances.extend(module.load_all(data_dir / family))
    return instances


def _bandit_spec(name: str, dispatcher_cls, model_path: Path) -> MethodSpec:
    """Build a bandit ``MethodSpec`` whose factory prefers a saved model.

    The factory receives a seed. If *model_path* exists, the saved model is
    loaded and the seed is unused. Otherwise a fresh dispatcher is created
    with ``rng_seed=seed``.
    """
    def factory(seed: int):
        # Existence is checked on every call, so a model saved mid-run is picked up.
        if model_path.exists():
            return dispatcher_cls.load(str(model_path))
        return dispatcher_cls(rng_seed=seed)

    return MethodSpec(name=name, factory=factory, is_bandit=True)


def main() -> None:
    """Run every method on every instance and seed, then write results.

    Writes into ``--out-dir``:
      - ``benchmark_full.csv``: the raw DataFrame from ``run_benchmark``.
      - ``benchmark_summary.csv``: the output of ``aggregate_by_method``.
      - ``statistical_tests.json``: ``full_statistical_battery`` on makespan
        (lower is better), using ``--headline-method`` as the headline.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark dispatching rules and bandit dispatchers.",
    )
    parser.add_argument("--data-dir", default="data")
    parser.add_argument("--models-dir", default="models")
    parser.add_argument("--out-dir", default="results")
    # `choices` rejects typos that would otherwise silently load nothing.
    parser.add_argument(
        "--families", nargs="+", default=["taillard", "lawrence"],
        choices=list(_FAMILY_LOADERS),
    )
    parser.add_argument("--seeds", nargs="+", type=int, default=[0, 1, 2])
    parser.add_argument(
        "--headline-method", default="dash_linucb",
        help="Method passed as headline_method to the statistical battery.",
    )
    args = parser.parse_args()

    data_dir = Path(args.data_dir)
    models_dir = Path(args.models_dir)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    instances = _load_instances(data_dir, args.families)
    if not instances:
        parser.error(
            f"no instances loaded for families {args.families} under {data_dir}"
        )

    methods = [
        MethodSpec.from_fixed_rule(name, rule) for name, rule in ALL_RULES.items()
    ]
    bandit_names = ("dash_linucb", "dash_thompson")
    methods.append(_bandit_spec(
        "dash_linucb", LinUCBDispatcher, models_dir / "bandit_linucb.npz",
    ))
    methods.append(_bandit_spec(
        "dash_thompson", ThompsonDispatcher, models_dir / "bandit_thompson.npz",
    ))

    # Fail fast instead of after the (possibly long) benchmark run.
    known_methods = set(ALL_RULES) | set(bandit_names)
    if args.headline_method not in known_methods:
        parser.error(
            f"unknown --headline-method {args.headline_method!r}; "
            f"choose from {sorted(known_methods)}"
        )

    print(f"Running {len(methods)} methods on {len(instances)} instances "
          f"× {len(args.seeds)} seeds")
    df = run_benchmark(methods, instances, seeds=args.seeds)

    out_csv = out_dir / "benchmark_full.csv"
    df.to_csv(out_csv, index=False)
    print(f"Wrote {out_csv} ({len(df)} rows)")

    summary = aggregate_by_method(df)
    summary.to_csv(out_dir / "benchmark_summary.csv", index=False)
    print(summary.to_string(index=False))

    # Statistical battery
    stats = full_statistical_battery(
        df, headline_method=args.headline_method, metric="makespan",
        direction="lower",
    )
    stats_path = out_dir / "statistical_tests.json"
    write_results(stats, str(stats_path))
    print(f"Wrote {stats_path}")


if __name__ == "__main__":
    main()