# Source: CyberSOC/server/play_environment.py
# Author: Ajayyy00
# Commit bb0d7fd — "Initial commit: CyberSOC Enterprise Environment Baseline"
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
CyberSOCEnv — Enterprise Cybersecurity Operations Center Environment.
Implements the OpenEnv Environment interface for a deterministic SOC
incident response simulation on a 500-node enterprise network.
The agent receives SIEM/EDR alerts, queries hosts, runs forensics,
isolates segments, blocks IOCs, kills processes, and submits a
containment plan — all while minimizing business downtime.
"""
from __future__ import annotations
import copy
from typing import Any, Dict, List, Optional
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
try:
from ..models import (
SOCObservation,
SOCActionWrapper,
SOCState,
Alert,
NetworkTopology,
ForensicsResult,
TimelineEntry,
QueryHost,
IsolateSegment,
BlockIOC,
RunForensics,
KillProcess,
SubmitContainmentPlan,
)
except ImportError:
from models import (
SOCObservation,
SOCActionWrapper,
SOCState,
Alert,
NetworkTopology,
ForensicsResult,
TimelineEntry,
QueryHost,
IsolateSegment,
BlockIOC,
RunForensics,
KillProcess,
SubmitContainmentPlan,
)
from .tasks import get_task, build_network
from .graders import grade_episode
class CyberSOCEnvironment(Environment):
    """
    Deterministic SOC incident response environment.

    Simulates a 500-node enterprise network under attack. The agent must
    investigate alerts, contain threats, and submit a containment plan
    while minimizing business downtime.

    Supports concurrent WebSocket sessions (each gets own instance).

    Example:
        >>> env = CyberSOCEnvironment()
        >>> obs = env.reset(task_id="easy")
        >>> print(len(obs.alert_queue))  # Initial alerts
        >>> obs = env.step(SOCActionWrapper(type="query_host", hostname="WS-042"))
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self) -> None:
        """Initialize the environment (actual state set in reset)."""
        super().__init__()
        self._state = SOCState(episode_id=str(uuid4()), step_count=0)
        self._network: Dict[str, List[Dict[str, Any]]] = {}  # subnet name -> host dicts
        self._task_def: Dict[str, Any] = {}
        self._alert_queue: List[Dict[str, Any]] = []
        self._host_index: Dict[str, Dict[str, Any]] = {}  # hostname -> host dict
        self._plan_entries: List[Dict[str, Any]] = []
        self._last_forensics: Optional[ForensicsResult] = None
        # Grade computed once when the containment plan is submitted. It is
        # reused by _build_observation so the reported final_score always
        # matches the reward handed out for the submission.
        self._final_score: Optional[float] = None

    # ===========================================================================
    # reset()
    # ===========================================================================

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> SOCObservation:
        """Reset the environment for a specific task.

        Args:
            seed: Ignored (environment is fully deterministic).
            episode_id: Optional custom episode ID; a random UUID is used
                when omitted.
            **kwargs: May include task_id ('easy', 'medium', or 'hard');
                defaults to 'easy' when absent.

        Returns:
            Initial SOCObservation with alert queue and network state.
        """
        task_id = kwargs.get("task_id", "easy")
        self._task_def = get_task(task_id)

        # Build deterministic network
        self._network = build_network()

        # Build hostname index for O(1) lookups (later duplicates win, as
        # with plain assignment).
        self._host_index = {
            host["hostname"]: host
            for hosts in self._network.values()
            for host in hosts
        }

        # Inject attack chain: mark compromised hosts, add malicious processes
        for threat in self._task_def["attack_chain"]:
            for hostname in threat["compromised_hosts"]:
                host = self._host_index.get(hostname)
                if host is None:
                    continue
                host["status"] = "compromised"
                for proc in threat["malicious_processes"]:
                    if proc not in host["running_processes"]:
                        host["running_processes"].append(proc)

        # Initialize alert queue (deep copy so mutations don't affect task def)
        self._alert_queue = copy.deepcopy(self._task_def["initial_alerts"])

        # Reset state
        eid = episode_id or str(uuid4())
        self._state = SOCState(
            episode_id=eid,
            step_count=0,
            task_id=task_id,
            max_steps=self._task_def["max_steps"],
            total_reward=0.0,
            business_impact=self._task_def["initial_business_impact"],
            contained_threats=[],
            active_threats=[t["threat_id"] for t in self._task_def["attack_chain"]],
            blocked_iocs=[],
            isolated_subnets=[],
            forensics_run=[],
            killed_processes=[],
            queried_hosts=[],
            timeline=[],
            is_done=False,
            submitted_plan=False,
        )
        self._plan_entries = []
        self._last_forensics = None
        self._final_score = None
        self._reset_rubric()

        return self._build_observation(reward=0.0, done=False)

    # ===========================================================================
    # step()
    # ===========================================================================

    def step(
        self,
        action: SOCActionWrapper,  # type: ignore[override]
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> SOCObservation:
        """Process one agent action.

        Args:
            action: SOCActionWrapper containing the typed action.
            timeout_s: Ignored.

        Returns:
            SOCObservation with updated state, reward, and done flag. Once the
            episode is done, further calls return a zero-reward, done
            observation without modifying state.
        """
        if self._state.is_done:
            return self._build_observation(reward=0.0, done=True)

        # Increment step
        self._state.step_count += 1

        # Convert wrapper to typed action
        typed_action = action.to_typed_action()

        # Dispatch to handler
        reward = 0.0
        result_description = "unknown action"
        if isinstance(typed_action, QueryHost):
            reward, result_description = self._handle_query_host(typed_action)
        elif isinstance(typed_action, IsolateSegment):
            reward, result_description = self._handle_isolate_segment(typed_action)
        elif isinstance(typed_action, BlockIOC):
            reward, result_description = self._handle_block_ioc(typed_action)
        elif isinstance(typed_action, RunForensics):
            reward, result_description = self._handle_run_forensics(typed_action)
        elif isinstance(typed_action, KillProcess):
            reward, result_description = self._handle_kill_process(typed_action)
        elif isinstance(typed_action, SubmitContainmentPlan):
            reward, result_description = self._handle_submit_plan(typed_action)

        # Business impact grows each step (attacker progresses). The previous
        # guard (`not is_done`) was always true here because is_done is only
        # set below, so impact kept growing on the submission step, after the
        # plan had already been graded. Skip growth once the plan is in.
        if not self._state.submitted_plan:
            impact_rate = self._task_def.get("impact_per_step", 0.02)
            # Reduce impact growth if threats are being contained
            active_ratio = len(self._state.active_threats) / max(1, len(self._task_def["attack_chain"]))
            self._state.business_impact = min(
                1.0,
                self._state.business_impact + impact_rate * active_ratio,
            )

        # Record timeline (action reward only; any timeout penalty below is
        # reflected in the returned reward and total_reward, not here).
        self._state.timeline.append({
            "step": self._state.step_count,
            "action_type": typed_action.type,
            "target": self._get_action_target(typed_action),
            "result": result_description,
            "reward": reward,
        })

        # Accumulate reward
        self._state.total_reward += reward

        # Check termination: plan submission takes precedence over timeout.
        done = False
        if self._state.submitted_plan:
            done = True
            self._state.is_done = True
        elif self._state.step_count >= self._state.max_steps:
            done = True
            self._state.is_done = True
            reward -= 0.20  # Penalty for running out of time
            self._state.total_reward -= 0.20

        return self._build_observation(reward=reward, done=done)

    # ===========================================================================
    # Action Handlers (return (reward, description))
    # ===========================================================================

    def _handle_query_host(self, action: QueryHost) -> tuple[float, str]:
        """Query a host for status info.

        Rewards the first query of a compromised host, penalizes re-queries.
        Each host is recorded in ``queried_hosts`` at most once.
        """
        hostname = action.hostname
        self._last_forensics = None  # Clear forensics from previous step

        host = self._host_index.get(hostname)
        if host is None:
            return -0.05, f"Host '{hostname}' not found in network"

        first_query = hostname not in self._state.queried_hosts
        reward = 0.0
        if not first_query:
            reward = -0.02  # Penalty: re-querying same host wastes time
        elif host["status"] == "compromised":
            reward = 0.05  # Good: investigating a compromised host

        if first_query:
            self._state.queried_hosts.append(hostname)
        return reward, f"Queried {hostname}: status={host['status']}, procs={len(host['running_processes'])}"

    def _handle_isolate_segment(self, action: IsolateSegment) -> tuple[float, str]:
        """Isolate a network segment.

        Every host in the subnet is marked "isolated". Each still-active
        threat with at least one compromised host in the subnet is moved to
        contained. Isolating a subnet listed in the task's ``must_not_isolate``
        is penalized and raises business impact.
        """
        subnet = action.subnet
        self._last_forensics = None

        if subnet not in self._network:
            return -0.05, f"Subnet '{subnet}' does not exist"
        if subnet in self._state.isolated_subnets:
            return -0.02, f"Subnet '{subnet}' is already isolated"

        # Isolate all hosts in the subnet
        for host in self._network[subnet]:
            host["status"] = "isolated"
        self._state.isolated_subnets.append(subnet)

        # Active threats with any compromised host inside this subnet
        threats_contained = [
            threat["threat_id"]
            for threat in self._task_def["attack_chain"]
            if threat["threat_id"] in self._state.active_threats
            and any(
                ch in self._host_index and self._host_index[ch]["subnet"] == subnet
                for ch in threat["compromised_hosts"]
            )
        ]

        reward = 0.0
        if threats_contained:
            reward = 0.15 * len(threats_contained)  # Good: containing lateral movement
            for tid in threats_contained:
                if tid not in self._state.contained_threats:
                    self._state.contained_threats.append(tid)
                if tid in self._state.active_threats:
                    self._state.active_threats.remove(tid)

        # Check if this is an unnecessary isolation (business downtime)
        must_not_isolate = self._task_def["containment_requirements"].get("must_not_isolate", [])
        if subnet in must_not_isolate:
            reward -= 0.10  # Penalty: unnecessary downtime
            self._state.business_impact = min(1.0, self._state.business_impact + 0.08)

        return reward, f"Isolated subnet '{subnet}'. Threats contained: {threats_contained}"

    def _handle_block_ioc(self, action: BlockIOC) -> tuple[float, str]:
        """Block an IOC at the perimeter.

        The IOC is recorded as blocked even when irrelevant. Relevance is
        checked against every threat in the attack chain (active or already
        contained). Only the first matching threat is scored.
        """
        ioc = action.ioc_value
        self._last_forensics = None

        if ioc in self._state.blocked_iocs:
            return -0.02, f"IOC '{ioc}' is already blocked"

        self._state.blocked_iocs.append(ioc)

        reward = 0.0
        relevant = False
        for threat in self._task_def["attack_chain"]:
            iocs = threat["iocs"]
            all_iocs = iocs.get("hashes", []) + iocs.get("ips", []) + iocs.get("domains", [])
            if ioc not in all_iocs:
                continue
            relevant = True
            if ioc in threat.get("c2_servers", []):
                reward += 0.15  # High value: cutting C2
            else:
                reward += 0.10  # Good: blocking relevant IOC
            break

        if not relevant:
            reward = -0.03  # Noise: blocking irrelevant IOC

        return reward, f"Blocked IOC '{ioc}' (type={action.ioc_type}). Relevant: {relevant}"

    def _handle_run_forensics(self, action: RunForensics) -> tuple[float, str]:
        """Run forensic analysis on a host.

        Builds a ForensicsResult (stored for the next observation) from the
        host's current status. A host counts as compromised only while its
        status is "compromised"; an isolated host reports as not compromised.
        """
        hostname = action.hostname

        host = self._host_index.get(hostname)
        if host is None:
            self._last_forensics = None
            return -0.05, f"Host '{hostname}' not found"

        is_compromised = host["status"] == "compromised"

        malicious_procs: List[str] = []
        suspicious_files: List[str] = []
        network_conns: List[str] = []
        registry_mods: List[str] = []
        memory_artifacts: List[str] = []

        if is_compromised:
            # Collect deterministic artifacts from every threat touching this host
            for threat in self._task_def["attack_chain"]:
                if hostname not in threat["compromised_hosts"]:
                    continue
                malicious_procs.extend(threat["malicious_processes"])
                for proc in threat["malicious_processes"]:
                    suspicious_files.append(f"C:\\Windows\\Temp\\{proc}.dat")
                    registry_mods.append(f"HKLM\\Software\\Microsoft\\Windows\\CurrentVersion\\Run\\{proc}")
                for c2 in threat.get("c2_servers", []):
                    network_conns.append(f"{c2}:443")
                for ioc_hash in threat["iocs"].get("hashes", []):
                    memory_artifacts.append(f"memory_inject_{ioc_hash[:8]}")

        self._last_forensics = ForensicsResult(
            hostname=hostname,
            malicious_processes=malicious_procs,
            suspicious_files=suspicious_files,
            network_connections=network_conns,
            registry_modifications=registry_mods,
            memory_artifacts=memory_artifacts,
            is_compromised=is_compromised,
        )

        if hostname in self._state.forensics_run:
            reward = -0.02  # Re-running forensics wastes time
        else:
            reward = 0.10 if is_compromised else 0.02  # Found evidence / cleared a host
            self._state.forensics_run.append(hostname)

        return reward, f"Forensics on {hostname}: compromised={is_compromised}, procs={malicious_procs}"

    def _handle_kill_process(self, action: KillProcess) -> tuple[float, str]:
        """Kill a process on a host.

        Killing a threat's malicious process is rewarded. When every
        malicious process of that threat is gone from all of its hosts, the
        threat is moved to contained and a bonus is added. Killing any other
        process is penalized and raises business impact.
        """
        hostname = action.hostname
        process = action.process_name
        self._last_forensics = None

        host = self._host_index.get(hostname)
        if host is None:
            return -0.05, f"Host '{hostname}' not found"
        if host["status"] == "isolated":
            return -0.02, f"Host '{hostname}' is isolated — cannot interact"
        if process not in host["running_processes"]:
            return -0.03, f"Process '{process}' not running on {hostname}"

        # Kill the process
        host["running_processes"].remove(process)
        self._state.killed_processes.append({"hostname": hostname, "process": process})

        reward = 0.0
        was_malicious = False
        for threat in self._task_def["attack_chain"]:
            if hostname not in threat["compromised_hosts"] or process not in threat["malicious_processes"]:
                continue
            was_malicious = True
            reward = 0.15  # Major reward: stopping malicious activity

            # Threat fully neutralized when none of its malicious processes
            # are still running on any of its (indexed) hosts.
            all_killed = not any(
                th_host in self._host_index
                and th_proc in self._host_index[th_host]["running_processes"]
                for th_host in threat["compromised_hosts"]
                for th_proc in threat["malicious_processes"]
            )
            tid = threat["threat_id"]
            if all_killed and tid in self._state.active_threats:
                self._state.active_threats.remove(tid)
                if tid not in self._state.contained_threats:
                    self._state.contained_threats.append(tid)
                reward += 0.10  # Bonus: fully contained a threat
            # Only the first matching threat is scored.
            break

        if not was_malicious:
            reward = -0.08  # Penalty: killing legitimate process = downtime
            self._state.business_impact = min(1.0, self._state.business_impact + 0.03)

        return reward, f"Killed '{process}' on {hostname}. Malicious: {was_malicious}"

    def _handle_submit_plan(self, action: SubmitContainmentPlan) -> tuple[float, str]:
        """Submit the final containment plan.

        Grades the episode once, caches the score for later observations, and
        returns the score as this step's reward.
        """
        self._last_forensics = None
        self._state.submitted_plan = True
        self._plan_entries = [entry.model_dump() for entry in action.plan]

        final_score = self._grade_episode()
        self._final_score = final_score

        # Reward proportional to final grade
        reward = final_score * 1.0  # Scale: perfect score = 1.0 reward

        description = (
            f"Containment plan submitted. "
            f"Grade: {final_score:.3f}. "
            f"Threats contained: {len(self._state.contained_threats)}/{len(self._task_def['attack_chain'])}. "
            f"Business impact: {self._state.business_impact:.2f}"
        )
        return reward, description

    # ===========================================================================
    # Helpers
    # ===========================================================================

    def _grade_episode(self) -> float:
        """Run the task grader against the current episode state."""
        return grade_episode(
            task_id=self._state.task_id,
            task_def=self._task_def,
            killed_processes=self._state.killed_processes,
            blocked_iocs=self._state.blocked_iocs,
            forensics_run=self._state.forensics_run,
            isolated_subnets=self._state.isolated_subnets,
            submitted_plan=self._state.submitted_plan,
            plan_entries=self._plan_entries,
            final_business_impact=self._state.business_impact,
            step_count=self._state.step_count,
            total_reward=self._state.total_reward,
        )

    def _build_observation(self, reward: float, done: bool) -> SOCObservation:
        """Build the observation from current state.

        final_score and grade_breakdown are populated only once the episode
        is done via plan submission; the score is the one cached at
        submission time.
        """
        # Network topology summary in a single pass over all hosts
        subnet_counts: Dict[str, int] = {}
        total = compromised = isolated = 0
        for name, hosts in self._network.items():
            subnet_counts[name] = len(hosts)
            total += len(hosts)
            for h in hosts:
                status = h["status"]
                if status == "compromised":
                    compromised += 1
                elif status == "isolated":
                    isolated += 1

        topology = NetworkTopology(
            total_hosts=total,
            subnets=subnet_counts,
            compromised_count=compromised,
            isolated_count=isolated,
            online_count=total - compromised - isolated,
        )

        alerts = [Alert(**a) for a in self._alert_queue]

        timeline = [
            TimelineEntry(
                step=t["step"],
                action_type=t["action_type"],
                target=t["target"],
                result=t["result"],
                reward=t["reward"],
            )
            for t in self._state.timeline
        ]

        final_score_val = None
        grade_breakdown_val = None
        if done and self._state.submitted_plan:
            if self._final_score is None:
                # Defensive: normally set by _handle_submit_plan.
                self._final_score = self._grade_episode()
            final_score_val = round(self._final_score, 4)
            grade_breakdown_val = {
                "threats_contained": len(self._state.contained_threats),
                "total_threats": len(self._task_def["attack_chain"]),
                "iocs_blocked": len(self._state.blocked_iocs),
                "hosts_forensics": len(self._state.forensics_run),
                "subnets_isolated": len(self._state.isolated_subnets),
                "business_impact": round(self._state.business_impact, 4),
            }

        return SOCObservation(
            alert_queue=alerts,
            network_topology=topology,
            host_forensics=self._last_forensics,
            timeline=timeline,
            business_impact_score=round(self._state.business_impact, 4),
            step_count=self._state.step_count,
            active_threats=list(self._state.active_threats),
            max_steps=self._state.max_steps,
            task_id=self._state.task_id,
            total_reward=round(self._state.total_reward, 4),
            final_score=final_score_val,
            grade_breakdown=grade_breakdown_val,
            done=done,
            reward=round(reward, 4),
        )

    def _get_action_target(self, action: Any) -> str:
        """Extract the target string from a typed action for timeline logging."""
        if isinstance(action, (QueryHost, RunForensics)):
            return action.hostname
        if isinstance(action, IsolateSegment):
            return action.subnet
        if isinstance(action, BlockIOC):
            return f"{action.ioc_type}:{action.ioc_value}"
        if isinstance(action, KillProcess):
            return f"{action.hostname}/{action.process_name}"
        if isinstance(action, SubmitContainmentPlan):
            return f"{len(action.plan)} entries"
        return "unknown"

    @property
    def state(self) -> SOCState:
        """Get the current internal environment state."""
        return self._state