File size: 3,321 Bytes
52c82e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""Full evaluation: every method × every instance × every seed, plus the statistical test battery (CSV + JSON outputs)."""

from __future__ import annotations

import argparse
import json
from pathlib import Path

import pandas as pd

from dash_jsp.benchmarks import taillard, lawrence, dmu
from dash_jsp.bandit.linucb import LinUCBDispatcher
from dash_jsp.bandit.thompson import ThompsonDispatcher
from dash_jsp.eval.benchmark import MethodSpec, run_benchmark, aggregate_by_method
from dash_jsp.eval.statistical import full_statistical_battery, write_results
from dash_jsp.heuristics.rules import ALL_RULES


def _bandit_factory(loader, path: Path):
    """Create a fresh dispatcher per seed, loading from disk if available."""
    if path.exists():
        return loader(str(path))
    return loader.__self__() if hasattr(loader, "__self__") else loader()


def main() -> None:
    """Benchmark every dispatching method on the selected instance families.

    Runs every fixed rule in ``ALL_RULES`` plus the two bandit dispatchers
    (``dash_linucb`` and ``dash_thompson``) on every loaded instance for every
    seed. It then writes to ``--out-dir``:

    * ``benchmark_full.csv``: the frame returned by ``run_benchmark``;
    * ``benchmark_summary.csv``: the result of ``aggregate_by_method``;
    * ``statistical_tests.json``: ``full_statistical_battery`` on makespan
      (lower is better), with ``--headline-method`` as the headline method.

    Raises:
        SystemExit: via argparse on an unknown ``--families`` value, or when
            no benchmark instances were loaded.
    """
    # Known instance families; insertion order is the load order.
    family_loaders = {"taillard": taillard, "lawrence": lawrence, "dmu": dmu}

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--data-dir", default="data")
    parser.add_argument("--models-dir", default="models")
    parser.add_argument("--out-dir", default="results")
    # ``choices`` rejects typos that would otherwise be silently ignored.
    parser.add_argument(
        "--families", nargs="+", default=["taillard", "lawrence"],
        choices=list(family_loaders),
    )
    parser.add_argument("--seeds", nargs="+", type=int, default=[0, 1, 2])
    parser.add_argument(
        "--headline-method", default="dash_linucb",
        help="method passed as headline_method to the statistical battery",
    )
    args = parser.parse_args()

    data_dir = Path(args.data_dir)
    models_dir = Path(args.models_dir)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Iterate the registry rather than args.families so the load order is
    # fixed and a family given twice is loaded only once.
    instances = []
    for family, module in family_loaders.items():
        if family in args.families:
            instances.extend(module.load_all(data_dir / family))
    if not instances:
        parser.error(
            f"no instances found for families {args.families} under {data_dir}"
        )

    methods = [
        MethodSpec.from_fixed_rule(name, rule) for name, rule in ALL_RULES.items()
    ]

    def seeded_factory(dispatcher_cls, model_path: Path):
        """Return a per-seed factory building a fresh *dispatcher_cls*.

        If *model_path* exists at call time, the saved model is loaded and the
        seed is unused. Otherwise a new dispatcher is built with
        ``rng_seed=seed``.
        """
        def factory(seed: int):
            if model_path.exists():
                return dispatcher_cls.load(str(model_path))
            return dispatcher_cls(rng_seed=seed)
        return factory

    methods.append(MethodSpec(
        name="dash_linucb",
        factory=seeded_factory(LinUCBDispatcher, models_dir / "bandit_linucb.npz"),
        is_bandit=True,
    ))
    methods.append(MethodSpec(
        name="dash_thompson",
        factory=seeded_factory(ThompsonDispatcher, models_dir / "bandit_thompson.npz"),
        is_bandit=True,
    ))

    print(f"Running {len(methods)} methods on {len(instances)} instances "
          f"× {len(args.seeds)} seeds")
    df = run_benchmark(methods, instances, seeds=args.seeds)

    out_csv = out_dir / "benchmark_full.csv"
    df.to_csv(out_csv, index=False)
    print(f"Wrote {out_csv} ({len(df)} rows)")

    summary = aggregate_by_method(df)
    summary.to_csv(out_dir / "benchmark_summary.csv", index=False)
    print(summary.to_string(index=False))

    # Statistical battery
    stats = full_statistical_battery(
        df, headline_method=args.headline_method, metric="makespan", direction="lower",
    )
    stats_path = out_dir / "statistical_tests.json"
    write_results(stats, str(stats_path))
    print(f"Wrote {stats_path}")


# Script entry point: the guard keeps a plain `import` of this module from
# launching the evaluation.
if __name__ == "__main__":
    main()