#!/usr/bin/env python3
"""Plot an SFT checkpoint curve with an optional honest baseline start point."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any
def load_json(path: Path) -> dict[str, Any]:
    """Read *path* as UTF-8 text and return the decoded JSON document.

    The decoded value is returned as-is; it is not checked to be a mapping.
    """
    raw_text = path.read_text(encoding="utf-8")
    document = json.loads(raw_text)
    return document
def mean_baseline(benchmark: dict[str, Any], key: str) -> float:
    """Return the arithmetic mean of the non-null values stored under *key*.

    ``benchmark[key]`` is expected to be a mapping; entries whose value is
    ``None`` are skipped and every other value is converted with ``float()``.

    Raises:
        ValueError: if *key* is missing or null, if its value is not a
            mapping, or if the mapping holds no non-null values.
            ``float()`` may additionally raise ``ValueError``/``TypeError``
            for a value it cannot convert.
    """
    # ``or {}`` also covers a key that is present but set to JSON ``null``,
    # which ``.get(key, {})`` would hand back as ``None``.
    values = benchmark.get(key) or {}
    if not isinstance(values, dict):
        raise ValueError(
            f"Benchmark key '{key}' must hold a mapping of scores, "
            f"got {type(values).__name__}"
        )
    numeric = [float(value) for value in values.values() if value is not None]
    if not numeric:
        raise ValueError(f"No numeric values found under benchmark key '{key}'")
    return sum(numeric) / len(numeric)
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse the process arguments.

    ``--metrics`` and ``--output`` are required; the three ``--baseline-*``
    options are optional and control the prepended baseline point.
    """
    cli = argparse.ArgumentParser(
        description="Plot SFT checkpoint learning curve with optional baseline point."
    )

    # Required path arguments, registered in declaration order.
    required_paths = (
        ("--metrics", "Path to sft_metrics.json"),
        ("--output", "Where to write the PNG"),
    )
    for flag, help_text in required_paths:
        cli.add_argument(flag, required=True, help=help_text)

    # Optional baseline settings, keyed by flag.
    baseline_options = {
        "--baseline-json": {
            "default": "",
            "help": "Optional benchmark_table.json path used to prepend a real baseline point.",
        },
        "--baseline-key": {
            "default": "tool_baseline",
            "choices": ["tool_baseline", "no_tool_baseline"],
            "help": "Which benchmark JSON field to average for the prepended baseline point.",
        },
        "--baseline-label": {
            "default": "baseline",
            "help": "X-axis label for the prepended baseline point.",
        },
    }
    for flag, settings in baseline_options.items():
        cli.add_argument(flag, **settings)

    return cli.parse_args()
def main() -> int:
    """Plot the SFT checkpoint curve described by the command-line arguments.

    Reads ``reward_curve_rows`` from the metrics JSON, optionally prepends a
    baseline point averaged from the benchmark JSON, and writes a PNG showing
    the in-distribution and held-out series. The output path is printed on
    success.

    Returns:
        ``0`` on success.

    Raises:
        SystemExit: if the metrics file has no rows, a row is missing a
            field or holds a non-numeric score, the baseline cannot be
            computed, or matplotlib is not installed.
    """
    args = parse_args()
    metrics = load_json(Path(args.metrics))
    rows = metrics.get("reward_curve_rows", []) or []
    if not rows:
        raise SystemExit("No reward_curve_rows found in the provided SFT metrics file.")

    # Report malformed rows as a clean exit message, like the other error paths.
    try:
        labels = [str(row["checkpoint"]) for row in rows]
        train_scores = [float(row["in_distribution_score"]) for row in rows]
        heldout_scores = [float(row["heldout_score"]) for row in rows]
    except (KeyError, TypeError, ValueError) as exc:
        raise SystemExit(f"Malformed reward_curve_rows entry: {exc!r}") from exc

    if args.baseline_json:
        benchmark = load_json(Path(args.baseline_json))
        try:
            baseline_value = mean_baseline(benchmark, args.baseline_key)
        except ValueError as exc:
            raise SystemExit(f"Cannot compute baseline point: {exc}") from exc
        # Both series start from the same baseline value.
        labels = [args.baseline_label, *labels]
        train_scores = [baseline_value, *train_scores]
        heldout_scores = [baseline_value, *heldout_scores]

    try:
        import matplotlib

        # Select the non-interactive Agg backend before pyplot is imported.
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except ImportError as exc:
        raise SystemExit(f"matplotlib is required to plot this curve: {exc}") from exc

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Plot against integer positions rather than the label strings: a string
    # x-axis is categorical, so duplicate labels (e.g. a baseline label equal
    # to a checkpoint name) would be drawn at the same x position.
    positions = list(range(len(labels)))

    fig = plt.figure(figsize=(11, 5))
    try:
        plt.plot(positions, train_scores, marker="o", linewidth=2, color="#174c7a", label="train family")
        plt.plot(positions, heldout_scores, marker="s", linewidth=2, color="#6d4acb", label="held-out family")
        plt.title("Janus SFT Checkpoint Learning Curve")
        plt.xlabel("Checkpoint")
        plt.ylabel("normalized_score")
        plt.ylim(0.0, 1.0)
        plt.grid(alpha=0.25)
        plt.legend()
        plt.xticks(positions, labels, rotation=30, ha="right")
        plt.tight_layout()
        plt.savefig(output_path, dpi=160)
    finally:
        # Release the figure even if rendering or saving fails.
        plt.close(fig)

    print(output_path)
    return 0
# Script entry point: run main() and use its return value as the exit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)