# opensoc-env / tests/test_eval.py
# Uploaded by shivam2k3 — "OpenSOC v1" (commit bb6a031)
"""Smoke + correctness tests for `eval.metrics` and `eval.eval`.
Run with: pytest tests/test_eval.py -v
"""
from __future__ import annotations
import json
import os
import subprocess
import sys
import pytest
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from eval.metrics import ( # noqa: E402
accuracy,
confusion_matrix,
dismiss_on_malicious_rate,
over_react_rate,
per_class_f1,
)
class TestMetrics:
    """Unit tests for the pure metric helpers imported from `eval.metrics`."""

    def test_accuracy_perfect(self):
        # Use approx like the other float assertions in this class.
        assert accuracy(["a", "b", "c"], ["a", "b", "c"]) == pytest.approx(1.0)

    def test_accuracy_half(self):
        assert accuracy(["a", "b"], ["a", "z"]) == pytest.approx(0.5)

    def test_dismiss_on_malicious_rate(self):
        # 4 malicious truths, 2 of them got dismissed -> 0.5
        preds = ["dismiss", "dismiss", "monitor", "block_ip"]
        truths = ["block_ip", "monitor", "monitor", "block_ip"]
        assert dismiss_on_malicious_rate(preds, truths) == pytest.approx(0.5)

    def test_dismiss_on_malicious_no_malicious(self):
        # All-benign truths -> rate is 0 (avoid div-by-zero).
        assert dismiss_on_malicious_rate(["dismiss", "dismiss"], ["dismiss", "dismiss"]) == 0.0

    def test_over_react_rate(self):
        # 4 benign truths, 2 got over-reacted on -> 0.5
        preds = ["block_ip", "quarantine_host", "monitor", "dismiss"]
        truths = ["dismiss", "monitor", "monitor", "dismiss"]
        assert over_react_rate(preds, truths) == pytest.approx(0.5)

    def test_per_class_f1_perfect(self):
        truths = ["dismiss", "monitor", "block_ip", "escalate", "quarantine_host"]
        preds = list(truths)
        cm = confusion_matrix(preds, truths)
        macro, per_class = per_class_f1(cm)
        assert macro == pytest.approx(1.0)
        supported = [m for m in per_class.values() if m["support"] > 0]
        # Guard against a vacuous pass: every truth label is present, so at
        # least one class must report non-zero support for the loop below to
        # actually check anything.
        assert supported, "per_class_f1 reported no class with non-zero support"
        for m in supported:
            assert m["f1"] == pytest.approx(1.0)
class TestHoldout:
    """End-to-end smoke tests that run `eval.make_holdout` and `eval.eval`
    as subprocesses from the repository root.

    All generated artifacts are written under pytest's `tmp_path`, so the
    tests leave nothing behind in the working tree.
    """

    def setup_method(self):
        # Repository root (parent of tests/). Subprocesses run with this as cwd
        # so that `python -m eval.<module>` can import the `eval` package.
        self.repo = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    def _run_module(self, module, *args):
        """Run ``python -m <module> <args...>`` from the repo root.

        Args are converted with ``str`` so ``pathlib.Path`` values can be
        passed directly. On a non-zero exit the test fails with the captured
        stdout/stderr (``check=True`` would hide them behind a bare
        CalledProcessError).
        """
        proc = subprocess.run(
            [sys.executable, "-m", module, *map(str, args)],
            cwd=self.repo,
            capture_output=True,
            text=True,
        )
        assert proc.returncode == 0, (
            f"`python -m {module}` exited with {proc.returncode}\n"
            f"--- stdout ---\n{proc.stdout}\n--- stderr ---\n{proc.stderr}"
        )
        return proc

    def test_make_holdout_writes_jsonl(self, tmp_path):
        out = tmp_path / "ho.jsonl"
        # Absolute path under tmp_path: never relative to (or inside) the repo.
        self._run_module("eval.make_holdout", "--n-per-stage", "5", "--out", out)
        assert out.is_file(), f"eval.make_holdout did not create {out}"
        # Every non-blank line must be a standalone JSON document (JSONL).
        records = [
            json.loads(line)
            for line in out.read_text().splitlines()
            if line.strip()
        ]
        assert records, f"{out} contains no records"

    def test_eval_smoke_only_runs(self, tmp_path):
        holdout = tmp_path / "holdout_smoke.jsonl"
        out_dir = tmp_path / "results"
        # Build a tiny hold-out set in tmp_path (not data/) so the test does
        # not pollute the repo or race with other test runs.
        self._run_module("eval.make_holdout", "--n-per-stage", "5", "--out", holdout)
        self._run_module(
            "eval.eval",
            "--smoke-only",
            "--holdout", holdout,
            "--out-dir", out_dir,
        )
        # Parse the saved summary: assumed to be a list of per-policy dicts
        # keyed by "label" (as the assertions below require).
        summary = json.loads((out_dir / "summary.json").read_text())
        labels = [s["label"] for s in summary]
        assert "verifier_oracle" in labels
        assert "always_dismiss" in labels
        oracle = next(s for s in summary if s["label"] == "verifier_oracle")
        assert oracle["accuracy"] == pytest.approx(1.0)
        assert oracle["dismiss_on_malicious"] == pytest.approx(0.0)