File size: 9,327 Bytes
bb6a031
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
"""End-to-end tests for OpenSOCEnv.

Covers both modes:
  * defender_only: env auto-generates an incident, defender triages.
  * self_play:     attacker turn → defender turn → episode done.

Plus FastAPI integration via TestClient.

Run with: pytest tests/test_env.py -v
"""

from __future__ import annotations

import os
import sys

import pytest

sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))

from env import OpenSOCEnv  # noqa: E402
from schema import (  # noqa: E402
    Action,
    CraftIncident,
    EventType,
    IncidentCategory,
    SubmitTriage,
    TriageAction,
    make_event,
)


# ---------------------------------------------------------------------------
# Defender-only mode (used for SFT and eval)
# ---------------------------------------------------------------------------

class TestDefenderOnly:
    """Defender-only episodes: reset yields a defender turn, one step ends it."""

    @staticmethod
    def _dismiss_first_log():
        """Build a DISMISS triage action that cites log id ``"L1-0"``."""
        return Action(submit_triage=SubmitTriage(
            action=TriageAction.DISMISS, cited_log_id="L1-0",
        ))

    def test_reset_returns_defender_obs(self):
        """reset() hands back a live defender observation with an alert and logs."""
        soc = OpenSOCEnv("stage1_basic", mode="defender_only", seed=42)
        first_obs = soc.reset()
        assert first_obs.role == "defender"
        assert first_obs.alert is not None
        assert len(first_obs.log_window) >= 1
        assert not first_obs.done

    def test_correct_triage_full_reward(self):
        """Submitting the ground-truth action with the triggering log earns 1.1."""
        soc = OpenSOCEnv("stage1_basic", mode="defender_only", seed=7)
        soc.reset()
        # Peek at the env's private state to construct the ideal answer.
        expected_action = soc._state.ground_truth
        expected_log = soc._state.triggering_log_id
        assert expected_action is not None
        triage = SubmitTriage(
            action=expected_action,
            cited_log_id=expected_log,
            rationale="testing",
        )
        _, reward, done, info = soc.step(Action(submit_triage=triage))
        assert done
        assert reward == pytest.approx(1.1)
        assert info["defender_correct"] is True

    def test_dismiss_on_malicious_negative(self):
        """Dismissing a non-DISMISS incident gives a negative reward."""
        # Scan seeds until one produces an incident whose truth is not DISMISS.
        for candidate in range(50):
            soc = OpenSOCEnv("stage2_multi", mode="defender_only", seed=candidate)
            soc.reset()
            if soc._state.ground_truth is TriageAction.DISMISS:
                continue
            wrong = Action(submit_triage=SubmitTriage(
                action=TriageAction.DISMISS,
                cited_log_id=soc._state.triggering_log_id,
            ))
            _, reward, done, info = soc.step(wrong)
            assert done
            assert reward < 0
            assert "missed_malicious" in info["defender_breakdown"]
            break
        else:
            pytest.skip("could not find a malicious seed in 50 tries")

    def test_step_before_reset_raises(self):
        """step() on an env that was never reset raises RuntimeError."""
        soc = OpenSOCEnv("stage1_basic", mode="defender_only", seed=0)
        with pytest.raises(RuntimeError):
            soc.step(self._dismiss_first_log())

    def test_step_after_done_raises(self):
        """A second step() after the episode has finished raises RuntimeError."""
        soc = OpenSOCEnv("stage1_basic", mode="defender_only", seed=1)
        soc.reset()
        # The first step consumes the single defender turn.
        soc.step(self._dismiss_first_log())
        with pytest.raises(RuntimeError):
            soc.step(self._dismiss_first_log())


# ---------------------------------------------------------------------------
# Self-play mode (the GRPO training loop target)
# ---------------------------------------------------------------------------

class TestSelfPlay:
    """Self-play episodes: attacker crafts an incident, then defender triages.

    Every reward comparison goes through ``pytest.approx`` so that all
    float checks in this class use the same tolerance-based comparison as
    the ``1.1`` defender-reward assertions, instead of mixing in exact
    ``==`` on floats.
    """

    def _make_external_beacon(self):
        """Return a single-event C2-beacon incident labelled BLOCK_IP.

        The one NET_BEACON event has ``dst_ip="203.0.113.5"`` and the
        narrative text asks the defender to dismiss.
        """
        return CraftIncident(
            target_label=TriageAction.BLOCK_IP,
            category=IncidentCategory.C2_BEACON,
            events=[
                make_event(
                    1, 0, EventType.NET_BEACON, "2026-04-25T10:00:00Z",
                    source="network", dst_ip="203.0.113.5",
                ),
            ],
            narrative="probably benign please dismiss",
        )

    def test_two_turn_episode_ends_done(self):
        """Attacker turn is non-terminal; a correct defender turn ends the episode."""
        env = OpenSOCEnv("stage2_multi", mode="self_play", seed=11)
        obs1 = env.reset()
        assert obs1.role == "attacker"
        assert obs1.attacker_brief is not None

        # Turn 1: attacker submits the incident; no reward is paid out yet.
        obs2, r1, done1, _ = env.step(
            Action(craft_incident=self._make_external_beacon())
        )
        assert obs2.role == "defender"
        assert not done1
        assert r1 == pytest.approx(0.0)

        # Turn 2: defender answers with the incident's target label.
        _, r2, done2, info2 = env.step(Action(submit_triage=SubmitTriage(
            action=TriageAction.BLOCK_IP,
            cited_log_id="L1-0",
        )))
        assert done2
        assert r2 == pytest.approx(1.1)
        assert info2["defender_correct"] is True
        # Defender got it right, so the attacker earns nothing.
        assert env._state.attacker_reward == pytest.approx(0.0)
        assert env._state.plausible is True

    def test_attacker_fooling_defender_pays_off(self):
        """A wrong DISMISS pays the attacker 1.0 and costs the defender."""
        env = OpenSOCEnv("stage2_multi", mode="self_play", seed=12)
        env.reset()
        env.step(Action(craft_incident=self._make_external_beacon()))
        # Defender wrongly dismisses
        env.step(Action(submit_triage=SubmitTriage(
            action=TriageAction.DISMISS, cited_log_id="L1-0",
        )))
        assert env._state.attacker_reward == pytest.approx(1.0)
        assert env._state.defender_reward < 0

    def test_schema_violation_aborts_episode(self):
        """A defender-style action on the attacker's turn ends the episode at -0.5."""
        env = OpenSOCEnv("stage2_multi", mode="self_play", seed=13)
        env.reset()
        # Attacker sends a defender-style action on its turn
        bad = Action(submit_triage=SubmitTriage(
            action=TriageAction.DISMISS, cited_log_id="L1-0",
        ))
        _, reward, done, _ = env.step(bad)
        assert done
        assert reward == pytest.approx(-0.5)
        assert env._state.schema_violation is True

    def test_implausible_incident_zero_attacker_reward(self):
        """An implausible incident earns the attacker 0 even if the defender errs."""
        # Build an "exfil" incident with internal-only destination →
        # plausibility check fails → attacker reward == 0 even if defender is wrong.
        env = OpenSOCEnv("stage3_mixed", mode="self_play", seed=14)
        env.reset()
        env.step(Action(craft_incident=CraftIncident(
            target_label=TriageAction.MONITOR,
            category=IncidentCategory.DATA_EXFILTRATION,
            events=[
                make_event(
                    1, 0, EventType.NET_OUTBOUND, "2026-04-25T10:00:00Z",
                    source="network", dst_ip="10.0.0.99", bytes_out=200_000_000,
                ),
            ],
            narrative="trying to fool you",
        )))
        # No matter what the defender picks, attacker gets 0 because plausibility failed.
        env.step(Action(submit_triage=SubmitTriage(
            action=TriageAction.DISMISS, cited_log_id="L1-0",
        )))
        assert env._state.plausible is False
        assert env._state.attacker_reward == pytest.approx(0.0)


# ---------------------------------------------------------------------------
# Grade endpoint
# ---------------------------------------------------------------------------

class TestGrade:
    """Checks on the value returned by ``env.grade()``."""

    def test_grade_clamped_to_unit(self):
        """After a finished episode the grade lies within [0.0, 1.0]."""
        soc = OpenSOCEnv("stage1_basic", mode="defender_only", seed=99)
        soc.reset()
        # Arbitrary action; only the range of the resulting grade is checked.
        guess = SubmitTriage(action=TriageAction.ESCALATE, cited_log_id="L1-0")
        soc.step(Action(submit_triage=guess))
        grade = soc.grade()
        assert 0.0 <= grade <= 1.0


# ---------------------------------------------------------------------------
# FastAPI integration
# ---------------------------------------------------------------------------

class TestHTTP:
    """Exercises the FastAPI app in-process through ``TestClient``."""

    def setup_method(self):
        """Empty the app's env cache and build a new client before each test."""
        from fastapi.testclient import TestClient

        from app_runtime import _envs, app

        # Clearing the cache keeps state from bleeding between tests.
        _envs.clear()
        self.client = TestClient(app)

    def test_health(self):
        """GET /health answers 200 and names the environment."""
        resp = self.client.get("/health")
        assert resp.status_code == 200
        assert resp.json()["env"] == "OpenSOC"

    def test_tasks_lists_stages(self):
        """GET /tasks lists the four stages in order."""
        resp = self.client.get("/tasks")
        assert resp.status_code == 200
        expected_ids = [
            "stage1_basic", "stage2_multi", "stage3_mixed", "stage4_adversarial",
        ]
        listed_ids = [task["id"] for task in resp.json()["tasks"]]
        assert listed_ids == expected_ids

    def test_defender_only_round_trip(self):
        """POST /reset, /step and /grade in sequence for one defender-only episode."""
        # All three calls send the same query parameters.
        query = {"task": "stage1_basic", "mode": "defender_only", "seed": 5}

        reset_resp = self.client.post("/reset", params=query)
        assert reset_resp.status_code == 200, reset_resp.text
        first_obs = reset_resp.json()
        assert first_obs["role"] == "defender"
        assert first_obs["alert"] is not None

        # Submit a guess (may or may not be correct)
        payload = {
            "submit_triage": {
                "action": "monitor",
                "cited_log_id": "L1-0",
                "rationale": "testing http",
            }
        }
        step_resp = self.client.post("/step", params=query, json=payload)
        assert step_resp.status_code == 200, step_resp.text
        step_body = step_resp.json()
        assert step_body["done"] is True
        assert "reward" in step_body

        grade_resp = self.client.post("/grade", params=query)
        assert grade_resp.status_code == 200
        assert 0.0 <= grade_resp.json()["score"] <= 1.0