File size: 3,754 Bytes
c1060df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#!/usr/bin/env python3
"""Plot AdaptShield training CSV or metrics JSON."""

from __future__ import annotations

import argparse
import csv
import json
from pathlib import Path
from typing import List, Tuple


def _split_rows(rows) -> Tuple[List[int], List[float], List[str]]:
    """Split row mappings into parallel episode, score and stage lists.

    Each row must provide ``episode`` and ``score``; the stage label is read
    from ``stage``, falling back to ``task`` and then to an empty string.
    """
    episodes = [int(row["episode"]) for row in rows]
    scores = [float(row["score"]) for row in rows]
    stages = [str(row.get("stage", row.get("task", ""))) for row in rows]
    return episodes, scores, stages


def load_scores(path: Path) -> Tuple[List[int], List[float], str, List[str]]:
    """Load per-episode scores from a metrics JSON file or a training CSV.

    A path whose suffix is ``.json`` (compared case-insensitively) is parsed
    as a JSON object whose rows live under ``rows`` or, when that is missing
    or empty, under ``evaluation_rows``. Any other path is read as a CSV with
    a header row.

    Args:
        path: File to read.

    Returns:
        A tuple ``(episodes, scores, label, stages)``. ``label`` is the JSON
        ``model`` value (default ``"adaptshield"``) or ``"adaptshield-smoke"``
        for CSV input. The three lists are parallel, one entry per row.

    Raises:
        KeyError: If a row lacks ``episode`` or ``score``.
        ValueError: If a value cannot be converted to ``int``/``float`` or
            the JSON is malformed.
        OSError: If the file cannot be read.
    """
    # "utf-8-sig" decodes plain UTF-8 as well as files that start with a BOM;
    # without it a BOM ends up glued to the first CSV header name.
    if path.suffix.lower() == ".json":
        data = json.loads(path.read_text(encoding="utf-8-sig"))
        # The trailing `or []` keeps a null "evaluation_rows" from being iterated.
        rows = data.get("rows") or data.get("evaluation_rows") or []
        label = str(data.get("model", "adaptshield"))
    else:
        # newline="" lets the csv module do its own line-ending handling.
        with path.open(newline="", encoding="utf-8-sig") as handle:
            rows = list(csv.DictReader(handle))
        label = "adaptshield-smoke"

    episodes, scores, stages = _split_rows(rows)
    return episodes, scores, label, stages


def moving_average(values: List[float], window: int) -> List[float]:
    """Return the trailing moving average of *values*.

    Entry ``i`` of the result is the mean of the last ``window`` values
    ending at position ``i``; near the start, where fewer than ``window``
    values exist, the mean is taken over the values available so far.

    Args:
        values: Sequence of numbers to smooth.
        window: Maximum number of trailing values per average; must be >= 1.

    Returns:
        A list with the same length as ``values``.

    Raises:
        ValueError: If ``window`` is smaller than 1 (such a window would
            select no values and divide by zero).
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")

    smoothed = []
    for end in range(1, len(values) + 1):
        # Slice ends at `end` (exclusive) and reaches back at most `window` items.
        chunk = values[max(0, end - window):end]
        smoothed.append(sum(chunk) / len(chunk))
    return smoothed


def plot(path: Path, output: Path) -> None:
    """Render the score curve stored at *path* into a PNG at *output*.

    The raw scores and a trailing moving average are drawn against the
    episode numbers, with a dashed vertical marker wherever the stage label
    changes. When matplotlib cannot be imported, no image is written;
    instead a short text summary comparing the first and last fifth of the
    scores is printed.

    Args:
        path: Metrics JSON or training CSV, as accepted by ``load_scores``.
        output: Destination image path; missing parent directories are created.

    Raises:
        SystemExit: If the input contains no scores.
    """
    episodes, scores, label, stages = load_scores(path)
    if not scores:
        raise SystemExit("No scores found to plot.")

    try:
        import matplotlib
        # Select the non-interactive Agg backend before pyplot is imported.
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except ImportError:
        # Text-only fallback: compare the mean of the first and last fifth.
        span = max(1, len(scores) // 5)
        first = sum(scores[:span]) / span
        last = sum(scores[-span:]) / span
        print("matplotlib is not installed; skipping PNG generation.")
        print(f"Episodes: {len(scores)}")
        print(f"First-window avg: {first:.3f}")
        print(f"Last-window avg:  {last:.3f}")
        print(f"Delta:            {last - first:+.3f}")
        return

    # Smooth over a fifth of the run, capped at 10 episodes and at least 1.
    window = max(1, min(10, len(scores) // 5))
    smoothed = moving_average(scores, window)

    output.parent.mkdir(parents=True, exist_ok=True)
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.plot(episodes, scores, color="#6b8fbf", alpha=0.35, label="raw score")
        ax.plot(episodes, smoothed, color="#123c69", linewidth=2.5, label=f"{window}-episode avg")
        for episode, stage in stage_boundaries(episodes, stages):
            ax.axvline(episode, color="#c44e52", linestyle="--", alpha=0.45)
            ax.text(episode, 0.04, stage.replace("curriculum:", ""), rotation=90, fontsize=8, color="#7a1f24")
        ax.set_title(f"AdaptShield Training Curve ({label})")
        ax.set_xlabel("Episode")
        ax.set_ylabel("normalized_score")
        ax.set_ylim(0.0, 1.0)
        ax.grid(alpha=0.25)
        ax.legend()
        fig.tight_layout()
        fig.savefig(output, dpi=160)
    finally:
        # pyplot keeps every figure it creates alive until it is closed, so
        # release it even when drawing or saving fails.
        plt.close(fig)
    print(f"Saved plot: {output}")


def stage_boundaries(episodes: List[int], stages: List[str]) -> List[Tuple[int, str]]:
    """Find the episodes at which the stage label switches.

    Episodes and stages are paired positionally (the longer list is
    truncated). Every position whose stage differs from the one directly
    before it yields an ``(episode, new_stage)`` tuple; the very first
    position never counts as a switch.
    """
    if len(stages) == 0:
        return []

    # Line each label up with its predecessor; the first label is paired
    # with itself so it can never register as a change.
    labels = list(stages[:len(episodes)])
    preceding = [stages[0]] + labels[:-1]
    return [
        (ep, current)
        for ep, before, current in zip(episodes, preceding, labels)
        if current != before
    ]


def parse_args() -> argparse.Namespace:
    """Parse the command line and return the ``--input``/``--output`` options.

    Both options are optional strings that default to files under
    ``training_runs/``.
    """
    cli = argparse.ArgumentParser(description="Plot AdaptShield training output.")
    option_defaults = (
        ("--input", "training_runs/train_smoke.csv"),
        ("--output", "training_runs/reward_curve.png"),
    )
    for flag, fallback in option_defaults:
        cli.add_argument(flag, default=fallback)
    namespace = cli.parse_args()
    return namespace


def main() -> int:
    """Script entry point: parse the CLI options, draw the plot, return 0."""
    options = parse_args()
    source = Path(options.input)
    target = Path(options.output)
    plot(source, target)
    return 0


# When executed as a script (not imported), run main() and use its return
# value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())