# Source: meta-rl-dsa-solver / training/plot_results.py
# (kaustubhg73, revision "v4", commit 5b695bd)
from __future__ import annotations
import argparse
import csv
from collections import defaultdict
from pathlib import Path
from typing import Any
def _parse_cell(value: Any) -> Any:
    """Convert one raw CSV cell to an int, a float, or a stripped string."""
    if not isinstance(value, str):
        # csv.DictReader yields None for cells missing from a short row and a
        # list of overflow fields (under the None key) for a long row; keep both as-is.
        return value
    text = value.strip()
    if not text:
        return text
    # int first so "12" stays an int; float then covers "0.5", "1e-05", "nan", "inf".
    try:
        return int(text)
    except ValueError:
        pass
    try:
        return float(text)
    except ValueError:
        return text


def read_rows(csv_path: Path) -> list[dict[str, Any]]:
    """Load a headed UTF-8 CSV file into one dict per data row.

    Every string cell is stripped of surrounding whitespace and then parsed as
    an ``int`` if possible, otherwise as a ``float`` (including scientific
    notation such as ``1e-05``), otherwise kept as the stripped string.
    Empty cells stay ``""``. Non-string values produced by ``csv.DictReader``
    (``None`` for missing cells, a list for extra fields) are passed through.

    Args:
        csv_path: Path to the CSV file; the first line must be the header.

    Returns:
        The parsed rows, keyed by header name, in file order.
    """
    with csv_path.open("r", encoding="utf-8", newline="") as handle:
        return [
            {key: _parse_cell(value) for key, value in row.items()}
            for row in csv.DictReader(handle)
        ]
def rolling_mean(values: list[float], window: int) -> list[float]:
    """Return the trailing moving average of *values*.

    Element ``i`` of the result is the mean of the last ``window`` values
    ending at index ``i``. The first ``window - 1`` outputs therefore average
    over fewer points (a warm-up period) instead of being dropped. The output
    always has the same length as the input.

    Args:
        values: Numbers to smooth, in order.
        window: Number of trailing points per average; must be at least 1.

    Returns:
        A list of floats, one per input value.

    Raises:
        ValueError: If ``window`` is less than 1. Such a window would give an
            empty averaging slice and a division by zero.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    averages: list[float] = []
    for end in range(1, len(values) + 1):
        chunk = values[max(0, end - window) : end]
        averages.append(sum(chunk) / len(chunk))
    return averages
def plot_reward_curve(rows: list[dict[str, Any]], output_dir: Path) -> None:
    """Write ``reward_curve.png`` showing training rewards and their smoothed trend.

    Only rows whose ``phase`` is ``"train"`` are used. ``step`` is the x axis.
    ``episode_reward`` is drawn faintly as-is and again as a 20-row trailing
    moving average. The PNG is saved at 200 dpi into *output_dir* and the
    figure is closed afterwards.
    """
    import matplotlib.pyplot as plt

    training = [entry for entry in rows if entry.get("phase") == "train"]
    x_steps = [int(entry["step"]) for entry in training]
    y_raw = [float(entry["episode_reward"]) for entry in training]
    y_smoothed = rolling_mean(y_raw, window=20)

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(x_steps, y_raw, alpha=0.25, label="Episode reward")
    ax.plot(x_steps, y_smoothed, linewidth=2, label="20-step moving average")
    ax.set_xlabel("Training step")
    ax.set_ylabel("Reward")
    ax.set_title("ADAPT Training Reward Curve")
    ax.legend()
    fig.tight_layout()
    fig.savefig(output_dir / "reward_curve.png", dpi=200)
    plt.close(fig)
def plot_pass_rate_by_difficulty(rows: list[dict[str, Any]], output_dir: Path) -> None:
    """Write ``pass_rate_by_difficulty.png``: smoothed pass rate per difficulty tier.

    Training rows (``phase == "train"``) are grouped by ``difficulty_tier``,
    compared as strings. Each tier's ``pass_rate`` series is smoothed with a
    10-row trailing mean and plotted against ``step``.

    Tier ordering:
        * ``easy``, ``medium`` and ``hard`` are drawn first, in that order.
        * Any other non-empty tier follows in sorted order, so numeric or
          unexpected tier names are no longer silently dropped.
        * Rows with a blank tier are ignored.

    The PNG is saved at 200 dpi into *output_dir*.
    """
    import matplotlib.pyplot as plt

    train_rows = [row for row in rows if row.get("phase") == "train"]
    grouped: dict[str, list[tuple[int, float]]] = defaultdict(list)
    for row in train_rows:
        grouped[str(row["difficulty_tier"])].append((int(row["step"]), float(row["pass_rate"])))

    canonical = ("easy", "medium", "hard")
    extra_tiers = sorted(tier for tier in grouped if tier and tier not in canonical)
    ordered_tiers = [tier for tier in canonical if tier in grouped] + extra_tiers

    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        for difficulty in ordered_tiers:
            points = grouped[difficulty]
            steps = [step for step, _ in points]
            values = [value for _, value in points]
            smooth = rolling_mean(values, window=10)
            ax.plot(steps, smooth, linewidth=2, label=difficulty.title())
        ax.set_xlabel("Training step")
        ax.set_ylabel("Pass rate")
        ax.set_title("Pass Rate by Difficulty Tier")
        if ordered_tiers:
            # An empty legend only triggers a matplotlib "no artists" warning.
            ax.legend()
        fig.tight_layout()
        fig.savefig(output_dir / "pass_rate_by_difficulty.png", dpi=200)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
def plot_family_productivity(rows: list[dict[str, Any]], output_dir: Path) -> None:
    """Write ``family_productivity.png``: productivity of the top 8 families over training.

    Columns named ``family_productivity__<family>`` are discovered from the
    first training row (``phase == "train"``). The 8 columns with the highest
    value in the last training row are plotted against ``step``, labelled
    with ``<family>``.

    Cells that are blank (``""``), ``None`` or absent are plotted as 0.0.
    Nothing is written, and matplotlib is not imported, when there are no
    training rows or no such columns.
    """
    prefix = "family_productivity__"
    train_rows = [row for row in rows if row.get("phase") == "train"]
    if not train_rows:
        # Eval-only logs have nothing to plot; also avoids indexing an empty list below.
        return
    productivity_columns = [key for key in train_rows[0] if str(key).startswith(prefix)]
    if not productivity_columns:
        return

    def cell_value(row: dict[str, Any], column: str) -> float:
        # read_rows keeps empty cells as "" and DictReader uses None for short rows;
        # treat both like a missing key (0.0) instead of letting float() fail.
        value = row.get(column)
        if value is None or value == "":
            return 0.0
        return float(value)

    last_row = train_rows[-1]
    ranked_columns = sorted(
        productivity_columns,
        key=lambda column: cell_value(last_row, column),
        reverse=True,
    )[:8]

    # Imported only once there is something to draw.
    import matplotlib.pyplot as plt

    steps = [int(row["step"]) for row in train_rows]
    fig, ax = plt.subplots(figsize=(11, 6))
    try:
        for column in ranked_columns:
            family = column[len(prefix):]
            values = [cell_value(row, column) for row in train_rows]
            ax.plot(steps, values, linewidth=2, label=family)
        ax.set_xlabel("Training step")
        ax.set_ylabel("Family productivity EMA")
        ax.set_title("Reward-Aware Family Productivity Over Training")
        ax.legend(loc="upper left", fontsize=8)
        fig.tight_layout()
        fig.savefig(output_dir / "family_productivity.png", dpi=200)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
def main(argv: list[str] | None = None) -> None:
parser = argparse.ArgumentParser(description="Plot ADAPT reward and curriculum artifacts from reward_curve.csv.")
parser.add_argument("csv_path", help="Path to reward_curve.csv")
parser.add_argument("--output-dir", default=None, help="Directory for PNG outputs. Defaults to the CSV directory.")
args = parser.parse_args(argv)
csv_path = Path(args.csv_path)
output_dir = Path(args.output_dir) if args.output_dir else csv_path.parent
output_dir.mkdir(parents=True, exist_ok=True)
rows = read_rows(csv_path)
if not rows:
raise RuntimeError(f"No rows found in {csv_path}")
plot_reward_curve(rows, output_dir)
plot_pass_rate_by_difficulty(rows, output_dir)
plot_family_productivity(rows, output_dir)
print(f"Saved plots to {output_dir}")
if __name__ == "__main__":
    # Script execution: argv=None lets argparse read the real command line.
    main(argv=None)