corp-env / plot_results.py
Navigam's picture
feat: add summary generation and visualization for model evaluation results
368fe4f
Raw
History Blame Contribute Delete
11.3 kB
"""Create hackathon result plots from CORP-ENV eval JSONL files or run folders."""
from __future__ import annotations
import argparse
import csv
import json
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, Iterable, List
# Stage families in display order. The order also drives substring matching
# in stage_family(): "baseline" has to come before "base", which it contains.
STAGE_ORDER = (
    "baseline",
    "base",
    "sft",
    "grpo",
    "oracle",
)

# Short display names for the known task ids.
TASK_LABELS = dict(
    e1_launch_readiness="E1 Launch",
    m1_budget_reallocation="M1 Budget",
    h1_acquisition_defence="H1 Acquisition",
)

# Hex plot color for each stage family listed in STAGE_ORDER.
COLORS = dict(
    baseline="#8c8c8c",
    base="#4C78A8",
    sft="#54A24B",
    grpo="#F58518",
    oracle="#B279A2",
)
def expand_inputs(inputs: Iterable[str]) -> List[Path]:
    """Expand CLI inputs into a de-duplicated list of eval files.

    Each input is handled as follows:

    * a directory is searched recursively for ``*_eval.jsonl`` and
      ``eval.jsonl`` files;
    * an existing non-directory path is used as is;
    * anything else is treated as a glob pattern, keeping only matches
      that are regular files (no match simply contributes nothing).

    Args:
        inputs: File paths, directory paths, or glob patterns.

    Returns:
        Paths in first-seen order, with duplicates (compared by resolved
        path) removed.
    """
    paths: List[Path] = []
    for raw in inputs:
        path = Path(raw)
        if path.is_dir():
            paths.extend(sorted(path.rglob("*_eval.jsonl")))
            paths.extend(sorted(path.rglob("eval.jsonl")))
        elif path.exists():
            paths.append(path)
        else:
            # Path.glob() raises NotImplementedError for absolute patterns,
            # so glob relative to the pattern's anchor (root/drive) instead.
            anchor = path.anchor
            if anchor:
                base, pattern = Path(anchor), str(path.relative_to(anchor))
            else:
                base, pattern = Path(), raw
            paths.extend(m for m in sorted(base.glob(pattern)) if m.is_file())

    # Preserve order but remove duplicates; dicts keep insertion order and
    # setdefault keeps the first path seen for each resolved location.
    unique: Dict[str, Path] = {}
    for path in paths:
        unique.setdefault(str(path.resolve()), path)
    return list(unique.values())
def read_rows(paths: Iterable[str]) -> List[Dict[str, Any]]:
    """Load every non-blank JSONL line from the expanded inputs.

    Each parsed row is normalised in place:

    * ``model_stage`` defaults to the file stem with ``_eval`` removed;
    * ``run_id`` defaults to the name of the file's parent directory;
    * ``invalid_action_rate`` is always set to
      ``invalid_action_count / max(steps, 1)``, where missing or falsy
      counts are treated as 0.

    Args:
        paths: Files, folders, or glob patterns accepted by ``expand_inputs``.

    Returns:
        All rows from all files, in file order then line order.
    """
    loaded: List[Dict[str, Any]] = []
    for source in expand_inputs(paths):
        with source.open("r", encoding="utf-8") as handle:
            for raw_line in handle:
                if not raw_line.strip():
                    continue  # blank separator line
                record = json.loads(raw_line)
                record.setdefault("model_stage", source.stem.replace("_eval", ""))
                record.setdefault("run_id", source.parent.name)
                # Clamp the denominator to 1 so zero-step episodes do not divide by zero.
                step_count = max(float(record.get("steps", 0) or 0), 1.0)
                invalid_count = float(record.get("invalid_action_count", 0) or 0)
                record["invalid_action_rate"] = invalid_count / step_count
                loaded.append(record)
    return loaded
def stage_family(stage: str) -> str:
    """Return the stage family whose name occurs in ``stage``.

    Matching is a case-insensitive substring test against ``STAGE_ORDER``,
    taking the first family that matches. When none matches, the
    lower-cased stage name itself is returned.
    """
    lowered = stage.lower()
    return next((name for name in STAGE_ORDER if name in lowered), lowered)
def stage_sort_key(stage: str) -> tuple:
    """Build a sort key that orders stages by family, then by name.

    Returns:
        ``(family_index, stage)`` where ``family_index`` is the family's
        position in ``STAGE_ORDER``; stages of an unrecognised family get
        ``len(STAGE_ORDER)`` so they sort after all known ones.
    """
    family = stage_family(stage)
    if family in STAGE_ORDER:
        rank = STAGE_ORDER.index(family)
    else:
        rank = len(STAGE_ORDER)
    return (rank, stage)
def task_label(task_id: str) -> str:
    """Return the display label for a task id.

    Known ids are looked up in ``TASK_LABELS``; any other id is shown with
    its underscores replaced by spaces.
    """
    fallback = task_id.replace("_", " ")
    return TASK_LABELS.get(task_id, fallback)
def stage_label(stage: str) -> str:
    """Return a legend-friendly stage name with ``_`` and ``-`` turned into spaces."""
    separators_to_space = str.maketrans({"_": " ", "-": " "})
    return stage.translate(separators_to_space)
def grouped_mean(rows: List[Dict[str, Any]], metric: str) -> Dict[str, Dict[str, float]]:
    """Average ``metric`` per model stage and task.

    Rows are grouped by ``model_stage`` and ``task_id`` (each defaulting to
    ``"unknown"`` when absent). A row whose metric is missing or null
    contributes 0.0 to its group's mean.

    Args:
        rows: Eval rows as dictionaries.
        metric: Key of the numeric field to average.

    Returns:
        ``{stage: {task: mean}}`` as plain dictionaries; empty when ``rows``
        is empty.
    """
    grouped: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))
    for row in rows:
        stage = str(row.get("model_stage", "unknown"))
        task = str(row.get("task_id", "unknown"))
        # `or 0.0` makes an explicit JSON null behave like a missing key
        # instead of crashing in float(None).
        grouped[stage][task].append(float(row.get(metric) or 0.0))
    return {
        stage: {task: sum(vals) / len(vals) for task, vals in by_task.items()}
        for stage, by_task in grouped.items()
    }
def summary_rows(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Aggregate eval rows into one summary record per (stage, task).

    Rows are grouped by ``model_stage`` and ``task_id`` (each defaulting to
    ``"unknown"``). Groups are ordered by stage sort key, then task id.
    Numeric fields that are missing or null count as 0.

    Args:
        rows: Eval rows as dictionaries.

    Returns:
        A list of dictionaries with the keys ``model_stage``, ``task_id``,
        ``episodes``, ``avg_terminal_reward``, ``avg_total_reward``,
        ``avg_verifier_pass_rate``, ``success_rate``,
        ``avg_invalid_action_rate`` (all rounded to 6 places) and
        ``avg_steps`` (rounded to 3 places).
    """

    def number(record: Dict[str, Any], key: str) -> float:
        # Missing keys and explicit JSON nulls both count as 0 rather than
        # crashing in float(None).
        return float(record.get(key) or 0)

    def average(values: List[float]) -> float:
        return sum(values) / len(values)

    grouped: Dict[tuple, List[Dict[str, Any]]] = defaultdict(list)
    for row in rows:
        key = (str(row.get("model_stage", "unknown")), str(row.get("task_id", "unknown")))
        grouped[key].append(row)

    out: List[Dict[str, Any]] = []
    ordered = sorted(grouped.items(), key=lambda item: (stage_sort_key(item[0][0]), item[0][1]))
    for (stage, task), vals in ordered:
        steps = [number(v, "steps") for v in vals]
        # Clamp the denominator to 1 so zero-step episodes do not divide by zero.
        invalid_rates = [
            number(v, "invalid_action_count") / max(step_count, 1.0)
            for v, step_count in zip(vals, steps)
        ]
        out.append(
            {
                "model_stage": stage,
                "task_id": task,
                "episodes": len(vals),
                "avg_terminal_reward": round(average([number(v, "terminal_reward") for v in vals]), 6),
                "avg_total_reward": round(average([number(v, "total_reward") for v in vals]), 6),
                "avg_verifier_pass_rate": round(average([number(v, "verifier_pass_rate") for v in vals]), 6),
                "success_rate": round(sum(1 for v in vals if v.get("success")) / len(vals), 6),
                "avg_invalid_action_rate": round(average(invalid_rates), 6),
                "avg_steps": round(average(steps), 3),
            }
        )
    return out
def write_summary(rows: List[Dict[str, Any]], output_dir: Path) -> None:
    """Write the per-stage/per-task summary as CSV and Markdown.

    Produces ``comparison_summary.csv`` (all summary columns) and
    ``comparison_summary.md`` (a formatted table of selected columns) inside
    ``output_dir``. Nothing is written when the summary is empty.

    Args:
        rows: Eval rows to summarise.
        output_dir: Directory that receives both files.
    """
    table = summary_rows(rows)
    if not table:
        return

    with (output_dir / "comparison_summary.csv").open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=list(table[0].keys()))
        writer.writeheader()
        writer.writerows(table)

    def md_row(cells: List[str]) -> str:
        # Render one pipe-delimited Markdown table row.
        return "| " + " | ".join(cells) + " |"

    headers = [
        "Model Stage",
        "Task",
        "Episodes",
        "Terminal Reward",
        "Verifier Pass",
        "Success",
        "Invalid Rate",
        "Avg Steps",
    ]
    lines = [
        "# CORP-ENV Result Comparison",
        "",
        md_row(headers),
        md_row(["---"] * len(headers)),
    ]
    for entry in table:
        lines.append(
            md_row(
                [
                    str(entry["model_stage"]),
                    task_label(str(entry["task_id"])),
                    str(entry["episodes"]),
                    f"{entry['avg_terminal_reward']:.3f}",
                    f"{entry['avg_verifier_pass_rate']:.3f}",
                    f"{entry['success_rate']:.3f}",
                    f"{entry['avg_invalid_action_rate']:.3f}",
                    f"{entry['avg_steps']:.1f}",
                ]
            )
        )
    lines.extend(["", "Generated by `plot_results.py` from eval JSONL files."])

    (output_dir / "comparison_summary.md").write_text("\n".join(lines) + "\n", encoding="utf-8")
def plot_grouped_bars(
    data: Dict[str, Dict[str, float]],
    title: str,
    ylabel: str,
    output: Path,
    *,
    clamp_unit: bool = True,
) -> None:
    """Save a grouped bar chart with one cluster per task and one bar per stage.

    Args:
        data: ``{stage: {task: value}}``; a task missing for a stage is
            drawn as 0.0.
        title: Chart title.
        ylabel: Y-axis label.
        output: Destination image path.
        clamp_unit: When true, fix the y-axis to the range 0 to 1.05.
    """
    import matplotlib.pyplot as plt

    plt.style.use("seaborn-v0_8-whitegrid")

    ordered_stages = sorted(data, key=stage_sort_key)
    task_ids = sorted({task for per_task in data.values() for task in per_task})
    positions = list(range(len(task_ids)))
    # Each task cluster spans 0.8 x-units, shared evenly between the stages.
    bar_width = 0.8 / max(len(ordered_stages), 1)
    label_gap = 0.015 if clamp_unit else 0.02

    fig, ax = plt.subplots(figsize=(max(10, len(task_ids) * 2.2), 5.8))
    for slot, stage in enumerate(ordered_stages):
        heights = [data[stage].get(task, 0.0) for task in task_ids]
        centers = [pos - 0.4 + bar_width / 2 + slot * bar_width for pos in positions]
        bars = ax.bar(
            centers,
            heights,
            bar_width,
            label=stage_label(stage),
            color=COLORS.get(stage_family(stage)),
            edgecolor="white",
            linewidth=0.8,
        )
        # Annotate only strictly positive bars with their value.
        for bar, height in zip(bars, heights):
            if height > 0:
                ax.text(
                    bar.get_x() + bar.get_width() / 2,
                    bar.get_height() + label_gap,
                    f"{height:.2f}",
                    ha="center",
                    va="bottom",
                    fontsize=8,
                    rotation=0,
                )

    ax.set_title(title, fontsize=15, weight="bold", pad=14)
    ax.set_xlabel("Task")
    ax.set_ylabel(ylabel)
    ax.set_xticks(positions)
    ax.set_xticklabels([task_label(task) for task in task_ids], rotation=0, ha="center")
    if clamp_unit:
        ax.set_ylim(0, 1.05)
    ax.spines[["top", "right"]].set_visible(False)
    # Legend sits below the axes, at most four entries per row.
    legend_columns = max(1, min(len(ordered_stages), 4))
    ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.12), ncol=legend_columns, frameon=False)

    fig.tight_layout()
    fig.savefig(output, dpi=160)
    plt.close(fig)
def plot_reward_curve(rows: List[Dict[str, Any]], output: Path) -> None:
    """Save a line chart of mean per-step reward for each (stage, task) pair.

    Rows without a non-empty ``reward_trace`` are ignored. For every group
    the mean reward per step is drawn; when a group has more than one
    trace, the min-max range per step is shaded. If no traces exist at
    all, the figure carries a "No reward traces found" message instead.

    Color encodes the stage family; line style and marker encode the task,
    so curves of the same stage on different tasks can be told apart.

    Args:
        rows: Eval rows; ``reward_trace`` is read as a list of numbers.
        output: Destination image path.
    """
    import matplotlib.pyplot as plt

    plt.style.use("seaborn-v0_8-whitegrid")

    grouped: Dict[tuple, List[List[float]]] = defaultdict(list)
    for row in rows:
        trace = [float(x) for x in (row.get("reward_trace") or [])]
        if trace:
            key = (str(row.get("model_stage", "model")), str(row.get("task_id", "task")))
            grouped[key].append(trace)

    # One line style / marker per task, cycling if there are more tasks than
    # styles. The first task keeps the solid line and round marker.
    line_styles = ("-", "--", ":", "-.")
    markers = ("o", "s", "^", "D")
    task_slot = {task: idx for idx, task in enumerate(sorted({task for _, task in grouped}))}

    fig, ax = plt.subplots(figsize=(12, 6.5))
    plotted = bool(grouped)
    for (stage, task), traces in sorted(grouped.items(), key=lambda x: (stage_sort_key(x[0][0]), x[0][1])):
        max_len = max(len(t) for t in traces)
        # Traces can differ in length; each step aggregates only the
        # episodes that actually reached it.
        per_step = [[t[idx] for t in traces if idx < len(t)] for idx in range(max_len)]
        means = [sum(vals) / len(vals) for vals in per_step]
        mins = [min(vals) for vals in per_step]
        maxs = [max(vals) for vals in per_step]
        xs = list(range(1, max_len + 1))

        slot = task_slot[task]
        color = COLORS.get(stage_family(stage))
        ax.plot(
            xs,
            means,
            marker=markers[slot % len(markers)],
            linestyle=line_styles[slot % len(line_styles)],
            linewidth=2.2,
            markersize=4,
            label=f"{stage_label(stage)} · {task_label(task)}",
            color=color,
        )
        if len(traces) > 1:
            ax.fill_between(xs, mins, maxs, alpha=0.12, color=color)

    if not plotted:
        ax.text(0.5, 0.5, "No reward traces found", ha="center", va="center")
    ax.axhline(0, color="#666666", linewidth=0.9, alpha=0.5)
    ax.set_title("Episode Reward Curve By Model Stage", fontsize=15, weight="bold", pad=14)
    ax.set_xlabel("Environment step")
    ax.set_ylabel("Step reward")
    ax.spines[["top", "right"]].set_visible(False)
    if plotted:
        ax.legend(fontsize=8, ncol=2, frameon=False, loc="upper center", bbox_to_anchor=(0.5, -0.12))
    fig.tight_layout()
    fig.savefig(output, dpi=160)
    plt.close(fig)
def main() -> None:
    """Command-line entry point: read eval rows, then write plots and summaries.

    Parses ``--inputs`` (required) and ``--output-dir`` (default
    ``results``), creates the output directory, and writes three grouped
    bar charts, the reward curve, and the CSV/Markdown summaries.

    Raises:
        SystemExit: If the inputs contain no rows.
    """
    parser = argparse.ArgumentParser(description="Plot CORP-ENV eval results.")
    parser.add_argument("--inputs", nargs="+", required=True, help="Eval JSONL files, folders, or glob patterns.")
    parser.add_argument("--output-dir", default="results")
    args = parser.parse_args()

    out = Path(args.output_dir)
    out.mkdir(parents=True, exist_ok=True)

    rows = read_rows(args.inputs)
    if not rows:
        raise SystemExit("No rows found in input files.")

    # (metric key, chart title, y-axis label, output file name)
    bar_charts = (
        (
            "terminal_reward",
            "Average Terminal Reward By Model Stage",
            "Terminal reward",
            "model_comparison.png",
        ),
        (
            "verifier_pass_rate",
            "Verifier Pass Rate By Task",
            "Verifier pass rate",
            "success_by_task.png",
        ),
        (
            "invalid_action_rate",
            "Invalid Action Rate By Task",
            "Invalid actions / environment step",
            "invalid_action_rate.png",
        ),
    )
    for metric, title, ylabel, filename in bar_charts:
        plot_grouped_bars(
            grouped_mean(rows, metric),
            title,
            ylabel,
            out / filename,
            clamp_unit=True,
        )

    plot_reward_curve(rows, out / "reward_curve.png")
    write_summary(rows, out)

    print(f"Wrote plots to {out}")
    print(f"Wrote summaries to {out / 'comparison_summary.md'} and {out / 'comparison_summary.csv'}")


if __name__ == "__main__":
    main()