File size: 3,679 Bytes
bb6a031
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""Smoke + correctness tests for `eval.metrics` and `eval.eval`.

Run with: pytest tests/test_eval.py -v
"""

from __future__ import annotations

import json
import os
import subprocess
import sys

import pytest

sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))

from eval.metrics import (  # noqa: E402
    accuracy,
    confusion_matrix,
    dismiss_on_malicious_rate,
    over_react_rate,
    per_class_f1,
)


class TestMetrics:
    """Unit checks for the metric helpers imported from `eval.metrics`."""

    def test_accuracy_perfect(self):
        """Identical prediction and truth sequences score exactly 1.0."""
        predicted = ["a", "b", "c"]
        expected = ["a", "b", "c"]
        assert accuracy(predicted, expected) == 1.0

    def test_accuracy_half(self):
        """One match out of two items gives an accuracy of 0.5."""
        score = accuracy(["a", "b"], ["a", "z"])
        assert score == pytest.approx(0.5)

    def test_dismiss_on_malicious_rate(self):
        """Two of the four non-"dismiss" truths are predicted "dismiss"."""
        # All 4 truths count as malicious here; 2 of them were dismissed -> 0.5
        predicted = ["dismiss", "dismiss", "monitor", "block_ip"]
        expected = ["block_ip", "monitor", "monitor", "block_ip"]
        rate = dismiss_on_malicious_rate(predicted, expected)
        assert rate == pytest.approx(0.5)

    def test_dismiss_on_malicious_no_malicious(self):
        """With only "dismiss" truths the rate is 0.0 rather than an error."""
        # No malicious truths at all: the helper must not divide by zero.
        all_dismiss = ["dismiss", "dismiss"]
        assert dismiss_on_malicious_rate(all_dismiss, ["dismiss", "dismiss"]) == 0.0

    def test_over_react_rate(self):
        """Two of the four benign truths receive a heavier predicted action."""
        # 4 benign truths ("dismiss"/"monitor"); 2 were over-reacted on -> 0.5
        predicted = ["block_ip", "quarantine_host", "monitor", "dismiss"]
        expected = ["dismiss", "monitor", "monitor", "dismiss"]
        rate = over_react_rate(predicted, expected)
        assert rate == pytest.approx(0.5)

    def test_per_class_f1_perfect(self):
        """Perfect predictions give macro-F1 of 1.0 and F1 of 1.0 per supported class."""
        expected = ["dismiss", "monitor", "block_ip", "escalate", "quarantine_host"]
        predicted = expected.copy()
        matrix = confusion_matrix(predicted, expected)
        macro_f1, class_stats = per_class_f1(matrix)
        assert macro_f1 == pytest.approx(1.0)
        # Only classes that actually occur in the truths are checked.
        for stats in class_stats.values():
            if stats["support"] > 0:
                assert stats["f1"] == pytest.approx(1.0)


class TestHoldout:
    """End-to-end checks that run `eval.make_holdout` and `eval.eval` as subprocesses."""

    def setup_method(self):
        """Record the repository root (the parent of this file's directory)."""
        # Used as the subprocess cwd so `python -m eval.<module>` can resolve
        # the `eval` package and repo-relative paths such as `data/...`.
        self.repo = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    def _run_module(self, module, *args):
        """Run `python -m <module> <args>` from the repo root.

        Output is captured and included in the assertion message, so a failing
        subprocess is diagnosable from the pytest report (a bare
        `CalledProcessError` does not show the captured stderr).

        Returns:
            The `subprocess.CompletedProcess` of the successful run.
        """
        proc = subprocess.run(
            [sys.executable, "-m", module, *args],
            cwd=self.repo,
            capture_output=True,
            text=True,
        )
        assert proc.returncode == 0, (
            f"`python -m {module}` exited with {proc.returncode}\n"
            f"--- stdout ---\n{proc.stdout}\n"
            f"--- stderr ---\n{proc.stderr}"
        )
        return proc

    def test_make_holdout_writes_jsonl(self, tmp_path):
        """`eval.make_holdout` creates the requested file and it parses as JSONL."""
        out = tmp_path / "ho.jsonl"
        # An absolute path is unambiguous regardless of the subprocess cwd.
        self._run_module(
            "eval.make_holdout",
            "--n-per-stage", "5",
            "--out", str(out),
        )
        assert out.is_file(), f"make_holdout did not create {out}"
        # Every non-blank line must be a standalone JSON document.
        records = [
            json.loads(line)
            for line in out.read_text(encoding="utf-8").splitlines()
            if line.strip()
        ]
        assert records, "make_holdout wrote an empty hold-out file"

    def test_eval_smoke_only_runs(self, tmp_path):
        """`eval.eval --smoke-only` writes a summary with the expected baselines."""
        out_dir = tmp_path / "results"
        # NOTE(review): this hold-out is written inside the repo tree
        # (relative to cwd=self.repo), not under tmp_path; it is left in place
        # after the test, matching the previous behaviour.
        holdout = "data/holdout_smoke.jsonl"
        self._run_module(
            "eval.make_holdout",
            "--n-per-stage", "5",
            "--out", holdout,
        )
        self._run_module(
            "eval.eval",
            "--smoke-only",
            "--holdout", holdout,
            "--out-dir", str(out_dir),
        )
        # Parse the saved summary: a list of per-policy result dicts.
        summary = json.loads((out_dir / "summary.json").read_text(encoding="utf-8"))
        labels = [s["label"] for s in summary]
        assert "verifier_oracle" in labels
        assert "always_dismiss" in labels
        oracle = next(s for s in summary if s["label"] == "verifier_oracle")
        assert oracle["accuracy"] == pytest.approx(1.0)
        assert oracle["dismiss_on_malicious"] == pytest.approx(0.0)