Spaces:
Sleeping
Sleeping
| import json | |
| from pathlib import Path | |
| import matplotlib.pyplot as plt | |
| def extract_overall(data): | |
| if "overall" in data: | |
| return data["overall"] | |
| return { | |
| "naive": data.get("naive", {}), | |
| "rule": data.get("rule", {}), | |
| } | |
| def plot_metrics(): | |
| try: | |
| input_path = Path("outputs/eval_results.json") | |
| with open(input_path, "r", encoding="utf-8") as handle: | |
| data = json.load(handle) | |
| overall = extract_overall(data) | |
| naive_avg = overall["naive"].get("avg_reward", data.get("naive_avg", 0.0)) | |
| rule_avg = overall["rule"].get("avg_reward", data.get("rule_avg", 0.0)) | |
| fig, axes = plt.subplots(1, 2, figsize=(10, 4)) | |
| axes[0].bar(["Naive", "Rule"], [naive_avg, rule_avg], color=["#d64545", "#2f6db0"]) | |
| axes[0].set_ylabel("Average Reward") | |
| axes[0].set_title("Overall Reward") | |
| seen_safe = data.get("seen", {}).get("rule", {}).get("safe_ship_rate") | |
| unseen_safe = data.get("unseen", {}).get("rule", {}).get("safe_ship_rate") | |
| if seen_safe is None or unseen_safe is None: | |
| seen_safe = overall["rule"].get("safe_ship_rate", 0.0) | |
| unseen_safe = overall["naive"].get("safe_ship_rate", 0.0) | |
| axes[1].bar(["Seen", "Unseen"], [seen_safe, unseen_safe], color=["#4f8f4f", "#c58f3a"]) | |
| axes[1].set_ylim(0.0, 1.0) | |
| axes[1].set_ylabel("Safe Ship Rate") | |
| axes[1].set_title("Rule Baseline Slice Safety") | |
| fig.tight_layout() | |
| output_path = Path("outputs/eval_chart.png") | |
| fig.savefig(output_path) | |
| print(f"Chart saved to {output_path}") | |
| except Exception as exc: | |
| print(f"Error plotting metrics: {exc}") | |
| if __name__ == "__main__": | |
| plot_metrics() | |