Spaces:
Sleeping
Sleeping
File size: 1,678 Bytes
90bdd23 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 | from __future__ import annotations
import json
from pathlib import Path
from runway_zero.baselines import FifoPolicy, RandomPolicy, RecoveryPolicy, rollout
from runway_zero.qlearning import TrainedRLPolicy
# Metric names copied verbatim from each rollout's "metrics" dict into the
# output rows. Their order here is the key order after the fixed
# policy/stage/seed/total_reward columns in baseline_metrics.json.
_METRIC_KEYS = (
    "flights_arrived",
    "flights_cancelled",
    "total_dep_delay",
    "total_arr_delay",
    "stranded_passengers",
    "avg_satisfaction",
)


def _metrics_row(result: dict, stage: int, seed: int) -> dict:
    """Flatten one ``rollout`` result into a summary row.

    Args:
        result: Mapping returned by ``rollout``; must contain ``"policy"``,
            ``"total_reward"`` and a ``"metrics"`` mapping holding every key
            in ``_METRIC_KEYS``.
        stage: Stage number the rollout was run on.
        seed: Seed the rollout was run with.

    Returns:
        A dict with keys ``policy``, ``stage``, ``seed``, ``total_reward``,
        followed by the entries of ``_METRIC_KEYS`` in that order.

    Raises:
        KeyError: if ``result`` or its ``"metrics"`` lacks an expected key.
    """
    metrics = result["metrics"]
    row = {
        "policy": result["policy"],
        "stage": stage,
        "seed": seed,
        "total_reward": result["total_reward"],
    }
    row.update({key: metrics[key] for key in _METRIC_KEYS})
    return row


def main() -> None:
    """Evaluate baseline (and any trained) policies and dump their metrics.

    For each stage in 1-3 and each seed in ``[7, 11, 17]``, this runs
    ``rollout`` for ``RandomPolicy(1)``, ``FifoPolicy()`` and
    ``RecoveryPolicy()``. It also runs a ``TrainedRLPolicy`` when
    ``results/trained/q_policy_stage{stage}.json`` exists.

    One summary row per (stage, seed, policy) is written as indented JSON
    to ``results/baseline_metrics.json`` and echoed to stdout.
    """
    out_dir = Path("results")
    out_dir.mkdir(exist_ok=True)
    seeds = [7, 11, 17]
    rows = []
    for stage in [1, 2, 3]:
        # Derived from out_dir so the trained-policy lookup follows it if the
        # results directory is ever moved.
        trained_path = out_dir / "trained" / f"q_policy_stage{stage}.json"
        policies = [RandomPolicy(1), FifoPolicy(), RecoveryPolicy()]
        if trained_path.exists():
            policies.append(TrainedRLPolicy.from_file(trained_path))
        # NOTE(review): the same policy instances are reused for every seed of
        # a stage. If RandomPolicy(1) keeps RNG state between rollouts, its
        # result for a seed depends on which seeds ran before it -- confirm
        # whether rollout resets policy state.
        for seed in seeds:
            for policy in policies:
                result = rollout(policy, stage=stage, seed=seed)
                rows.append(_metrics_row(result, stage, seed))
    # Serialize once; the identical text goes to disk and to stdout.
    payload = json.dumps(rows, indent=2)
    (out_dir / "baseline_metrics.json").write_text(payload, encoding="utf-8")
    print(payload)
if __name__ == "__main__":
    # Run the evaluation only when executed as a script; importing this
    # module (e.g. from tests) performs no rollouts or file writes.
    main()
|