| """Smoke + correctness tests for `eval.metrics` and `eval.eval`. |
| |
| Run with: pytest tests/test_eval.py -v |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import os |
| import subprocess |
| import sys |
|
|
| import pytest |
|
|
| sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) |
|
|
| from eval.metrics import ( |
| accuracy, |
| confusion_matrix, |
| dismiss_on_malicious_rate, |
| over_react_rate, |
| per_class_f1, |
| ) |
|
|
|
|
class TestMetrics:
    """Unit checks for the pure metric helpers imported from ``eval.metrics``."""

    def test_accuracy_perfect(self):
        """Identical prediction and truth sequences score exactly 1.0."""
        labels = ["a", "b", "c"]
        assert accuracy(labels, ["a", "b", "c"]) == 1.0

    def test_accuracy_half(self):
        """One match out of two positions scores 0.5."""
        predicted = ["a", "b"]
        expected = ["a", "z"]
        assert accuracy(predicted, expected) == pytest.approx(0.5)

    def test_dismiss_on_malicious_rate(self):
        """Two of four predictions are "dismiss" while no truth label is; expect 0.5."""
        predicted = ["dismiss", "dismiss", "monitor", "block_ip"]
        expected = ["block_ip", "monitor", "monitor", "block_ip"]
        rate = dismiss_on_malicious_rate(predicted, expected)
        assert rate == pytest.approx(0.5)

    def test_dismiss_on_malicious_no_malicious(self):
        """When every truth label is "dismiss" the rate is exactly 0.0."""
        all_dismiss = ["dismiss", "dismiss"]
        assert dismiss_on_malicious_rate(all_dismiss, ["dismiss", "dismiss"]) == 0.0

    def test_over_react_rate(self):
        """First two predictions differ from their truth labels, last two match; expect 0.5."""
        predicted = ["block_ip", "quarantine_host", "monitor", "dismiss"]
        expected = ["dismiss", "monitor", "monitor", "dismiss"]
        rate = over_react_rate(predicted, expected)
        assert rate == pytest.approx(0.5)

    def test_per_class_f1_perfect(self):
        """Perfect predictions give macro F1 of 1.0 and F1 of 1.0 for every supported class."""
        expected = ["dismiss", "monitor", "block_ip", "escalate", "quarantine_host"]
        predicted = expected[:]
        matrix = confusion_matrix(predicted, expected)
        macro_f1, class_stats = per_class_f1(matrix)
        assert macro_f1 == pytest.approx(1.0)
        for stats in class_stats.values():
            # Classes with zero support are not checked.
            if not stats["support"] > 0:
                continue
            assert stats["f1"] == pytest.approx(1.0)
|
|
|
|
class TestHoldout:
    """End-to-end checks that run ``eval.make_holdout`` and ``eval.eval`` as subprocesses."""

    def setup_method(self):
        """Record the repository root: the parent of this file's directory."""
        # Used as the working directory for every subprocess below.
        self.repo = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    def _run_module(self, module, *args):
        """Run ``python -m <module> <args>`` from the repo root.

        Output is captured; on a non-zero exit the test fails with the child's
        stdout and stderr in the assertion message so the cause is visible.

        Returns:
            The ``subprocess.CompletedProcess`` of the finished run.
        """
        proc = subprocess.run(
            [sys.executable, "-m", module, *args],
            cwd=self.repo,
            capture_output=True,
            text=True,
        )
        assert proc.returncode == 0, (
            f"`python -m {module}` exited with {proc.returncode}\n"
            f"--- stdout ---\n{proc.stdout}\n"
            f"--- stderr ---\n{proc.stderr}"
        )
        return proc

    def test_make_holdout_writes_jsonl(self, tmp_path):
        """``eval.make_holdout`` creates a non-empty file whose lines each parse as JSON."""
        out = tmp_path / "ho.jsonl"
        # Pass a repo-relative path when the temp dir happens to live inside the
        # repo, otherwise the absolute path.
        if out.is_relative_to(self.repo):
            out_arg = str(out.relative_to(self.repo))
        else:
            out_arg = str(out)
        self._run_module("eval.make_holdout", "--n-per-stage", "5", "--out", out_arg)

        assert out.is_file(), f"make_holdout did not create {out}"
        records = [ln for ln in out.read_text(encoding="utf-8").splitlines() if ln.strip()]
        assert records, "make_holdout wrote an empty file"
        for ln in records:
            json.loads(ln)  # raises ValueError if a line is not valid JSON

    def test_eval_smoke_only_runs(self, tmp_path):
        """``eval.eval --smoke-only`` writes a summary with both expected entries.

        The ``verifier_oracle`` entry must score accuracy 1.0 and
        ``dismiss_on_malicious`` 0.0; ``always_dismiss`` must be present.
        """
        # Keep every artifact under tmp_path so the run leaves nothing behind
        # in the repository's data/ directory.
        # NOTE(review): assumes eval.eval accepts an absolute --holdout path,
        # as make_holdout does for --out in the test above; confirm.
        holdout = tmp_path / "holdout_smoke.jsonl"
        out_dir = tmp_path / "results"

        self._run_module(
            "eval.make_holdout",
            "--n-per-stage", "5",
            "--out", str(holdout),
        )
        self._run_module(
            "eval.eval",
            "--smoke-only",
            "--holdout", str(holdout),
            "--out-dir", str(out_dir),
        )

        summary = json.loads((out_dir / "summary.json").read_text())
        labels = [s["label"] for s in summary]
        assert "verifier_oracle" in labels
        assert "always_dismiss" in labels
        oracle = next(s for s in summary if s["label"] == "verifier_oracle")
        assert oracle["accuracy"] == pytest.approx(1.0)
        assert oracle["dismiss_on_malicious"] == pytest.approx(0.0)
|
|