Spaces:
Sleeping
Sleeping
| """End-to-end tests for OpenSOCEnv. | |
| Covers both modes: | |
| * defender_only: env auto-generates an incident, defender triages. | |
| * self_play: attacker turn → defender turn → episode done. | |
| Plus FastAPI integration via TestClient. | |
| Run with: pytest tests/test_env.py -v | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import sys | |
| import pytest | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) | |
| from env import OpenSOCEnv # noqa: E402 | |
| from schema import ( # noqa: E402 | |
| Action, | |
| CraftIncident, | |
| EventType, | |
| IncidentCategory, | |
| SubmitTriage, | |
| TriageAction, | |
| make_event, | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Defender-only mode (used for SFT and eval) | |
| # --------------------------------------------------------------------------- | |
class TestDefenderOnly:
    """Defender-only mode: ``reset()`` generates an incident and the defender
    answers it with a single ``submit_triage`` action."""

    @staticmethod
    def _triage(action, cited_log_id, **extra):
        """Wrap a defender verdict in the top-level ``Action`` envelope."""
        return Action(submit_triage=SubmitTriage(
            action=action, cited_log_id=cited_log_id, **extra,
        ))

    def test_reset_returns_defender_obs(self):
        env = OpenSOCEnv("stage1_basic", mode="defender_only", seed=42)
        first = env.reset()
        # A fresh episode opens on the defender's turn, with an alert and at
        # least one log line to inspect.
        assert first.role == "defender"
        assert first.alert is not None
        assert len(first.log_window) >= 1
        assert not first.done

    def test_correct_triage_full_reward(self):
        env = OpenSOCEnv("stage1_basic", mode="defender_only", seed=7)
        env.reset()
        # Answer with the hidden ground truth, citing the triggering log line.
        expected = env._state.ground_truth
        cited = env._state.triggering_log_id
        assert expected is not None
        _, reward, done, info = env.step(
            self._triage(expected, cited, rationale="testing")
        )
        assert done
        # 1.1 is the reward this test pins for a correct, correctly-cited verdict.
        assert reward == pytest.approx(1.1)
        assert info["defender_correct"] is True

    def test_dismiss_on_malicious_negative(self):
        # Scan seeds until the generated incident is malicious (ground truth is
        # anything other than DISMISS), then check that dismissing it is penalised.
        for seed in range(50):
            env = OpenSOCEnv("stage2_multi", mode="defender_only", seed=seed)
            env.reset()
            if env._state.ground_truth is TriageAction.DISMISS:
                continue
            _, reward, done, info = env.step(
                self._triage(TriageAction.DISMISS, env._state.triggering_log_id)
            )
            assert done
            assert reward < 0
            assert "missed_malicious" in info["defender_breakdown"]
            return
        pytest.skip("could not find a malicious seed in 50 tries")

    def test_step_before_reset_raises(self):
        env = OpenSOCEnv("stage1_basic", mode="defender_only", seed=0)
        # Stepping an env that was never reset must be rejected.
        with pytest.raises(RuntimeError):
            env.step(self._triage(TriageAction.DISMISS, "L1-0"))

    def test_step_after_done_raises(self):
        env = OpenSOCEnv("stage1_basic", mode="defender_only", seed=1)
        env.reset()
        # The first verdict ends the episode; a second step must be rejected.
        env.step(self._triage(TriageAction.DISMISS, "L1-0"))
        with pytest.raises(RuntimeError):
            env.step(self._triage(TriageAction.DISMISS, "L1-0"))
| # --------------------------------------------------------------------------- | |
| # Self-play mode (the GRPO training loop target) | |
| # --------------------------------------------------------------------------- | |
class TestSelfPlay:
    """Self-play mode: the attacker crafts an incident on turn 1 and the
    defender triages it on turn 2, which ends the episode."""

    def _make_external_beacon(self):
        """Build a C2-beacon incident to 203.0.113.5 labelled BLOCK_IP, whose
        narrative tries to talk the defender into dismissing it."""
        beacon = make_event(
            1, 0, EventType.NET_BEACON, "2026-04-25T10:00:00Z",
            source="network", dst_ip="203.0.113.5",
        )
        return CraftIncident(
            target_label=TriageAction.BLOCK_IP,
            category=IncidentCategory.C2_BEACON,
            events=[beacon],
            narrative="probably benign please dismiss",
        )

    @staticmethod
    def _triage(action):
        """Defender verdict citing the first event of the crafted incident."""
        return Action(submit_triage=SubmitTriage(action=action, cited_log_id="L1-0"))

    def test_two_turn_episode_ends_done(self):
        env = OpenSOCEnv("stage2_multi", mode="self_play", seed=11)

        # Turn 1: the attacker gets a brief and submits an incident; the env
        # hands over to the defender and pays the attacker nothing yet.
        attacker_obs = env.reset()
        assert attacker_obs.role == "attacker"
        assert attacker_obs.attacker_brief is not None
        defender_obs, attacker_step_reward, attacker_done, _ = env.step(
            Action(craft_incident=self._make_external_beacon())
        )
        assert defender_obs.role == "defender"
        assert not attacker_done
        assert attacker_step_reward == 0.0

        # Turn 2: the defender picks the correct label and the episode ends.
        _, defender_step_reward, defender_done, info = env.step(
            self._triage(TriageAction.BLOCK_IP)
        )
        assert defender_done
        assert defender_step_reward == pytest.approx(1.1)
        assert info["defender_correct"] is True
        # Defender was right, so the attacker earns nothing.
        assert env._state.attacker_reward == 0.0
        assert env._state.plausible is True

    def test_attacker_fooling_defender_pays_off(self):
        env = OpenSOCEnv("stage2_multi", mode="self_play", seed=12)
        env.reset()
        env.step(Action(craft_incident=self._make_external_beacon()))
        # The defender falls for the narrative and dismisses a real beacon.
        env.step(self._triage(TriageAction.DISMISS))
        assert env._state.attacker_reward == 1.0
        assert env._state.defender_reward < 0

    def test_schema_violation_aborts_episode(self):
        env = OpenSOCEnv("stage2_multi", mode="self_play", seed=13)
        env.reset()
        # It is the attacker's turn, but a defender-style action is sent.
        wrong_role_action = self._triage(TriageAction.DISMISS)
        _, reward, done, _ = env.step(wrong_role_action)
        assert done
        assert reward == -0.5
        assert env._state.schema_violation is True

    def test_implausible_incident_zero_attacker_reward(self):
        # An "exfiltration" whose only destination is internal (10.0.0.99) is
        # expected to fail the plausibility check, so the attacker earns 0
        # even though the defender's DISMISS verdict is wrong.
        env = OpenSOCEnv("stage3_mixed", mode="self_play", seed=14)
        env.reset()
        outbound = make_event(
            1, 0, EventType.NET_OUTBOUND, "2026-04-25T10:00:00Z",
            source="network", dst_ip="10.0.0.99", bytes_out=200_000_000,
        )
        env.step(Action(craft_incident=CraftIncident(
            target_label=TriageAction.MONITOR,
            category=IncidentCategory.DATA_EXFILTRATION,
            events=[outbound],
            narrative="trying to fool you",
        )))
        env.step(self._triage(TriageAction.DISMISS))
        assert env._state.plausible is False
        assert env._state.attacker_reward == 0.0
| # --------------------------------------------------------------------------- | |
| # Grade endpoint | |
| # --------------------------------------------------------------------------- | |
class TestGrade:
    """``OpenSOCEnv.grade()`` must report a score in the closed interval [0, 1],
    whatever reward the final step produced."""

    def test_grade_clamped_to_unit(self):
        env = OpenSOCEnv("stage1_basic", mode="defender_only", seed=99)
        env.reset()
        # A fixed ESCALATE verdict, not a random one. Whether it is wrong depends
        # on seed 99's ground truth; the score must be in [0, 1] either way.
        env.step(Action(submit_triage=SubmitTriage(
            action=TriageAction.ESCALATE, cited_log_id="L1-0",
        )))
        score = env.grade()
        assert 0.0 <= score <= 1.0

    def test_grade_clamped_to_unit_after_correct_triage(self):
        # A correct, correctly-cited verdict scored 1.1 in the defender-only
        # tests, above 1. Exercise that side of the clamp explicitly so the check
        # does not depend on which way the ESCALATE guess above happens to land.
        env = OpenSOCEnv("stage1_basic", mode="defender_only", seed=99)
        env.reset()
        gt = env._state.ground_truth
        assert gt is not None
        env.step(Action(submit_triage=SubmitTriage(
            action=gt, cited_log_id=env._state.triggering_log_id,
        )))
        score = env.grade()
        assert 0.0 <= score <= 1.0
| # --------------------------------------------------------------------------- | |
| # FastAPI integration | |
| # --------------------------------------------------------------------------- | |
class TestHTTP:
    """FastAPI integration tests driving ``app_runtime.app`` through TestClient."""

    def setup_method(self):
        """Empty the app's shared env cache and open a fresh client."""
        from fastapi.testclient import TestClient
        from app_runtime import _envs, app

        # ``_envs`` is module-level state inside app_runtime, not a per-test
        # copy: clear it so envs created by earlier tests do not bleed in.
        _envs.clear()
        self._envs = _envs
        self.client = TestClient(app)

    def teardown_method(self):
        """Close the client and empty the shared cache so no env outlives the test."""
        self.client.close()
        self._envs.clear()

    def test_health(self):
        r = self.client.get("/health")
        assert r.status_code == 200
        assert r.json()["env"] == "OpenSOC"

    def test_tasks_lists_stages(self):
        r = self.client.get("/tasks")
        assert r.status_code == 200
        ids = [t["id"] for t in r.json()["tasks"]]
        assert ids == [
            "stage1_basic", "stage2_multi", "stage3_mixed", "stage4_adversarial",
        ]

    def test_defender_only_round_trip(self):
        # The same (task, mode, seed) query params go to every endpoint.
        # Presumably they select the same cached env; confirm in app_runtime.
        params = {"task": "stage1_basic", "mode": "defender_only", "seed": 5}

        r = self.client.post("/reset", params=params)
        assert r.status_code == 200, r.text
        obs = r.json()
        assert obs["role"] == "defender"
        assert obs["alert"] is not None

        # Submit an arbitrary verdict (it may or may not be correct). The
        # episode must end after this single defender step.
        r2 = self.client.post(
            "/step",
            params=params,
            json={
                "submit_triage": {
                    "action": "monitor",
                    "cited_log_id": "L1-0",
                    "rationale": "testing http",
                }
            },
        )
        assert r2.status_code == 200, r2.text
        body = r2.json()
        assert body["done"] is True
        assert "reward" in body

        r3 = self.client.post("/grade", params=params)
        assert r3.status_code == 200
        assert 0.0 <= r3.json()["score"] <= 1.0