Spaces:
Running
Running
File size: 4,954 Bytes
5b695bd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 | from __future__ import annotations
import argparse
import csv
from collections import defaultdict
from pathlib import Path
from typing import Any
def _parse_cell(value: Any) -> Any:
    """Convert one raw ``csv.DictReader`` cell to ``int``, ``float`` or ``str``.

    Non-string cells are returned unchanged: ``DictReader`` yields ``None`` for
    cells missing from a short row and a ``list`` (under the ``None`` key) for
    surplus cells of a long row, neither of which can be stripped or parsed.
    """
    if not isinstance(value, str):
        return value
    text = value.strip()
    if text == "":
        return text
    # Try int first so integral cells keep their type, then float, which also
    # accepts exponent forms ("1e-05") and nan/inf that contain no ".".
    try:
        return int(text)
    except ValueError:
        pass
    try:
        return float(text)
    except ValueError:
        return text


def read_rows(csv_path: Path) -> list[dict[str, Any]]:
    """Load *csv_path* (UTF-8, header row required) into a list of dicts.

    Every string cell is whitespace-stripped and converted:

    - a cell that ``int()`` accepts becomes an ``int``;
    - otherwise a cell that ``float()`` accepts becomes a ``float``, including
      exponent forms such as ``1e-05`` and ``nan``/``inf``;
    - anything else stays a ``str``, and empty cells stay ``""``.

    Cells missing from a short row are ``None``. Surplus cells of a long row
    are kept as a list under the ``None`` key. These are the
    ``csv.DictReader`` ``restval``/``restkey`` defaults, passed through as-is.

    Raises:
        OSError: the file cannot be opened (e.g. ``FileNotFoundError``).
    """
    rows: list[dict[str, Any]] = []
    with csv_path.open("r", encoding="utf-8", newline="") as handle:
        for raw_row in csv.DictReader(handle):
            rows.append({key: _parse_cell(value) for key, value in raw_row.items()})
    return rows
def rolling_mean(values: list[float], window: int) -> list[float]:
    """Return the trailing moving average of *values* over *window* points.

    Element ``i`` of the result is the mean of ``values[max(0, i - window + 1)
    : i + 1]``. The first ``window - 1`` outputs therefore average fewer
    points instead of being dropped, and the result has the same length as
    *values*.

    Raises:
        ValueError: *window* is smaller than 1. Such a window would otherwise
            produce an empty slice and a ``ZeroDivisionError``.
    """
    if window < 1:
        raise ValueError(f"window must be a positive integer, got {window!r}")
    output: list[float] = []
    for index in range(len(values)):
        chunk = values[max(0, index - window + 1) : index + 1]
        output.append(sum(chunk) / len(chunk))
    return output
def plot_reward_curve(
    rows: list[dict[str, Any]],
    output_dir: Path,
    window: int = 20,
) -> None:
    """Plot raw and smoothed episode reward of the ``train`` rows.

    Writes ``reward_curve.png`` into *output_dir*.

    Args:
        rows: Parsed CSV rows. Only rows whose ``phase`` is ``"train"`` are
            used, and each of them needs ``step`` and ``episode_reward``.
        output_dir: Directory receiving the PNG (must already exist).
        window: Number of consecutive train rows averaged by the smoothed
            curve. Defaults to 20, the previously hard-coded value.

    Raises:
        KeyError: a train row lacks ``step`` or ``episode_reward``.
        ValueError, TypeError: one of those cells is blank or non-numeric.
    """
    import matplotlib.pyplot as plt

    train_rows = [row for row in rows if row.get("phase") == "train"]
    steps = [int(row["step"]) for row in train_rows]
    rewards = [float(row["episode_reward"]) for row in train_rows]
    reward_smooth = rolling_mean(rewards, window=window)

    fig = plt.figure(figsize=(10, 5))
    try:
        plt.plot(steps, rewards, alpha=0.25, label="Episode reward")
        plt.plot(steps, reward_smooth, linewidth=2, label=f"{window}-step moving average")
        plt.xlabel("Training step")
        plt.ylabel("Reward")
        plt.title("ADAPT Training Reward Curve")
        plt.legend()
        plt.tight_layout()
        plt.savefig(output_dir / "reward_curve.png", dpi=200)
    finally:
        # Release the figure even when saving fails, so pyplot does not keep it alive.
        plt.close(fig)
def plot_pass_rate_by_difficulty(
    rows: list[dict[str, Any]],
    output_dir: Path,
    tiers: tuple[str, ...] = ("easy", "medium", "hard"),
    window: int = 10,
) -> None:
    """Plot the smoothed pass rate of ``train`` rows, one line per difficulty tier.

    Train rows are grouped by ``str(row["difficulty_tier"])``. Each tier in
    *tiers* that has at least one row is drawn, in that order, as a
    *window*-row moving average of ``pass_rate`` against ``step``. Tiers not
    listed in *tiers* are ignored. Writes ``pass_rate_by_difficulty.png`` into
    *output_dir*.

    Args:
        rows: Parsed CSV rows. Train rows need ``difficulty_tier``, ``step``
            and ``pass_rate``.
        output_dir: Directory receiving the PNG (must already exist).
        tiers: Tier labels to plot, in legend order. Defaults to the
            previously hard-coded easy/medium/hard.
        window: Smoothing window in rows. Defaults to 10, the previous constant.

    Raises:
        KeyError: a train row lacks one of the required columns.
        ValueError, TypeError: ``step`` or ``pass_rate`` is blank or non-numeric.
    """
    import matplotlib.pyplot as plt

    train_rows = [row for row in rows if row.get("phase") == "train"]
    grouped: dict[str, list[tuple[int, float]]] = defaultdict(list)
    for row in train_rows:
        grouped[str(row["difficulty_tier"])].append((int(row["step"]), float(row["pass_rate"])))

    fig = plt.figure(figsize=(10, 5))
    try:
        plotted_any = False
        for difficulty in tiers:
            points = grouped.get(difficulty, [])
            if not points:
                continue
            steps = [step for step, _ in points]
            values = [value for _, value in points]
            smooth = rolling_mean(values, window=window)
            plt.plot(steps, smooth, linewidth=2, label=difficulty.title())
            plotted_any = True
        plt.xlabel("Training step")
        plt.ylabel("Pass rate")
        plt.title("Pass Rate by Difficulty Tier")
        # With no labelled lines, legend() would only warn that there is
        # nothing to show, so skip it in that case.
        if plotted_any:
            plt.legend()
        plt.tight_layout()
        plt.savefig(output_dir / "pass_rate_by_difficulty.png", dpi=200)
    finally:
        # Release the figure even when saving fails.
        plt.close(fig)
def plot_family_productivity(rows: list[dict[str, Any]], output_dir: Path) -> None:
import matplotlib.pyplot as plt
train_rows = [row for row in rows if row.get("phase") == "train"]
productivity_columns = [key for key in train_rows[0].keys() if str(key).startswith("family_productivity__")]
if not productivity_columns:
return
ranked_columns = sorted(
productivity_columns,
key=lambda column: float(train_rows[-1].get(column, 0.0)),
reverse=True,
)[:8]
plt.figure(figsize=(11, 6))
steps = [int(row["step"]) for row in train_rows]
for column in ranked_columns:
family = column.split("__", 1)[1]
values = [float(row.get(column, 0.0)) for row in train_rows]
plt.plot(steps, values, linewidth=2, label=family)
plt.xlabel("Training step")
plt.ylabel("Family productivity EMA")
plt.title("Reward-Aware Family Productivity Over Training")
plt.legend(loc="upper left", fontsize=8)
plt.tight_layout()
plt.savefig(output_dir / "family_productivity.png", dpi=200)
plt.close()
def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: read a reward CSV and write the PNG plots.

    Args:
        argv: Argument list to parse. ``None`` lets argparse use the process
            arguments (``sys.argv[1:]``).

    Raises:
        RuntimeError: the CSV contains no data rows.
        OSError: the CSV cannot be opened (e.g. ``FileNotFoundError``).
    """
    parser = argparse.ArgumentParser(
        description="Plot ADAPT reward and curriculum artifacts from reward_curve.csv."
    )
    parser.add_argument("csv_path", help="Path to reward_curve.csv")
    parser.add_argument(
        "--output-dir",
        default=None,
        help="Directory for PNG outputs. Defaults to the CSV directory.",
    )
    args = parser.parse_args(argv)

    csv_path = Path(args.csv_path)
    output_dir = Path(args.output_dir) if args.output_dir else csv_path.parent

    # Validate the input before touching the filesystem, so a missing or empty
    # CSV does not leave behind a newly created, empty output directory.
    rows = read_rows(csv_path)
    if not rows:
        raise RuntimeError(f"No rows found in {csv_path}")
    output_dir.mkdir(parents=True, exist_ok=True)

    plot_reward_curve(rows, output_dir)
    plot_pass_rate_by_difficulty(rows, output_dir)
    plot_family_productivity(rows, output_dir)
    print(f"Saved plots to {output_dir}")
if __name__ == "__main__":
    # Script entry point. Passing argv=None makes argparse read the process
    # arguments; importing this module never triggers plotting.
    main(argv=None)
|