File size: 3,370 Bytes
991c049
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""Generate all six required plots from existing CSV outputs.

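Inputs (default paths, overridable via CLI flags; usually written by
``citeguard evaluate``). Any missing file is replaced by empty data rather
than aborting the script:
- outputs/tables/claim_eval.csv (a warning is printed to stderr if absent)
- outputs/tables/aggregate_metrics.json (for confusion matrix; all-zero if absent)
- outputs/tables/benchmark_summary.csv (optional — for baseline/ablation plots)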

Outputs (all six):
- outputs/figures/fig_error_distribution.png
- outputs/figures/fig_confusion_matrix.png
- outputs/figures/fig_baseline_comparison.png
- outputs/figures/fig_ablation.png
- outputs/figures/fig_retrieval_vs_support.png
- outputs/figures/fig_runtime.png
"""
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

# Directory two levels above this file, computed from its resolved
# (absolute, symlink-free) location.
ROOT = Path(__file__).resolve().parent.parent
# Put ROOT/src at the front of sys.path before the ``citeguard`` imports below
# run (presumably so the package is importable without installation — confirm).
sys.path.insert(0, str(ROOT / "src"))

import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402

from citeguard.reporting.plots import (  # noqa: E402
    plot_ablation,
    plot_baseline_comparison,
    plot_confusion_matrix,
    plot_error_distribution,
    plot_retrieval_vs_support,
    plot_runtime,
)
from citeguard.schemas import LABELS  # noqa: E402


def _load_confusion(agg_path: Path) -> tuple[np.ndarray, list[str]]:
    if not agg_path.exists():
        return np.zeros((len(LABELS), len(LABELS)), dtype=int), LABELS
    with open(agg_path, "r", encoding="utf-8") as f:
        agg = json.load(f)
    conf = agg.get("confusion", {})
    labels = conf.get("labels") or LABELS
    matrix = conf.get("matrix") or []
    if not matrix:
        return np.zeros((len(labels), len(labels)), dtype=int), labels
    return np.asarray(matrix, dtype=int), list(labels)


def main() -> int:
    """Parse the CLI options, render the six figures, and list their paths.

    Each input path has a default and can be overridden with ``--claim-eval``,
    ``--aggregate``, ``--benchmark`` and ``--figures-dir``. A missing claim CSV
    triggers a warning on stderr and an empty DataFrame is plotted instead; a
    missing benchmark CSV yields an empty row list.

    Returns:
        ``0`` once every plotting call has returned and the paths are printed.
    """
    parser = argparse.ArgumentParser()
    for flag, fallback in (
        ("--claim-eval", "outputs/tables/claim_eval.csv"),
        ("--aggregate", "outputs/tables/aggregate_metrics.json"),
        ("--benchmark", "outputs/tables/benchmark_summary.csv"),
        ("--figures-dir", "outputs/figures"),
    ):
        parser.add_argument(flag, default=fallback)
    opts = parser.parse_args()

    out_dir = Path(opts.figures_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Destination of every figure, in the order they are reported at the end.
    error_png = out_dir / "fig_error_distribution.png"
    confusion_png = out_dir / "fig_confusion_matrix.png"
    baseline_png = out_dir / "fig_baseline_comparison.png"
    ablation_png = out_dir / "fig_ablation.png"
    retrieval_png = out_dir / "fig_retrieval_vs_support.png"
    runtime_png = out_dir / "fig_runtime.png"

    claim_csv = Path(opts.claim_eval)
    if not claim_csv.exists():
        print(f"WARNING: {claim_csv} not found; producing empty figures.", file=sys.stderr)
        claims = pd.DataFrame()
    else:
        claims = pd.read_csv(claim_csv)

    plot_error_distribution(claims, error_png)

    matrix, label_names = _load_confusion(Path(opts.aggregate))
    plot_confusion_matrix(matrix, label_names, confusion_png)

    # The baseline and ablation plots are both fed the same benchmark rows.
    bench_csv = Path(opts.benchmark)
    records = pd.read_csv(bench_csv).to_dict(orient="records") if bench_csv.exists() else []
    plot_baseline_comparison(records, baseline_png)
    plot_ablation(records, ablation_png)

    plot_retrieval_vs_support(claims, retrieval_png)
    plot_runtime(claims, runtime_png)

    for written in (
        error_png,
        confusion_png,
        baseline_png,
        ablation_png,
        retrieval_png,
        runtime_png,
    ):
        print(f"  {written}")
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())