runway-zero / scripts /evaluate_baselines.py
work-dwivediishivam's picture
Space Redeployed to preserve GPU; switched to FREE TIER; fixed GUI with Actual Replay
90bdd23
from __future__ import annotations
import json
from pathlib import Path
from runway_zero.baselines import FifoPolicy, RandomPolicy, RecoveryPolicy, rollout
from runway_zero.qlearning import TrainedRLPolicy
def main() -> None:
    """Evaluate the baseline policies and dump their metrics as JSON.

    For each of stages 1-3 the random, FIFO and recovery policies are rolled
    out once per seed (7, 11, 17). When a trained policy file exists at
    ``results/trained/q_policy_stage{stage}.json`` it is loaded and evaluated
    alongside them. One record per (stage, seed, policy) rollout is collected,
    written to ``results/baseline_metrics.json`` and echoed to stdout.
    """
    results_dir = Path("results")
    results_dir.mkdir(exist_ok=True)

    # Metrics copied verbatim from each rollout's "metrics" mapping, in the
    # order they appear in the output records.
    metric_keys = (
        "flights_arrived",
        "flights_cancelled",
        "total_dep_delay",
        "total_arr_delay",
        "stranded_passengers",
        "avg_satisfaction",
    )

    records = []
    for stage in (1, 2, 3):
        checkpoint = Path(f"results/trained/q_policy_stage{stage}.json")

        # Policy instances are created once per stage and shared by every seed.
        contenders = [RandomPolicy(1), FifoPolicy(), RecoveryPolicy()]
        if checkpoint.exists():
            contenders.append(TrainedRLPolicy.from_file(checkpoint))

        for seed in (7, 11, 17):
            for contender in contenders:
                outcome = rollout(contender, stage=stage, seed=seed)
                metrics = outcome["metrics"]
                record = {
                    "policy": outcome["policy"],
                    "stage": stage,
                    "seed": seed,
                    "total_reward": outcome["total_reward"],
                }
                record.update((key, metrics[key]) for key in metric_keys)
                records.append(record)

    # Serialise once; the same text goes to the file and to stdout.
    report = json.dumps(records, indent=2)
    (results_dir / "baseline_metrics.json").write_text(report, encoding="utf-8")
    print(report)
# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()