"""
Visualize eval_metrics.json results. Charts are saved to the charts/ subdirectory.

Produces:
    charts/overall_scores.png        — bar chart of avg score per scorer
    charts/category_heatmap.png      — heatmap of avg score by category × scorer
    charts/latency_distribution.png  — per-case latency score distribution
    charts/failures_by_category.png  — count of failing cases per category
                                       (any score below the threshold, or an error)

Usage:
    python agent-evaluation/visualize_metrics.py
    python agent-evaluation/visualize_metrics.py --metrics eval_metrics.json --threshold 1.0
"""
|
|
| import argparse |
| import json |
| from pathlib import Path |
|
|
| import matplotlib.pyplot as plt |
| import matplotlib.colors as mcolors |
| import numpy as np |
|
|
# Default input file and chart output directory. Both are resolved relative to
# this script's own directory, not the current working directory.
DEFAULT_METRICS_PATH = Path(__file__).resolve().parent / "eval_metrics.json"
CHARTS_DIR = Path(__file__).resolve().parent / "charts"
# Default for --threshold: in the failures chart a case counts as failed when
# any of its scores is below this value (or it carries an error).
DEFAULT_THRESHOLD = 1.0


# Shared chart colours: pass/warn/fail status colours, a neutral bar colour,
# and the figure/axes background.
PALETTE = {
    "pass": "#4CAF50",
    "warn": "#FF9800",
    "fail": "#F44336",
    "bar": "#5C6BC0",
    "bg": "#FAFAFA",
}


# Display names for scorer keys. Keys missing here fall back to _fmt's
# underscore-to-space, title-cased formatting.
_SCORER_LABELS = {
    "BenchmarkSimilarityValidity": "Benchmark Similarity Validity",
    "ClassificationValidity": "Classification Validity",
    "ClauseStructureValidity": "Clause Structure Validity",
    "ExpectedClauseType": "Expected Clause Type",
    "Factuality": "Factuality",
    "Latency": "Latency",
    "NoError": "No Error",
    "OutputStructureValidity": "Output Structure Validity",
    "RiskFactorsPresence": "Risk Factors Presence",
    "RiskScoreValidity": "Risk Score Validity",
}


# Display names for category keys. Most entries equal _fmt's title-case
# fallback; the explicit entries matter for casing the fallback would get
# wrong, e.g. "nda" -> "NDA" and "saas_agreement" -> "SaaS Agreement".
_CATEGORY_LABELS = {
    "agency_agreement": "Agency Agreement",
    "commercial_lease": "Commercial Lease",
    "consulting_agreement": "Consulting Agreement",
    "data_processing_agreement": "Data Processing Agreement",
    "distribution_agreement": "Distribution Agreement",
    "distributor_agreement": "Distributor Agreement",
    "employment_contract": "Employment Contract",
    "enterprise_software": "Enterprise Software",
    "franchise_agreement": "Franchise Agreement",
    "joint_venture": "Joint Venture",
    "master_service_agreement": "Master Service Agreement",
    "nda": "NDA",
    "nda_services_agreement": "NDA Services Agreement",
    "partnership_agreement": "Partnership Agreement",
    "research_collaboration": "Research Collaboration",
    "reseller_agreement": "Reseller Agreement",
    "saas_agreement": "SaaS Agreement",
    "services_agreement": "Services Agreement",
    "software_license": "Software License",
    "strategic_alliance": "Strategic Alliance",
    "supply_agreement": "Supply Agreement",
    "technology_licensing": "Technology Licensing",
    "technology_transfer": "Technology Transfer",
}
|
|
|
|
| def _fmt(name: str, mapping: dict) -> str: |
| return mapping.get(name, name.replace("_", " ").title()) |
|
|
|
|
def load_metrics(path: str) -> dict:
    """Load and return the parsed contents of a metrics JSON file.

    Args:
        path: Location of the metrics file (a ``pathlib.Path`` works too).

    Returns:
        The decoded JSON document. A dict is expected for the usual
        eval_metrics.json layout.

    Raises:
        FileNotFoundError: If nothing exists at *path*.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"Metrics file not found: {path}")
    # JSON text is UTF-8 (RFC 8259). Do not depend on the platform's locale
    # encoding (e.g. cp1252 on Windows), which fails on non-ASCII content.
    with p.open(encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
def _save(fig: plt.Figure, name: str) -> None:
    """Save *fig* as ``CHARTS_DIR / name`` (150 dpi, tight bbox) and close it.

    The charts directory is created on demand. The figure is closed even if
    saving fails, so a failed chart does not leave an open pyplot figure
    behind.
    """
    CHARTS_DIR.mkdir(parents=True, exist_ok=True)
    out = CHARTS_DIR / name
    try:
        fig.savefig(out, bbox_inches="tight", dpi=150)
    finally:
        plt.close(fig)
    print(f" saved → {out}")
|
|
|
|
| |
| |
| |
|
|
def plot_overall_scores(metrics: dict) -> None:
    """Draw a horizontal bar chart of each scorer's average score.

    Reads ``metrics["overall_scores"]`` (scorer name -> entry with
    ``"average"`` and ``"count"`` keys) and ``metrics["total_cases"]``.
    Bars are green for an average of exactly 1.0, orange for averages of at
    least 0.75, and red below that. Saves ``overall_scores.png``. The chart
    is skipped when there are no overall scores.
    """
    overall = metrics.get("overall_scores", {})
    if not overall:
        # Same policy as the latency/failures charts: no data, no empty image.
        print(" no overall scores found — skipping overall scores chart")
        return

    scorers = sorted(overall)
    averages = [overall[s]["average"] for s in scorers]
    counts = [overall[s]["count"] for s in scorers]

    colors = [
        PALETTE["pass"] if v == 1.0
        else PALETTE["warn"] if v >= 0.75
        else PALETTE["fail"]
        for v in averages
    ]

    fig, ax = plt.subplots(figsize=(10, 5))
    fig.patch.set_facecolor(PALETTE["bg"])
    ax.set_facecolor(PALETTE["bg"])

    scorer_labels = [_fmt(s, _SCORER_LABELS) for s in scorers]
    bars = ax.barh(scorer_labels, averages, color=colors, edgecolor="white", height=0.6)

    # Value labels sit just right of each bar. The x start is capped at 1.01
    # so labels for perfect scores still begin inside the 1.12 x-limit.
    for bar, avg, n in zip(bars, averages, counts):
        ax.text(
            min(avg + 0.005, 1.01), bar.get_y() + bar.get_height() / 2,
            f"{avg:.3f} (n={n})",
            va="center", ha="left", fontsize=9, color="#333333",
        )

    ax.set_xlim(0, 1.12)
    ax.set_xlabel("Average Score", fontsize=11)
    ax.set_title("Overall Scorer Averages", fontsize=13, fontweight="bold", pad=12)
    ax.axvline(1.0, color="#BDBDBD", linewidth=0.8, linestyle="--")
    ax.tick_params(axis="y", labelsize=10)
    ax.spines[["top", "right", "bottom"]].set_visible(False)

    # Footnote below the axes (axes-relative coordinates).
    total = metrics.get("total_cases", 0)
    ax.text(
        0.99, -0.08, f"n = {total} total cases",
        transform=ax.transAxes, ha="right", fontsize=9, color="#757575",
    )

    _save(fig, "overall_scores.png")
|
|
|
|
| |
| |
| |
|
|
def plot_category_heatmap(metrics: dict) -> None:
    """Render a category × scorer heatmap of average scores.

    Reads ``metrics["per_category"]``, a mapping of category -> scorer ->
    entry, and plots each entry's ``"average"``. Category/scorer pairs with
    no entry are drawn as grey (NaN) cells. Saves ``category_heatmap.png``.
    The chart is skipped when there is nothing to plot.
    """
    per_category = metrics.get("per_category", {})
    categories = sorted(per_category)

    # Take the union of scorers, since a category need not report every scorer.
    all_scorers: set[str] = set()
    for cat_data in per_category.values():
        all_scorers.update(cat_data)
    scorers = sorted(all_scorers)

    if not categories or not scorers:
        # A 0-row or 0-column matrix gives a meaningless (possibly unrenderable)
        # heatmap, so skip it like the other charts skip missing data.
        print(" no per-category scores found — skipping heatmap")
        return

    # Rows = categories, columns = scorers. Missing entries stay NaN.
    matrix = np.full((len(categories), len(scorers)), np.nan)
    for r, cat in enumerate(categories):
        for c, scorer in enumerate(scorers):
            entry = per_category[cat].get(scorer)
            if entry is not None:
                matrix[r, c] = entry["average"]

    fig, ax = plt.subplots(figsize=(max(10, len(scorers) * 1.3), max(6, len(categories) * 0.55)))
    fig.patch.set_facecolor(PALETTE["bg"])
    ax.set_facecolor(PALETTE["bg"])

    # fail -> warn -> pass gradient. NaN ("bad") cells are painted light grey.
    # with_extremes returns a configured copy instead of mutating in place.
    cmap = mcolors.LinearSegmentedColormap.from_list(
        "eval", [PALETTE["fail"], PALETTE["warn"], PALETTE["pass"]]
    ).with_extremes(bad="#E0E0E0")

    im = ax.imshow(matrix, cmap=cmap, vmin=0.0, vmax=1.0, aspect="auto")

    ax.set_xticks(range(len(scorers)))
    ax.set_xticklabels([_fmt(s, _SCORER_LABELS) for s in scorers], rotation=35, ha="right", fontsize=9)
    ax.set_yticks(range(len(categories)))
    ax.set_yticklabels([_fmt(c, _CATEGORY_LABELS) for c in categories], fontsize=9)

    # Annotate each populated cell. Use white text on the darker low-score colours.
    for r in range(len(categories)):
        for c in range(len(scorers)):
            val = matrix[r, c]
            if not np.isnan(val):
                text_color = "white" if val < 0.6 else "#222222"
                ax.text(c, r, f"{val:.2f}", ha="center", va="center",
                        fontsize=7.5, color=text_color)

    plt.colorbar(im, ax=ax, fraction=0.02, pad=0.02, label="Average Score")
    ax.set_title("Score Heatmap: Category × Scorer", fontsize=13, fontweight="bold", pad=12)
    ax.tick_params(length=0)

    _save(fig, "category_heatmap.png")
|
|
|
|
| |
| |
| |
|
|
def plot_latency_distribution(metrics: dict) -> None:
    """Plot how many cases landed in each discrete Latency score bucket.

    Collects the non-None ``"Latency"`` score of every entry in
    ``metrics["per_case"]`` and draws one bar per expected score value
    (1.0, 0.75, 0.5, 0.25, 0.0), labelled with the time band shown on the
    axis. Scores outside those values are reported, not plotted. Saves
    ``latency_distribution.png``. The chart is skipped when no case has a
    Latency score.
    """
    latency_scores = [
        case["scores"].get("Latency")
        for case in metrics.get("per_case", [])
        if case.get("scores", {}).get("Latency") is not None
    ]

    if not latency_scores:
        print(" no latency scores found — skipping latency chart")
        return

    # Expected score values, best to worst, with their axis labels.
    # NOTE(review): the time bands come only from these labels; confirm they
    # still match the Latency scorer's actual cut-offs.
    labels = {
        1.0: "< 30s\n(1.0)",
        0.75: "30–60s\n(0.75)",
        0.5: "60–120s\n(0.5)",
        0.25: "120–180s\n(0.25)",
        0.0: "> 180s\n(0.0)",
    }
    ordered = list(labels)

    buckets = dict.fromkeys(ordered, 0)
    unexpected = 0
    for s in latency_scores:
        if s in buckets:  # int 1 / 0 from JSON match 1.0 / 0.0 here
            buckets[s] += 1
        else:
            unexpected += 1
    if unexpected:
        # These values have no bar. Say so rather than dropping them silently.
        print(f" {unexpected} latency score(s) not in {ordered} were not plotted")

    counts = [buckets[v] for v in ordered]
    tick_labels = [labels[v] for v in ordered]
    colors = [PALETTE["pass"], PALETTE["warn"], PALETTE["warn"], PALETTE["fail"], PALETTE["fail"]]

    fig, ax = plt.subplots(figsize=(8, 4))
    fig.patch.set_facecolor(PALETTE["bg"])
    ax.set_facecolor(PALETTE["bg"])

    bars = ax.bar(tick_labels, counts, color=colors, edgecolor="white", width=0.55)
    # Count labels above non-empty bars. They replace the hidden y-axis ticks.
    for bar, count in zip(bars, counts):
        if count:
            ax.text(
                bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.15,
                str(count), ha="center", va="bottom", fontsize=10,
            )

    ax.set_ylabel("Number of Cases", fontsize=11)
    ax.set_title("Latency Score Distribution (per case)", fontsize=13, fontweight="bold", pad=12)
    ax.spines[["top", "right", "left"]].set_visible(False)
    ax.tick_params(axis="x", length=0)
    ax.set_yticks([])

    _save(fig, "latency_distribution.png")
|
|
|
|
| |
| |
| |
|
|
def plot_failures_by_category(metrics: dict, threshold: float) -> None:
    """Plot, per category, how many cases failed.

    A case in ``metrics["per_case"]`` fails when any of its non-None scores
    is below *threshold*, or when it has a truthy ``"error"`` field. Cases
    without a category are grouped under ``"unknown"``. Categories are drawn
    with the most failures at the top. Saves ``failures_by_category.png``.
    The chart is skipped when nothing failed.

    Args:
        metrics: Parsed eval_metrics.json content.
        threshold: Scores strictly below this value count as failures.
    """
    failure_counts: dict[str, int] = {}

    for case in metrics.get("per_case", []):
        # A null/empty category would otherwise crash _fmt (None.replace).
        cat = case.get("category") or "unknown"
        # Scores may be None (the latency chart filters those out too).
        # None < float raises TypeError, so skip missing scores.
        scores = (case.get("scores") or {}).values()
        has_failure = bool(case.get("error")) or any(
            s is not None and s < threshold for s in scores
        )
        if has_failure:
            failure_counts[cat] = failure_counts.get(cat, 0) + 1

    if not failure_counts:
        print(f" no failures at threshold {threshold} — skipping failures chart")
        return

    # Most failures first. barh draws index 0 at the bottom, so the y-axis is
    # inverted below to put the worst category at the top.
    categories = sorted(failure_counts, key=lambda c: -failure_counts[c])
    counts = [failure_counts[c] for c in categories]
    category_labels = [_fmt(c, _CATEGORY_LABELS) for c in categories]

    fig, ax = plt.subplots(figsize=(9, max(4, len(categories) * 0.5)))
    fig.patch.set_facecolor(PALETTE["bg"])
    ax.set_facecolor(PALETTE["bg"])

    ax.barh(category_labels, counts, color=PALETTE["fail"], edgecolor="white", height=0.6)
    # Categorical bars sit at y = 0..n-1, so index i aligns the label with bar i.
    for i, count in enumerate(counts):
        ax.text(count + 0.05, i, str(count), va="center", fontsize=10)
    ax.invert_yaxis()

    ax.set_xlabel("Number of Failed Cases", fontsize=11)
    ax.set_title(
        f"Failures by Category (score < {threshold} or error)",
        fontsize=13, fontweight="bold", pad=12,
    )
    ax.spines[["top", "right", "bottom"]].set_visible(False)
    ax.set_xlim(0, max(counts) + 1.5)
    ax.tick_params(axis="y", labelsize=9)

    _save(fig, "failures_by_category.png")
|
|
|
|
| |
| |
| |
|
|
def main() -> None:
    """Parse CLI options, load the metrics file and write every chart.

    A missing or malformed metrics file is reported as a command-line error
    (argparse exits with status 2) instead of a traceback.
    """
    parser = argparse.ArgumentParser(
        description="Visualize eval_metrics.json. Charts saved to charts/",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    # %(default)s is filled in by argparse. Interpolating the path directly
    # would break --help if the path contained a literal '%'.
    parser.add_argument(
        "--metrics", type=str, default=str(DEFAULT_METRICS_PATH),
        help="Path to eval_metrics.json (default: %(default)s)",
    )
    parser.add_argument(
        "--threshold", type=float, default=DEFAULT_THRESHOLD,
        help="Failure threshold for the failures-by-category chart (default: %(default)s)",
    )
    args = parser.parse_args()

    try:
        metrics = load_metrics(args.metrics)
    except (FileNotFoundError, json.JSONDecodeError) as exc:
        parser.error(str(exc))  # prints usage + message, exits with status 2
    print(f"\nGenerating charts from {args.metrics} → {CHARTS_DIR}/\n")

    plot_overall_scores(metrics)
    plot_category_heatmap(metrics)
    plot_latency_distribution(metrics)
    plot_failures_by_category(metrics, args.threshold)

    print("\nDone.\n")


if __name__ == "__main__":
    main()
|
|