New_gpu_space / training /plot_metrics.py
hiitsesh's picture
Deploy ReleaseOps Arena GPU Space
9aa9d87
Raw
History Blame Contribute Delete
1.76 kB
import json
from pathlib import Path
import matplotlib.pyplot as plt
def extract_overall(data):
if "overall" in data:
return data["overall"]
return {
"naive": data.get("naive", {}),
"rule": data.get("rule", {}),
}
def plot_metrics():
try:
input_path = Path("outputs/eval_results.json")
with open(input_path, "r", encoding="utf-8") as handle:
data = json.load(handle)
overall = extract_overall(data)
naive_avg = overall["naive"].get("avg_reward", data.get("naive_avg", 0.0))
rule_avg = overall["rule"].get("avg_reward", data.get("rule_avg", 0.0))
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].bar(["Naive", "Rule"], [naive_avg, rule_avg], color=["#d64545", "#2f6db0"])
axes[0].set_ylabel("Average Reward")
axes[0].set_title("Overall Reward")
seen_safe = data.get("seen", {}).get("rule", {}).get("safe_ship_rate")
unseen_safe = data.get("unseen", {}).get("rule", {}).get("safe_ship_rate")
if seen_safe is None or unseen_safe is None:
seen_safe = overall["rule"].get("safe_ship_rate", 0.0)
unseen_safe = overall["naive"].get("safe_ship_rate", 0.0)
axes[1].bar(["Seen", "Unseen"], [seen_safe, unseen_safe], color=["#4f8f4f", "#c58f3a"])
axes[1].set_ylim(0.0, 1.0)
axes[1].set_ylabel("Safe Ship Rate")
axes[1].set_title("Rule Baseline Slice Safety")
fig.tight_layout()
output_path = Path("outputs/eval_chart.png")
fig.savefig(output_path)
print(f"Chart saved to {output_path}")
except Exception as exc:
print(f"Error plotting metrics: {exc}")
if __name__ == "__main__":
plot_metrics()