File size: 2,431 Bytes
90bdd23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from __future__ import annotations

import json
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt


# Bar colours are assigned per policy (cycling if there are more policies than
# colours) so a policy keeps the same colour in every stage, even when some
# (stage, policy) combinations are absent from the metrics file.
_POLICY_PALETTE = ("#1f77b4", "#ff7f0e", "#2ca02c")


def _rolling_mean(values: list[float], window: int) -> list[float]:
    """Return the trailing mean of *values* over at most *window* points."""
    means = []
    for end in range(1, len(values) + 1):
        # The first ``window - 1`` points average over a shorter prefix.
        chunk = values[max(0, end - window) : end]
        means.append(sum(chunk) / len(chunk))
    return means


def _plot_baseline_metric(rows: list[dict], metric: str, out_dir: Path) -> None:
    """Save a bar chart of the mean of *metric* per (stage, policy) pair."""
    grouped = defaultdict(list)
    for row in rows:
        grouped[(row["stage"], row["policy"])].append(row[metric])

    keys = sorted(grouped)
    policies = sorted({policy for _, policy in keys})
    colour_of = {
        policy: _POLICY_PALETTE[index % len(_POLICY_PALETTE)]
        for index, policy in enumerate(policies)
    }

    labels = [f"S{stage}\n{policy}" for stage, policy in keys]
    values = [sum(grouped[key]) / len(grouped[key]) for key in keys]
    colors = [colour_of[policy] for _, policy in keys]

    plt.figure(figsize=(12, 5))
    plt.bar(labels, values, color=colors)
    plt.title(f"Runway Zero baseline comparison: {metric}")
    plt.ylabel(metric)
    plt.xticks(rotation=35, ha="right")
    plt.tight_layout()
    plt.savefig(out_dir / f"{metric}.png", dpi=180)
    plt.close()


def _plot_learning_curve(
    stage: int, curve: list[dict], out_dir: Path, window: int = 10
) -> None:
    """Save the per-episode reward of *curve* together with its rolling mean."""
    xs = [point["episode"] for point in curve]
    ys = [point["reward"] for point in curve]
    smooth = _rolling_mean(ys, window)

    plt.figure(figsize=(10, 4))
    plt.plot(xs, ys, color="#9ab8d6", linewidth=1, label="episode reward")
    # The legend text is derived from ``window`` so the two cannot drift apart.
    plt.plot(xs, smooth, color="#1f6fb8", linewidth=2, label=f"{window}-episode average")
    plt.title(f"Runway Zero RL controller learning curve: stage {stage}")
    plt.xlabel("episode")
    plt.ylabel("reward")
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_dir / f"rl_learning_stage{stage}.png", dpi=180)
    plt.close()


def main() -> None:
    """Render baseline comparison charts and RL learning curves as PNG files.

    Reads ``results/baseline_metrics.json`` (a list of row objects carrying
    ``stage``, ``policy`` and the plotted metric keys) and writes one bar chart
    per metric into ``results/plots``. For each of stages 1-3 whose
    ``results/trained/q_policy_stage{stage}.json`` exists, the artifact's
    ``learning_curve`` entries (``episode`` / ``reward``) are plotted as well.

    Raises:
        SystemExit: if the baseline metrics file does not exist.
    """
    metrics_path = Path("results/baseline_metrics.json")
    if not metrics_path.exists():
        raise SystemExit("Run scripts/evaluate_baselines.py first")
    rows = json.loads(metrics_path.read_text(encoding="utf-8"))
    out_dir = Path("results/plots")
    out_dir.mkdir(parents=True, exist_ok=True)

    for metric in ("total_reward", "total_dep_delay", "stranded_passengers", "avg_satisfaction"):
        _plot_baseline_metric(rows, metric, out_dir)

    trained_dir = Path("results/trained")
    for stage in (1, 2, 3):
        path = trained_dir / f"q_policy_stage{stage}.json"
        if not path.exists():
            # Training is optional; stages without an artifact are skipped.
            continue
        artifact = json.loads(path.read_text(encoding="utf-8"))
        _plot_learning_curve(stage, artifact["learning_curve"], out_dir)
    print(f"Wrote plots to {out_dir}")


# Generate the plots only when this file is executed as a script; importing
# the module must not trigger any file reads or writes.
if __name__ == "__main__":
    main()