File size: 3,343 Bytes
bb6a031
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""Plot eval/results/summary.json into PNG images for the README.

Run after `eval.eval`::

    python -m eval.plot_results --in eval/results/summary.json --out-dir eval/results

Generates:
  * `bar_dismiss_on_malicious.png` — the headline plot.
  * `bar_macro_f1.png` — macro F1 by model.
  * `confusion_<model>.png` — one heatmap per evaluated model.

We use matplotlib only; no seaborn dependency.  This keeps the Hugging
Face Space slim and lets the plotter run on CPU only.
"""

from __future__ import annotations

import argparse
import json
import os
import sys

# Absolute directory containing this file; main() also uses its parent to
# resolve the --in / --out-dir paths.
# Putting that parent directory first on sys.path lets the
# `from eval.metrics import ...` statement below resolve independently of the
# current working directory. Must run before that import.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.dirname(_HERE))

from eval.metrics import ALL_ACTIONS  # noqa: E402


def _try_matplotlib():
    try:
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        return plt
    except ImportError:
        return None


def _save_bar_chart(plt, labels, values, ylabel, title, path):
    """Draw one bar per model label and save the figure as a PNG at ``path``.

    The figure is closed even if rendering or saving fails, so open figures
    do not accumulate.
    """
    fig, ax = plt.subplots(figsize=(7, 4))
    try:
        ax.bar(labels, values)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        # Tilt and right-align the labels so long model names don't collide.
        # Applied to this figure's axes explicitly rather than via pyplot's
        # "current axes" state.
        plt.setp(ax.get_xticklabels(), rotation=20, ha="right")
        fig.tight_layout()
        fig.savefig(path, dpi=150)
    finally:
        plt.close(fig)


def _confusion_rows(cm, actions):
    """Densify a nested ``{ground_truth: {predicted: count}}`` dict.

    Returns a list of rows in ``actions`` order; missing cells become 0.
    """
    return [[cm.get(gt, {}).get(p, 0) for p in actions] for gt in actions]


def _filename_fragment(label):
    """Return ``label`` as a string, with path separators replaced by ``_``.

    Labels such as ``org/model`` would otherwise point at a non-existent
    subdirectory. Labels without separators are returned unchanged.
    """
    text = str(label)
    for sep in (os.sep, os.altsep):
        if sep:
            text = text.replace(sep, "_")
    return text


def _save_confusion_matrix(plt, rows, actions, title, path):
    """Render ``rows`` as an annotated heatmap and save it as a PNG at ``path``.

    Rows are ground-truth classes and columns are predicted classes, both in
    ``actions`` order.
    """
    fig, ax = plt.subplots(figsize=(5.5, 4.5))
    try:
        im = ax.imshow(rows, cmap="Blues")
        ax.set_xticks(range(len(actions)), actions, rotation=25, ha="right")
        ax.set_yticks(range(len(actions)), actions)
        ax.set_xlabel("predicted")
        ax.set_ylabel("ground truth")
        ax.set_title(title)
        # Cells above half of the largest count get white text so the number
        # stays readable on dark blue. The threshold is the same for every
        # cell, so compute it once.
        threshold = max((max(row) for row in rows), default=0) / 2
        for r, row in enumerate(rows):
            for c, v in enumerate(row):
                ax.text(c, r, str(v), ha="center", va="center", fontsize=8,
                        color="white" if v > threshold else "black")
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        fig.tight_layout()
        fig.savefig(path, dpi=150)
    finally:
        plt.close(fig)


def main() -> None:
    """Render the PNG plots listed in the module docstring.

    Reads the JSON list of per-model summaries. Each entry must provide the
    keys ``label``, ``dismiss_on_malicious``, ``macro_f1`` and
    ``confusion_matrix``. The PNGs are written into ``--out-dir``, which is
    created if needed.

    Relative ``--in`` / ``--out-dir`` paths are resolved against the parent
    of this file's directory, not the current working directory. Absolute
    paths are used as given.

    Exits with an error message if matplotlib is not installed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--in", dest="inp", default="eval/results/summary.json",
        help="summary JSON to plot (relative paths resolve against the repo root)",
    )
    parser.add_argument(
        "--out-dir", default="eval/results",
        help="output directory for PNGs (relative paths resolve against the repo root)",
    )
    args = parser.parse_args()

    plt = _try_matplotlib()
    if plt is None:
        sys.exit("matplotlib is required to render plots: `pip install matplotlib`")

    # os.path.join discards repo_root when the user-supplied path is absolute.
    repo_root = os.path.dirname(_HERE)
    inp = os.path.join(repo_root, args.inp)
    out_dir = os.path.join(repo_root, args.out_dir)
    os.makedirs(out_dir, exist_ok=True)

    # JSON is UTF-8 by specification; don't rely on the platform default.
    with open(inp, "r", encoding="utf-8") as f:
        summaries = json.load(f)

    # Extract every series up front so a missing key fails before any file
    # is written.
    labels = [s["label"] for s in summaries]
    miss = [s["dismiss_on_malicious"] for s in summaries]
    f1s = [s["macro_f1"] for s in summaries]

    _save_bar_chart(
        plt, labels, miss,
        ylabel="dismiss-on-malicious rate (lower is better)",
        title="Missed-malicious rate by model",
        path=os.path.join(out_dir, "bar_dismiss_on_malicious.png"),
    )
    _save_bar_chart(
        plt, labels, f1s,
        ylabel="macro F1 (higher is better)",
        title="Macro F1 by model",
        path=os.path.join(out_dir, "bar_macro_f1.png"),
    )

    for s in summaries:
        rows = _confusion_rows(s["confusion_matrix"], ALL_ACTIONS)
        _save_confusion_matrix(
            plt, rows, ALL_ACTIONS,
            title=f"Confusion matrix: {s['label']}",
            path=os.path.join(out_dir, f"confusion_{_filename_fragment(s['label'])}.png"),
        )

    print(f"Wrote plots to {out_dir}")


if __name__ == "__main__":
    main()