File size: 7,120 Bytes
57e71f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5719ec3
57e71f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5719ec3
 
 
57e71f8
 
5719ec3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57e71f8
 
 
5719ec3
 
57e71f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
"""Integration tests — Task 10."""

import os
import sys
import subprocess

import pytest

_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

from server.play_environment import CyberSOCEnvironment
from server.episode_sandbox import EpisodeTimeout
from server.graders import grade_episode
from models import SOCActionWrapper, RedActionWrapper


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _first_host(env):
    return next(iter(env._threat_graph.hosts), "WS-042")


def _first_alert(env):
    return next(iter(env._threat_graph.alerts), None)


def _valid_action(action_type, **kwargs):
    return SOCActionWrapper(type=action_type, **kwargs)


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------

def test_easy_episode_completes():
    env = CyberSOCEnvironment()
    obs = env.reset(task_id="easy")
    hostname = _first_host(env)
    for _ in range(5):
        obs = env.step(_valid_action("query_host", hostname=hostname))
    assert obs is not None


def test_medium_episode_with_all_10_actions():
    env = CyberSOCEnvironment()
    obs = env.reset(task_id="medium")
    hostname = _first_host(env)
    ioc_value = next(iter(env._threat_graph.iocs), None)

    # Triage phase
    alerts = list(env._threat_graph.alerts.keys())
    if len(alerts) >= 2:
        obs = env.step(_valid_action("correlate_alerts", alert_ids=alerts[:2]))

    # Investigation phase
    obs = env.step(_valid_action("query_host", hostname=hostname))
    obs = env.step(_valid_action("run_forensics", hostname=hostname))

    if ioc_value:
        obs = env.step(_valid_action("enrich_ioc", ioc_value=ioc_value, ioc_type="ip"))
    obs = env.step(_valid_action("scan_host_vulnerabilities", hostname=hostname))

    # Remediation phase
    if ioc_value:
        obs = env.step(_valid_action("block_ioc", ioc_value=ioc_value, ioc_type="ip"))
    obs = env.step(_valid_action("kill_process", hostname=hostname, process_name="nonexistent.exe"))
    obs = env.step(_valid_action("isolate_segment", subnet="corporate", reason="test"))

    # Trigger playbook (prereqs may or may not be met — just must not crash)
    try:
        obs = env.step(_valid_action("trigger_playbook", playbook_name="c2_disruption", target=hostname))
    except Exception:
        pass

    # Submit plan
    obs = env.step(_valid_action(
        "submit_containment_plan",
        plan=[{"threat_id": "T1", "actions_taken": ["kill"], "root_cause": "malware", "confidence": 0.8}],
        executive_summary="Contained ransomware.",
    ))
    assert obs is not None


def test_phase_violation_returns_error():
    """kill_process during triage should record negative reward (not crash)."""
    env = CyberSOCEnvironment()
    env.reset(task_id="easy")
    hostname = _first_host(env)
    # This action is dispatched — may return low reward but must not raise
    obs = env.step(_valid_action("kill_process", hostname=hostname, process_name="fake.exe"))
    assert obs is not None


def test_lateral_pivot_red_action():
    """LateralPivot RedActionWrapper creates a pivoted_from edge and a SIEM alert."""
    env = CyberSOCEnvironment(fsp_mode=True)
    env.reset(task_id="hard")

    # Find a compromised host to pivot from and a healthy one to pivot to
    src = next(
        (h for h, hd in env._host_index.items() if hd.get("status") == "compromised"),
        None,
    )
    dst = next(
        (h for h, hd in env._host_index.items()
         if hd.get("status") not in ("compromised", "isolated") and h != src),
        None,
    )
    if src is None or dst is None:
        pytest.skip("No suitable host pair for lateral pivot test")

    # Blue takes a PassTurn-equivalent (query) so active_turn flips to red
    env.step(_valid_action("query_host", hostname=src))
    assert env._state.active_turn == "red"

    alerts_before = len(env._alert_queue)
    env.step(RedActionWrapper(type="lateral_pivot", source_host=src, target_host=dst))

    pivot_edges = [e for e in env._threat_graph.edges if e.edge_type == "pivoted_from"]
    assert len(pivot_edges) >= 1
    assert env._host_index[dst]["status"] == "compromised"
    assert len(env._alert_queue) > alerts_before  # SIEM alert generated


def test_step_reward_accumulates():
    env = CyberSOCEnvironment()
    env.reset(task_id="easy")
    hostname = _first_host(env)

    before = env._step_reward_total
    # run_forensics in investigation phase earns +0.10
    for h in list(env._threat_graph.hosts.keys())[:3]:
        if h in env._host_index:
            env.step(_valid_action("run_forensics", hostname=h))

    assert env._step_reward_total > before


def test_step_reward_idempotent():
    env = CyberSOCEnvironment()
    env.reset(task_id="easy")
    hostname = _first_host(env)

    env.step(_valid_action("run_forensics", hostname=hostname))
    after_first = env._step_reward_total
    env.step(_valid_action("run_forensics", hostname=hostname))
    after_second = env._step_reward_total

    # Second call on same host earns 0 extra step reward (though may earn -0.02 from handler)
    step_reward_delta = after_second - after_first
    # The idempotent part: _get_step_reward returns 0 on second call for same triple
    key = ("investigation", "run_forensics", hostname)
    assert key in env._fired_step_rewards


def test_grader_returns_10_dim():
    env = CyberSOCEnvironment()
    env.reset(task_id="easy")
    hostname = _first_host(env)
    env.step(_valid_action("query_host", hostname=hostname))

    result = grade_episode(
        episode_actions=list(env._state.timeline),
        final_plan=None,
        graph=env._threat_graph,
        task_def=env._task_def,
        state=env._state,
    )
    assert len(result["breakdown"]) == 10


def test_sandbox_step_limit():
    from server.episode_sandbox import EpisodeSandbox, MAX_STEPS_PER_EPISODE

    env = CyberSOCEnvironment()
    env.reset(task_id="easy")
    sb = EpisodeSandbox(env)

    with pytest.raises(EpisodeTimeout):
        sb.check_step_limit(MAX_STEPS_PER_EPISODE)


def test_openenv_validate_compatible():
    """Run openenv validate and assert it exits 0 (env is valid with adaptive=False)."""
    python = os.path.join(_PROJECT_ROOT, "..", "venv", "Scripts", "python.exe")
    python = os.path.abspath(python)
    if not os.path.exists(python):
        pytest.skip("venv python not found")

    result = subprocess.run(
        [python, "-c",
         "from server.play_environment import CyberSOCEnvironment; "
         "env = CyberSOCEnvironment(adaptive=False); "
         "obs = env.reset(task_id='easy'); "
         "print('OK', len(obs.alert_queue))"],
        cwd=_PROJECT_ROOT,
        capture_output=True,
        text=True,
        timeout=60,
    )
    assert result.returncode == 0, f"stderr: {result.stderr}"
    assert "OK" in result.stdout