satomitheito's picture
Add new agents and observability, fix sys.path for HF Space
3487f22
"""
Visualize eval_metrics.json results. Charts are saved to the charts/ subdirectory.
Produces:
charts/overall_scores.png — bar chart of avg score per scorer
charts/category_heatmap.png — heatmap of avg score by category × scorer
charts/latency_distribution.png — per-case latency score distribution
charts/failures_by_category.png — count of sub-threshold cases per category
Usage:
python agent-evaluation/visualize_metrics.py
python agent-evaluation/visualize_metrics.py --metrics eval_metrics.json --threshold 1.0
"""
import argparse
import json
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
# Both default locations are anchored to the directory containing this script,
# so they do not depend on the caller's working directory.
_SCRIPT_DIR = Path(__file__).resolve().parent
DEFAULT_METRICS_PATH = _SCRIPT_DIR / "eval_metrics.json"
CHARTS_DIR = _SCRIPT_DIR / "charts"

# Default cut-off handed to the failures-by-category chart.
DEFAULT_THRESHOLD = 1.0

# Colors shared by all charts.
PALETTE = {
    "pass": "#4CAF50",
    "warn": "#FF9800",
    "fail": "#F44336",
    "bar": "#5C6BC0",
    "bg": "#FAFAFA",
}

# Display names for scorer keys. These are spelled out because the generic
# fallback in _fmt() title-cases the whole key and would collapse CamelCase
# names into a single word.
_SCORER_LABELS = {
    "BenchmarkSimilarityValidity": "Benchmark Similarity Validity",
    "ClassificationValidity": "Classification Validity",
    "ClauseStructureValidity": "Clause Structure Validity",
    "ExpectedClauseType": "Expected Clause Type",
    "Factuality": "Factuality",
    "Latency": "Latency",
    "NoError": "No Error",
    "OutputStructureValidity": "Output Structure Validity",
    "RiskFactorsPresence": "Risk Factors Presence",
    "RiskScoreValidity": "Risk Score Validity",
}

# Display names for category keys (snake_case key -> human-readable label).
_CATEGORY_LABELS = {
    "agency_agreement": "Agency Agreement",
    "commercial_lease": "Commercial Lease",
    "consulting_agreement": "Consulting Agreement",
    "data_processing_agreement": "Data Processing Agreement",
    "distribution_agreement": "Distribution Agreement",
    "distributor_agreement": "Distributor Agreement",
    "employment_contract": "Employment Contract",
    "enterprise_software": "Enterprise Software",
    "franchise_agreement": "Franchise Agreement",
    "joint_venture": "Joint Venture",
    "master_service_agreement": "Master Service Agreement",
    "nda": "NDA",
    "nda_services_agreement": "NDA Services Agreement",
    "partnership_agreement": "Partnership Agreement",
    "research_collaboration": "Research Collaboration",
    "reseller_agreement": "Reseller Agreement",
    "saas_agreement": "SaaS Agreement",
    "services_agreement": "Services Agreement",
    "software_license": "Software License",
    "strategic_alliance": "Strategic Alliance",
    "supply_agreement": "Supply Agreement",
    "technology_licensing": "Technology Licensing",
    "technology_transfer": "Technology Transfer",
}
def _fmt(name: str, mapping: dict) -> str:
    """Return the display label for *name*.

    An explicit entry in *mapping* wins; otherwise underscores become spaces
    and the result is title-cased.
    """
    generic_label = name.replace("_", " ").title()
    if name in mapping:
        return mapping[name]
    return generic_label
def load_metrics(path: str) -> dict:
    """Load the metrics JSON file at *path* and return the parsed content.

    Args:
        path: Location of the metrics file.

    Returns:
        Whatever ``json.load`` produces for the file content.

    Raises:
        FileNotFoundError: If *path* does not exist.
    """
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"Metrics file not found: {path}")
    # Decode explicitly as UTF-8: the platform default encoding is
    # locale-dependent and can mis-decode non-ASCII text in the file.
    with p.open(encoding="utf-8") as f:
        return json.load(f)
def _save(fig: plt.Figure, name: str) -> None:
    """Save *fig* under ``CHARTS_DIR`` using file name *name*, then close it.

    The charts directory is created on demand. The destination path is
    printed so the user can see where each chart ended up.
    """
    CHARTS_DIR.mkdir(parents=True, exist_ok=True)
    out = CHARTS_DIR / name
    fig.savefig(out, bbox_inches="tight", dpi=150)
    # Close explicitly so finished figures are released instead of
    # accumulating while the remaining charts are drawn.
    plt.close(fig)
    # The arrow was mis-encoded (mojibake) in the original message.
    print(f" saved → {out}")
# ---------------------------------------------------------------------------
# Chart 1: Overall scores bar chart
# ---------------------------------------------------------------------------
def plot_overall_scores(metrics: dict) -> None:
    """Draw one horizontal bar per scorer showing its overall average.

    Reads ``metrics["overall_scores"]`` (scorer name -> entry with
    ``"average"`` and ``"count"``) and ``metrics["total_cases"]``, and saves
    the figure as ``overall_scores.png``.
    """

    def _bar_color(value):
        # Only a perfect average is "pass"; 0.75 and above is "warn".
        if value == 1.0:
            return PALETTE["pass"]
        if value >= 0.75:
            return PALETTE["warn"]
        return PALETTE["fail"]

    overall = metrics.get("overall_scores", {})
    names = sorted(overall.keys())
    averages = [overall[name]["average"] for name in names]
    counts = [overall[name]["count"] for name in names]
    bar_colors = [_bar_color(value) for value in averages]

    fig, ax = plt.subplots(figsize=(10, 5))
    fig.patch.set_facecolor(PALETTE["bg"])
    ax.set_facecolor(PALETTE["bg"])

    display_names = [_fmt(name, _SCORER_LABELS) for name in names]
    bars = ax.barh(
        display_names,
        averages,
        color=bar_colors,
        edgecolor="white",
        height=0.6,
    )

    # Value label just right of each bar end, capped so that full-length bars
    # keep their label inside the x-limits set below.
    for rect, avg, n in zip(bars, averages, counts):
        label_x = min(avg + 0.005, 1.01)
        label_y = rect.get_y() + rect.get_height() / 2
        ax.text(
            label_x,
            label_y,
            f"{avg:.3f} (n={n})",
            va="center",
            ha="left",
            fontsize=9,
            color="#333333",
        )

    ax.set_xlim(0, 1.12)
    ax.set_xlabel("Average Score", fontsize=11)
    ax.set_title("Overall Scorer Averages", fontsize=13, fontweight="bold", pad=12)
    # Dashed reference line at the maximum score.
    ax.axvline(1.0, color="#BDBDBD", linewidth=0.8, linestyle="--")
    ax.tick_params(axis="y", labelsize=10)
    ax.spines[["top", "right", "bottom"]].set_visible(False)

    total_cases = metrics.get("total_cases", 0)
    ax.text(
        0.99,
        -0.08,
        f"n = {total_cases} total cases",
        transform=ax.transAxes,
        ha="right",
        fontsize=9,
        color="#757575",
    )

    _save(fig, "overall_scores.png")
# ---------------------------------------------------------------------------
# Chart 2: Category × Scorer heatmap
# ---------------------------------------------------------------------------
def plot_category_heatmap(metrics: dict) -> None:
    """Draw a category-by-scorer heatmap of average scores.

    Reads ``metrics["per_category"]`` (category -> scorer -> entry with an
    ``"average"`` value). Rows are categories and columns are the union of all
    scorers seen in any category, both sorted by key. Saves the figure as
    ``category_heatmap.png``.
    """
    per_category = metrics.get("per_category", {})
    categories = sorted(per_category.keys())
    scorers = sorted(
        {scorer for cat_data in per_category.values() for scorer in cat_data}
    )

    # NaN marks category/scorer pairs without data; those cells are drawn in
    # the colormap's "bad" color and get no numeric annotation.
    matrix = np.full((len(categories), len(scorers)), np.nan)
    for r, cat in enumerate(categories):
        for c, scorer in enumerate(scorers):
            entry = per_category[cat].get(scorer)
            if entry is not None:
                matrix[r, c] = entry["average"]

    # Grow the figure with the number of rows/columns so labels stay readable.
    fig, ax = plt.subplots(
        figsize=(max(10, len(scorers) * 1.3), max(6, len(categories) * 0.55))
    )
    fig.patch.set_facecolor(PALETTE["bg"])
    ax.set_facecolor(PALETTE["bg"])

    cmap = mcolors.LinearSegmentedColormap.from_list(
        "eval", [PALETTE["fail"], PALETTE["warn"], PALETTE["pass"]]
    )
    cmap.set_bad(color="#E0E0E0")
    im = ax.imshow(matrix, cmap=cmap, vmin=0.0, vmax=1.0, aspect="auto")

    ax.set_xticks(range(len(scorers)))
    ax.set_xticklabels(
        [_fmt(s, _SCORER_LABELS) for s in scorers],
        rotation=35, ha="right", fontsize=9,
    )
    ax.set_yticks(range(len(categories)))
    ax.set_yticklabels([_fmt(c, _CATEGORY_LABELS) for c in categories], fontsize=9)

    # Annotate every populated cell; white text on the low-score end of the
    # colormap, dark text elsewhere.
    for (r, c), val in np.ndenumerate(matrix):
        if np.isnan(val):
            continue
        text_color = "white" if val < 0.6 else "#222222"
        ax.text(c, r, f"{val:.2f}", ha="center", va="center",
                fontsize=7.5, color=text_color)

    plt.colorbar(im, ax=ax, fraction=0.02, pad=0.02, label="Average Score")
    # The multiplication sign was mis-encoded (mojibake) in the original title.
    ax.set_title("Score Heatmap: Category × Scorer", fontsize=13, fontweight="bold", pad=12)
    ax.tick_params(length=0)

    _save(fig, "category_heatmap.png")
# ---------------------------------------------------------------------------
# Chart 3: Per-case Latency score distribution
# ---------------------------------------------------------------------------
def _tally_latency_scores(scores, levels, tol=1e-9):
    """Count how many of *scores* fall on each value in *levels*.

    Returns ``(counts, unmatched)``: ``counts`` is aligned with *levels*, and
    ``unmatched`` is the number of scores that are non-numeric or not within
    *tol* of any level.
    """
    counts = [0] * len(levels)
    unmatched = 0
    for score in scores:
        if not isinstance(score, (int, float)):
            unmatched += 1
            continue
        for idx, level in enumerate(levels):
            if abs(score - level) <= tol:
                counts[idx] += 1
                break
        else:
            unmatched += 1
    return counts, unmatched


def plot_latency_distribution(metrics: dict) -> None:
    """Draw a bar chart of how many cases received each Latency score.

    Reads the ``"Latency"`` entry of every case's ``"scores"`` in
    ``metrics["per_case"]``; cases without one (missing or ``None``) are
    ignored. Scores are expected to take one of the five bucket values
    1.0 / 0.75 / 0.5 / 0.25 / 0.0. Scores that match no bucket cannot be
    placed on the chart; they are reported on stdout rather than being
    dropped silently. Saves the figure as ``latency_distribution.png``, or
    prints a notice and returns if there are no latency scores at all.
    """
    latency_scores = []
    for case in metrics.get("per_case", []):
        # ``or {}`` also tolerates an explicit ``"scores": null`` in the JSON.
        score = (case.get("scores") or {}).get("Latency")
        if score is not None:
            latency_scores.append(score)

    if not latency_scores:
        print(" no latency scores found — skipping latency chart")
        return

    labels = {
        1.0: "< 30s\n(1.0)",
        0.75: "30–60s\n(0.75)",
        0.5: "60–120s\n(0.5)",
        0.25: "120–180s\n(0.25)",
        0.0: "> 180s\n(0.0)",
    }
    ordered = [1.0, 0.75, 0.5, 0.25, 0.0]

    counts, unmatched = _tally_latency_scores(latency_scores, ordered)
    if unmatched:
        print(
            f" warning: {unmatched} latency score(s) did not match any bucket "
            "and are not shown"
        )

    tick_labels = [labels[v] for v in ordered]
    colors = [PALETTE["pass"], PALETTE["warn"], PALETTE["warn"], PALETTE["fail"], PALETTE["fail"]]

    fig, ax = plt.subplots(figsize=(8, 4))
    fig.patch.set_facecolor(PALETTE["bg"])
    ax.set_facecolor(PALETTE["bg"])

    bars = ax.bar(tick_labels, counts, color=colors, edgecolor="white", width=0.55)
    # Label only non-empty bars with their count.
    for bar, count in zip(bars, counts):
        if count:
            ax.text(
                bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.15,
                str(count), ha="center", va="bottom", fontsize=10,
            )

    ax.set_ylabel("Number of Cases", fontsize=11)
    ax.set_title("Latency Score Distribution (per case)", fontsize=13, fontweight="bold", pad=12)
    ax.spines[["top", "right", "left"]].set_visible(False)
    ax.tick_params(axis="x", length=0)
    # Counts are printed above the bars, so the y tick marks are redundant.
    ax.set_yticks([])

    _save(fig, "latency_distribution.png")
# ---------------------------------------------------------------------------
# Chart 4: Failures per category (cases below threshold)
# ---------------------------------------------------------------------------
def _count_failures_by_category(metrics: dict, threshold: float) -> dict:
    """Return ``{category: number of failed cases}`` for ``metrics["per_case"]``.

    A case counts as failed when it has a truthy ``"error"`` or any score
    strictly below *threshold*. Scores that are ``None`` cannot be compared
    and are skipped. Cases without a ``"category"`` are filed under
    ``"unknown"``.
    """
    failure_counts: dict = {}
    for case in metrics.get("per_case", []):
        # ``or {}`` also tolerates an explicit ``"scores": null`` in the JSON.
        scores = (case.get("scores") or {}).values()
        has_failure = any(
            s is not None and s < threshold for s in scores
        ) or bool(case.get("error"))
        if has_failure:
            cat = case.get("category", "unknown")
            failure_counts[cat] = failure_counts.get(cat, 0) + 1
    return failure_counts


def plot_failures_by_category(metrics: dict, threshold: float) -> None:
    """Draw a horizontal bar chart of failed-case counts per category.

    Args:
        metrics: Parsed metrics; ``metrics["per_case"]`` is read.
        threshold: A case fails if any of its scores is strictly below this
            value, or if it carries a truthy ``"error"``.

    Saves the figure as ``failures_by_category.png``, or prints a notice and
    returns if no case fails.
    """
    failure_counts = _count_failures_by_category(metrics, threshold)

    if not failure_counts:
        print(f" no failures at threshold {threshold} — skipping failures chart")
        return

    # Most failures first; ties keep first-seen order because the sort is stable.
    categories = sorted(failure_counts, key=lambda c: -failure_counts[c])
    counts = [failure_counts[c] for c in categories]
    category_labels = [_fmt(c, _CATEGORY_LABELS) for c in categories]

    fig, ax = plt.subplots(figsize=(9, max(4, len(categories) * 0.5)))
    fig.patch.set_facecolor(PALETTE["bg"])
    ax.set_facecolor(PALETTE["bg"])

    ax.barh(category_labels, counts, color=PALETTE["fail"], edgecolor="white", height=0.6)
    # Print each count just past the end of its bar.
    for i, count in enumerate(counts):
        ax.text(count + 0.05, i, str(count), va="center", fontsize=10)

    ax.set_xlabel("Number of Failed Cases", fontsize=11)
    ax.set_title(
        f"Failures by Category (threshold < {threshold})",
        fontsize=13, fontweight="bold", pad=12,
    )
    ax.spines[["top", "right", "bottom"]].set_visible(False)
    # Leave room on the right for the count labels.
    ax.set_xlim(0, max(counts) + 1.5)
    ax.tick_params(axis="y", labelsize=9)

    _save(fig, "failures_by_category.png")
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main(argv=None) -> None:
    """Parse command-line options, load the metrics file and render all charts.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) keeps the original
            behavior of reading ``sys.argv``.

    Raises:
        FileNotFoundError: Propagated from ``load_metrics`` when the metrics
            file does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Visualize eval_metrics.json. Charts saved to charts/",
        # Raw formatter so the module docstring used as epilog keeps its
        # line breaks in ``--help`` output.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "--metrics", type=str, default=str(DEFAULT_METRICS_PATH),
        help=f"Path to eval_metrics.json (default: {DEFAULT_METRICS_PATH})",
    )
    parser.add_argument(
        "--threshold", type=float, default=DEFAULT_THRESHOLD,
        help=f"Failure threshold for the failures-by-category chart (default: {DEFAULT_THRESHOLD})",
    )
    args = parser.parse_args(argv)

    metrics = load_metrics(args.metrics)
    # The arrow was mis-encoded (mojibake) in the original message.
    print(f"\nGenerating charts from {args.metrics} → {CHARTS_DIR}/\n")

    plot_overall_scores(metrics)
    plot_category_heatmap(metrics)
    plot_latency_distribution(metrics)
    plot_failures_by_category(metrics, args.threshold)

    print("\nDone.\n")


if __name__ == "__main__":
    main()