File size: 4,954 Bytes
5b695bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
from __future__ import annotations

import argparse
import csv
from collections import defaultdict
from pathlib import Path
from typing import Any


def _parse_cell(text: str) -> Any:
    """Convert one stripped, non-empty CSV cell to a number when possible.

    ``int`` is tried first so whole numbers stay integers; ``float`` is tried
    next so that every spelling ``float()`` accepts is recognised, including
    exponent notation without a decimal point such as ``1e-3``. Text that is
    not numeric is returned unchanged.
    """
    try:
        return int(text)
    except ValueError:
        pass
    try:
        return float(text)
    except ValueError:
        return text


def read_rows(csv_path: Path) -> list[dict[str, Any]]:
    """Load a CSV file into a list of dicts with numeric cells converted.

    Args:
        csv_path: Path of the CSV file; its first line is used as the header.

    Returns:
        One dict per data row, keyed by header name. String cells are
        stripped of surrounding whitespace; numeric cells become ``int`` or
        ``float``; empty cells stay ``""``. Values that ``csv.DictReader``
        produces for malformed rows are passed through untouched: ``None``
        for columns missing from a short row, and a list of surplus fields
        (under the key ``None``) for a row longer than the header.
    """
    rows: list[dict[str, Any]] = []
    # "utf-8-sig" decodes plain UTF-8 identically but also drops a leading
    # byte-order mark, which would otherwise become part of the first header.
    with csv_path.open("r", encoding="utf-8-sig", newline="") as handle:
        for row in csv.DictReader(handle):
            parsed: dict[str, Any] = {}
            for key, value in row.items():
                if not isinstance(value, str):
                    # None (short row) or a list of surplus fields (long row):
                    # neither can be stripped or parsed, so keep as-is.
                    parsed[key] = value
                    continue
                text = value.strip()
                parsed[key] = _parse_cell(text) if text else text
            rows.append(parsed)
    return rows


def rolling_mean(values: list[float], window: int) -> list[float]:
    """Return the trailing moving average of ``values``.

    Element ``i`` of the result is the mean of the last ``window`` entries
    ending at position ``i``; near the start, where fewer than ``window``
    entries exist, the mean is taken over the entries available so far.
    The result has the same length as ``values``.
    """

    def trailing_average(end: int) -> float:
        # Clamp the left edge so early positions average a shorter prefix.
        begin = end - window + 1
        if begin < 0:
            begin = 0
        segment = values[begin : end + 1]
        return sum(segment) / len(segment)

    return [trailing_average(position) for position, _ in enumerate(values)]


def plot_reward_curve(rows: list[dict[str, Any]], output_dir: Path) -> None:
    """Write ``reward_curve.png`` showing training reward against step.

    Only rows whose ``phase`` equals ``"train"`` are drawn. The raw
    ``episode_reward`` values are plotted faintly, overlaid with their
    trailing moving average over a window of 20 rows.

    Args:
        rows: Parsed CSV rows; train rows must provide ``step`` and
            ``episode_reward``.
        output_dir: Directory the PNG is saved into.
    """
    import matplotlib.pyplot as plt

    training = [record for record in rows if record.get("phase") == "train"]
    x_steps = [int(record["step"]) for record in training]
    raw_rewards = [float(record["episode_reward"]) for record in training]
    averaged = rolling_mean(raw_rewards, window=20)

    figure, axes = plt.subplots(figsize=(10, 5))
    axes.plot(x_steps, raw_rewards, alpha=0.25, label="Episode reward")
    axes.plot(x_steps, averaged, linewidth=2, label="20-step moving average")
    axes.set_xlabel("Training step")
    axes.set_ylabel("Reward")
    axes.set_title("ADAPT Training Reward Curve")
    axes.legend()
    figure.tight_layout()
    figure.savefig(output_dir / "reward_curve.png", dpi=200)
    plt.close(figure)


def plot_pass_rate_by_difficulty(rows: list[dict[str, Any]], output_dir: Path) -> None:
    """Write ``pass_rate_by_difficulty.png`` with one smoothed line per tier.

    Train-phase rows are grouped by ``difficulty_tier``. For each of the
    tiers ``easy``, ``medium`` and ``hard`` that has data, the ``pass_rate``
    values are smoothed with a trailing moving average over 10 rows and
    plotted against ``step``. Rows with any other tier value are not drawn.

    Args:
        rows: Parsed CSV rows; train rows must provide ``difficulty_tier``,
            ``step`` and ``pass_rate``.
        output_dir: Directory the PNG is saved into.
    """
    import matplotlib.pyplot as plt

    series_by_tier: dict[str, list[tuple[int, float]]] = defaultdict(list)
    for record in rows:
        if record.get("phase") != "train":
            continue
        tier = str(record["difficulty_tier"])
        series_by_tier[tier].append((int(record["step"]), float(record["pass_rate"])))

    figure, axes = plt.subplots(figsize=(10, 5))
    for tier in ("easy", "medium", "hard"):
        samples = series_by_tier.get(tier, [])
        if samples:
            tier_steps = [sample[0] for sample in samples]
            tier_rates = [sample[1] for sample in samples]
            smoothed = rolling_mean(tier_rates, window=10)
            axes.plot(tier_steps, smoothed, linewidth=2, label=tier.title())

    axes.set_xlabel("Training step")
    axes.set_ylabel("Pass rate")
    axes.set_title("Pass Rate by Difficulty Tier")
    axes.legend()
    figure.tight_layout()
    figure.savefig(output_dir / "pass_rate_by_difficulty.png", dpi=200)
    plt.close(figure)


def plot_family_productivity(rows: list[dict[str, Any]], output_dir: Path) -> None:
    """Write ``family_productivity.png`` for the top productivity columns.

    Columns whose name starts with ``family_productivity__`` are discovered
    from the first train-phase row. They are ranked by their value in the
    last train-phase row, and the eight highest are plotted against ``step``,
    labelled with the part of the column name after the first ``__``.

    Nothing is written (and matplotlib is not imported) when there are no
    train-phase rows or no productivity columns.

    Args:
        rows: Parsed CSV rows; train rows must provide ``step``.
        output_dir: Directory the PNG is saved into.
    """
    train_rows = [row for row in rows if row.get("phase") == "train"]
    if not train_rows:
        # A CSV holding only non-train rows would otherwise hit IndexError
        # on train_rows[0] below.
        return
    productivity_columns = [key for key in train_rows[0] if str(key).startswith("family_productivity__")]
    if not productivity_columns:
        return

    def productivity(row: dict[str, Any], column: str) -> float:
        # Blank cells ("") and cells missing from a short row (None) cannot go
        # through float(); give them the same 0.0 used for an absent key.
        value = row.get(column)
        if value is None or value == "":
            return 0.0
        return float(value)

    # Imported only once there is something to draw.
    import matplotlib.pyplot as plt

    ranked_columns = sorted(
        productivity_columns,
        key=lambda column: productivity(train_rows[-1], column),
        reverse=True,
    )[:8]

    plt.figure(figsize=(11, 6))
    steps = [int(row["step"]) for row in train_rows]
    for column in ranked_columns:
        family = column.split("__", 1)[1]
        values = [productivity(row, column) for row in train_rows]
        plt.plot(steps, values, linewidth=2, label=family)

    plt.xlabel("Training step")
    plt.ylabel("Family productivity EMA")
    plt.title("Reward-Aware Family Productivity Over Training")
    plt.legend(loc="upper left", fontsize=8)
    plt.tight_layout()
    plt.savefig(output_dir / "family_productivity.png", dpi=200)
    plt.close()


def main(argv: list[str] | None = None) -> None:
    parser = argparse.ArgumentParser(description="Plot ADAPT reward and curriculum artifacts from reward_curve.csv.")
    parser.add_argument("csv_path", help="Path to reward_curve.csv")
    parser.add_argument("--output-dir", default=None, help="Directory for PNG outputs. Defaults to the CSV directory.")
    args = parser.parse_args(argv)

    csv_path = Path(args.csv_path)
    output_dir = Path(args.output_dir) if args.output_dir else csv_path.parent
    output_dir.mkdir(parents=True, exist_ok=True)

    rows = read_rows(csv_path)
    if not rows:
        raise RuntimeError(f"No rows found in {csv_path}")

    plot_reward_curve(rows, output_dir)
    plot_pass_rate_by_difficulty(rows, output_dir)
    plot_family_productivity(rows, output_dir)
    print(f"Saved plots to {output_dir}")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()