# Demo / tests / test_task7.py
# Ajayyy00
# Initial commit of CyberSOC upgraded RLVR environment
# 57e71f8
"""Tests for Task 7 — Tool Router + Triage Solver."""
import os
import sys
# Repository root = the parent of the directory holding this file. Putting it
# on sys.path lets the ``server.*`` imports below resolve without installing
# the project as a package.
_PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)
if _PROJECT_ROOT not in sys.path:
    # Prepend rather than append so this checkout's ``server`` package is
    # found before any same-named package later on sys.path.
    sys.path.insert(0, _PROJECT_ROOT)
from server.tool_router import (
ToolRouter,
compute_triage_priority,
solve_triage_order,
)
from server.threat_graph import (
ThreatGraph,
AlertNode,
HostNode,
IOCNode,
ProcessNode,
)
def _alert(aid="A1", severity="high", source="WS-001"):
    """Build an ``AlertNode`` fixture whose ``priority_score`` is fixed at 1.0.

    Args:
        aid: Alert identifier.
        severity: Severity label passed through unchanged.
        source: Hostname recorded as the alert's ``source_host``.
    """
    return AlertNode(
        alert_id=aid,
        source_host=source,
        severity=severity,
        priority_score=1.0,
    )
def _host(name="WS-001", crit="medium", status="compromised"):
    """Build a ``HostNode`` fixture that always sits on the ``corporate`` subnet.

    Args:
        name: Hostname.
        crit: Value for ``business_criticality``.
        status: Host status label (defaults to ``"compromised"``).
    """
    return HostNode(
        hostname=name,
        business_criticality=crit,
        status=status,
        subnet="corporate",
    )
def _full_evidence_graph():
    """Return a graph holding one host, one IOC and one process on that host.

    The investigation-exit test uses this fixture as its "sufficient
    evidence" graph.
    """
    graph = ThreatGraph()
    graph.add_host(_host("WS-001"))
    graph.add_ioc(
        IOCNode(ioc_value="1.1.1.1", ioc_type="ip", confidence=0.9)
    )
    graph.add_process(
        ProcessNode(
            process_id="WS-001:1",
            hostname="WS-001",
            process_name="x",
        )
    )
    return graph
def test_triage_to_investigation_with_alerts():
    """Triage advances to investigation once the graph contains an alert."""
    graph = ThreatGraph()
    graph.add_alert(_alert())
    router = ToolRouter()
    # NOTE(review): the meaning of the trailing ``10`` argument is not visible here.
    phase = router.next_phase("triage", graph, 10)
    assert phase == "investigation"
def test_triage_to_report_no_alerts():
    """Triage on an alert-free graph skips straight to the report phase."""
    empty_graph = ThreatGraph()
    router = ToolRouter()
    phase = router.next_phase("triage", empty_graph, 10)
    assert phase == "report"
def test_investigation_loops_then_exits():
    """An evidence-free investigation eventually falls through to remediation.

    ``next_phase`` is called ``MAX_INVESTIGATION_LOOPS + 1`` times; only the
    phase returned by the final call is checked.
    """
    graph = ThreatGraph()  # deliberately evidence-free
    router = ToolRouter()
    attempts = router.MAX_INVESTIGATION_LOOPS + 1
    # Seed with the starting phase so that, if no call is made, the assertion
    # below sees "investigation".
    phases = ["investigation"]
    phases.extend(
        router.next_phase("investigation", graph, 10) for _ in range(attempts)
    )
    assert phases[-1] == "remediation"
def test_investigation_exits_on_sufficient_evidence():
    """A graph with host, IOC and process evidence ends investigation at once."""
    router = ToolRouter()
    evidence = _full_evidence_graph()
    phase = router.next_phase("investigation", evidence, 10)
    assert phase == "remediation"
def test_remediation_exits_when_contained():
    """Remediation hands off to the report once the graph's only host is isolated."""
    router = ToolRouter()
    graph = ThreatGraph()
    isolated_host = _host("WS-001", status="isolated")
    graph.add_host(isolated_host)
    phase = router.next_phase("remediation", graph, 10)
    assert phase == "report"
def test_report_returns_done():
    """The report phase is terminal: the router answers ``"done"``."""
    router = ToolRouter()
    phase = router.next_phase("report", ThreatGraph(), 10)
    assert phase == "done"
def test_honor_pushback_rejects_no_graph_refs():
    """Pushback that cites no graph entities is refused."""
    router = ToolRouter()
    no_refs = []
    # honor_pushback returns a 2-tuple; only its first element is checked here.
    accepted, _unused = router.honor_pushback("investigation", no_refs, ThreatGraph())
    assert accepted is False
def test_honor_pushback_accepts_valid_critical_alert():
    """Pushback citing an existing critical alert ID is honoured."""
    graph = ThreatGraph()
    graph.add_alert(_alert("A1", severity="critical"))
    router = ToolRouter()
    cited = ["A1"]
    accepted, _unused = router.honor_pushback("investigation", cited, graph)
    assert accepted is True
def test_triage_priority_higher_for_critical():
    """A critical alert on a critical host outranks a low alert on a low host.

    Neither alert nor host is added to ``graph``; only node attributes differ.
    """
    graph = ThreatGraph()
    critical_alert = _alert("A1", severity="critical")
    low_alert = _alert("A2", severity="low", source="WS-002")
    critical_host = _host("WS-001", crit="critical")
    low_host = _host("WS-002", crit="low")
    assert (
        compute_triage_priority(critical_alert, critical_host, graph)
        > compute_triage_priority(low_alert, low_host, graph)
    )
def test_solve_triage_order_descending():
    """Alerts come back ordered from most to least severe.

    Each alert's severity matches its source host's criticality, so the
    expected order is simply the table order below.
    """
    tiers = (
        ("A1", "WS-001", "critical"),
        ("A2", "WS-002", "medium"),
        ("A3", "WS-003", "low"),
    )
    graph = ThreatGraph()
    # All hosts are registered before any alert references them.
    for _alert_id, hostname, level in tiers:
        graph.add_host(_host(hostname, crit=level))
    for alert_id, hostname, level in tiers:
        graph.add_alert(_alert(alert_id, severity=level, source=hostname))
    expected = [alert_id for alert_id, _hostname, _level in tiers]
    assert solve_triage_order(graph) == expected