Spaces:
Sleeping
Sleeping
File size: 2,431 Bytes
90bdd23 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 | from __future__ import annotations
import json
from collections import defaultdict
from pathlib import Path
import matplotlib.pyplot as plt
# Metrics rendered as one baseline-comparison bar chart each, in this order.
_BASELINE_METRICS = (
    "total_reward",
    "total_dep_delay",
    "stranded_passengers",
    "avg_satisfaction",
)

# Bar colours assigned per policy name (cycled if there are more policies than
# entries). The first three match the colours this script originally used.
_POLICY_PALETTE = (
    "#1f77b4",
    "#ff7f0e",
    "#2ca02c",
    "#d62728",
    "#9467bd",
    "#8c564b",
    "#e377c2",
    "#7f7f7f",
    "#bcbd22",
    "#17becf",
)

# Stages whose trained Q-policy learning curves are plotted when present.
_TRAINED_STAGES = (1, 2, 3)


def _moving_average(values: list[float], window: int) -> list[float]:
    """Return the trailing mean of *values* over windows of up to *window* points.

    Entry ``i`` is the mean of ``values[max(0, i - window + 1) : i + 1]``, so
    the first ``window - 1`` entries average over the shorter prefix available
    instead of being dropped.

    Raises:
        ValueError: if *window* is less than 1.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    averages = []
    for end in range(1, len(values) + 1):
        chunk = values[max(0, end - window) : end]
        averages.append(sum(chunk) / len(chunk))
    return averages


def _plot_baseline_metrics(rows: list[dict], out_dir: Path) -> None:
    """Write one bar chart per baseline metric into *out_dir*.

    Each bar is the mean of ``row[metric]`` over all rows sharing the same
    ``(stage, policy)`` pair. Bars are ordered by ``(stage, policy)``.
    """
    # Colour by policy name rather than by bar position, so a policy keeps the
    # same colour across stages even if some stage lacks a policy or there are
    # more than three policies.
    policies = sorted({row["policy"] for row in rows})
    color_of = {
        policy: _POLICY_PALETTE[i % len(_POLICY_PALETTE)]
        for i, policy in enumerate(policies)
    }

    for metric in _BASELINE_METRICS:
        grouped = defaultdict(list)
        for row in rows:
            grouped[(row["stage"], row["policy"])].append(row[metric])

        keys = sorted(grouped)
        labels = [f"S{stage}\n{policy}" for stage, policy in keys]
        values = [sum(grouped[key]) / len(grouped[key]) for key in keys]
        colors = [color_of[policy] for _, policy in keys]

        plt.figure(figsize=(12, 5))
        try:
            plt.bar(labels, values, color=colors)
            plt.title(f"Runway Zero baseline comparison: {metric}")
            plt.ylabel(metric)
            plt.xticks(rotation=35, ha="right")
            plt.tight_layout()
            plt.savefig(out_dir / f"{metric}.png", dpi=180)
        finally:
            # Release the figure even if saving fails.
            plt.close()


def _plot_learning_curves(trained_dir: Path, out_dir: Path, window: int = 10) -> None:
    """Plot raw and smoothed episode rewards for each trained stage artifact.

    Stages whose ``q_policy_stage{stage}.json`` file is absent in
    *trained_dir* are skipped silently. Each artifact is expected to hold a
    ``learning_curve`` list of ``{"episode": ..., "reward": ...}`` points.
    """
    for stage in _TRAINED_STAGES:
        path = trained_dir / f"q_policy_stage{stage}.json"
        if not path.exists():
            continue
        artifact = json.loads(path.read_text(encoding="utf-8"))
        curve = artifact["learning_curve"]
        xs = [point["episode"] for point in curve]
        ys = [point["reward"] for point in curve]
        smooth = _moving_average(ys, window)

        plt.figure(figsize=(10, 4))
        try:
            plt.plot(xs, ys, color="#9ab8d6", linewidth=1, label="episode reward")
            # Label derived from the actual window so it cannot drift from it.
            plt.plot(xs, smooth, color="#1f6fb8", linewidth=2, label=f"{window}-episode average")
            plt.title(f"Runway Zero RL controller learning curve: stage {stage}")
            plt.xlabel("episode")
            plt.ylabel("reward")
            plt.legend()
            plt.tight_layout()
            plt.savefig(out_dir / f"rl_learning_stage{stage}.png", dpi=180)
        finally:
            plt.close()


def main() -> None:
    """Render baseline-comparison bar charts and RL learning-curve plots.

    Reads ``results/baseline_metrics.json`` (a JSON list of per-run rows with
    ``stage``, ``policy`` and metric keys) and writes one bar chart per metric
    to ``results/plots``. For each trained stage whose
    ``results/trained/q_policy_stage{stage}.json`` exists, also writes a
    learning-curve plot to the same directory.

    Raises:
        SystemExit: if the baseline metrics file has not been generated yet.
    """
    metrics_path = Path("results/baseline_metrics.json")
    if not metrics_path.exists():
        raise SystemExit("Run scripts/evaluate_baselines.py first")
    rows = json.loads(metrics_path.read_text(encoding="utf-8"))

    out_dir = Path("results/plots")
    out_dir.mkdir(parents=True, exist_ok=True)

    _plot_baseline_metrics(rows, out_dir)
    _plot_learning_curves(Path("results/trained"), out_dir)
    print(f"Wrote plots to {out_dir}")
if __name__ == "__main__":
    # Generate the plots only when run as a script, never on import.
    main()
|