Demo / tests /test_integration.py
Ajayyy00
Add FSP multi-agent architecture: Red Team LLM action space + alternating turns
5719ec3
"""Integration tests — Task 10."""
import os
import sys
import subprocess
import pytest
_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
from server.play_environment import CyberSOCEnvironment
from server.episode_sandbox import EpisodeTimeout
from server.graders import grade_episode
from models import SOCActionWrapper, RedActionWrapper
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _first_host(env):
return next(iter(env._threat_graph.hosts), "WS-042")
def _first_alert(env):
return next(iter(env._threat_graph.alerts), None)
def _valid_action(action_type, **kwargs):
return SOCActionWrapper(type=action_type, **kwargs)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
def test_easy_episode_completes():
env = CyberSOCEnvironment()
obs = env.reset(task_id="easy")
hostname = _first_host(env)
for _ in range(5):
obs = env.step(_valid_action("query_host", hostname=hostname))
assert obs is not None
def test_medium_episode_with_all_10_actions():
env = CyberSOCEnvironment()
obs = env.reset(task_id="medium")
hostname = _first_host(env)
ioc_value = next(iter(env._threat_graph.iocs), None)
# Triage phase
alerts = list(env._threat_graph.alerts.keys())
if len(alerts) >= 2:
obs = env.step(_valid_action("correlate_alerts", alert_ids=alerts[:2]))
# Investigation phase
obs = env.step(_valid_action("query_host", hostname=hostname))
obs = env.step(_valid_action("run_forensics", hostname=hostname))
if ioc_value:
obs = env.step(_valid_action("enrich_ioc", ioc_value=ioc_value, ioc_type="ip"))
obs = env.step(_valid_action("scan_host_vulnerabilities", hostname=hostname))
# Remediation phase
if ioc_value:
obs = env.step(_valid_action("block_ioc", ioc_value=ioc_value, ioc_type="ip"))
obs = env.step(_valid_action("kill_process", hostname=hostname, process_name="nonexistent.exe"))
obs = env.step(_valid_action("isolate_segment", subnet="corporate", reason="test"))
# Trigger playbook (prereqs may or may not be met — just must not crash)
try:
obs = env.step(_valid_action("trigger_playbook", playbook_name="c2_disruption", target=hostname))
except Exception:
pass
# Submit plan
obs = env.step(_valid_action(
"submit_containment_plan",
plan=[{"threat_id": "T1", "actions_taken": ["kill"], "root_cause": "malware", "confidence": 0.8}],
executive_summary="Contained ransomware.",
))
assert obs is not None
def test_phase_violation_returns_error():
"""kill_process during triage should record negative reward (not crash)."""
env = CyberSOCEnvironment()
env.reset(task_id="easy")
hostname = _first_host(env)
# This action is dispatched — may return low reward but must not raise
obs = env.step(_valid_action("kill_process", hostname=hostname, process_name="fake.exe"))
assert obs is not None
def test_lateral_pivot_red_action():
"""LateralPivot RedActionWrapper creates a pivoted_from edge and a SIEM alert."""
env = CyberSOCEnvironment(fsp_mode=True)
env.reset(task_id="hard")
# Find a compromised host to pivot from and a healthy one to pivot to
src = next(
(h for h, hd in env._host_index.items() if hd.get("status") == "compromised"),
None,
)
dst = next(
(h for h, hd in env._host_index.items()
if hd.get("status") not in ("compromised", "isolated") and h != src),
None,
)
if src is None or dst is None:
pytest.skip("No suitable host pair for lateral pivot test")
# Blue takes a PassTurn-equivalent (query) so active_turn flips to red
env.step(_valid_action("query_host", hostname=src))
assert env._state.active_turn == "red"
alerts_before = len(env._alert_queue)
env.step(RedActionWrapper(type="lateral_pivot", source_host=src, target_host=dst))
pivot_edges = [e for e in env._threat_graph.edges if e.edge_type == "pivoted_from"]
assert len(pivot_edges) >= 1
assert env._host_index[dst]["status"] == "compromised"
assert len(env._alert_queue) > alerts_before # SIEM alert generated
def test_step_reward_accumulates():
env = CyberSOCEnvironment()
env.reset(task_id="easy")
hostname = _first_host(env)
before = env._step_reward_total
# run_forensics in investigation phase earns +0.10
for h in list(env._threat_graph.hosts.keys())[:3]:
if h in env._host_index:
env.step(_valid_action("run_forensics", hostname=h))
assert env._step_reward_total > before
def test_step_reward_idempotent():
env = CyberSOCEnvironment()
env.reset(task_id="easy")
hostname = _first_host(env)
env.step(_valid_action("run_forensics", hostname=hostname))
after_first = env._step_reward_total
env.step(_valid_action("run_forensics", hostname=hostname))
after_second = env._step_reward_total
# Second call on same host earns 0 extra step reward (though may earn -0.02 from handler)
step_reward_delta = after_second - after_first
# The idempotent part: _get_step_reward returns 0 on second call for same triple
key = ("investigation", "run_forensics", hostname)
assert key in env._fired_step_rewards
def test_grader_returns_10_dim():
env = CyberSOCEnvironment()
env.reset(task_id="easy")
hostname = _first_host(env)
env.step(_valid_action("query_host", hostname=hostname))
result = grade_episode(
episode_actions=list(env._state.timeline),
final_plan=None,
graph=env._threat_graph,
task_def=env._task_def,
state=env._state,
)
assert len(result["breakdown"]) == 10
def test_sandbox_step_limit():
from server.episode_sandbox import EpisodeSandbox, MAX_STEPS_PER_EPISODE
env = CyberSOCEnvironment()
env.reset(task_id="easy")
sb = EpisodeSandbox(env)
with pytest.raises(EpisodeTimeout):
sb.check_step_limit(MAX_STEPS_PER_EPISODE)
def test_openenv_validate_compatible():
"""Run openenv validate and assert it exits 0 (env is valid with adaptive=False)."""
python = os.path.join(_PROJECT_ROOT, "..", "venv", "Scripts", "python.exe")
python = os.path.abspath(python)
if not os.path.exists(python):
pytest.skip("venv python not found")
result = subprocess.run(
[python, "-c",
"from server.play_environment import CyberSOCEnvironment; "
"env = CyberSOCEnvironment(adaptive=False); "
"obs = env.reset(task_id='easy'); "
"print('OK', len(obs.alert_queue))"],
cwd=_PROJECT_ROOT,
capture_output=True,
text=True,
timeout=60,
)
assert result.returncode == 0, f"stderr: {result.stderr}"
assert "OK" in result.stdout