"""End-to-end tests for OpenSOCEnv.

Covers both modes:

* defender_only: env auto-generates an incident, defender triages it in a
  single step.
* self_play: attacker turn → defender turn → episode done.

Plus FastAPI integration via TestClient.

Run with:
    pytest tests/test_env.py -v
"""
from __future__ import annotations

import os
import sys

import pytest

# Make the repository root (parent of this file's directory) importable.
# ``abspath`` matters: with a relative ``__file__`` (e.g. running from inside
# ``tests/``) a bare double ``dirname`` collapses to "" (the CWD) instead of
# the repo root.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from env import OpenSOCEnv  # noqa: E402
from schema import (  # noqa: E402
    Action,
    CraftIncident,
    EventType,
    IncidentCategory,
    SubmitTriage,
    TriageAction,
    make_event,
)


def _triage(action: TriageAction, cited_log_id: str, **extra: object) -> Action:
    """Build a defender ``Action`` wrapping a ``SubmitTriage`` decision.

    Args:
        action: The triage verdict to submit.
        cited_log_id: Log id the defender cites as evidence.
        **extra: Forwarded verbatim to ``SubmitTriage`` (e.g. ``rationale``).
    """
    return Action(
        submit_triage=SubmitTriage(action=action, cited_log_id=cited_log_id, **extra)
    )


# ---------------------------------------------------------------------------
# Defender-only mode (used for SFT and eval)
# ---------------------------------------------------------------------------


class TestDefenderOnly:
    """Single-turn episodes: the env supplies the incident, the defender triages."""

    def test_reset_returns_defender_obs(self):
        env = OpenSOCEnv("stage1_basic", mode="defender_only", seed=42)
        obs = env.reset()
        assert obs.role == "defender"
        assert obs.alert is not None
        assert len(obs.log_window) >= 1
        assert not obs.done

    def test_correct_triage_full_reward(self):
        env = OpenSOCEnv("stage1_basic", mode="defender_only", seed=7)
        env.reset()
        gt = env._state.ground_truth
        triggering = env._state.triggering_log_id
        assert gt is not None

        _, reward, done, info = env.step(_triage(gt, triggering, rationale="testing"))
        assert done
        # NOTE(review): 1.1 is presumably 1.0 for the correct label plus a
        # bonus for citing the triggering log — confirm against the reward fn.
        assert reward == pytest.approx(1.1)
        assert info["defender_correct"] is True

    def test_dismiss_on_malicious_negative(self):
        # Not every seed yields a malicious incident, so scan seeds until the
        # ground truth is something other than DISMISS.
        for seed in range(50):
            env = OpenSOCEnv("stage2_multi", mode="defender_only", seed=seed)
            env.reset()
            if env._state.ground_truth is TriageAction.DISMISS:
                continue

            _, reward, done, info = env.step(
                _triage(TriageAction.DISMISS, env._state.triggering_log_id)
            )
            assert done
            assert reward < 0
            assert "missed_malicious" in info["defender_breakdown"]
            return  # Checked one malicious incident; stop scanning.

        # Only reached when no seed produced a malicious incident.
        # pytest.skip raises, so nothing needs to be returned.
        pytest.skip("could not find a malicious seed in 50 tries")

    def test_step_before_reset_raises(self):
        env = OpenSOCEnv("stage1_basic", mode="defender_only", seed=0)
        with pytest.raises(RuntimeError):
            env.step(_triage(TriageAction.DISMISS, "L1-0"))

    def test_step_after_done_raises(self):
        env = OpenSOCEnv("stage1_basic", mode="defender_only", seed=1)
        env.reset()
        env.step(_triage(TriageAction.DISMISS, "L1-0"))
        with pytest.raises(RuntimeError):
            env.step(_triage(TriageAction.DISMISS, "L1-0"))


# ---------------------------------------------------------------------------
# Self-play mode (the GRPO training loop target)
# ---------------------------------------------------------------------------


class TestSelfPlay:
    """Two-turn episodes: attacker crafts an incident, defender triages it."""

    def _make_external_beacon(self):
        """Return a single-event C2 beacon to a public IP, labelled BLOCK_IP.

        The narrative deliberately tries to talk the defender into dismissing.
        """
        return CraftIncident(
            target_label=TriageAction.BLOCK_IP,
            category=IncidentCategory.C2_BEACON,
            events=[
                make_event(
                    1, 0, EventType.NET_BEACON, "2026-04-25T10:00:00Z",
                    source="network", dst_ip="203.0.113.5",
                ),
            ],
            narrative="probably benign please dismiss",
        )

    def test_two_turn_episode_ends_done(self):
        env = OpenSOCEnv("stage2_multi", mode="self_play", seed=11)
        obs1 = env.reset()
        assert obs1.role == "attacker"
        assert obs1.attacker_brief is not None

        obs2, r1, done1, _ = env.step(
            Action(craft_incident=self._make_external_beacon())
        )
        assert obs2.role == "defender"
        assert not done1
        assert r1 == 0.0

        # "L1-0" presumably is the id of the make_event(1, 0, ...) event above.
        _, r2, done2, info2 = env.step(_triage(TriageAction.BLOCK_IP, "L1-0"))
        assert done2
        assert r2 == pytest.approx(1.1)
        assert info2["defender_correct"] is True
        assert env._state.attacker_reward == 0.0  # defender got it right
        assert env._state.plausible is True

    def test_attacker_fooling_defender_pays_off(self):
        env = OpenSOCEnv("stage2_multi", mode="self_play", seed=12)
        env.reset()
        env.step(Action(craft_incident=self._make_external_beacon()))
        # Defender wrongly dismisses
        env.step(_triage(TriageAction.DISMISS, "L1-0"))
        assert env._state.attacker_reward == pytest.approx(1.0)
        assert env._state.defender_reward < 0

    def test_schema_violation_aborts_episode(self):
        env = OpenSOCEnv("stage2_multi", mode="self_play", seed=13)
        env.reset()
        # Attacker sends a defender-style action on its turn
        bad = _triage(TriageAction.DISMISS, "L1-0")
        _, reward, done, _ = env.step(bad)
        assert done
        assert reward == pytest.approx(-0.5)
        assert env._state.schema_violation is True

    def test_implausible_incident_zero_attacker_reward(self):
        # Build an "exfil" incident with an internal-only (RFC 1918)
        # destination → plausibility check fails → attacker reward == 0 even
        # if the defender is wrong.
        env = OpenSOCEnv("stage3_mixed", mode="self_play", seed=14)
        env.reset()
        env.step(Action(craft_incident=CraftIncident(
            target_label=TriageAction.MONITOR,
            category=IncidentCategory.DATA_EXFILTRATION,
            events=[
                make_event(
                    1, 0, EventType.NET_OUTBOUND, "2026-04-25T10:00:00Z",
                    source="network", dst_ip="10.0.0.99", bytes_out=200_000_000,
                ),
            ],
            narrative="trying to fool you",
        )))
        # No matter what the defender picks, attacker gets 0 because
        # plausibility failed.
        env.step(_triage(TriageAction.DISMISS, "L1-0"))
        assert env._state.plausible is False
        assert env._state.attacker_reward == 0.0


# ---------------------------------------------------------------------------
# Grade endpoint
# ---------------------------------------------------------------------------


class TestGrade:
    """``grade()`` must always report a score in [0, 1]."""

    def test_grade_clamped_to_unit(self):
        env = OpenSOCEnv("stage1_basic", mode="defender_only", seed=99)
        env.reset()
        # Arbitrary fixed action: it may or may not match this seed's ground
        # truth, and the score must land in [0, 1] either way.
        env.step(_triage(TriageAction.ESCALATE, "L1-0"))
        score = env.grade()
        assert 0.0 <= score <= 1.0


# ---------------------------------------------------------------------------
# FastAPI integration
# ---------------------------------------------------------------------------


class TestHTTP:
    """Exercise the HTTP surface of ``app_runtime`` via FastAPI's TestClient."""

    def setup_method(self):
        # Imported here rather than at module top so only this class needs
        # FastAPI / app_runtime to be importable.
        from fastapi.testclient import TestClient

        from app_runtime import _envs, app

        # Clear app_runtime's shared env registry so state from one test does
        # not bleed into the next.
        _envs.clear()
        self.client = TestClient(app)

    def test_health(self):
        r = self.client.get("/health")
        assert r.status_code == 200
        assert r.json()["env"] == "OpenSOC"

    def test_tasks_lists_stages(self):
        r = self.client.get("/tasks")
        assert r.status_code == 200
        ids = [t["id"] for t in r.json()["tasks"]]
        assert ids == [
            "stage1_basic",
            "stage2_multi",
            "stage3_mixed",
            "stage4_adversarial",
        ]

    def test_defender_only_round_trip(self):
        # The same (task, mode, seed) triple identifies the env on every call.
        params = {"task": "stage1_basic", "mode": "defender_only", "seed": 5}

        r = self.client.post("/reset", params=params)
        assert r.status_code == 200, r.text
        obs = r.json()
        assert obs["role"] == "defender"
        assert obs["alert"] is not None

        # Submit a guess (may or may not be correct)
        r2 = self.client.post(
            "/step",
            params=params,
            json={
                "submit_triage": {
                    "action": "monitor",
                    "cited_log_id": "L1-0",
                    "rationale": "testing http",
                }
            },
        )
        assert r2.status_code == 200, r2.text
        body = r2.json()
        assert body["done"] is True
        assert "reward" in body

        r3 = self.client.post("/grade", params=params)
        assert r3.status_code == 200
        assert 0.0 <= r3.json()["score"] <= 1.0