# Demo / server / play_environment.py
# Author: Ajayyy00
# Commit message: Add hotseat multiplayer Red Team controls and 4 architectural fixes
# Commit: a144947
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
CyberSOCEnv — Enterprise Cybersecurity Operations Center Environment.
Implements the OpenEnv Environment interface for a deterministic SOC
incident response simulation on a 500-node enterprise network.
The agent receives SIEM/EDR alerts, queries hosts, runs forensics,
isolates segments, blocks IOCs, kills processes, and submits a
containment plan — all while minimizing business downtime.
"""
from __future__ import annotations
import copy
import random
import uuid
from typing import Any, Callable, Dict, List, Optional
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
try:
from ..models import (
SOCObservation,
SOCActionWrapper,
SOCState,
Alert,
NetworkTopology,
ForensicsResult,
TimelineEntry,
QueryHost,
IsolateSegment,
BlockIOC,
RunForensics,
KillProcess,
SubmitContainmentPlan,
CorrelateAlerts,
EnrichIOC,
ScanHostVulnerabilities,
TerminatePID,
CreateFirewallRule,
QuarantineFile,
RedActionWrapper,
LateralPivot,
DeployPayload,
EvadeDetection,
PassTurn,
RED_ACTION_TYPES,
)
except ImportError:
from models import (
SOCObservation,
SOCActionWrapper,
SOCState,
Alert,
NetworkTopology,
ForensicsResult,
TimelineEntry,
QueryHost,
IsolateSegment,
BlockIOC,
RunForensics,
KillProcess,
SubmitContainmentPlan,
CorrelateAlerts,
EnrichIOC,
ScanHostVulnerabilities,
TerminatePID,
CreateFirewallRule,
QuarantineFile,
RedActionWrapper,
LateralPivot,
DeployPayload,
EvadeDetection,
PassTurn,
RED_ACTION_TYPES,
)
from .tasks import get_task, build_network
from .graders import grade_episode
from .threat_graph import (
ThreatGraph,
HostNode,
ProcessNode,
IOCNode,
VulnerabilityNode,
AlertNode,
Edge,
)
class ActionMiddleware:
    """Pre-flight validator that runs before a Blue SOC action executes.

    Three problems are detected:

    * ``PHASE_VIOLATION`` -- ``submit_containment_plan`` while still in the
      ``triage`` phase.
    * ``GRAPH_FAILURE`` -- ``enrich_ioc`` / ``scan_host_vulnerabilities``
      naming an IOC / host absent from the ThreatGraph (not checked when the
      value is empty or no graph is supplied).
    * ``UNJUSTIFIED_EMERGENCY`` -- ``isolate_segment`` during triage without
      a critical-severity alert whose source host is the requested
      ``target_host`` or sits on the requested ``subnet``.

    ``validate`` returns ``None`` when the action may proceed, otherwise a
    dict with ``error_type`` and ``message`` keys.
    """

    def validate(
        self,
        current_phase: str,
        action_type: str,
        args: Dict[str, Any],
        graph,
    ) -> Optional[Dict[str, str]]:
        """Check one action; return ``None`` if allowed, else an error dict.

        Args:
            current_phase: Phase label; only ``"triage"`` is treated specially.
            action_type: The action's ``type`` string.
            args: The action's arguments; ``ioc_value``, ``hostname``,
                ``subnet`` and ``target_host`` are read where relevant.
            graph: ThreatGraph-like object exposing ``iocs``, ``hosts`` and
                ``alerts`` mappings, or ``None``.
        """
        in_triage = current_phase == "triage"

        if action_type == "submit_containment_plan":
            if not in_triage:
                return None
            return {
                "error_type": "PHASE_VIOLATION",
                "message": "submit_containment_plan requires investigation phase first",
            }

        if action_type == "enrich_ioc":
            ioc_val = args.get("ioc_value", "")
            if not ioc_val or graph is None or ioc_val in graph.iocs:
                return None
            return {
                "error_type": "GRAPH_FAILURE",
                "message": f"IOC '{ioc_val}' not in threat graph; receive an alert or run forensics first",
            }

        if action_type == "scan_host_vulnerabilities":
            hostname = args.get("hostname", "")
            if not hostname or graph is None or hostname in graph.hosts:
                return None
            return {
                "error_type": "GRAPH_FAILURE",
                "message": f"Host '{hostname}' not in threat graph; run query_host first",
            }

        # Only early (triage-phase) isolation needs emergency justification.
        if action_type != "isolate_segment" or not in_triage:
            return None
        subnet = args.get("subnet", "")
        target_host = args.get("target_host", "")
        if self._has_critical_alert(graph, subnet, target_host):
            return None
        return {
            "error_type": "UNJUSTIFIED_EMERGENCY",
            "message": (
                "isolate_segment during triage requires a critical-severity alert "
                "on the targeted subnet/host to justify emergency response"
            ),
        }

    @staticmethod
    def _has_critical_alert(graph, subnet, target_host) -> bool:
        """Return True if a critical alert originates on target_host or on a
        graph host whose ``subnet`` equals *subnet*; False when graph is None."""
        if graph is None:
            return False
        for alert in graph.alerts.values():
            if alert.severity != "critical":
                continue
            origin = alert.source_host
            if target_host and origin == target_host:
                return True
            if not subnet or origin not in graph.hosts:
                continue
            node = graph.hosts.get(origin)
            if node and getattr(node, "subnet", "") == subnet:
                return True
        return False
class CyberSOCEnvironment(Environment):
"""
Deterministic SOC incident response environment.
Simulates a 500-node enterprise network under attack. The agent must
investigate alerts, contain threats, and submit a containment plan
while minimizing business downtime.
Supports concurrent WebSocket sessions (each gets own instance).
Example:
>>> env = CyberSOCEnvironment()
>>> obs = env.reset(task_id="easy")
>>> print(len(obs.alert_queue)) # Initial alerts
>>> obs = env.step(SOCActionWrapper(type="query_host", hostname="WS-042"))
"""
SUPPORTS_CONCURRENT_SESSIONS: bool = True
def __init__(
    self,
    adaptive: bool = False,
    neural_red_policy: Optional[Any] = None,
    red_team_logger: Optional[Callable[[Dict[str, Any]], None]] = None,
    fsp_mode: bool = False,
):
    """Initialize the environment (actual episode state is set in reset()).

    Args:
        adaptive: Legacy adaptive-adversary flag (kept for backward compat).
        neural_red_policy: Optional callable for neural Red policy (legacy hook).
        red_team_logger: Optional callback for recording Red decisions.
        fsp_mode: When True, step() uses strict alternating turns and
            step_count only increments after BOTH Blue and Red have acted.
            When False (default), step(SOCActionWrapper) behaves exactly as
            before — Red's PassTurn is applied automatically so existing code
            and tests remain unaffected.
    """
    super().__init__()
    self._adaptive = adaptive
    self._neural_red_policy = neural_red_policy
    self._red_team_logger = red_team_logger
    self._fsp_mode = fsp_mode
    self._red_team_decisions: List[Dict[str, Any]] = []
    self._live_requirements: Dict[str, Any] = {}
    self._threat_graph = None  # will be initialized on reset()
    self._state = SOCState(episode_id=str(uuid4()), step_count=0)
    self._network: Dict[str, List[Dict[str, Any]]] = {}
    self._task_def: Dict[str, Any] = {}
    self._alert_queue: List[Dict[str, Any]] = []
    self._host_index: Dict[str, Dict[str, Any]] = {}  # hostname -> host dict
    self._plan_entries: List[Dict[str, Any]] = []
    self._last_forensics: Optional[ForensicsResult] = None
    self._middleware = ActionMiddleware()
    self._rng = random.Random(0)  # overwritten in reset()
    self._pending_followup: Dict[str, bool] = {}  # hostname -> responded_to
    self._disruption_cost: float = 0.0  # accumulates per clean host/subnet isolated
    self._discovered_iocs: set = set()  # IOCs revealed via run_forensics or enrich_ioc
    self._quarantined_files: set[tuple[str, str]] = set()
    self._step_reward_total: float = 0.0
    # Per-episode bookkeeping that reset() re-creates. These were previously
    # only created in reset(), so step() paths touching them before the first
    # reset() could raise AttributeError (hence the hasattr() guard elsewhere).
    self._recent_actions: List[Any] = []  # stall-detector history
    self._fired_step_rewards: set = set()  # idempotent step-reward keys
    self._last_obs_extras: Dict[str, Any] = {}  # per-step observation extras
def _reset_rubric(self):
    """Snapshot the task's containment requirements for the new episode.

    Stores a deep copy of ``self._task_def["containment_requirements"]``
    (``{}`` when the task defines none) in ``self._live_requirements``, so
    later changes to the live copy cannot alter the task definition.
    ``reset()`` calls this on every reset, not only in adaptive mode.
    """
    # ``copy`` is the module-level import; no function-local re-import needed.
    self._live_requirements = copy.deepcopy(
        self._task_def.get("containment_requirements", {})
    )
# ===========================================================================
# reset()
# ===========================================================================
def reset(
    self,
    seed: Optional[int] = None,
    episode_id: Optional[str] = None,
    **kwargs: Any,
) -> SOCObservation:
    """Reset the environment for a specific task.

    Args:
        seed: Ignored — the episode is a deterministic function of task_id.
        episode_id: Optional custom episode ID; a fresh UUID4 string otherwise.
        **kwargs: ``task_id`` selects the scenario ('easy', 'medium', or
            'hard'); defaults to 'easy' when omitted.

    Returns:
        Initial SOCObservation with alert queue and network state.
    """
    task_id = kwargs.get("task_id", "easy")
    # Seed with the task id itself. random.Random derives a stable seed from
    # a str, whereas hash(str) is salted per interpreter process
    # (PYTHONHASHSEED), so the previous hash(task_id) seed changed between
    # runs and broke the determinism this environment promises.
    self._rng = random.Random(task_id)
    self._task_def = get_task(task_id)
    self._recent_actions = []  # reset stall detector

    # Deterministic network (cached per task for GRPO throughput).
    self._network = self._load_task_network(task_id)

    # Hostname index for O(1) lookups. Values are the same dict objects held
    # in self._network, so status edits through either view stay in sync.
    self._host_index = {
        host["hostname"]: host
        for hosts in self._network.values()
        for host in hosts
    }

    # Inject attack chain: mark compromised hosts, add malicious processes.
    for threat in self._task_def["attack_chain"]:
        for hostname in threat["compromised_hosts"]:
            host = self._host_index.get(hostname)
            if host is None:
                continue
            host["status"] = "compromised"
            for proc in threat["malicious_processes"]:
                if proc not in host["running_processes"]:
                    host["running_processes"].append(proc)

    # Initialize alert queue (deep copy so mutations don't affect task def).
    self._alert_queue = copy.deepcopy(self._task_def["initial_alerts"])

    # Reset state
    eid = episode_id or str(uuid4())
    self._state = SOCState(
        episode_id=eid,
        step_count=0,
        task_id=task_id,
        max_steps=self._task_def["max_steps"],
        total_reward=0.0,
        business_impact=self._task_def["initial_business_impact"],
        contained_threats=[],
        active_threats=[t["threat_id"] for t in self._task_def["attack_chain"]],
        blocked_iocs=[],
        isolated_subnets=[],
        forensics_run=[],
        killed_processes=[],
        queried_hosts=[],
        timeline=[],
        is_done=False,
        submitted_plan=False,
        active_turn="blue",
    )
    self._plan_entries = []
    self._last_forensics = None
    self._reset_rubric()
    self._fired_step_rewards: set = set()
    self._step_reward_total: float = 0.0
    self._pending_followup: Dict[str, bool] = {}
    self._disruption_cost = 0.0
    self._discovered_iocs: set = set()
    self._quarantined_files: set[tuple[str, str]] = set()
    self._red_team_decisions = []

    # Initialize threat graph from task definition, then layer on any
    # external threat-intel feed IOCs.
    self._threat_graph = ThreatGraph()
    self._populate_threat_graph()
    self._ingest_external_intel_feed()

    self._last_obs_extras: Dict[str, Any] = {}
    return self._build_observation(reward=0.0, done=False)

def _load_task_network(self, task_id: str) -> Dict[str, List[Dict[str, Any]]]:
    """Return a network for *task_id* that this instance may mutate freely.

    The first build per task_id is stored (as a pristine deep copy) in a
    class-level cache; later calls return deep copies of that cache entry.
    """
    if not hasattr(CyberSOCEnvironment, "_network_cache"):
        CyberSOCEnvironment._network_cache = {}
    cache = CyberSOCEnvironment._network_cache
    if task_id in cache:
        return copy.deepcopy(cache[task_id])
    network = build_network()
    cache[task_id] = copy.deepcopy(network)
    return network

def _ingest_external_intel_feed(self) -> None:
    """Add the task's external threat-intel feed IOCs and mark them discovered.

    Entries are bare strings (type inferred: dotted quad of digits -> "ip",
    32+ chars with no "." -> "hash", otherwise "domain") or dicts with
    "value" and optional "type" (default "ip"). IOCs not yet in the graph are
    added with confidence 0.70. Other entry shapes and empty values are
    skipped.
    """
    # Simulates acting on CISA or partner feed data: Blue can enrich/block
    # these immediately without hitting GRAPH_FAILURE in ActionMiddleware.
    for ioc_entry in self._task_def.get("external_intel_feed", []) or []:
        if isinstance(ioc_entry, str):
            ioc_value = ioc_entry
            parts = ioc_entry.split(".")
            if len(parts) == 4 and all(p.isdigit() for p in parts):
                ioc_type = "ip"
            elif len(ioc_entry) >= 32 and "." not in ioc_entry:
                ioc_type = "hash"
            else:
                ioc_type = "domain"
        elif isinstance(ioc_entry, dict):
            ioc_value = ioc_entry.get("value", "")
            ioc_type = ioc_entry.get("type", "ip")
        else:
            continue
        if not ioc_value:
            continue
        if ioc_value not in self._threat_graph.iocs:
            self._threat_graph.add_ioc(
                IOCNode(ioc_value=ioc_value, ioc_type=ioc_type, confidence=0.70)
            )
        self._discovered_iocs.add(ioc_value)
def _populate_threat_graph(self) -> None:
    """Seed the threat graph from the task definition.

    Adds, in order:
      * a HostNode for each attack-chain compromised host present in the
        network (subnet and criticality read from the host record);
      * a ProcessNode "<host>:<process>" for each malicious process on those
        hosts, plus exactly one ``part_of_chain`` edge per (threat, host);
      * IOC nodes for each threat's hashes / IPs / domains (confidence 0.85)
        and C2 servers (as IPs, confidence 0.95), skipping existing IOCs;
      * an AlertNode for every initial alert carrying an ``alert_id``.
    """
    graph = self._threat_graph
    attack_chain = self._task_def.get("attack_chain", [])

    # Hosts: compromised hosts from the attack chain, in first-seen order.
    # (Iterating a set of str here made graph insertion order depend on the
    # per-process hash seed; dict.fromkeys de-duplicates deterministically.)
    compromised_hosts = list(dict.fromkeys(
        hn for threat in attack_chain for hn in threat.get("compromised_hosts", [])
    ))
    for hostname in compromised_hosts:
        host_dict = self._host_index.get(hostname)
        if host_dict is None:
            continue
        graph.add_host(HostNode(
            hostname=hostname,
            subnet=host_dict.get("subnet", "corporate"),
            business_criticality="high" if host_dict.get("criticality", 0.5) >= 0.7 else "medium",
            status="compromised",
        ))

    # Processes per compromised host, plus a single part_of_chain edge for
    # each (threat, host) pair even if the host is listed more than once.
    linked: set[tuple[str, str]] = set()
    for threat in attack_chain:
        tid = threat.get("threat_id", "T?")
        for hostname in threat.get("compromised_hosts", []):
            if hostname not in graph.hosts:
                continue
            for proc in threat.get("malicious_processes", []):
                pid = f"{hostname}:{proc}"
                if pid not in graph.processes:
                    graph.add_process(ProcessNode(
                        process_id=pid,
                        hostname=hostname,
                        process_name=proc,
                    ))
            if (tid, hostname) not in linked:
                linked.add((tid, hostname))
                graph.add_edge(Edge(
                    edge_type="part_of_chain",
                    source_id=tid,
                    target_id=hostname,
                ))

    # IOCs from the attack chain (order: hashes, ips, domains, then C2s).
    ioc_fields = (("hashes", "hash"), ("ips", "ip"), ("domains", "domain"))
    for threat in attack_chain:
        iocs = threat.get("iocs", {}) or {}
        for field, ioc_type in ioc_fields:
            for ioc_value in iocs.get(field, []):
                if ioc_value not in graph.iocs:
                    graph.add_ioc(IOCNode(ioc_value=ioc_value, ioc_type=ioc_type, confidence=0.85))
        for c2 in threat.get("c2_servers", []):
            if c2 not in graph.iocs:
                graph.add_ioc(IOCNode(ioc_value=c2, ioc_type="ip", confidence=0.95))

    # Alerts
    for a in self._task_def.get("initial_alerts", []):
        aid = a.get("alert_id")
        if aid and aid not in graph.alerts:
            graph.add_alert(AlertNode(
                alert_id=aid,
                severity=a.get("severity", "medium"),
                priority_score=1.0,
                source_host=a.get("source_host", ""),
            ))
# ===========================================================================
# step()
# ===========================================================================
def step(
    self,
    action,  # SOCActionWrapper | RedActionWrapper
    timeout_s: Optional[float] = None,
    **kwargs: Any,
) -> SOCObservation:
    """Advance the episode by one agent move — Blue or Red.

    The action type decides the turn handler: a RedActionWrapper goes to
    ``_step_red``; anything else is treated as a Blue (SOC analyst) action.

    Turn semantics (fsp_mode=True):
      • Blue step: execute, flip active_turn → 'red', do NOT increment step_count.
      • Red step: execute, flip active_turn → 'blue', increment step_count.
    When fsp_mode=False (default / backward-compat):
      • Blue step auto-applies a Red PassTurn so step_count always increments,
        preserving all existing test and dashboard behaviour.

    ``timeout_s`` and extra keyword arguments are accepted for interface
    compatibility and are not used.

    Returns:
        SOCObservation; includes active_turn and red_observation fields.
    """
    # A finished episode is terminal: further calls just re-report done.
    if self._state.is_done:
        return self._build_observation(reward=0.0, done=True)
    handler = self._step_red if isinstance(action, RedActionWrapper) else self._step_blue
    return handler(action)
# ------------------------------------------------------------------
# _step_blue — execute a Blue (SOC analyst) action
# ------------------------------------------------------------------
def _step_blue(
    self,
    action: SOCActionWrapper,
) -> SOCObservation:
    """Execute one Blue (SOC analyst) turn.

    Order of operations:
      1. Parse the wrapper into a typed action. A malformed action costs
         -0.2, is logged to the timeline as INVALID_ACTION and consumes a
         step — which can now also end the episode at ``max_steps``.
      2. Run ActionMiddleware pre-flight validation. A failure is penalised
         (-0.10 PHASE_VIOLATION, -0.15 UNJUSTIFIED_EMERGENCY, -0.05 other)
         without consuming a step.
      3. Dispatch to the handler, then add the idempotent step reward, the
         stall penalty and the per-step business-impact growth.
      4. End the episode on plan submission; otherwise hand the turn to Red
         (or, with fsp_mode=False, auto-play Red and close the round).

    Args:
        action: The Blue action wrapper received from the agent.

    Returns:
        The observation after Blue's move (and after the automatic Red turn
        when fsp_mode is False).
    """
    # Clear per-step extras before any early return, so an invalid or
    # rejected action never re-serves the previous step's extras.
    self._last_obs_extras = {}

    # Convert wrapper to typed action — gracefully handle hallucinated
    # action types or wrong parameters from the LLM instead of crashing.
    try:
        typed_action = action.to_typed_action()
    except Exception as exc:
        # Return a negative reward signal so GRPO can learn from the mistake.
        penalty = -0.2
        self._state.total_reward += penalty
        self._state.timeline.append({
            "step": self._state.step_count + 1,
            "action_type": getattr(action, "type", "unknown"),
            "target": "N/A",
            "result": f"INVALID_ACTION: {exc}",
            "reward": penalty,
        })
        self._state.step_count += 1
        done = False
        # Invalid actions consume a step, so they must be able to hit the
        # timeout too; otherwise a stream of malformed actions runs past
        # max_steps without the episode ever reporting done.
        if self._state.step_count >= self._state.max_steps:
            penalty -= 0.20
            self._state.total_reward -= 0.20
            self._state.is_done = True
            done = True
        return self._build_observation(reward=penalty, done=done)

    args = typed_action.model_dump(exclude={"metadata", "type"})

    # Pre-flight validation — penalise without consuming a step.
    current_phase = self._get_current_phase()
    validation_error = self._middleware.validate(
        current_phase, typed_action.type, args, self._threat_graph
    )
    if validation_error:
        error_type = validation_error.get("error_type", "")
        if error_type == "PHASE_VIOLATION":
            penalty = -0.10
        elif error_type == "UNJUSTIFIED_EMERGENCY":
            penalty = -0.15
        else:
            penalty = -0.05
        self._state.total_reward += penalty
        return self._build_observation(reward=penalty, done=False)

    reward, result_description = self._dispatch_blue_action(typed_action)

    # Idempotent step reward
    target = self._get_action_target(typed_action)
    step_r = self._get_step_reward(
        phase="investigation", action_type=typed_action.type, target=target
    )
    reward += step_r
    self._step_reward_total += step_r

    # Stall detection: penalise 3+ consecutive identical (type, target) actions.
    stall_key = (typed_action.type, target)
    if not hasattr(self, "_recent_actions"):
        self._recent_actions = []
    self._recent_actions.append(stall_key)
    if len(self._recent_actions) >= 3:
        last_three = self._recent_actions[-3:]
        if last_three[0] == last_three[1] == last_three[2]:
            reward -= 0.05

    # Business impact grows each step, scaled by the share of threats still active.
    if not self._state.is_done:
        impact_rate = self._task_def.get("impact_per_step", 0.02)
        active_ratio = len(self._state.active_threats) / max(
            1, len(self._task_def.get("attack_chain", []))
        )
        self._state.business_impact = min(
            1.0, self._state.business_impact + impact_rate * active_ratio
        )

    # Round label: step_count+1 = current round being played (not yet closed)
    round_label = self._state.step_count + 1
    self._state.timeline.append({
        "step": round_label,
        "action_type": typed_action.type,
        "target": target,
        "result": result_description,
        "reward": reward,
    })
    self._state.total_reward += reward

    # Plan submission ends the episode immediately.
    done = False
    if self._state.submitted_plan:
        done = True
        self._state.is_done = True
        self._state.active_turn = "blue"  # episode over — keep at blue
        # In non-FSP mode, still increment step_count for consistency.
        if not self._fsp_mode:
            self._state.step_count += 1
        return self._build_observation(reward=reward, done=done)

    # Flip turn to Red
    self._state.active_turn = "red"

    # fsp_mode=False (backward compat): auto-apply Red PassTurn so callers
    # that only drive Blue see step_count increment as before.
    if not self._fsp_mode:
        # Embedded Red dynamics: only when a policy is wired (training) or
        # adaptive=True (SFT).
        if self._neural_red_policy is not None or self._adaptive:
            self._apply_red_team_dynamics(typed_action.type, target)
        self._state.step_count += 1
        self._state.active_turn = "blue"

    # Timeout check (after Red's automatic turn in non-FSP mode).
    if self._state.step_count >= self._state.max_steps:
        reward -= 0.20
        self._state.total_reward -= 0.20
        self._state.is_done = True
        done = True
    return self._build_observation(reward=reward, done=done)

def _dispatch_blue_action(self, typed_action) -> tuple[float, str]:
    """Route a validated Blue action to its handler; return (reward, description)."""
    if isinstance(typed_action, QueryHost):
        return self._handle_query_host(typed_action)
    if isinstance(typed_action, IsolateSegment):
        return self._handle_isolate_segment(typed_action)
    if isinstance(typed_action, BlockIOC):
        return self._handle_block_ioc(typed_action)
    if isinstance(typed_action, RunForensics):
        return self._handle_run_forensics(typed_action)
    if isinstance(typed_action, KillProcess):
        return self._handle_kill_process(typed_action)
    if isinstance(typed_action, SubmitContainmentPlan):
        return self._handle_submit_plan(typed_action)
    if isinstance(typed_action, CorrelateAlerts):
        return self._apply_analysis_result(
            self._handle_correlate_alerts(typed_action), "correlate_alerts"
        )
    if isinstance(typed_action, EnrichIOC):
        return self._apply_analysis_result(
            self._handle_enrich_ioc(typed_action), "enrich_ioc"
        )
    if isinstance(typed_action, ScanHostVulnerabilities):
        return self._apply_analysis_result(
            self._handle_scan_vulnerabilities(typed_action), "scan_host_vulnerabilities"
        )
    if isinstance(typed_action, TerminatePID):
        return self._handle_terminate_pid(typed_action)
    if isinstance(typed_action, CreateFirewallRule):
        return self._handle_create_firewall_rule(typed_action)
    if isinstance(typed_action, QuarantineFile):
        return self._handle_quarantine_file(typed_action)
    return 0.0, "unknown action"

def _apply_analysis_result(
    self, result: Dict[str, Any], default_description: str
) -> tuple[float, str]:
    """Merge an analysis handler's result dict into the observation extras.

    Scores +0.05, or -0.05 when the result contains an "error" key.
    """
    self._last_obs_extras.update(result)
    reward = 0.05 if "error" not in result else -0.05
    return reward, result.get("description", default_description)
# ------------------------------------------------------------------
# _step_red — execute a Red Team action
# ------------------------------------------------------------------
def _step_red(self, action: RedActionWrapper) -> SOCObservation:
    """Execute one Red Team turn and close the round.

    Only valid when ``active_turn == 'red'``; otherwise nothing changes and
    the current observation is returned with zero reward.

    A malformed Red action (bad hotseat submission or hallucinated LLM
    output) no longer raises out of step(). It is recorded in the timeline
    as ``INVALID_ACTION: <error>`` and treated as a pass that still closes
    the round, mirroring how _step_blue handles invalid Blue actions.

    Args:
        action: The Red action wrapper to execute.

    Returns:
        The observation after the round closes. Red's handler reward is
        returned, but the timeline records 0.0 (Red does not affect Blue's
        total_reward).
    """
    if self._state.active_turn != "red":
        # Wrong turn — return current obs with 0 reward (no state change)
        return self._build_observation(reward=0.0, done=False)

    self._last_obs_extras = {}
    reward = 0.0
    try:
        typed_action = action.to_typed_action()
    except Exception as exc:
        typed_action = None
        result_description = f"INVALID_ACTION: {exc}"
    else:
        result_description = "red: noop"
        if isinstance(typed_action, LateralPivot):
            reward, result_description = self._handle_lateral_pivot(typed_action)
        elif isinstance(typed_action, DeployPayload):
            reward, result_description = self._handle_deploy_payload(typed_action)
        elif isinstance(typed_action, EvadeDetection):
            reward, result_description = self._handle_evade_detection(typed_action)
        elif isinstance(typed_action, PassTurn):
            reward, result_description = self._handle_pass_turn(typed_action)

    # Close the round: increment step_count, flip turn back to Blue.
    self._state.step_count += 1
    self._state.active_turn = "blue"

    if typed_action is None:
        action_label = f"red:{getattr(action, 'type', 'unknown')}"
        target = "N/A"
    else:
        action_label = f"red:{typed_action.type}"
        target = self._get_red_action_target(typed_action)

    # Record Red's action in timeline (prefixed with "red:" to distinguish).
    self._state.timeline.append({
        "step": self._state.step_count,
        "action_type": action_label,
        "target": target,
        "result": result_description,
        "reward": 0.0,  # Red actions don't add to Blue's reward total
    })

    # Timeout check after the full round.
    done = False
    if self._state.step_count >= self._state.max_steps:
        done = True
        self._state.is_done = True
    return self._build_observation(reward=reward, done=done)
# ===========================================================================
# Action Handlers (return (reward, description))
# ===========================================================================
def _handle_query_host(self, action: QueryHost) -> tuple[float, str]:
    """Look up a host's status and attach graph context to the observation.

    Scoring: -0.05 for an unknown host, -0.02 for a repeat query, +0.05 for
    a first query of a compromised host, 0.0 otherwise. Every successful
    query is appended to ``queried_hosts`` (repeats included).

    Extras set: ``process_tree`` (graph processes on this host) and
    ``network_connections`` (IOC peers over ``communicates_with`` edges).
    """
    target = action.hostname
    # Any cached forensics result belongs to an earlier step.
    self._last_forensics = None

    if target not in self._host_index:
        return -0.05, f"Host '{target}' not found in network"
    host = self._host_index[target]

    if target in self._state.queried_hosts:
        reward = -0.02  # re-querying the same host wastes time
    elif host["status"] == "compromised":
        reward = 0.05  # useful investigation of a compromised host
    else:
        reward = 0.0
    self._state.queried_hosts.append(target)

    graph = self._threat_graph
    tree = []
    peers = []
    if graph is not None:
        tree = [
            {"process_id": p.process_id, "process_name": p.process_name, "killed": p.killed}
            for p in graph.processes.values()
            if p.hostname == target
        ]
        for edge in graph.edges:
            if edge.edge_type != "communicates_with":
                continue
            if edge.source_id == target:
                peer = edge.target_id
            elif edge.target_id == target:
                peer = edge.source_id
            else:
                continue
            if peer in graph.iocs:
                peers.append(peer)

    self._last_obs_extras["process_tree"] = tree
    self._last_obs_extras["network_connections"] = peers
    return reward, f"Queried {target}: status={host['status']}, procs={len(host['running_processes'])}"
def _handle_isolate_segment(self, action: IsolateSegment) -> tuple[float, str]:
    """Isolate a network segment, or a single host if ``target_host`` is set.

    Single-host path: -0.05 for an unknown host; -0.02 (no further effect)
    if the host is already isolated; otherwise isolate it and return +0.10
    for a compromised host, or -0.35 for a clean one (disruption cost +0.35,
    business impact +0.10).

    Subnet path: -0.05 for an unknown subnet, -0.02 if already isolated.
    Otherwise isolate every host, +0.07 per active threat with a compromised
    host there (those threats become contained), -0.25 per clean host swept
    up (disruption cost +0.25 and business impact +0.05 each), and a further
    -0.10 / +0.08 impact if the subnet is in ``must_not_isolate``.
    """
    self._last_forensics = None

    def compromised_hosts() -> set:
        # Hosts listed as compromised anywhere in the task's attack chain.
        return {
            h for threat in self._task_def["attack_chain"]
            for h in threat.get("compromised_hosts", [])
        }

    # Single-host isolation path
    target_host = getattr(action, "target_host", None)
    if target_host:
        if target_host not in self._host_index:
            return -0.05, f"Host '{target_host}' not found"
        host = self._host_index[target_host]
        if host["status"] == "isolated":
            # Re-isolation is a no-op. Without this guard an agent could farm
            # +0.10 per repeat on a compromised host, and a clean host would
            # be charged disruption cost again for no real change.
            return -0.02, f"Host '{target_host}' is already isolated"
        host["status"] = "isolated"
        if self._threat_graph is not None and target_host in self._threat_graph.hosts:
            self._threat_graph.hosts[target_host].status = "isolated"
        if target_host in self._pending_followup:
            self._pending_followup[target_host] = True
        # Penalise isolating a clean (non-compromised) host — business disruption.
        if target_host not in compromised_hosts():
            self._disruption_cost += 0.35
            self._state.business_impact = min(1.0, self._state.business_impact + 0.10)
            return -0.35, (
                f"Isolated clean host '{target_host}' — unjustified business disruption "
                f"(cumulative cost={self._disruption_cost:.2f})"
            )
        return 0.10, f"Isolated single host '{target_host}'"

    subnet = action.subnet
    if subnet not in self._network:
        return -0.05, f"Subnet '{subnet}' does not exist"
    if subnet in self._state.isolated_subnets:
        return -0.02, f"Subnet '{subnet}' is already isolated"

    compromised_host_set = compromised_hosts()

    # Isolate all hosts in the subnet; count clean hosts for disruption cost.
    clean_isolated_count = 0
    for host in self._network[subnet]:
        name = host["hostname"]
        host["status"] = "isolated"
        if self._threat_graph is not None and name in self._threat_graph.hosts:
            self._threat_graph.hosts[name].status = "isolated"
        if name in self._pending_followup:
            self._pending_followup[name] = True
        if name not in compromised_host_set:
            clean_isolated_count += 1
    self._state.isolated_subnets.append(subnet)

    # Accumulate disruption cost for each clean host swept up in the isolation.
    if clean_isolated_count > 0:
        self._disruption_cost += 0.25 * clean_isolated_count
        self._state.business_impact = min(
            1.0, self._state.business_impact + 0.05 * clean_isolated_count
        )

    # Active threats with at least one compromised host in this subnet.
    reward = 0.0
    threats_contained = []
    for threat in self._task_def["attack_chain"]:
        if threat["threat_id"] not in self._state.active_threats:
            continue
        for ch in threat["compromised_hosts"]:
            if ch in self._host_index and self._host_index[ch]["subnet"] == subnet:
                threats_contained.append(threat["threat_id"])
                break

    if threats_contained:
        # Reduced reward — isolation is a blunt instrument; prefer kill_process / block_ioc.
        reward = 0.07 * len(threats_contained)
        for tid in threats_contained:
            if tid not in self._state.contained_threats:
                self._state.contained_threats.append(tid)
            if tid in self._state.active_threats:
                self._state.active_threats.remove(tid)

    # Heavy per-clean-host penalty to deter blunt-force isolation spam.
    if clean_isolated_count > 0:
        reward -= 0.25 * clean_isolated_count

    # Additional penalty for explicitly prohibited isolation.
    must_not_isolate = self._task_def["containment_requirements"].get("must_not_isolate", [])
    if subnet in must_not_isolate:
        reward -= 0.10
        self._state.business_impact = min(1.0, self._state.business_impact + 0.08)

    return reward, (
        f"Isolated subnet '{subnet}'. Threats contained: {threats_contained}. "
        f"Clean hosts disrupted: {clean_isolated_count} "
        f"(cumulative cost={self._disruption_cost:.2f})"
    )
def _handle_block_ioc(self, action: BlockIOC) -> tuple[float, str]:
    """Block an IOC at the perimeter.

    Requires prior discovery via run_forensics or enrich_ioc; blind blocks
    are recorded but yield 0 reward to prevent reward hacking.

    Scoring for discovered IOCs: +0.30 for a threat's C2 server, +0.20 for
    any other indicator of a threat (hash / IP / domain), -0.03 if the IOC
    belongs to no threat. Re-blocking costs -0.02. C2 servers count as
    relevant even when they are not repeated under ``iocs.ips``. Previously
    such a C2 block was scored as irrelevant noise (-0.03).
    """
    ioc = action.ioc_value
    self._last_forensics = None

    if ioc in self._state.blocked_iocs:
        return -0.02, f"IOC '{ioc}' is already blocked"

    # Prerequisite gate: IOC must have been discovered via run_forensics or enrich_ioc.
    if ioc not in self._discovered_iocs:
        self._state.blocked_iocs.append(ioc)  # record the block, but no reward
        return 0.0, (
            f"IOC '{ioc}' blocked without prior investigation — 0 reward "
            "(run_forensics or enrich_ioc required to unlock reward)"
        )

    self._state.blocked_iocs.append(ioc)
    attack_chain = self._task_def["attack_chain"]

    def threat_iocs(threat) -> list:
        # Every indicator tied to a threat, including its C2 servers.
        return (
            threat["iocs"].get("hashes", [])
            + threat["iocs"].get("ips", [])
            + threat["iocs"].get("domains", [])
            + threat.get("c2_servers", [])
        )

    # Mark forensics-confirmed hosts as responded-to when this IOC belongs to
    # one of their threats (investigated-then-actioned flow).
    for hostname, responded in list(self._pending_followup.items()):
        if responded:
            continue
        for threat in attack_chain:
            if hostname in threat["compromised_hosts"] and ioc in threat_iocs(threat):
                self._pending_followup[hostname] = True
                break

    # Surgical strikes are heavily preferred over blunt isolation.
    reward = -0.03  # noise: blocking an IOC unrelated to any threat
    relevant = False
    for threat in attack_chain:
        if ioc in threat_iocs(threat):
            relevant = True
            # Severing a C2 channel is worth more than other indicators.
            reward = 0.30 if ioc in threat.get("c2_servers", []) else 0.20
            break
    return reward, f"Blocked IOC '{ioc}' (type={action.ioc_type}). Relevant: {relevant}"
def _handle_run_forensics(self, action: RunForensics) -> tuple[float, str]:
    """Run forensic analysis on a host and record a ForensicsResult.

    For a host whose status is "compromised", artifacts are synthesised
    deterministically from every attack-chain threat listing it: the
    threat's malicious processes, a Temp file and a Run registry key per
    process, ``<c2>:443`` connections, and ``memory_inject_<hash[:8]>``
    markers.

    Scoring: -0.05 unknown host; -0.02 on a repeat run; first run +0.10 for a
    compromised host (which also flags it for follow-up and reveals all of
    its threats' IOCs to block_ioc) or +0.02 for a clean one.

    Extras set: ``behavioral_chain`` (graph edges touching the host) and
    ``network_flows`` (IOC peers over ``communicates_with`` edges).
    """
    target = action.hostname
    if target not in self._host_index:
        self._last_forensics = None
        return -0.05, f"Host '{target}' not found"

    compromised = self._host_index[target]["status"] == "compromised"
    chain = self._task_def["attack_chain"]

    procs, files, conns, reg_keys, mem = [], [], [], [], []
    if compromised:
        for threat in chain:
            if target not in threat["compromised_hosts"]:
                continue
            procs.extend(threat["malicious_processes"])
            for proc in threat["malicious_processes"]:
                files.append(f"C:\\Windows\\Temp\\{proc}.dat")
                reg_keys.append(f"HKLM\\Software\\Microsoft\\Windows\\CurrentVersion\\Run\\{proc}")
            conns.extend(f"{c2}:443" for c2 in threat.get("c2_servers", []))
            mem.extend(f"memory_inject_{digest[:8]}" for digest in threat["iocs"].get("hashes", []))

    self._last_forensics = ForensicsResult(
        hostname=target,
        malicious_processes=procs,
        suspicious_files=files,
        network_connections=conns,
        registry_modifications=reg_keys,
        memory_artifacts=mem,
        is_compromised=compromised,
    )

    if target in self._state.forensics_run:
        reward = -0.02  # re-running forensics wastes time
    else:
        if compromised:
            reward = 0.10  # found evidence
            self._pending_followup.setdefault(target, False)  # needs a response action
            # Reveal every IOC of this host's threat chain so block_ioc can earn reward.
            for threat in chain:
                if target in threat.get("compromised_hosts", []):
                    self._discovered_iocs.update(
                        threat["iocs"].get("hashes", [])
                        + threat["iocs"].get("ips", [])
                        + threat["iocs"].get("domains", [])
                        + threat.get("c2_servers", [])
                    )
        else:
            reward = 0.02  # cleared a host (some value)
        self._state.forensics_run.append(target)

    # One pass over the graph edges builds both observation extras.
    chain_edges = []
    flows = []
    graph = self._threat_graph
    if graph is not None:
        for edge in graph.edges:
            if target not in (edge.source_id, edge.target_id):
                continue
            chain_edges.append({
                "edge_type": edge.edge_type,
                "source_id": edge.source_id,
                "target_id": edge.target_id,
            })
            if edge.edge_type == "communicates_with":
                peer = edge.target_id if edge.source_id == target else edge.source_id
                if peer in graph.iocs:
                    flows.append(peer)
    self._last_obs_extras["behavioral_chain"] = chain_edges
    self._last_obs_extras["network_flows"] = flows
    return reward, f"Forensics on {target}: compromised={compromised}, procs={procs}"
def _handle_kill_process(self, action: KillProcess) -> tuple[float, str]:
"""Kill a process on a host."""
hostname = action.hostname
process = action.process_name
self._last_forensics = None
if hostname not in self._host_index:
return -0.05, f"Host '{hostname}' not found"
host = self._host_index[hostname]
if host["status"] == "isolated":
return -0.02, f"Host '{hostname}' is isolated — cannot interact"
if process not in host["running_processes"]:
return -0.03, f"Process '{process}' not running on {hostname}"
# Kill the process
host["running_processes"].remove(process)
self._state.killed_processes.append({"hostname": hostname, "process": process})
if hostname in self._pending_followup:
self._pending_followup[hostname] = True
# Check if this was a malicious process
reward = 0.0
was_malicious = False
for threat in self._task_def["attack_chain"]:
if hostname in threat["compromised_hosts"] and process in threat["malicious_processes"]:
was_malicious = True
reward = 0.25 # Surgical strike: high reward for targeted process kill
# Check if all processes for this threat are killed
all_killed = True
for th_host in threat["compromised_hosts"]:
for th_proc in threat["malicious_processes"]:
still_running = (
th_host in self._host_index
and th_proc in self._host_index[th_host]["running_processes"]
)
if still_running:
all_killed = False
break
if all_killed and threat["threat_id"] in self._state.active_threats:
self._state.active_threats.remove(threat["threat_id"])
if threat["threat_id"] not in self._state.contained_threats:
self._state.contained_threats.append(threat["threat_id"])
reward += 0.15 # Bonus: fully contained a threat via surgical action
break
if not was_malicious:
reward = -0.08 # Penalty: killing legitimate process = downtime
self._state.business_impact = min(1.0, self._state.business_impact + 0.03)
return reward, f"Killed '{process}' on {hostname}. Malicious: {was_malicious}"
def _handle_terminate_pid(self, action: TerminatePID) -> tuple[float, str]:
"""Terminate a process by PID. PID is mapped to process name in this simulation."""
hostname = action.hostname
pid = action.pid
self._last_forensics = None
if hostname not in self._host_index:
return -0.05, f"Host '{hostname}' not found"
host = self._host_index[hostname]
if host["status"] == "isolated":
return -0.02, f"Host '{hostname}' is isolated - cannot interact"
process_name = pid
if ":" in pid:
pid_host, _, pid_proc = pid.partition(":")
if pid_host == hostname and pid_proc:
process_name = pid_proc
if process_name not in host["running_processes"]:
return -0.03, f"PID '{pid}' is not running on {hostname}"
host["running_processes"].remove(process_name)
self._state.killed_processes.append({"hostname": hostname, "process": process_name, "pid": pid})
if hostname in self._pending_followup:
self._pending_followup[hostname] = True
was_malicious = False
reward = 0.0
for threat in self._task_def["attack_chain"]:
if hostname in threat["compromised_hosts"] and process_name in threat["malicious_processes"]:
was_malicious = True
reward = 0.24
all_killed = True
for th_host in threat["compromised_hosts"]:
for th_proc in threat["malicious_processes"]:
if th_host in self._host_index and th_proc in self._host_index[th_host]["running_processes"]:
all_killed = False
break
if all_killed and threat["threat_id"] in self._state.active_threats:
self._state.active_threats.remove(threat["threat_id"])
if threat["threat_id"] not in self._state.contained_threats:
self._state.contained_threats.append(threat["threat_id"])
reward += 0.12
break
if not was_malicious:
reward = -0.10
self._state.business_impact = min(1.0, self._state.business_impact + 0.04)
return reward, f"Terminated benign PID '{pid}' on {hostname} - business disruption"
return reward, f"Terminated PID '{pid}' on {hostname}. Malicious: True"
def _handle_create_firewall_rule(self, action: CreateFirewallRule) -> tuple[float, str]:
"""Create firewall rule; drop blocks target IP as IOC, allow is neutral."""
hostname = action.hostname
target_ip = action.target_ip
if hostname not in self._host_index:
return -0.05, f"Host '{hostname}' not found"
if action.action == "drop":
if target_ip in self._state.blocked_iocs:
return -0.01, f"Firewall drop rule already exists for {target_ip}"
self._state.blocked_iocs.append(target_ip)
return 0.08, f"Created firewall DROP rule on {hostname} for {target_ip}"
return 0.0, f"Created firewall ALLOW rule on {hostname} for {target_ip}"
def _handle_quarantine_file(self, action: QuarantineFile) -> tuple[float, str]:
"""Quarantine suspicious files; requires terminating associated malicious PID first."""
hostname = action.hostname
file_path = action.file_path
if hostname not in self._host_index:
return -0.05, f"Host '{hostname}' not found"
file_key = (hostname, file_path)
if file_key in self._quarantined_files:
return -0.01, f"File '{file_path}' already quarantined on {hostname}"
associated_processes: List[str] = []
lowered = file_path.lower()
for threat in self._task_def.get("attack_chain", []):
if hostname not in threat.get("compromised_hosts", []):
continue
for proc in threat.get("malicious_processes", []):
expected_suffix = f"\\{proc}.dat".lower()
if lowered.endswith(expected_suffix):
associated_processes.append(proc)
if not associated_processes:
self._quarantined_files.add(file_key)
return -0.02, f"Quarantined untracked file '{file_path}' on {hostname}"
host = self._host_index[hostname]
locked = any(proc in host["running_processes"] for proc in associated_processes)
if locked:
self._state.business_impact = min(1.0, self._state.business_impact + 0.01)
return -0.04, (
f"Quarantine failed: file '{file_path}' is locked. "
"Terminate associated PID first."
)
self._quarantined_files.add(file_key)
return 0.10, f"Quarantined file '{file_path}' on {hostname}"
def _handle_submit_plan(self, action: SubmitContainmentPlan) -> tuple[float, str]:
    """Record the final containment plan and grade the whole episode.

    Marks the plan as submitted, stores each plan entry as a plain dict in
    ``_plan_entries``, and runs ``grade_episode``. The grader receives the
    timeline, the plan (the first entry's ``threat_id`` is the primary
    threat), the Threat Graph, the task definition, the state and the
    disruption cost. The grader's ``final_score`` is returned as the reward.

    Returns:
        ``(reward, description)``.
    """
    self._last_forensics = None
    self._state.submitted_plan = True
    self._plan_entries = [plan_item.model_dump() for plan_item in action.plan]

    entries = self._plan_entries
    plan_payload = {
        "entries": entries,
        "primary_threat_id": entries[0]["threat_id"] if entries else "",
    }
    graded = grade_episode(
        episode_actions=list(self._state.timeline),
        final_plan=plan_payload,
        graph=self._threat_graph,
        task_def=self._task_def,
        state=self._state,
        disruption_cost=self._disruption_cost,
    )
    score = graded["final_score"]
    reward = score * 1.0  # perfect grade (1.0) maps to a reward of 1.0

    contained_n = len(self._state.contained_threats)
    threat_n = len(self._task_def['attack_chain'])
    summary = (
        "Containment plan submitted. "
        f"Grade: {score:.3f}. "
        f"Threats contained: {contained_n}/{threat_n}. "
        f"Business impact: {self._state.business_impact:.2f}"
    )
    return reward, summary
# ===========================================================================
# New Action Handlers (return observation-update dict)
# ===========================================================================
def _handle_correlate_alerts(self, action: CorrelateAlerts) -> dict:
"""Correlate alerts to find shared hosts/IOCs."""
if len(action.alert_ids) < 2:
return {"error": "correlate_alerts requires at least 2 alert IDs",
"description": "correlate_alerts error"}
graph = self._threat_graph
known_alerts = {aid: graph.alerts[aid] for aid in action.alert_ids if aid in graph.alerts}
if len(known_alerts) < 2:
return {"error": "fewer than 2 alert IDs found in graph",
"description": "correlate_alerts error"}
# Find shared source hosts
source_hosts: dict[str, list[str]] = {}
for aid, alert in known_alerts.items():
source_hosts.setdefault(alert.source_host, []).append(aid)
shared_hosts = [h for h, aids in source_hosts.items() if len(aids) >= 2]
# Find shared IOCs via "involves" edges
shared_iocs: set[str] = set()
for e in graph.edges:
if e.edge_type == "involves" and e.source_id in known_alerts:
if any(
e2.edge_type == "involves" and e2.target_id == e.target_id
and e2.source_id in known_alerts and e2.source_id != e.source_id
for e2 in graph.edges
):
shared_iocs.add(e.target_id)
# Update correlated_with on each alert
all_ids = list(known_alerts.keys())
for aid, alert in known_alerts.items():
for other_id in all_ids:
if other_id != aid and other_id not in alert.correlated_with:
alert.correlated_with.append(other_id)
self._state.correlated_alert_pairs.append(tuple(all_ids))
shared_count = len(shared_hosts) + len(shared_iocs)
correlation_score = min(1.0, shared_count / len(all_ids))
result = {
"correlation_results": {
"shared_hosts": shared_hosts,
"shared_iocs": list(shared_iocs),
"correlation_score": correlation_score,
},
"description": f"Correlated {len(all_ids)} alerts: {len(shared_hosts)} shared hosts",
}
return result
def _handle_enrich_ioc(self, action: EnrichIOC) -> dict:
"""Enrich an IOC with threat-intel data."""
graph = self._threat_graph
if action.ioc_value not in graph.iocs:
return {"error": "IOC not yet discovered",
"description": "enrich_ioc error"}
intel = self._task_def.get("threat_intel_data", {}) or {}
data = intel.get(action.ioc_value, {
"reputation": 0.5,
"threat_actor": "unknown",
"mitre_ttps": [],
})
# Update IOC node in graph
ioc_node = graph.iocs[action.ioc_value]
ioc_node.enriched = True
ioc_node.threat_actor = data.get("threat_actor")
ioc_node.mitre_ttps = data.get("mitre_ttps", [])
if action.ioc_value not in self._state.enriched_iocs:
self._state.enriched_iocs.append(action.ioc_value)
# Mark IOC as discovered — future block_ioc on it will receive full reward
self._discovered_iocs.add(action.ioc_value)
return {
"ioc_enrichment": data,
"description": f"Enriched IOC {action.ioc_value}: actor={data.get('threat_actor')}",
}
def _handle_scan_vulnerabilities(self, action: ScanHostVulnerabilities) -> dict:
"""Scan a host for CVE vulnerabilities."""
graph = self._threat_graph
hostname = action.hostname
if hostname not in graph.hosts:
return {"error": f"Host '{hostname}' not in Threat Graph",
"description": "scan_host_vulnerabilities error"}
vuln_chain = self._task_def.get("vulnerability_chain", []) or []
vuln_results: list[dict] = []
for entry in vuln_chain:
if not isinstance(entry, dict):
continue
if entry.get("hostname") == hostname or entry.get("affected_hosts") and hostname in entry["affected_hosts"]:
cve_id = entry.get("cve_id", "CVE-UNKNOWN")
vuln_node = VulnerabilityNode(
cve_id=cve_id,
hostname=hostname,
cvss_score=entry.get("cvss_score", 5.0),
exploitability=entry.get("exploitability", "theoretical"),
patch_available=entry.get("patch_available", False),
exploited_by_threat=entry.get("threat_id"),
)
graph.add_vulnerability(vuln_node)
graph.add_edge(Edge(
edge_type="exploits",
source_id=cve_id,
target_id=hostname,
))
vuln_results.append(entry)
# Mark host as scanned
graph.hosts[hostname].scanned = True
if hostname not in self._state.scanned_hosts:
self._state.scanned_hosts.append(hostname)
return {
"vulnerability_results": vuln_results,
"description": f"Scanned {hostname}: found {len(vuln_results)} CVEs",
}
# ===========================================================================
# Red Team Action Handlers
# ===========================================================================
def _handle_lateral_pivot(self, action: LateralPivot) -> tuple[float, str]:
"""Red: spread from a compromised host to a new target."""
src = action.source_host
dst = action.target_host
if src not in self._host_index:
return 0.0, f"red: lateral_pivot — source '{src}' not in network"
if self._host_index[src].get("status") != "compromised":
return 0.0, f"red: lateral_pivot — '{src}' not under Red control"
if dst not in self._host_index:
return 0.0, f"red: lateral_pivot — target '{dst}' not in network"
dst_status = self._host_index[dst].get("status", "online")
if dst_status == "isolated":
return 0.0, f"red: lateral_pivot — '{dst}' is isolated, pivot blocked by Blue"
if dst_status == "compromised":
return 0.0, f"red: lateral_pivot — '{dst}' already compromised"
# Compromise target and copy a process from source
self._host_index[dst]["status"] = "compromised"
src_procs = (
[p for p in self._threat_graph.processes.values() if p.hostname == src]
if self._threat_graph else []
)
proc_name = src_procs[0].process_name if src_procs else "cmd.exe"
self._host_index[dst].setdefault("running_processes", [])
if proc_name not in self._host_index[dst]["running_processes"]:
self._host_index[dst]["running_processes"].append(proc_name)
# Update threat graph
if self._threat_graph is not None:
if dst not in self._threat_graph.hosts:
hd = self._host_index[dst]
self._threat_graph.add_host(HostNode(
hostname=dst,
subnet=hd.get("subnet", "corporate"),
business_criticality="medium",
status="compromised",
))
else:
self._threat_graph.hosts[dst].status = "compromised"
pid = f"{dst}:{proc_name}"
if pid not in self._threat_graph.processes:
self._threat_graph.add_process(ProcessNode(
process_id=pid, hostname=dst, process_name=proc_name
))
self._threat_graph.add_edge(Edge(
edge_type="pivoted_from", source_id=dst, target_id=src
))
# Generate SIEM alert for Blue
alert_id = f"PIVOT-{uuid.uuid4().hex[:6].upper()}"
subnet = self._host_index.get(dst, {}).get("subnet", "unknown")
self._alert_queue.append({
"alert_id": alert_id,
"timestamp": "2024-01-01T00:00:00Z",
"source_host": dst,
"severity": "critical",
"threat_type": "lateral_movement",
"description": (
f"Lateral movement detected: {proc_name} spawned on {dst} "
f"(pivot from {src})"
),
"ioc_indicators": [],
"subnet": subnet,
"is_acknowledged": False,
})
if self._threat_graph is not None:
self._threat_graph.add_alert(AlertNode(
alert_id=alert_id, severity="critical",
priority_score=15.0, source_host=dst,
))
# Update live rubric
if self._live_requirements is not None:
self._live_requirements.setdefault("must_kill", []).append({
"hostname": dst, "process": proc_name, "threat_id": "FSP_PIVOT",
})
return 0.0, f"red: lateral_pivot {src}{dst} (proc={proc_name})"
def _handle_deploy_payload(self, action: DeployPayload) -> tuple[float, str]:
"""Red: deploy a malicious payload on a host Red controls."""
hostname = action.hostname
payload_type = action.payload_type
if hostname not in self._host_index:
return 0.0, f"red: deploy_payload — '{hostname}' not in network"
if self._host_index[hostname].get("status") != "compromised":
return 0.0, f"red: deploy_payload — no shell on '{hostname}'"
proc_name = {
"ransomware": "ransomware.exe",
"exfiltration": "exfil_agent.exe",
"c2": "c2_beacon.exe",
}[payload_type]
host = self._host_index[hostname]
if proc_name not in host.get("running_processes", []):
host.setdefault("running_processes", []).append(proc_name)
if self._threat_graph is not None:
pid = f"{hostname}:{proc_name}"
if pid not in self._threat_graph.processes:
self._threat_graph.add_process(ProcessNode(
process_id=pid, hostname=hostname, process_name=proc_name
))
impact_delta = {"ransomware": 0.15, "exfiltration": 0.08, "c2": 0.05}[payload_type]
self._state.business_impact = min(1.0, self._state.business_impact + impact_delta)
severity = {"ransomware": "critical", "exfiltration": "high", "c2": "high"}[payload_type]
alert_id = f"PAYLOAD-{uuid.uuid4().hex[:6].upper()}"
self._alert_queue.append({
"alert_id": alert_id,
"timestamp": "2024-01-01T00:00:00Z",
"source_host": hostname,
"severity": severity,
"threat_type": payload_type,
"description": (
f"{payload_type.capitalize()} payload deployed on {hostname}: {proc_name}"
),
"ioc_indicators": [],
"subnet": host.get("subnet", "unknown"),
"is_acknowledged": False,
})
if self._threat_graph is not None:
self._threat_graph.add_alert(AlertNode(
alert_id=alert_id, severity=severity,
priority_score=18.0, source_host=hostname,
))
return 0.0, f"red: deployed {payload_type} payload on {hostname}"
def _handle_evade_detection(self, action: EvadeDetection) -> tuple[float, str]:
"""Red: apply a detection-evasion technique on a controlled host."""
hostname = action.hostname
technique = action.technique
if hostname not in self._host_index:
return 0.0, f"red: evade_detection — '{hostname}' not in network"
if self._host_index[hostname].get("status") != "compromised":
return 0.0, f"red: evade_detection — no shell on '{hostname}'"
if technique == "migrate_pid":
host = self._host_index[hostname]
malicious_procs = {
proc
for threat in self._task_def.get("attack_chain", [])
if hostname in threat.get("compromised_hosts", [])
for proc in threat.get("malicious_processes", [])
}
for i, proc in enumerate(list(host.get("running_processes", []))):
if proc in malicious_procs:
new_name = f"svchost_{i}.exe"
host["running_processes"][i] = new_name
if self._threat_graph:
old_pid = f"{hostname}:{proc}"
if old_pid in self._threat_graph.processes:
self._threat_graph.processes.pop(old_pid)
new_pid = f"{hostname}:{new_name}"
self._threat_graph.add_process(ProcessNode(
process_id=new_pid, hostname=hostname,
process_name=new_name,
))
return 0.0, f"red: migrated PIDs on {hostname} to blend with system processes"
if technique == "clear_logs":
before = len(self._alert_queue)
self._alert_queue = [
a for a in self._alert_queue
if a.get("source_host") != hostname
]
removed = before - len(self._alert_queue)
return 0.0, f"red: cleared {removed} SIEM alert(s) from {hostname}"
return 0.0, f"red: evasion '{technique}' applied on {hostname}"
def _handle_pass_turn(self, action: PassTurn) -> tuple[float, str]: # noqa: ARG002
"""Red: remain stealthy, take no action."""
return 0.0, "red: pass_turn (stealth)"
def _get_red_action_target(self, action: Any) -> str:
    """Compact target string for a Red action's timeline entry.

    - ``LateralPivot`` -> ``"<source>→<target>"``. The separator keeps the
      two hostnames distinguishable; without it they were concatenated.
    - ``DeployPayload`` -> ``"<host>/<payload_type>"``.
    - ``EvadeDetection`` -> ``"<host>/<technique>"``.
    - Anything else (e.g. ``PassTurn``) -> ``"—"``.
    """
    if isinstance(action, LateralPivot):
        return f"{action.source_host}→{action.target_host}"
    if isinstance(action, DeployPayload):
        return f"{action.hostname}/{action.payload_type}"
    if isinstance(action, EvadeDetection):
        return f"{action.hostname}/{action.technique}"
    return "—"
# ===========================================================================
# Helpers
# ===========================================================================
def _compute_reward_dimensions(self) -> Dict[str, float]:
"""Per-step heuristic partial scores for all 10 grading dimensions.
Evidence-gated: actions only score if prior evidence justified them.
Result-usage: forensics-confirmed hosts with no followup are penalized.
Scores in [0, 1]; terminal grade_breakdown supersedes these on plan submission.
"""
state = self._state
task_chain = self._task_def.get("attack_chain", [])
total_threats = max(1, len(task_chain))
total_compromised = max(1, sum(len(t.get("compromised_hosts", [])) for t in task_chain))
total_iocs = max(1, sum(
len(t.get("iocs", {}).get("hashes", []))
+ len(t.get("iocs", {}).get("ips", []))
+ len(t.get("iocs", {}).get("domains", []))
for t in task_chain
))
# --- Build evidence pools: what the agent could have observed ---
# Hosts mentioned as alert source (visible from turn 0)
alert_source_hosts: set = set()
for a in self._task_def.get("initial_alerts", []):
alert_source_hosts.add(a.get("source_host", ""))
for a in self._alert_queue:
alert_source_hosts.add(a.get("source_host", ""))
alert_source_hosts.discard("")
# IOCs visible from alert ioc_indicators
alert_iocs: set = set()
for a_list in (self._task_def.get("initial_alerts", []), self._alert_queue):
for a in a_list:
for ioc in a.get("ioc_indicators", []):
alert_iocs.add(ioc)
# IOCs revealed by running forensics on a host
forensics_revealed_iocs: set = set()
for hostname in state.forensics_run:
for threat in task_chain:
if hostname in threat.get("compromised_hosts", []):
forensics_revealed_iocs.update(threat.get("c2_servers", []))
forensics_revealed_iocs.update(threat["iocs"].get("hashes", []))
forensics_revealed_iocs.update(threat["iocs"].get("ips", []))
forensics_revealed_iocs.update(threat["iocs"].get("domains", []))
discovered_iocs = alert_iocs | forensics_revealed_iocs
# 1. threat_containment — fraction of threats neutralised (no evidence gate; outcome IS evidence)
threat_containment = min(1.0, len(state.contained_threats) / total_threats)
# 2. ioc_blocking — only blocks of IOCs the agent actually discovered count
justified_blocks = [ioc for ioc in state.blocked_iocs if ioc in discovered_iocs]
ioc_blocking = min(1.0, len(justified_blocks) / total_iocs)
# 3. forensic_investigation — only counts forensics on alert-mentioned or previously queried
# hosts; penalizes confirmed compromises left with no response action
justified_forensics = [
h for h in state.forensics_run
if h in alert_source_hosts or h in state.queried_hosts
]
pending = self._pending_followup
unresponded = sum(1 for v in pending.values() if not v)
followup_penalty = min(0.30, unresponded * 0.10)
forensic_investigation = max(0.0,
min(1.0, len(justified_forensics) / total_compromised) - followup_penalty
)
# 4. siem_correlation — scored by semantic quality (shared source hosts or IOCs)
if not state.correlated_alert_pairs:
siem_correlation = 0.0
else:
alert_map: Dict[str, Any] = {}
for a in self._task_def.get("initial_alerts", []):
alert_map[a.get("alert_id", "")] = a
for a in self._alert_queue:
alert_map[a.get("alert_id", "")] = a
quality_scores = []
for pair in state.correlated_alert_pairs:
pair_alerts = [alert_map[aid] for aid in pair if aid in alert_map]
if len(pair_alerts) < 2:
quality_scores.append(0.3)
continue
sources = [a.get("source_host") for a in pair_alerts]
ioc_sets = [set(a.get("ioc_indicators", [])) for a in pair_alerts]
shared_hosts = len(sources) != len({s for s in sources if s})
shared_iocs = bool(ioc_sets[0] & ioc_sets[1]) if len(ioc_sets) >= 2 else False
quality_scores.append(1.0 if (shared_hosts or shared_iocs) else 0.2)
siem_correlation = sum(quality_scores) / max(1, len(quality_scores))
# 5. threat_intel_usage — only enrichments of discovered IOCs count
justified_enrichments = [ioc for ioc in state.enriched_iocs if ioc in discovered_iocs]
threat_intel_usage = min(1.0, len(justified_enrichments) / total_iocs)
# 6. vuln_root_cause — fraction of threats with a scanned host
vuln_root_cause = min(1.0, len(state.scanned_hosts) / total_threats)
# 7. business_impact — proportionate isolation + low overall impact
# Reward: isolating confirmed-compromised hosts Penalize: isolating clean hosts
isolated_host_set = {
h for h, hd in self._host_index.items() if hd.get("status") == "isolated"
} if self._host_index else set()
compromised_host_set = {
h for threat in task_chain for h in threat.get("compromised_hosts", [])
}
if isolated_host_set:
over_isolated = isolated_host_set - compromised_host_set
isolation_proportion = (
len(isolated_host_set - over_isolated) / len(isolated_host_set)
)
over_iso_penalty = min(0.40, len(over_isolated) * 0.15)
else:
isolation_proportion = 1.0
over_iso_penalty = 0.0
raw_impact_score = max(0.0, 1.0 - state.business_impact)
business_impact = max(0.0, min(1.0,
0.6 * raw_impact_score + 0.4 * isolation_proportion - over_iso_penalty
))
# 8. step_efficiency — reward early resolution
ratio = state.step_count / max(1, state.max_steps)
step_efficiency = max(0.0, 1.0 - max(0.0, ratio - 0.5) * 1.5)
# 9. plan_coverage — partial credit scales with threats addressed
if state.submitted_plan:
plan_coverage = min(1.0, len(self._plan_entries) / total_threats)
else:
plan_coverage = min(0.5, len(state.contained_threats) / total_threats * 0.5)
# 10. plan_evidence_quality — confidence of submitted plan; else evidence depth proxy
if state.submitted_plan and self._plan_entries:
avg_conf = sum(e.get("confidence", 0.0) for e in self._plan_entries) / len(self._plan_entries)
plan_evidence_quality = float(avg_conf)
else:
evidence_count = len(justified_forensics) + len(justified_enrichments) + len(state.scanned_hosts)
plan_evidence_quality = min(0.5, evidence_count / (total_compromised * 3) * 0.5)
return {
"threat_containment": round(threat_containment, 4),
"ioc_blocking": round(ioc_blocking, 4),
"forensic_investigation": round(forensic_investigation, 4),
"siem_correlation": round(siem_correlation, 4),
"threat_intel_usage": round(threat_intel_usage, 4),
"vuln_root_cause": round(vuln_root_cause, 4),
"business_impact": round(business_impact, 4),
"step_efficiency": round(step_efficiency, 4),
"plan_coverage": round(plan_coverage, 4),
"plan_evidence_quality": round(plan_evidence_quality, 4),
}
def _get_current_phase(self) -> str:
"""Derive episode phase from the action history in the timeline."""
action_types = {t["action_type"] for t in self._state.timeline}
if any(t in action_types for t in ["kill_process", "block_ioc", "isolate_segment", "terminate_pid", "create_firewall_rule", "quarantine_file"]):
return "remediation"
if any(t in action_types for t in ["run_forensics", "enrich_ioc", "scan_host_vulnerabilities", "query_host"]):
return "investigation"
return "triage"
def _build_observation(self, reward: float, done: bool) -> SOCObservation:
    """Snapshot the current environment state as a ``SOCObservation``.

    Tallies host statuses into a ``NetworkTopology`` and converts the alert
    queue and timeline into their typed models. The final grade is attached
    only when the episode is done and a plan was submitted. Per-step extras
    from ``_last_obs_extras`` are merged in. The Red view is included only
    when ``active_turn`` is ``"red"``.
    """
    # Topology: per-subnet host counts plus status tallies over every host.
    sizes_by_subnet = {}
    host_total = compromised_total = isolated_total = 0
    for subnet_name, members in self._network.items():
        sizes_by_subnet[subnet_name] = len(members)
        host_total += len(members)
        for member in members:
            status = member["status"]
            if status == "compromised":
                compromised_total += 1
            elif status == "isolated":
                isolated_total += 1
    topology = NetworkTopology(
        total_hosts=host_total,
        subnets=sizes_by_subnet,
        compromised_count=compromised_total,
        isolated_count=isolated_total,
        online_count=host_total - compromised_total - isolated_total,
    )

    alert_models = [Alert(**raw_alert) for raw_alert in self._alert_queue]
    timeline_models = [
        TimelineEntry(
            step=rec["step"],
            action_type=rec["action_type"],
            target=rec["target"],
            result=rec["result"],
            reward=rec["reward"],
        )
        for rec in self._state.timeline
    ]

    final_score_val = grade_breakdown_val = None
    if done and self._state.submitted_plan:
        entries = self._plan_entries
        graded = grade_episode(
            episode_actions=list(self._state.timeline),
            final_plan={
                "entries": entries,
                "primary_threat_id": entries[0]["threat_id"] if entries else "",
            },
            graph=self._threat_graph,
            task_def=self._task_def,
            state=self._state,
            disruption_cost=self._disruption_cost,
        )
        final_score_val = round(graded["final_score"], 4)
        grade_breakdown_val = graded["breakdown"]

    # Per-step extras (correlation_results, ioc_enrichment, ...) set by handlers.
    extras = getattr(self, "_last_obs_extras", {}) or {}
    graph = self._threat_graph
    graph_summary = graph.get_context_summary() if graph is not None else None
    # Per-step partial reward dimensions for GRPO credit assignment
    dimension_scores = self._compute_reward_dimensions()
    red_view = self._generate_red_observation() if self._state.active_turn == "red" else None

    state = self._state
    return SOCObservation(
        episode_id=state.episode_id or "",
        alert_queue=alert_models,
        network_topology=topology,
        host_forensics=self._last_forensics,
        timeline=timeline_models,
        business_impact_score=round(state.business_impact, 4),
        step_count=state.step_count,
        active_threats=list(state.active_threats),
        max_steps=state.max_steps,
        task_id=state.task_id,
        total_reward=round(state.total_reward, 4),
        final_score=final_score_val,
        grade_breakdown=grade_breakdown_val,
        done=done,
        reward=round(reward, 4),
        correlation_results=extras.get("correlation_results"),
        ioc_enrichment=extras.get("ioc_enrichment"),
        vulnerability_results=extras.get("vulnerability_results"),
        playbook_result=None,
        threat_graph_summary=graph_summary,
        available_playbooks=[],
        reward_dimensions=dimension_scores,
        active_turn=state.active_turn,
        red_observation=red_view,
    )
def _get_action_target(self, action: Any) -> str:
    """Render a Blue action's target as a short string for the timeline.

    The first matching action type (checked in the order below) determines
    the format. Unrecognised actions yield ``"unknown"``.
    """
    renderers = (
        (QueryHost, lambda a: a.hostname),
        (IsolateSegment, lambda a: getattr(a, "target_host", None) or a.subnet),
        (BlockIOC, lambda a: f"{a.ioc_type}:{a.ioc_value}"),
        (RunForensics, lambda a: a.hostname),
        (KillProcess, lambda a: f"{a.hostname}/{a.process_name}"),
        (SubmitContainmentPlan, lambda a: f"{len(a.plan)} entries"),
        (CorrelateAlerts, lambda a: ",".join(a.alert_ids)),
        (EnrichIOC, lambda a: a.ioc_value),
        (ScanHostVulnerabilities, lambda a: a.hostname),
        (TerminatePID, lambda a: f"{a.hostname}/{a.pid}"),
        (CreateFirewallRule, lambda a: f"{a.hostname}:{a.action}:{a.target_ip}"),
        (QuarantineFile, lambda a: f"{a.hostname}:{a.file_path}"),
    )
    for action_cls, render in renderers:
        if isinstance(action, action_cls):
            return render(action)
    return "unknown"
# ===========================================================================
# Adaptive Red Team + Step Rewards (Task 10)
# ===========================================================================
def _generate_red_observation(self) -> Dict[str, Any]:
"""What the Red Team LLM sees: footholds it controls + Blue's last action.
Returned as the ``red_observation`` field in SOCObservation whenever
``active_turn == 'red'``, so inference.py can feed it straight to the
Red LLM without a separate API call.
"""
compromised_hosts = [
h for h, hd in self._host_index.items()
if hd.get("status") == "compromised"
]
# Most recent Blue action from the timeline (exclude Red's own entries)
blue_actions_detected: List[Dict[str, Any]] = []
for entry in reversed(self._state.timeline):
action_type = entry.get("action_type", "")
if not action_type.startswith("red:"):
blue_actions_detected.append({
"step": entry["step"],
"action": action_type,
"target": entry["target"],
"result": entry["result"],
})
break # Only the single most recent Blue action
return {
"episode_id": self._state.episode_id,
"round": self._state.step_count + 1,
"compromised_hosts": compromised_hosts,
"blue_actions_detected": blue_actions_detected,
"active_threats": list(self._state.active_threats),
"business_impact": round(self._state.business_impact, 4),
}
def _log_red_decision(self, observation: Dict[str, Any], action: Dict[str, Any]) -> None:
"""Record (observation -> action) tuples for red-team imitation warm-start."""
record = {"observation": observation, "action": action}
self._red_team_decisions.append(record)
if self._red_team_logger is not None:
try:
self._red_team_logger(record)
except Exception:
# Logging is best effort and should never affect environment execution.
pass
def _apply_red_team_dynamics(self, action_type: str, target: str) -> None:
"""Execute embedded Red dynamics in non-FSP mode.
When neural_red_policy is callable: invoke it with the current red
observation, route the returned action through the Red handlers, and
log the (obs → action) pair for offline SFT.
When neural_red_policy is None (adaptive=True path): apply the
deterministic fallback policy and log the pair.
"""
red_obs = self._generate_red_observation()
if callable(self._neural_red_policy):
try:
action_dict = self._neural_red_policy(red_obs)
if not isinstance(action_dict, dict):
action_dict = {"type": "pass_turn"}
except Exception:
action_dict = {"type": "pass_turn"}
atype = action_dict.get("type", "pass_turn")
if atype == "lateral_pivot":
src = action_dict.get("source_host", "")
dst = action_dict.get("target_host", "")
if src and dst:
self._handle_lateral_pivot(
LateralPivot(type="lateral_pivot", source_host=src, target_host=dst)
)
elif atype == "deploy_payload":
h = action_dict.get("hostname", "")
pl = action_dict.get("payload_type", "ransomware")
if h:
self._handle_deploy_payload(
DeployPayload(type="deploy_payload", hostname=h, payload_type=pl)
)
elif atype == "evade_detection":
h = action_dict.get("hostname", "")
tech = action_dict.get("technique", "migrate_pid")
if h:
self._handle_evade_detection(
EvadeDetection(type="evade_detection", hostname=h, technique=tech)
)
# pass_turn → no graph mutation needed
self._log_red_decision(red_obs, action_dict)
else:
# Deterministic fallback for imitation warm-start (adaptive=True path)
det_action = self._deterministic_red_policy(action_type, target, red_obs)
atype = det_action.get("type", "pass_turn")
if atype == "lateral_pivot":
self._handle_lateral_pivot(
LateralPivot(
type="lateral_pivot",
source_host=det_action["source_host"],
target_host=det_action["target_host"],
)
)
elif atype == "deploy_payload":
dp_host = det_action.get("hostname", "")
dp_payload = det_action.get("payload_type", "ransomware")
if dp_host:
self._handle_deploy_payload(
DeployPayload(
type="deploy_payload",
hostname=dp_host,
payload_type=dp_payload,
)
)
self._log_red_decision(red_obs, det_action)
def _deterministic_red_policy(
    self, blue_action: str, blue_target: str, red_obs: Dict[str, Any]
) -> Dict[str, Any]:
    """Scripted Red Team move used to collect SFT imitation warm-start data.

    Rules are tried in order and the first one that produces a move wins:

    1. Stall punishment: Blue's action is passive (``query_host`` /
       ``pass_turn``), Red holds at least one compromised host, and the
       trailing run of passive ``(action, ...)`` tuples in
       ``self._recent_actions`` is 3 or longer -> deploy ransomware on the
       first compromised host.
    2. Reactive pivot: Blue's action is containment (``kill_process`` /
       ``isolate_segment`` / ``block_ioc``) -> pivot from the first
       compromised host (or, when there is none, from ``blue_target``) to the
       first host in ``self._host_index`` that is neither compromised nor
       isolated. The source must itself be a key of ``self._host_index``.
    3. Autonomous pivot: Blue's action is passive and Red has a foothold ->
       with 15 % probability (one ``self._rng.random()`` draw) pivot the same
       way from the first compromised host.

    Args:
        blue_action: Action type Blue just took.
        blue_target: Target of Blue's action; only used as a fallback pivot
            source in rule 2.
        red_obs: Red observation; only ``"compromised_hosts"`` is read.

    Returns:
        An action dict: ``deploy_payload``, ``lateral_pivot`` or
        ``{"type": "pass_turn"}`` when no rule applies.
    """
    passive_actions = frozenset({"query_host", "pass_turn"})
    containment_actions = frozenset({"kill_process", "isolate_segment", "block_ioc"})
    footholds = red_obs.get("compromised_hosts", [])
    blue_is_passive = blue_action in passive_actions

    def first_clean_host(origin):
        # First host (index order) that Red does not own and Blue has not isolated.
        for name, info in self._host_index.items():
            if info.get("status") not in ("compromised", "isolated") and name != origin:
                return name
        return None

    # Rule 1: punish Blue for stalling with a run of passive actions.
    if blue_is_passive and footholds:
        passive_run = 0
        for entry in reversed(getattr(self, "_recent_actions", [])):
            if not (isinstance(entry, tuple) and entry[0] in passive_actions):
                break
            passive_run += 1
        if passive_run >= 3:
            return {
                "type": "deploy_payload",
                "hostname": footholds[0],
                "payload_type": "ransomware",
            }

    # Rule 2: answer containment with lateral movement.
    if blue_action in containment_actions:
        origin = footholds[0] if footholds else (blue_target or None)
        if origin is not None and origin in self._host_index:
            victim = first_clean_host(origin)
            if victim:
                return {"type": "lateral_pivot", "source_host": origin, "target_host": victim}

    # Rule 3: occasionally spread even while Blue is passive.
    if blue_is_passive and footholds and self._rng.random() < 0.15:
        origin = footholds[0]
        victim = first_clean_host(origin)
        if victim:
            return {"type": "lateral_pivot", "source_host": origin, "target_host": victim}

    return {"type": "pass_turn"}
def export_red_team_decisions(self) -> List[Dict[str, Any]]:
    """Snapshot the recorded Red Team decisions for offline SFT.

    Returns a fresh list containing the entries of
    ``self._red_team_decisions``. The list is new, but the entries themselves
    are shared with the internal log (shallow copy).
    """
    snapshot = [*self._red_team_decisions]
    return snapshot
# Per-step shaping rewards.
# Tuple keys are (phase, action_type); _get_step_reward looks these up via
# STEP_REWARDS.get((phase, action_type), 0.0), so only tuple keys are reachable
# from there.
STEP_REWARDS: Dict[Any, float] = {
    ("investigation", "run_forensics"): +0.10,
    ("investigation", "enrich_ioc"): +0.05,
    ("investigation", "scan_host_vulnerabilities"): +0.05,
    ("triage", "correlate_alerts"): +0.05,
    # String-keyed penalties: not read by _get_step_reward; presumably looked
    # up by name elsewhere in this class — TODO confirm.
    "phase_violation_attempt": -0.20,
    "ungrounded_action_attempt": -0.10,
}
def _get_step_reward(self, phase: str, action_type: str, target: str) -> float:
    """Return the one-shot shaping reward for a (phase, action_type, target) triple.

    The amount comes from ``self.STEP_REWARDS[(phase, action_type)]``
    (0.0 when absent). A non-zero reward fires at most once per
    ``(phase, action_type, target)`` triple; fired triples are tracked in
    ``self._fired_step_rewards``, which is created lazily.

    Hard cap: total step rewards per episode never exceed 0.40. The running
    total is read from ``self._step_reward_total`` (0.0 if unset), and any
    positive reward is clamped to the remaining budget. This method only
    *reads* the total; accumulating into it is presumably done by the
    caller — TODO confirm.

    Args:
        phase: Current incident-response phase.
        action_type: Blue action type being rewarded.
        target: Action target, used to make the reward idempotent per target.

    Returns:
        The reward to apply, or 0.0 when nothing is due.
    """
    step_reward_cap = 0.40
    if not hasattr(self, "_fired_step_rewards"):
        self._fired_step_rewards = set()
    total = getattr(self, "_step_reward_total", 0.0)
    # Budget exhausted: nothing more is awarded this episode.
    if total >= step_reward_cap:
        return 0.0
    key = (phase, action_type, target)
    if key in self._fired_step_rewards:
        return 0.0
    reward = self.STEP_REWARDS.get((phase, action_type), 0.0)
    if reward == 0.0:
        return 0.0
    self._fired_step_rewards.add(key)
    # Clamp so total + reward never overshoots the cap (e.g. 0.35 + 0.10).
    # min() leaves negative rewards untouched.
    return min(reward, step_reward_cap - total)
def _maybe_reinfect(self, hostname: str, process_name: str) -> None:
    """Possibly respawn a ``<process_name>_v2`` variant on ``hostname``.

    Does nothing unless ``self._adaptive`` is truthy and a threat graph is
    present. Reinfection is only considered when at least one unblocked IOC
    is connected to ``hostname`` through the threat graph's edges (followed
    in either direction, transitively). If so, it fires with 30 % probability
    using one ``self._rng.random()`` draw.

    On reinfection:
      * the variant is appended to the host's ``running_processes`` (if not
        already there) and the host status is set to ``"compromised"``;
      * a ``ProcessNode`` with id ``"<hostname>:<variant>"`` is added to the
        graph if missing;
      * a critical ``REINFECT-xxxxxx`` alert is added to the graph and
        appended to ``self._alert_queue``.

    Args:
        hostname: Host to check and potentially reinfect.
        process_name: Base process name; the variant is ``f"{process_name}_v2"``.
    """
    if not self._adaptive:
        return
    graph = self._threat_graph
    if graph is None:
        return
    # The IOC itself must be linked to the host's chain, not merely "some
    # unblocked IOC exists somewhere and the host has some edge".
    # Assumes graph.iocs is keyed by the same node ids used as edge endpoints
    # (graph.processes is keyed by process_id) — TODO confirm.
    unblocked_iocs = {
        ioc_id for ioc_id, ioc_node in graph.iocs.items() if not ioc_node.blocked
    }
    if not unblocked_iocs:
        return
    adjacency = {}
    for e in graph.edges:
        adjacency.setdefault(e.source_id, []).append(e.target_id)
        adjacency.setdefault(e.target_id, []).append(e.source_id)

    def _chain_has_unblocked_ioc() -> bool:
        # Walk the undirected edge set outward from the host.
        seen = {hostname}
        stack = [hostname]
        while stack:
            for neighbour in adjacency.get(stack.pop(), ()):
                if neighbour in unblocked_iocs:
                    return True
                if neighbour not in seen:
                    seen.add(neighbour)
                    stack.append(neighbour)
        return False

    if not _chain_has_unblocked_ioc():
        return
    if self._rng.random() >= 0.3:
        return
    # Reinfect: spawn a _v2 variant process on the host
    variant_name = f"{process_name}_v2"
    if hostname in self._host_index:
        host = self._host_index[hostname]
        if variant_name not in host["running_processes"]:
            host["running_processes"].append(variant_name)
        host["status"] = "compromised"
    # Add the variant to the threat graph
    pid = f"{hostname}:{variant_name}"
    if pid not in graph.processes:
        graph.add_process(ProcessNode(
            process_id=pid,
            hostname=hostname,
            process_name=variant_name,
            killed=False,
        ))
    # Emit a CRITICAL alert to signal the reinfection
    alert_id = f"REINFECT-{uuid.uuid4().hex[:6].upper()}"
    graph.add_alert(AlertNode(
        alert_id=alert_id,
        severity="critical",
        priority_score=18.0,
        source_host=hostname,
    ))
    self._alert_queue.append({
        "alert_id": alert_id,
        "timestamp": "2024-01-01T00:00:00Z",
        "source_host": hostname,
        "severity": "critical",
        "threat_type": "malware",
        "description": f"Reinfection detected: {variant_name} spawned on {hostname} (IOC-assisted persistence)",
        "ioc_indicators": [],
        "subnet": self._host_index.get(hostname, {}).get("subnet", "unknown"),
        "is_acknowledged": False,
    })
def _adversary_react(self, action_type: str, target: str) -> Optional[Dict[str, Any]]:
    """Deprecated Red Team reaction hook; intentionally a no-op.

    Red Team moves now arrive as explicit ``RedActionWrapper`` steps, so the
    arguments are ignored and the result is always ``None``.
    """
    _ = (action_type, target)  # kept only for signature compatibility
    return None
@property
def state(self) -> SOCState:
    """Expose the environment's internal ``SOCState`` (stored on ``self._state``)."""
    current_state = self._state
    return current_state