| """End-to-end tests for OpenSOCEnv. |
| |
| Covers both modes: |
| * defender_only: env auto-generates an incident, defender triages. |
| * self_play: attacker turn → defender turn → episode done. |
| |
| Plus FastAPI integration via TestClient. |
| |
| Run with: pytest tests/test_env.py -v |
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
| import sys |
|
|
| import pytest |
|
|
| sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) |
|
|
| from env import OpenSOCEnv |
| from schema import ( |
| Action, |
| CraftIncident, |
| EventType, |
| IncidentCategory, |
| SubmitTriage, |
| TriageAction, |
| make_event, |
| ) |
|
|
|
|
| |
| |
| |
|
|
class TestDefenderOnly:
    """Exercise OpenSOCEnv when constructed with mode="defender_only"."""

    @staticmethod
    def _dismiss_first_log():
        """Return an Action that submits DISMISS while citing log id "L1-0"."""
        return Action(submit_triage=SubmitTriage(
            action=TriageAction.DISMISS, cited_log_id="L1-0",
        ))

    def test_reset_returns_defender_obs(self):
        """reset() yields a non-terminal defender observation with alert and logs."""
        first_obs = OpenSOCEnv(
            "stage1_basic", mode="defender_only", seed=42,
        ).reset()
        assert first_obs.role == "defender"
        assert first_obs.alert is not None
        assert len(first_obs.log_window) >= 1
        assert not first_obs.done

    def test_correct_triage_full_reward(self):
        """Submitting the ground-truth label with the triggering log earns ~1.1."""
        env = OpenSOCEnv("stage1_basic", mode="defender_only", seed=7)
        env.reset()

        # Peek at the env's private state to build a known-correct answer.
        expected_label = env._state.ground_truth
        trigger_id = env._state.triggering_log_id
        assert expected_label is not None

        triage = SubmitTriage(
            action=expected_label,
            cited_log_id=trigger_id,
            rationale="testing",
        )
        _, reward, done, info = env.step(Action(submit_triage=triage))
        assert done
        assert reward == pytest.approx(1.1)
        assert info["defender_correct"] is True

    def test_dismiss_on_malicious_negative(self):
        """Dismissing a non-DISMISS incident gives a negative reward.

        Scans seeds 0-49 for the first episode whose ground truth is not
        DISMISS; the test is skipped if none of them produce one.
        """
        for seed in range(50):
            env = OpenSOCEnv("stage2_multi", mode="defender_only", seed=seed)
            env.reset()
            if env._state.ground_truth is TriageAction.DISMISS:
                # DISMISS would be the right answer here; try the next seed.
                continue

            wrong_call = Action(submit_triage=SubmitTriage(
                action=TriageAction.DISMISS,
                cited_log_id=env._state.triggering_log_id,
            ))
            _, reward, done, info = env.step(wrong_call)
            assert done
            assert reward < 0
            assert "missed_malicious" in info["defender_breakdown"]
            return

        pytest.skip("could not find a malicious seed in 50 tries")

    def test_step_before_reset_raises(self):
        """step() on an env that was never reset raises RuntimeError."""
        env = OpenSOCEnv("stage1_basic", mode="defender_only", seed=0)
        with pytest.raises(RuntimeError):
            env.step(self._dismiss_first_log())

    def test_step_after_done_raises(self):
        """A second step() after the first triage raises RuntimeError."""
        env = OpenSOCEnv("stage1_basic", mode="defender_only", seed=1)
        env.reset()
        env.step(self._dismiss_first_log())
        with pytest.raises(RuntimeError):
            env.step(self._dismiss_first_log())
|
|
|
|
| |
| |
| |
|
|
class TestSelfPlay:
    """Exercise the attacker-then-defender flow of mode="self_play"."""

    def _make_external_beacon(self):
        """Build a one-event C2_BEACON incident labelled BLOCK_IP."""
        beacon_event = make_event(
            1, 0, EventType.NET_BEACON, "2026-04-25T10:00:00Z",
            source="network", dst_ip="203.0.113.5",
        )
        return CraftIncident(
            target_label=TriageAction.BLOCK_IP,
            category=IncidentCategory.C2_BEACON,
            events=[beacon_event],
            narrative="probably benign please dismiss",
        )

    def test_two_turn_episode_ends_done(self):
        """Attacker crafts, defender answers correctly, and the episode ends."""
        env = OpenSOCEnv("stage2_multi", mode="self_play", seed=11)
        opening = env.reset()
        assert opening.role == "attacker"
        assert opening.attacker_brief is not None

        # Turn 1: the attacker submits the crafted incident.
        attacker_move = Action(craft_incident=self._make_external_beacon())
        defender_obs, attacker_step_reward, attacker_done, _ = env.step(
            attacker_move
        )
        assert defender_obs.role == "defender"
        assert not attacker_done
        assert attacker_step_reward == 0.0

        # Turn 2: the defender answers with the incident's target label.
        defender_move = Action(submit_triage=SubmitTriage(
            action=TriageAction.BLOCK_IP,
            cited_log_id="L1-0",
        ))
        _, defender_step_reward, episode_done, info = env.step(defender_move)
        assert episode_done
        assert defender_step_reward == pytest.approx(1.1)
        assert info["defender_correct"] is True
        assert env._state.attacker_reward == 0.0
        assert env._state.plausible is True

    def test_attacker_fooling_defender_pays_off(self):
        """A wrong DISMISS from the defender gives the attacker reward 1.0."""
        env = OpenSOCEnv("stage2_multi", mode="self_play", seed=12)
        env.reset()
        env.step(Action(craft_incident=self._make_external_beacon()))

        # The defender dismisses although the crafted label is BLOCK_IP.
        fooled = SubmitTriage(action=TriageAction.DISMISS, cited_log_id="L1-0")
        env.step(Action(submit_triage=fooled))
        assert env._state.attacker_reward == 1.0
        assert env._state.defender_reward < 0

    def test_schema_violation_aborts_episode(self):
        """A triage action sent as the first move ends the episode at -0.5."""
        env = OpenSOCEnv("stage2_multi", mode="self_play", seed=13)
        env.reset()

        # First move after reset is a submit_triage, not a craft_incident.
        wrong_turn = Action(submit_triage=SubmitTriage(
            action=TriageAction.DISMISS, cited_log_id="L1-0",
        ))
        _, reward, done, _ = env.step(wrong_turn)
        assert done
        assert reward == -0.5
        assert env._state.schema_violation is True

    def test_implausible_incident_zero_attacker_reward(self):
        """An incident flagged implausible earns the attacker nothing."""
        env = OpenSOCEnv("stage3_mixed", mode="self_play", seed=14)
        env.reset()

        # NOTE(review): presumably implausible because the "exfiltration"
        # destination 10.0.0.99 is a private address -- confirm against the
        # env's plausibility rules.
        outbound_event = make_event(
            1, 0, EventType.NET_OUTBOUND, "2026-04-25T10:00:00Z",
            source="network", dst_ip="10.0.0.99", bytes_out=200_000_000,
        )
        crafted = CraftIncident(
            target_label=TriageAction.MONITOR,
            category=IncidentCategory.DATA_EXFILTRATION,
            events=[outbound_event],
            narrative="trying to fool you",
        )
        env.step(Action(craft_incident=crafted))

        # The defender's answer differs from the crafted MONITOR label.
        env.step(Action(submit_triage=SubmitTriage(
            action=TriageAction.DISMISS, cited_log_id="L1-0",
        )))
        assert env._state.plausible is False
        assert env._state.attacker_reward == 0.0
|
|
|
|
| |
| |
| |
|
|
class TestGrade:
    """Check the range of OpenSOCEnv.grade()."""

    def test_grade_clamped_to_unit(self):
        """grade() lies in [0.0, 1.0] after an ESCALATE triage on seed 99."""
        env = OpenSOCEnv("stage1_basic", mode="defender_only", seed=99)
        env.reset()

        triage = SubmitTriage(action=TriageAction.ESCALATE, cited_log_id="L1-0")
        env.step(Action(submit_triage=triage))

        final_score = env.grade()
        assert 0.0 <= final_score <= 1.0
|
|
|
|
| |
| |
| |
|
|
class TestHTTP:
    """Drive the FastAPI app from app_runtime through a TestClient."""

    def setup_method(self):
        """Create a fresh TestClient and empty app_runtime._envs before each test."""
        from fastapi.testclient import TestClient

        from app_runtime import _envs, app

        # NOTE(review): _envs looks like the app's store of live envs;
        # clearing it presumably stops state leaking between tests -- confirm.
        _envs.clear()
        self.client = TestClient(app)

    def test_health(self):
        """GET /health returns 200 and reports env "OpenSOC"."""
        response = self.client.get("/health")
        assert response.status_code == 200
        assert response.json()["env"] == "OpenSOC"

    def test_tasks_lists_stages(self):
        """GET /tasks lists the four stage ids in order."""
        response = self.client.get("/tasks")
        assert response.status_code == 200

        expected_ids = [
            "stage1_basic", "stage2_multi", "stage3_mixed", "stage4_adversarial",
        ]
        listed_ids = [task["id"] for task in response.json()["tasks"]]
        assert listed_ids == expected_ids

    def test_defender_only_round_trip(self):
        """POST /reset, /step and /grade in sequence with the same query params."""
        query = {"task": "stage1_basic", "mode": "defender_only", "seed": 5}

        reset_resp = self.client.post("/reset", params=query)
        assert reset_resp.status_code == 200, reset_resp.text
        first_obs = reset_resp.json()
        assert first_obs["role"] == "defender"
        assert first_obs["alert"] is not None

        # The triage action is sent as the JSON body of /step.
        triage_payload = {
            "submit_triage": {
                "action": "monitor",
                "cited_log_id": "L1-0",
                "rationale": "testing http",
            }
        }
        step_resp = self.client.post("/step", params=query, json=triage_payload)
        assert step_resp.status_code == 200, step_resp.text
        step_body = step_resp.json()
        assert step_body["done"] is True
        assert "reward" in step_body

        grade_resp = self.client.post("/grade", params=query)
        assert grade_resp.status_code == 200
        assert 0.0 <= grade_resp.json()["score"] <= 1.0
|
|