File size: 1,678 Bytes
90bdd23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from __future__ import annotations

import json
from pathlib import Path

from runway_zero.baselines import FifoPolicy, RandomPolicy, RecoveryPolicy, rollout
from runway_zero.qlearning import TrainedRLPolicy


# Per-rollout metric fields copied verbatim into each summary row; the tuple
# order is the key order that ends up in the JSON report.
_METRIC_KEYS = (
    "flights_arrived",
    "flights_cancelled",
    "total_dep_delay",
    "total_arr_delay",
    "stranded_passengers",
    "avg_satisfaction",
)


def _summary_row(result: dict, *, stage: int, seed: int) -> dict[str, object]:
    """Flatten one ``rollout`` result into a flat summary row.

    Args:
        result: Value returned by ``rollout``; must provide ``"policy"``,
            ``"total_reward"`` and a ``"metrics"`` mapping containing every
            key in ``_METRIC_KEYS``.
        stage: Stage number the rollout was run on.
        seed: Seed the rollout was run with.

    Returns:
        A dict with ``policy``, ``stage``, ``seed``, ``total_reward`` followed
        by the selected metric fields, in that order.

    Raises:
        KeyError: If ``result`` or its ``metrics`` lack an expected key.
    """
    metrics = result["metrics"]
    return {
        "policy": result["policy"],
        "stage": stage,
        "seed": seed,
        "total_reward": result["total_reward"],
        **{key: metrics[key] for key in _METRIC_KEYS},
    }


def main() -> None:
    """Evaluate the baseline policies (plus any trained Q-policy) per stage and seed.

    For each stage 1-3, the Random, FIFO and Recovery baselines are rolled
    out on every seed. If ``results/trained/q_policy_stage{stage}.json``
    exists, the trained RL policy loaded from it is evaluated as well. One
    summary row per (stage, seed, policy) rollout is written to
    ``results/baseline_metrics.json`` and the same JSON is printed to stdout.
    """
    out_dir = Path("results")
    out_dir.mkdir(exist_ok=True)
    seeds = [7, 11, 17]
    rows: list[dict[str, object]] = []
    for stage in [1, 2, 3]:
        # Derived from out_dir so the trained-policy lookup cannot drift from
        # the output root if it is ever changed.
        trained_path = out_dir / "trained" / f"q_policy_stage{stage}.json"
        # NOTE(review): RandomPolicy(1) is built once per stage and reused for
        # every seed below; unless rollout() resets it, its RNG state carries
        # over between seeds, so per-seed results depend on evaluation order.
        # Confirm this is intended.
        policies = [RandomPolicy(1), FifoPolicy(), RecoveryPolicy()]
        if trained_path.exists():
            policies.append(TrainedRLPolicy.from_file(trained_path))
        for seed in seeds:
            for policy in policies:
                result = rollout(policy, stage=stage, seed=seed)
                rows.append(_summary_row(result, stage=stage, seed=seed))
    # Serialise once and reuse for both the file and stdout.
    report = json.dumps(rows, indent=2)
    (out_dir / "baseline_metrics.json").write_text(report, encoding="utf-8")
    print(report)


# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()