Spaces:
Running
Running
File size: 3,343 Bytes
"""Plot eval/results/summary.json into PNG images for the README.

Run after `eval.eval`::

    python -m eval.plot_results --in eval/results/summary.json --out-dir eval/results

Relative `--in` and `--out-dir` paths are resolved against the parent of this
script's directory, not against the current working directory.

Generates:

* `bar_dismiss_on_malicious.png` — the headline plot.
* `bar_macro_f1.png` — macro F1 by model.
* `confusion_<model>.png` — one heatmap per evaluated model.

We use matplotlib only (no seaborn dependency) with the non-interactive Agg
backend. This keeps the Hugging Face Space slim and lets the plotter run on
CPU only, without a display.
"""
from __future__ import annotations
import argparse
import json
import os
import sys
# Absolute path of the directory containing this script.
_HERE = os.path.dirname(os.path.abspath(__file__))
# Put the parent of that directory first on sys.path so the `eval` package
# (imported just below as `eval.metrics`) resolves even when this file is run
# directly as a script. This must execute before that import.
sys.path.insert(0, os.path.dirname(_HERE))
from eval.metrics import ALL_ACTIONS # noqa: E402
def _try_matplotlib():
try:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
return plt
except ImportError:
return None
def _safe_filename(label):
    """Turn a model label into something usable as one path component.

    Characters other than letters, digits, ``-``, ``_`` and ``.`` become
    ``_``. Labels may be model ids such as ``org/name``; an unescaped ``/``
    would make ``savefig`` target a (usually nonexistent) subdirectory.
    """
    return "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in str(label))


def _bar_chart(plt, labels, values, ylabel, title, path):
    """Draw one bar per model and save the figure as a PNG at *path*.

    Bars sit at integer positions and are labelled via ``set_xticks`` rather
    than a categorical string axis, so two entries sharing a label still get
    separate bars instead of being stacked on one slot.
    """
    positions = range(len(labels))
    fig, ax = plt.subplots(figsize=(7, 4))
    try:
        ax.bar(positions, values)
        ax.set_xticks(positions, labels, rotation=20, ha="right")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(path, dpi=150)
    finally:
        # Always release the figure, even if rendering/saving fails.
        plt.close(fig)


def _confusion_heatmap(plt, cm, label, path):
    """Draw one model's confusion matrix as an annotated heatmap at *path*.

    ``cm`` is read as ``cm[ground_truth][predicted] -> count``; missing rows
    or cells are treated as 0. Both axes follow the ``ALL_ACTIONS`` order.
    """
    rows = [[cm.get(gt, {}).get(p, 0) for p in ALL_ACTIONS] for gt in ALL_ACTIONS]
    # Cells above half the peak are dark on the "Blues" map, so their text is
    # drawn white. The threshold is computed once, not per cell.
    threshold = max((max(row) for row in rows), default=0) / 2
    ticks = range(len(ALL_ACTIONS))
    fig, ax = plt.subplots(figsize=(5.5, 4.5))
    try:
        im = ax.imshow(rows, cmap="Blues")
        ax.set_xticks(ticks, ALL_ACTIONS, rotation=25, ha="right")
        ax.set_yticks(ticks, ALL_ACTIONS)
        ax.set_xlabel("predicted")
        ax.set_ylabel("ground truth")
        ax.set_title(f"Confusion matrix: {label}")
        for r, row in enumerate(rows):
            for c, v in enumerate(row):
                ax.text(c, r, str(v), ha="center", va="center", fontsize=8,
                        color="white" if v > threshold else "black")
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        fig.tight_layout()
        fig.savefig(path, dpi=150)
    finally:
        plt.close(fig)


def main() -> None:
    """Render bar charts and per-model confusion heatmaps from an eval summary.

    Reads the JSON list of per-model summaries. Each entry must provide
    ``label``, ``dismiss_on_malicious``, ``macro_f1`` and ``confusion_matrix``.
    The function writes ``bar_dismiss_on_malicious.png``, ``bar_macro_f1.png``
    and one ``confusion_<label>.png`` per entry into ``--out-dir``; that
    directory is created if it is missing.

    Relative ``--in`` / ``--out-dir`` paths are joined onto the parent of this
    script's directory, not the current working directory. Absolute paths
    are used as given.

    Exits with an error message if matplotlib cannot be imported.
    """
    parser = argparse.ArgumentParser(
        description="Plot an eval summary JSON into PNG images.")
    parser.add_argument(
        "--in", dest="inp", default="eval/results/summary.json",
        help="summary JSON; relative paths resolve against the repo root")
    parser.add_argument(
        "--out-dir", default="eval/results",
        help="output directory for the PNGs (created if missing)")
    args = parser.parse_args()

    plt = _try_matplotlib()
    if plt is None:
        sys.exit("matplotlib is required to render plots: `pip install matplotlib`")

    root = os.path.dirname(_HERE)
    inp = os.path.join(root, args.inp)
    out_dir = os.path.join(root, args.out_dir)
    os.makedirs(out_dir, exist_ok=True)

    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(inp, "r", encoding="utf-8") as f:
        summaries = json.load(f)

    # Read every metric up front so a malformed summary fails before any
    # PNG is written.
    labels = [s["label"] for s in summaries]
    miss = [s["dismiss_on_malicious"] for s in summaries]
    f1s = [s["macro_f1"] for s in summaries]

    _bar_chart(plt, labels, miss,
               "dismiss-on-malicious rate (lower is better)",
               "Missed-malicious rate by model",
               os.path.join(out_dir, "bar_dismiss_on_malicious.png"))
    _bar_chart(plt, labels, f1s,
               "macro F1 (higher is better)",
               "Macro F1 by model",
               os.path.join(out_dir, "bar_macro_f1.png"))

    for s in summaries:
        filename = f"confusion_{_safe_filename(s['label'])}.png"
        _confusion_heatmap(plt, s["confusion_matrix"], s["label"],
                           os.path.join(out_dir, filename))

    print(f"Wrote plots to {out_dir}")


if __name__ == "__main__":
    main()
|