adaptshield / plot_sft_checkpoint_curve.py
SaiManish123's picture
Initial deploy of AdaptShield two-phase cybersecurity environment
c1060df verified
#!/usr/bin/env python3
"""Plot an SFT checkpoint curve with an optional honest baseline start point."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any
def load_json(path: Path) -> dict[str, Any]:
    """Decode the UTF-8 JSON document stored at *path* and return it.

    Args:
        path: Location of the JSON file to read.

    Returns:
        The parsed JSON value (callers in this script expect an object).
    """
    with path.open("r", encoding="utf-8") as handle:
        return json.load(handle)
def mean_baseline(benchmark: dict[str, Any], key: str) -> float:
    """Return the arithmetic mean of the non-null values stored under *key*.

    Args:
        benchmark: Parsed benchmark JSON document.
        key: Name of the field whose values are averaged. The field is
            expected to hold an object (mapping); its values are converted
            with ``float`` and ``None`` entries are skipped.

    Returns:
        The mean of the remaining values.

    Raises:
        ValueError: If the field is missing, null, empty, not an object, or
            contains no non-null values.
    """
    # ``get(key, {})`` alone still yields None when the key is present with a
    # JSON null, which used to surface as an AttributeError on ``.values()``.
    values = benchmark.get(key) or {}
    if not isinstance(values, dict):
        raise ValueError(
            f"Benchmark key '{key}' must hold an object of scores, "
            f"got {type(values).__name__}"
        )
    numeric = [float(value) for value in values.values() if value is not None]
    if not numeric:
        raise ValueError(f"No numeric values found under benchmark key '{key}'")
    return sum(numeric) / len(numeric)
def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse the process arguments.

    Returns:
        Namespace with ``metrics``, ``output``, ``baseline_json``,
        ``baseline_key`` and ``baseline_label`` attributes.
    """
    cli = argparse.ArgumentParser(
        description="Plot SFT checkpoint learning curve with optional baseline point."
    )

    # Mandatory input/output paths, registered in the same order as before.
    required_flags = (
        ("--metrics", "Path to sft_metrics.json"),
        ("--output", "Where to write the PNG"),
    )
    for flag, help_text in required_flags:
        cli.add_argument(flag, required=True, help=help_text)

    # Optional baseline configuration.
    baseline_keys = ["tool_baseline", "no_tool_baseline"]
    cli.add_argument(
        "--baseline-json",
        default="",
        help="Optional benchmark_table.json path used to prepend a real baseline point.",
    )
    cli.add_argument(
        "--baseline-key",
        default="tool_baseline",
        choices=baseline_keys,
        help="Which benchmark JSON field to average for the prepended baseline point.",
    )
    cli.add_argument(
        "--baseline-label",
        default="baseline",
        help="X-axis label for the prepended baseline point.",
    )

    parsed = cli.parse_args()
    return parsed
def main() -> int:
    """Render the SFT checkpoint learning curve to a PNG file.

    Reads ``reward_curve_rows`` from the metrics JSON, optionally prepends a
    baseline point (the mean of the selected benchmark field, used for both
    curves), plots the in-distribution and held-out scores per checkpoint,
    saves the figure to the requested path and prints that path.

    Returns:
        ``0`` on success.

    Raises:
        SystemExit: If the metrics file has no rows, a row is malformed, the
            baseline cannot be computed, or matplotlib is not installed.
    """
    args = parse_args()
    metrics = load_json(Path(args.metrics))
    rows = metrics.get("reward_curve_rows", []) or []
    if not rows:
        raise SystemExit("No reward_curve_rows found in the provided SFT metrics file.")

    # Report malformed rows as a clean exit message rather than a raw
    # traceback, in line with the other failure paths of this script.
    try:
        labels = [str(row["checkpoint"]) for row in rows]
        train_scores = [float(row["in_distribution_score"]) for row in rows]
        heldout_scores = [float(row["heldout_score"]) for row in rows]
    except (KeyError, TypeError, ValueError) as exc:
        raise SystemExit(
            f"Malformed reward_curve_rows entry in the SFT metrics file: {exc!r}"
        ) from exc

    if args.baseline_json:
        benchmark = load_json(Path(args.baseline_json))
        try:
            baseline_value = mean_baseline(benchmark, args.baseline_key)
        except ValueError as exc:
            raise SystemExit(str(exc)) from exc
        # Both curves share the same starting point.
        labels = [args.baseline_label, *labels]
        train_scores = [baseline_value, *train_scores]
        heldout_scores = [baseline_value, *heldout_scores]

    try:
        import matplotlib

        # Select the non-interactive backend before pyplot is imported so the
        # script works without a display.
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except ImportError as exc:
        raise SystemExit(f"matplotlib is required to plot this curve: {exc}") from exc

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Plot against integer positions and label the ticks explicitly. Passing
    # the strings themselves as x-values makes matplotlib treat them as
    # categories, so repeated labels would be drawn at the same x position.
    positions = list(range(len(labels)))
    fig = plt.figure(figsize=(11, 5))
    try:
        plt.plot(positions, train_scores, marker="o", linewidth=2, color="#174c7a", label="train family")
        plt.plot(positions, heldout_scores, marker="s", linewidth=2, color="#6d4acb", label="held-out family")
        plt.title("Janus SFT Checkpoint Learning Curve")
        plt.xlabel("Checkpoint")
        plt.ylabel("normalized_score")
        plt.ylim(0.0, 1.0)
        plt.grid(alpha=0.25)
        plt.legend()
        plt.xticks(positions, labels, rotation=30, ha="right")
        plt.tight_layout()
        plt.savefig(output_path, dpi=160)
    finally:
        # Release the figure even if rendering or saving fails.
        plt.close(fig)

    print(output_path)
    return 0
# Script entry point: run main() and use its return value as the process
# exit status.
if __name__ == "__main__":
    raise SystemExit(main())