# suraj-01's picture
# Initial
# b14c6e3
"""
Plotting and Visualization for Evaluation Results
Generates comparison charts for RL vs baseline agents.
"""
import argparse
import json
import os
from typing import Dict, Any, List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
def load_results(filename: str = "evaluation_results.json") -> Dict[str, Any]:
    """
    Load evaluation results from JSON file.

    Args:
        filename: Path to results JSON file

    Returns:
        Results dictionary

    Raises:
        FileNotFoundError: If ``filename`` does not exist.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # Decode explicitly as UTF-8: without it the platform default encoding
    # is used, which can fail or mis-decode non-ASCII agent/task names.
    with open(filename, "r", encoding="utf-8") as f:
        return json.load(f)
def plot_score_comparison(results: Dict[str, Any], output_file: str = "score_comparison.png") -> None:
    """
    Create bar plot comparing mean scores across agents and tasks.

    One group of bars is drawn per task (one bar per agent, with the standard
    deviation as error bar). A dashed red segment above each group marks the
    task's success threshold. Agent/task pairs missing from ``results`` are
    drawn as a zero-height bar.

    Args:
        results: Evaluation results dictionary, read as
            ``results[agent][task_id]["mean_score"]`` / ``["std_score"]``
        output_file: Output filename for plot

    Raises:
        KeyError: If a task entry lacks ``mean_score`` or ``std_score``.
    """
    # Flatten the nested results into (agent, task, mean, std) rows
    rows = [
        (agent_name, task_id.capitalize(), task_results["mean_score"], task_results["std_score"])
        for agent_name, agent_results in results.items()
        for task_id, task_results in agent_results.items()
    ]

    # First-seen order for agents and tasks (dict keys keep insertion order)
    agents_unique = list(dict.fromkeys(row[0] for row in rows))
    tasks_unique = list(dict.fromkeys(row[1] for row in rows))

    # (agent, task) -> (mean, std); the first occurrence wins on duplicates
    stats = {}
    for agent, task, mean, std in rows:
        stats.setdefault((agent, task), (mean, std))

    n_agents = len(agents_unique)
    x = np.arange(len(tasks_unique))
    # 0.25 is the width used for the usual three agents; shrink it when there
    # are more so that a task's group never spans more than 0.8 of its slot
    # and neighbouring groups cannot overlap.
    width = min(0.25, 0.8 / max(n_agents, 1))

    # Create plot
    plt.figure(figsize=(12, 6))
    for i, agent in enumerate(agents_unique):
        scores_by_task = [stats.get((agent, task), (0, 0))[0] for task in tasks_unique]
        errors_by_task = [stats.get((agent, task), (0, 0))[1] for task in tasks_unique]
        plt.bar(
            x + i * width,
            scores_by_task,
            width,
            label=agent,
            yerr=errors_by_task,
            capsize=5,
            alpha=0.8,
        )

    # Add success thresholds. hlines() takes data coordinates, so each segment
    # spans exactly its own bar group (axhline's xmin/xmax are axes fractions
    # and would not line up with the bars).
    thresholds = {"Easy": 0.7, "Medium": 0.65, "Hard": 0.6}
    if tasks_unique:
        plt.hlines(
            y=[thresholds.get(task, 0.6) for task in tasks_unique],
            xmin=x - width / 2,
            xmax=x + (n_agents - 1) * width + width / 2,
            colors="red",
            linestyles="dashed",
            linewidth=1,
            alpha=0.5,
            label="Success Threshold",
        )

    plt.xlabel("Task", fontsize=12, fontweight="bold")
    plt.ylabel("Score", fontsize=12, fontweight="bold")
    plt.title("Agent Performance Comparison Across Tasks", fontsize=14, fontweight="bold")
    # Centre each tick under its group regardless of how many agents there are
    plt.xticks(x + width * (n_agents - 1) / 2, tasks_unique)
    plt.legend(loc="upper right", fontsize=10)
    plt.ylim(0, 1.0)
    plt.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    plt.savefig(output_file, dpi=300, bbox_inches="tight")
    print(f"Score comparison plot saved to: {output_file}")
    plt.close()
def plot_success_rates(results: Dict[str, Any], output_file: str = "success_rates.png") -> None:
    """
    Create plot showing success rates (% episodes above threshold).

    Tasks and agents are drawn in the order they first appear in ``results``.

    Args:
        results: Evaluation results dictionary, read as
            ``results[agent][task_id]["success_rate"]`` (a fraction, shown
            here multiplied by 100)
        output_file: Output filename for plot

    Raises:
        KeyError: If a task entry lacks ``success_rate``.
    """
    # Prepare data: one row per (agent, task) with the rate as a percentage
    rows = [
        (agent_name, task_id.capitalize(), task_results["success_rate"] * 100)
        for agent_name, agent_results in results.items()
        for task_id, task_results in agent_results.items()
    ]
    df = pd.DataFrame(rows, columns=["Agent", "Task", "Success Rate"])

    # Create plot
    plt.figure(figsize=(10, 6))

    # Grouped bar chart
    pivot_df = df.pivot(index="Task", columns="Agent", values="Success Rate")
    # pivot() returns sorted labels (Easy, Hard, Medium); restore the
    # first-seen order so tasks read in their natural difficulty order.
    pivot_df = pivot_df.reindex(index=df["Task"].unique(), columns=df["Agent"].unique())
    pivot_df.plot(kind="bar", ax=plt.gca(), width=0.8, alpha=0.8)

    plt.xlabel("Task", fontsize=12, fontweight="bold")
    plt.ylabel("Success Rate (%)", fontsize=12, fontweight="bold")
    plt.title("Success Rate: % Episodes Above Threshold", fontsize=14, fontweight="bold")
    plt.xticks(rotation=0)
    plt.legend(title="Agent", fontsize=10)
    plt.ylim(0, 100)
    plt.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    plt.savefig(output_file, dpi=300, bbox_inches="tight")
    print(f"Success rates plot saved to: {output_file}")
    plt.close()
def plot_failure_analysis(results: Dict[str, Any], output_file: str = "failure_analysis.png") -> None:
    """
    Create plot showing mean system failures per episode.

    Tasks and agents are drawn in the order they first appear in ``results``;
    the bar colours are assigned to agents in that same order.

    Args:
        results: Evaluation results dictionary, read as
            ``results[agent][task_id]["mean_failures"]``
        output_file: Output filename for plot

    Raises:
        KeyError: If a task entry lacks ``mean_failures``.
    """
    # Prepare data: one row per (agent, task)
    rows = [
        (agent_name, task_id.capitalize(), task_results["mean_failures"])
        for agent_name, agent_results in results.items()
        for task_id, task_results in agent_results.items()
    ]
    df = pd.DataFrame(rows, columns=["Agent", "Task", "Mean Failures"])

    # Create plot
    plt.figure(figsize=(10, 6))

    # Grouped bar chart
    pivot_df = df.pivot(index="Task", columns="Agent", values="Mean Failures")
    # pivot() returns sorted labels (Easy, Hard, Medium); restore the
    # first-seen order so tasks read in their natural difficulty order.
    pivot_df = pivot_df.reindex(index=df["Task"].unique(), columns=df["Agent"].unique())
    pivot_df.plot(
        kind="bar",
        ax=plt.gca(),
        width=0.8,
        alpha=0.8,
        color=["#d32f2f", "#f57c00", "#fbc02d"],
    )

    plt.xlabel("Task", fontsize=12, fontweight="bold")
    plt.ylabel("Mean System Failures", fontsize=12, fontweight="bold")
    plt.title("System Failure Analysis (Lower is Better)", fontsize=14, fontweight="bold")
    plt.xticks(rotation=0)
    plt.legend(title="Agent", fontsize=10)
    plt.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    plt.savefig(output_file, dpi=300, bbox_inches="tight")
    print(f"Failure analysis plot saved to: {output_file}")
    plt.close()
def plot_episode_curves(results: Dict[str, Any], task_id: str = "medium",
                        output_file: str = "episode_curves.png") -> None:
    """
    Create line plots showing score progression across episodes.

    The left panel shows per-episode scores with the task's success threshold;
    the right panel shows per-episode total rewards. Agents that have no entry
    for ``task_id`` are skipped.

    Args:
        results: Evaluation results dictionary, read as
            ``results[agent][task_id]["episode_scores"]`` /
            ``["episode_rewards"]``
        task_id: Task to plot; tasks without a known threshold use 0.6
        output_file: Output filename for plot

    Raises:
        KeyError: If a plotted task entry lacks ``episode_scores`` or
            ``episode_rewards``.
    """
    thresholds = {"easy": 0.7, "medium": 0.65, "hard": 0.6}

    # Only agents that were actually evaluated on this task
    series = [
        (agent_name, agent_results[task_id])
        for agent_name, agent_results in results.items()
        if task_id in agent_results
    ]

    fig, (score_ax, reward_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Plot scores across episodes
    for agent_name, task_results in series:
        scores = task_results["episode_scores"]
        score_ax.plot(range(1, len(scores) + 1), scores, marker="o", label=agent_name, alpha=0.7)

    # Add threshold line; unknown tasks fall back to 0.6 instead of raising,
    # the same default plot_score_comparison uses.
    score_ax.axhline(
        y=thresholds.get(task_id, 0.6),
        color="red",
        linestyle="--",
        label="Success Threshold",
        alpha=0.5,
    )
    score_ax.set_xlabel("Episode", fontsize=11, fontweight="bold")
    score_ax.set_ylabel("Task Score", fontsize=11, fontweight="bold")
    score_ax.set_title(f"{task_id.capitalize()} Task: Score per Episode", fontsize=12, fontweight="bold")
    score_ax.legend(fontsize=9)
    score_ax.grid(alpha=0.3)

    # Plot rewards across episodes
    for agent_name, task_results in series:
        rewards = task_results["episode_rewards"]
        reward_ax.plot(range(1, len(rewards) + 1), rewards, marker="s", label=agent_name, alpha=0.7)
    reward_ax.set_xlabel("Episode", fontsize=11, fontweight="bold")
    reward_ax.set_ylabel("Total Reward", fontsize=11, fontweight="bold")
    reward_ax.set_title(f"{task_id.capitalize()} Task: Reward per Episode", fontsize=12, fontweight="bold")
    if series:
        # An empty legend only triggers a matplotlib warning, so skip it
        reward_ax.legend(fontsize=9)
    reward_ax.grid(alpha=0.3)

    fig.tight_layout()
    fig.savefig(output_file, dpi=300, bbox_inches="tight")
    print(f"Episode curves plot saved to: {output_file}")
    plt.close(fig)
def generate_all_plots(results_file: str = "evaluation_results.json", output_dir: str = ".") -> None:
    """
    Generate all visualization plots.

    Writes ``score_comparison.png``, ``success_rates.png``,
    ``failure_analysis.png`` and one ``episode_curves_<task>.png`` for each of
    the easy, medium and hard tasks into ``output_dir``.

    Args:
        results_file: Path to evaluation results JSON
        output_dir: Directory to save plots; created if it does not exist
    """
    print(f"\nGenerating plots from: {results_file}\n")

    # Load results
    results = load_results(results_file)

    # savefig() does not create missing directories, so make sure the target
    # exists first. An empty string means the current directory.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Set style
    sns.set_style("whitegrid")
    sns.set_palette("husl")

    # Generate plots
    plot_score_comparison(results, os.path.join(output_dir, "score_comparison.png"))
    plot_success_rates(results, os.path.join(output_dir, "success_rates.png"))
    plot_failure_analysis(results, os.path.join(output_dir, "failure_analysis.png"))

    # Generate episode curves for each task
    for task_id in ["easy", "medium", "hard"]:
        plot_episode_curves(
            results,
            task_id,
            os.path.join(output_dir, f"episode_curves_{task_id}.png"),
        )

    print("\n✅ All plots generated successfully!")
def main():
    """
    Main plotting entry point.

    Parses ``--input``, ``--output-dir`` and ``--plot-type`` from the command
    line, loads the results file and generates the requested plot(s). If the
    results file is missing or is not valid JSON, an error message is printed
    and nothing is plotted.
    """
    parser = argparse.ArgumentParser(
        description="Generate plots from evaluation results"
    )
    parser.add_argument(
        "--input",
        type=str,
        default="evaluation_results.json",
        help="Input JSON file with evaluation results",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=".",
        help="Output directory for plots",
    )
    parser.add_argument(
        "--plot-type",
        type=str,
        choices=["all", "scores", "success", "failures", "episodes"],
        default="all",
        help="Type of plot to generate",
    )
    args = parser.parse_args()

    print("Adaptive Alert Triage - Result Visualization")

    # Load results
    try:
        results = load_results(args.input)
    except FileNotFoundError:
        print(f"Error: Results file not found: {args.input}")
        print("Run evaluate.py first to generate results.")
        return
    except json.JSONDecodeError as exc:
        print(f"Error: Results file is not valid JSON: {args.input} ({exc})")
        return

    # savefig() does not create missing directories
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    if args.plot_type == "all":
        generate_all_plots(args.input, args.output_dir)
        return

    # Apply the same seaborn style that generate_all_plots uses, so a single
    # plot looks the same as its counterpart from an "all" run.
    sns.set_style("whitegrid")
    sns.set_palette("husl")

    # Plot types that map to exactly one output file
    single_plots = {
        "scores": (plot_score_comparison, "score_comparison.png"),
        "success": (plot_success_rates, "success_rates.png"),
        "failures": (plot_failure_analysis, "failure_analysis.png"),
    }

    if args.plot_type in single_plots:
        plot_func, plot_name = single_plots[args.plot_type]
        plot_func(results, os.path.join(args.output_dir, plot_name))
    elif args.plot_type == "episodes":
        for task_id in ["easy", "medium", "hard"]:
            plot_episode_curves(
                results,
                task_id,
                os.path.join(args.output_dir, f"episode_curves_{task_id}.png"),
            )
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()