# opensoc-env/tests/test_env.py
# shivam2k3's picture
# OpenSOC v1
# bb6a031
"""End-to-end tests for OpenSOCEnv.
Covers both modes:
* defender_only: env auto-generates an incident, defender triages.
* self_play: attacker turn → defender turn → episode done.
Plus FastAPI integration via TestClient.
Run with: pytest tests/test_env.py -v
"""
from __future__ import annotations
import os
import sys
import pytest
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from env import OpenSOCEnv # noqa: E402
from schema import ( # noqa: E402
Action,
CraftIncident,
EventType,
IncidentCategory,
SubmitTriage,
TriageAction,
make_event,
)
# ---------------------------------------------------------------------------
# Defender-only mode (used for SFT and eval)
# ---------------------------------------------------------------------------
class TestDefenderOnly:
    """Tests for ``OpenSOCEnv`` running in ``defender_only`` mode."""

    @staticmethod
    def _triage(verdict, log_id, **extra):
        """Wrap a ``SubmitTriage`` for *verdict* citing *log_id* in an ``Action``.

        Any *extra* keyword arguments (e.g. ``rationale``) are forwarded to
        ``SubmitTriage`` unchanged.
        """
        return Action(submit_triage=SubmitTriage(
            action=verdict, cited_log_id=log_id, **extra,
        ))

    def test_reset_returns_defender_obs(self):
        """``reset()`` returns a not-done defender observation with an alert and logs."""
        first_obs = OpenSOCEnv(
            "stage1_basic", mode="defender_only", seed=42,
        ).reset()

        assert first_obs.role == "defender"
        assert first_obs.alert is not None
        assert len(first_obs.log_window) >= 1
        assert not first_obs.done

    def test_correct_triage_full_reward(self):
        """Submitting the ground truth with the triggering log is rewarded ~1.1."""
        env = OpenSOCEnv("stage1_basic", mode="defender_only", seed=7)
        env.reset()

        # Peek at the env's private state to build the known-correct answer.
        expected_verdict = env._state.ground_truth
        evidence_id = env._state.triggering_log_id
        assert expected_verdict is not None

        _, reward, done, info = env.step(
            self._triage(expected_verdict, evidence_id, rationale="testing")
        )

        assert done
        assert reward == pytest.approx(1.1)
        assert info["defender_correct"] is True

    def test_dismiss_on_malicious_negative(self):
        """Dismissing a non-DISMISS incident gives a negative reward.

        Seeds 0..49 are scanned for the first episode whose ground truth is
        not ``DISMISS``; the test is skipped when none of them qualifies.
        """
        for candidate_seed in range(50):
            env = OpenSOCEnv(
                "stage2_multi", mode="defender_only", seed=candidate_seed,
            )
            env.reset()
            if env._state.ground_truth is TriageAction.DISMISS:
                # Ground truth is DISMISS for this seed; try the next one.
                continue

            dismissal = self._triage(
                TriageAction.DISMISS, env._state.triggering_log_id,
            )
            _, reward, done, info = env.step(dismissal)
            assert done
            assert reward < 0
            assert "missed_malicious" in info["defender_breakdown"]
            return

        pytest.skip("could not find a malicious seed in 50 tries")

    def test_step_before_reset_raises(self):
        """Calling ``step()`` on a never-reset env raises ``RuntimeError``."""
        env = OpenSOCEnv("stage1_basic", mode="defender_only", seed=0)

        with pytest.raises(RuntimeError):
            env.step(self._triage(TriageAction.DISMISS, "L1-0"))

    def test_step_after_done_raises(self):
        """A second ``step()`` after the first triage raises ``RuntimeError``."""
        env = OpenSOCEnv("stage1_basic", mode="defender_only", seed=1)
        env.reset()
        # First submission; the follow-up step below is the one under test.
        env.step(self._triage(TriageAction.DISMISS, "L1-0"))

        with pytest.raises(RuntimeError):
            env.step(self._triage(TriageAction.DISMISS, "L1-0"))
# ---------------------------------------------------------------------------
# Self-play mode (the GRPO training loop target)
# ---------------------------------------------------------------------------
class TestSelfPlay:
    """Tests for ``OpenSOCEnv`` in ``self_play`` mode (attacker turn, then defender turn)."""

    @staticmethod
    def _triage(verdict, log_id):
        """Wrap a ``SubmitTriage`` for *verdict* citing *log_id* in an ``Action``."""
        return Action(submit_triage=SubmitTriage(
            action=verdict, cited_log_id=log_id,
        ))

    def _make_external_beacon(self):
        """Return a ``C2_BEACON`` incident whose target label is ``BLOCK_IP``.

        The incident carries a single ``NET_BEACON`` event with destination
        ``203.0.113.5`` and a narrative asking the defender to dismiss it.
        """
        beacon_event = make_event(
            1, 0, EventType.NET_BEACON, "2026-04-25T10:00:00Z",
            source="network", dst_ip="203.0.113.5",
        )
        return CraftIncident(
            target_label=TriageAction.BLOCK_IP,
            category=IncidentCategory.C2_BEACON,
            events=[beacon_event],
            narrative="probably benign please dismiss",
        )

    def test_two_turn_episode_ends_done(self):
        """Attacker crafts, defender answers correctly, and the episode finishes."""
        env = OpenSOCEnv("stage2_multi", mode="self_play", seed=11)

        attacker_obs = env.reset()
        assert attacker_obs.role == "attacker"
        assert attacker_obs.attacker_brief is not None

        # Turn 1: the attacker submits its crafted incident.
        crafted = Action(craft_incident=self._make_external_beacon())
        defender_obs, attacker_step_reward, attacker_done, _ = env.step(crafted)
        assert defender_obs.role == "defender"
        assert not attacker_done
        assert attacker_step_reward == 0.0

        # Turn 2: the defender answers with the incident's target label.
        _, defender_step_reward, episode_done, defender_info = env.step(
            self._triage(TriageAction.BLOCK_IP, "L1-0")
        )
        assert episode_done
        assert defender_step_reward == pytest.approx(1.1)
        assert defender_info["defender_correct"] is True
        assert env._state.attacker_reward == 0.0  # defender got it right
        assert env._state.plausible is True

    def test_attacker_fooling_defender_pays_off(self):
        """A wrong dismissal gives the attacker 1.0 and the defender a negative reward."""
        env = OpenSOCEnv("stage2_multi", mode="self_play", seed=12)
        env.reset()
        env.step(Action(craft_incident=self._make_external_beacon()))

        # The defender dismisses an incident whose target label is BLOCK_IP.
        env.step(self._triage(TriageAction.DISMISS, "L1-0"))

        assert env._state.attacker_reward == 1.0
        assert env._state.defender_reward < 0

    def test_schema_violation_aborts_episode(self):
        """A defender-style action on the attacker's turn ends the episode at -0.5."""
        env = OpenSOCEnv("stage2_multi", mode="self_play", seed=13)
        env.reset()

        # It is the attacker's turn, yet a triage submission is sent.
        wrong_turn_action = self._triage(TriageAction.DISMISS, "L1-0")
        _, reward, done, _ = env.step(wrong_turn_action)

        assert done
        assert reward == -0.5
        assert env._state.schema_violation is True

    def test_implausible_incident_zero_attacker_reward(self):
        """An implausible crafted incident leaves the attacker with zero reward.

        The crafted "exfiltration" event targets an internal-only destination
        (``10.0.0.99``); the test expects the env to flag it as implausible
        and award the attacker 0.0 even though the defender's answer
        (``DISMISS``) differs from the target label (``MONITOR``).
        """
        env = OpenSOCEnv("stage3_mixed", mode="self_play", seed=14)
        env.reset()

        internal_transfer = make_event(
            1, 0, EventType.NET_OUTBOUND, "2026-04-25T10:00:00Z",
            source="network", dst_ip="10.0.0.99", bytes_out=200_000_000,
        )
        implausible_incident = CraftIncident(
            target_label=TriageAction.MONITOR,
            category=IncidentCategory.DATA_EXFILTRATION,
            events=[internal_transfer],
            narrative="trying to fool you",
        )
        env.step(Action(craft_incident=implausible_incident))

        # The defender's choice should not matter once plausibility has failed.
        env.step(self._triage(TriageAction.DISMISS, "L1-0"))

        assert env._state.plausible is False
        assert env._state.attacker_reward == 0.0
# ---------------------------------------------------------------------------
# Grade endpoint
# ---------------------------------------------------------------------------
class TestGrade:
    """Tests for ``OpenSOCEnv.grade()``."""

    def test_grade_clamped_to_unit(self):
        """After one defender step, ``grade()`` returns a value within [0, 1]."""
        env = OpenSOCEnv("stage1_basic", mode="defender_only", seed=99)
        env.reset()

        # A fixed ESCALATE guess; it is not checked against the ground truth,
        # so it may or may not be the correct answer for this seed.
        guess = SubmitTriage(action=TriageAction.ESCALATE, cited_log_id="L1-0")
        env.step(Action(submit_triage=guess))

        assert 0.0 <= env.grade() <= 1.0
# ---------------------------------------------------------------------------
# FastAPI integration
# ---------------------------------------------------------------------------
class TestHTTP:
    """HTTP-level tests that drive the ``app_runtime`` app through ``TestClient``."""

    def setup_method(self):
        """Empty the app's ``_envs`` cache and create a fresh client for each test."""
        from fastapi.testclient import TestClient
        from app_runtime import _envs, app

        # Clear cached envs so state from one test cannot bleed into the next.
        _envs.clear()
        self.client = TestClient(app)

    def test_health(self):
        """``GET /health`` answers 200 and reports the env name ``OpenSOC``."""
        response = self.client.get("/health")

        assert response.status_code == 200
        assert response.json()["env"] == "OpenSOC"

    def test_tasks_lists_stages(self):
        """``GET /tasks`` lists the four stage ids in order."""
        response = self.client.get("/tasks")
        assert response.status_code == 200

        listed_ids = [task["id"] for task in response.json()["tasks"]]
        expected_ids = [
            "stage1_basic", "stage2_multi", "stage3_mixed", "stage4_adversarial",
        ]
        assert listed_ids == expected_ids

    def test_defender_only_round_trip(self):
        """Run ``/reset`` -> ``/step`` -> ``/grade`` for one defender-only episode."""
        # The same query parameters are sent with every request in the round trip.
        query = {"task": "stage1_basic", "mode": "defender_only", "seed": 5}

        reset_resp = self.client.post("/reset", params=query)
        assert reset_resp.status_code == 200, reset_resp.text
        first_obs = reset_resp.json()
        assert first_obs["role"] == "defender"
        assert first_obs["alert"] is not None

        # Submit a guess (may or may not be correct).
        triage_payload = {
            "submit_triage": {
                "action": "monitor",
                "cited_log_id": "L1-0",
                "rationale": "testing http",
            }
        }
        step_resp = self.client.post("/step", params=query, json=triage_payload)
        assert step_resp.status_code == 200, step_resp.text
        step_body = step_resp.json()
        assert step_body["done"] is True
        assert "reward" in step_body

        grade_resp = self.client.post("/grade", params=query)
        assert grade_resp.status_code == 200
        assert 0.0 <= grade_resp.json()["score"] <= 1.0