Spaces:
Running
Running
File size: 3,370 Bytes
"""Generate all six required plots from existing evaluation outputs.

Inputs (default paths, overridable via --claim-eval, --aggregate and
--benchmark; usually written by ``citeguard evaluate``). A missing input does
not abort the run: the affected figures are produced from empty data instead.
- outputs/tables/claim_eval.csv
- outputs/tables/aggregate_metrics.json (for the confusion matrix)
- outputs/tables/benchmark_summary.csv (optional — for baseline/ablation plots)

Outputs (all six, written to --figures-dir, default outputs/figures):
- outputs/figures/fig_error_distribution.png
- outputs/figures/fig_confusion_matrix.png
- outputs/figures/fig_baseline_comparison.png
- outputs/figures/fig_ablation.png
- outputs/figures/fig_retrieval_vs_support.png
- outputs/figures/fig_runtime.png
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT / "src"))
import numpy as np # noqa: E402
import pandas as pd # noqa: E402
from citeguard.reporting.plots import ( # noqa: E402
plot_ablation,
plot_baseline_comparison,
plot_confusion_matrix,
plot_error_distribution,
plot_retrieval_vs_support,
plot_runtime,
)
from citeguard.schemas import LABELS # noqa: E402
def _load_confusion(agg_path: Path) -> tuple[np.ndarray, list[str]]:
if not agg_path.exists():
return np.zeros((len(LABELS), len(LABELS)), dtype=int), LABELS
with open(agg_path, "r", encoding="utf-8") as f:
agg = json.load(f)
conf = agg.get("confusion", {})
labels = conf.get("labels") or LABELS
matrix = conf.get("matrix") or []
if not matrix:
return np.zeros((len(labels), len(labels)), dtype=int), labels
return np.asarray(matrix, dtype=int), list(labels)
def _warn_missing(path: Path, consequence: str) -> None:
    """Print a uniform stderr warning that an input file is absent."""
    print(f"WARNING: {path} not found; {consequence}.", file=sys.stderr)


def _read_csv_or_empty(path: Path, consequence: str) -> pd.DataFrame:
    """Read ``path`` as CSV, or return an empty DataFrame if missing or empty.

    A warning mentioning ``consequence`` is printed to stderr in either
    fallback case.
    """
    if not path.exists():
        _warn_missing(path, consequence)
        return pd.DataFrame()
    try:
        return pd.read_csv(path)
    except pd.errors.EmptyDataError:
        # A zero-byte CSV has no header row and makes read_csv raise; treat it
        # like a missing file instead of aborting the whole figure run.
        print(f"WARNING: {path} is empty; {consequence}.", file=sys.stderr)
        return pd.DataFrame()


def main(argv: list[str] | None = None) -> int:
    """Render all six evaluation figures into ``--figures-dir``.

    Missing inputs do not abort the run: a missing claim-eval CSV yields an
    empty DataFrame, a missing aggregate JSON yields an all-zero confusion
    matrix, and a missing benchmark CSV yields no baseline/ablation rows. A
    warning is printed to stderr for each missing input.

    Args:
        argv: Arguments to parse; ``None`` (the default) lets argparse read
            ``sys.argv[1:]``.

    Returns:
        Process exit status (always ``0``).
    """
    ap = argparse.ArgumentParser(
        description="Generate the six evaluation figures from existing output tables."
    )
    ap.add_argument(
        "--claim-eval",
        default="outputs/tables/claim_eval.csv",
        help="per-claim evaluation CSV (error, retrieval and runtime plots)",
    )
    ap.add_argument(
        "--aggregate",
        default="outputs/tables/aggregate_metrics.json",
        help="aggregate metrics JSON holding the confusion matrix",
    )
    ap.add_argument(
        "--benchmark",
        default="outputs/tables/benchmark_summary.csv",
        help="optional benchmark summary CSV (baseline and ablation plots)",
    )
    ap.add_argument(
        "--figures-dir",
        default="outputs/figures",
        help="directory to write PNG figures into (created if missing)",
    )
    args = ap.parse_args(argv)

    figs = Path(args.figures_dir)
    figs.mkdir(parents=True, exist_ok=True)

    claim_df = _read_csv_or_empty(Path(args.claim_eval), "producing empty figures")

    agg_path = Path(args.aggregate)
    if not agg_path.exists():
        _warn_missing(agg_path, "confusion matrix will be all zeros")
    cm, labels = _load_confusion(agg_path)

    bench_df = _read_csv_or_empty(
        Path(args.benchmark), "baseline/ablation plots will be empty"
    )
    # An empty DataFrame converts to [], matching the "no benchmark" case.
    rows = bench_df.to_dict(orient="records")

    # (file name, plot function, positional data args); the output path is
    # appended as the final argument. One table keeps the printed summary in
    # sync with the files actually written.
    figures = [
        ("fig_error_distribution.png", plot_error_distribution, (claim_df,)),
        ("fig_confusion_matrix.png", plot_confusion_matrix, (cm, labels)),
        # Baseline comparison and ablation both use the benchmark rows.
        ("fig_baseline_comparison.png", plot_baseline_comparison, (rows,)),
        ("fig_ablation.png", plot_ablation, (rows,)),
        ("fig_retrieval_vs_support.png", plot_retrieval_vs_support, (claim_df,)),
        ("fig_runtime.png", plot_runtime, (claim_df,)),
    ]
    for name, plot, data in figures:
        plot(*data, figs / name)

    print(f"Wrote {len(figures)} figures:")
    for name, _, _ in figures:
        print(f" {figs / name}")
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
|