# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. """ CyberSOCEnv — Enterprise Cybersecurity Operations Center Environment. Implements the OpenEnv Environment interface for a deterministic SOC incident response simulation on a 500-node enterprise network. The agent receives SIEM/EDR alerts, queries hosts, runs forensics, isolates segments, blocks IOCs, kills processes, and submits a containment plan — all while minimizing business downtime. """ from __future__ import annotations import copy from typing import Any, Dict, List, Optional from uuid import uuid4 from openenv.core.env_server.interfaces import Environment from openenv.core.env_server.types import State try: from ..models import ( SOCObservation, SOCActionWrapper, SOCState, Alert, NetworkTopology, ForensicsResult, TimelineEntry, QueryHost, IsolateSegment, BlockIOC, RunForensics, KillProcess, SubmitContainmentPlan, ) except ImportError: from models import ( SOCObservation, SOCActionWrapper, SOCState, Alert, NetworkTopology, ForensicsResult, TimelineEntry, QueryHost, IsolateSegment, BlockIOC, RunForensics, KillProcess, SubmitContainmentPlan, ) from .tasks import get_task, build_network from .graders import grade_episode class CyberSOCEnvironment(Environment): """ Deterministic SOC incident response environment. Simulates a 500-node enterprise network under attack. The agent must investigate alerts, contain threats, and submit a containment plan while minimizing business downtime. Supports concurrent WebSocket sessions (each gets own instance). Example: >>> env = CyberSOCEnvironment() >>> obs = env.reset(task_id="easy") >>> print(len(obs.alert_queue)) # Initial alerts >>> obs = env.step(SOCActionWrapper(type="query_host", hostname="WS-042")) """ SUPPORTS_CONCURRENT_SESSIONS: bool = True def __init__(self): """Initialize the environment (actual state set in reset).""" super().__init__() self._state = SOCState(episode_id=str(uuid4()), step_count=0) self._network: Dict[str, List[Dict[str, Any]]] = {} self._task_def: Dict[str, Any] = {} self._alert_queue: List[Dict[str, Any]] = [] self._host_index: Dict[str, Dict[str, Any]] = {} # hostname -> host dict self._plan_entries: List[Dict[str, Any]] = [] self._last_forensics: Optional[ForensicsResult] = None # =========================================================================== # reset() # =========================================================================== def reset( self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs: Any, ) -> SOCObservation: """Reset the environment for a specific task. Args: seed: Ignored (environment is fully deterministic). episode_id: Optional custom episode ID. **kwargs: Must include task_id ('easy', 'medium', or 'hard'). Returns: Initial SOCObservation with alert queue and network state. """ task_id = kwargs.get("task_id", "easy") self._task_def = get_task(task_id) # Build deterministic network self._network = build_network() # Build hostname index for O(1) lookups self._host_index = {} for subnet_name, hosts in self._network.items(): for host in hosts: self._host_index[host["hostname"]] = host # Inject attack chain: mark compromised hosts, add malicious processes for threat in self._task_def["attack_chain"]: for hostname in threat["compromised_hosts"]: if hostname in self._host_index: host = self._host_index[hostname] host["status"] = "compromised" for proc in threat["malicious_processes"]: if proc not in host["running_processes"]: host["running_processes"].append(proc) # Initialize alert queue (deep copy so mutations don't affect task def) self._alert_queue = copy.deepcopy(self._task_def["initial_alerts"]) # Reset state eid = episode_id or str(uuid4()) self._state = SOCState( episode_id=eid, step_count=0, task_id=task_id, max_steps=self._task_def["max_steps"], total_reward=0.0, business_impact=self._task_def["initial_business_impact"], contained_threats=[], active_threats=[t["threat_id"] for t in self._task_def["attack_chain"]], blocked_iocs=[], isolated_subnets=[], forensics_run=[], killed_processes=[], queried_hosts=[], timeline=[], is_done=False, submitted_plan=False, ) self._plan_entries = [] self._last_forensics = None self._reset_rubric() return self._build_observation(reward=0.0, done=False) # =========================================================================== # step() # =========================================================================== def step( self, action: SOCActionWrapper, # type: ignore[override] timeout_s: Optional[float] = None, **kwargs: Any, ) -> SOCObservation: """Process one agent action. Args: action: SOCActionWrapper containing the typed action. timeout_s: Ignored. Returns: SOCObservation with updated state, reward, and done flag. """ if self._state.is_done: return self._build_observation(reward=0.0, done=True) # Increment step self._state.step_count += 1 # Convert wrapper to typed action typed_action = action.to_typed_action() # Dispatch to handler reward = 0.0 result_description = "unknown action" if isinstance(typed_action, QueryHost): reward, result_description = self._handle_query_host(typed_action) elif isinstance(typed_action, IsolateSegment): reward, result_description = self._handle_isolate_segment(typed_action) elif isinstance(typed_action, BlockIOC): reward, result_description = self._handle_block_ioc(typed_action) elif isinstance(typed_action, RunForensics): reward, result_description = self._handle_run_forensics(typed_action) elif isinstance(typed_action, KillProcess): reward, result_description = self._handle_kill_process(typed_action) elif isinstance(typed_action, SubmitContainmentPlan): reward, result_description = self._handle_submit_plan(typed_action) # Business impact grows each step (attacker progresses) if not self._state.is_done: impact_rate = self._task_def.get("impact_per_step", 0.02) # Reduce impact growth if threats are being contained active_ratio = len(self._state.active_threats) / max(1, len(self._task_def["attack_chain"])) self._state.business_impact = min( 1.0, self._state.business_impact + impact_rate * active_ratio, ) # Record timeline self._state.timeline.append({ "step": self._state.step_count, "action_type": typed_action.type, "target": self._get_action_target(typed_action), "result": result_description, "reward": reward, }) # Accumulate reward self._state.total_reward += reward # Check termination done = False if self._state.submitted_plan: done = True self._state.is_done = True elif self._state.step_count >= self._state.max_steps: done = True self._state.is_done = True reward -= 0.20 # Penalty for running out of time self._state.total_reward += (-0.20) return self._build_observation(reward=reward, done=done) # =========================================================================== # Action Handlers (return (reward, description)) # =========================================================================== def _handle_query_host(self, action: QueryHost) -> tuple[float, str]: """Query a host for status info.""" hostname = action.hostname self._last_forensics = None # Clear forensics from previous step if hostname not in self._host_index: return -0.05, f"Host '{hostname}' not found in network" host = self._host_index[hostname] # Reward for querying compromised hosts (useful investigation) reward = 0.0 if host["status"] == "compromised" and hostname not in self._state.queried_hosts: reward = 0.05 # Good: investigating a compromised host elif hostname in self._state.queried_hosts: reward = -0.02 # Penalty: re-querying same host wastes time self._state.queried_hosts.append(hostname) return reward, f"Queried {hostname}: status={host['status']}, procs={len(host['running_processes'])}" def _handle_isolate_segment(self, action: IsolateSegment) -> tuple[float, str]: """Isolate a network segment.""" subnet = action.subnet self._last_forensics = None if subnet not in self._network: return -0.05, f"Subnet '{subnet}' does not exist" if subnet in self._state.isolated_subnets: return -0.02, f"Subnet '{subnet}' is already isolated" # Isolate all hosts in the subnet for host in self._network[subnet]: host["status"] = "isolated" self._state.isolated_subnets.append(subnet) # Check if this contains any active threats reward = 0.0 threats_contained = [] for threat in self._task_def["attack_chain"]: if threat["threat_id"] in self._state.active_threats: # Check if any compromised hosts are in this subnet for ch in threat["compromised_hosts"]: if ch in self._host_index and self._host_index[ch]["subnet"] == subnet: threats_contained.append(threat["threat_id"]) break if threats_contained: reward = 0.15 * len(threats_contained) # Good: containing lateral movement for tid in threats_contained: if tid not in self._state.contained_threats: self._state.contained_threats.append(tid) if tid in self._state.active_threats: self._state.active_threats.remove(tid) # Check if this is an unnecessary isolation (business downtime) must_not_isolate = self._task_def["containment_requirements"].get("must_not_isolate", []) if subnet in must_not_isolate: reward -= 0.10 # Penalty: unnecessary downtime self._state.business_impact = min(1.0, self._state.business_impact + 0.08) return reward, f"Isolated subnet '{subnet}'. Threats contained: {threats_contained}" def _handle_block_ioc(self, action: BlockIOC) -> tuple[float, str]: """Block an IOC at the perimeter.""" ioc = action.ioc_value self._last_forensics = None if ioc in self._state.blocked_iocs: return -0.02, f"IOC '{ioc}' is already blocked" self._state.blocked_iocs.append(ioc) # Check if this IOC is relevant to any active threat reward = 0.0 relevant = False for threat in self._task_def["attack_chain"]: all_iocs = ( threat["iocs"].get("hashes", []) + threat["iocs"].get("ips", []) + threat["iocs"].get("domains", []) ) if ioc in all_iocs: relevant = True # Extra reward for blocking C2 server IPs if ioc in threat.get("c2_servers", []): reward += 0.15 # High value: cutting C2 else: reward += 0.10 # Good: blocking relevant IOC break if not relevant: reward = -0.03 # Noise: blocking irrelevant IOC return reward, f"Blocked IOC '{ioc}' (type={action.ioc_type}). Relevant: {relevant}" def _handle_run_forensics(self, action: RunForensics) -> tuple[float, str]: """Run forensic analysis on a host.""" hostname = action.hostname if hostname not in self._host_index: self._last_forensics = None return -0.05, f"Host '{hostname}' not found" host = self._host_index[hostname] # Build forensics result based on actual host state is_compromised = host["status"] == "compromised" malicious_procs = [] suspicious_files = [] network_conns = [] registry_mods = [] memory_artifacts = [] if is_compromised: # Find which threat(s) affect this host for threat in self._task_def["attack_chain"]: if hostname in threat["compromised_hosts"]: malicious_procs.extend(threat["malicious_processes"]) # Generate deterministic forensic artifacts for proc in threat["malicious_processes"]: suspicious_files.append(f"C:\\Windows\\Temp\\{proc}.dat") registry_mods.append(f"HKLM\\Software\\Microsoft\\Windows\\CurrentVersion\\Run\\{proc}") for c2 in threat.get("c2_servers", []): network_conns.append(f"{c2}:443") for ioc_hash in threat["iocs"].get("hashes", []): memory_artifacts.append(f"memory_inject_{ioc_hash[:8]}") self._last_forensics = ForensicsResult( hostname=hostname, malicious_processes=malicious_procs, suspicious_files=suspicious_files, network_connections=network_conns, registry_modifications=registry_mods, memory_artifacts=memory_artifacts, is_compromised=is_compromised, ) # Reward reward = 0.0 if hostname not in self._state.forensics_run: if is_compromised: reward = 0.10 # Good: found evidence else: reward = 0.02 # Cleared a host (some value) self._state.forensics_run.append(hostname) else: reward = -0.02 # Re-running forensics wastes time return reward, f"Forensics on {hostname}: compromised={is_compromised}, procs={malicious_procs}" def _handle_kill_process(self, action: KillProcess) -> tuple[float, str]: """Kill a process on a host.""" hostname = action.hostname process = action.process_name self._last_forensics = None if hostname not in self._host_index: return -0.05, f"Host '{hostname}' not found" host = self._host_index[hostname] if host["status"] == "isolated": return -0.02, f"Host '{hostname}' is isolated — cannot interact" if process not in host["running_processes"]: return -0.03, f"Process '{process}' not running on {hostname}" # Kill the process host["running_processes"].remove(process) self._state.killed_processes.append({"hostname": hostname, "process": process}) # Check if this was a malicious process reward = 0.0 was_malicious = False for threat in self._task_def["attack_chain"]: if hostname in threat["compromised_hosts"] and process in threat["malicious_processes"]: was_malicious = True reward = 0.15 # Major reward: stopping malicious activity # Check if all processes for this threat are killed all_killed = True for th_host in threat["compromised_hosts"]: for th_proc in threat["malicious_processes"]: still_running = ( th_host in self._host_index and th_proc in self._host_index[th_host]["running_processes"] ) if still_running: all_killed = False break if all_killed and threat["threat_id"] in self._state.active_threats: self._state.active_threats.remove(threat["threat_id"]) if threat["threat_id"] not in self._state.contained_threats: self._state.contained_threats.append(threat["threat_id"]) reward += 0.10 # Bonus: fully contained a threat break if not was_malicious: reward = -0.08 # Penalty: killing legitimate process = downtime self._state.business_impact = min(1.0, self._state.business_impact + 0.03) return reward, f"Killed '{process}' on {hostname}. Malicious: {was_malicious}" def _handle_submit_plan(self, action: SubmitContainmentPlan) -> tuple[float, str]: """Submit the final containment plan.""" self._last_forensics = None self._state.submitted_plan = True self._plan_entries = [entry.model_dump() for entry in action.plan] # Grade the episode final_score = grade_episode( task_id=self._state.task_id, task_def=self._task_def, killed_processes=self._state.killed_processes, blocked_iocs=self._state.blocked_iocs, forensics_run=self._state.forensics_run, isolated_subnets=self._state.isolated_subnets, submitted_plan=True, plan_entries=self._plan_entries, final_business_impact=self._state.business_impact, step_count=self._state.step_count, total_reward=self._state.total_reward, ) # Reward proportional to final grade reward = final_score * 1.0 # Scale: perfect score = 1.0 reward description = ( f"Containment plan submitted. " f"Grade: {final_score:.3f}. " f"Threats contained: {len(self._state.contained_threats)}/{len(self._task_def['attack_chain'])}. " f"Business impact: {self._state.business_impact:.2f}" ) return reward, description # =========================================================================== # Helpers # =========================================================================== def _build_observation(self, reward: float, done: bool) -> SOCObservation: """Build the observation from current state.""" # Compute network topology summary subnet_counts = {name: len(hosts) for name, hosts in self._network.items()} compromised = sum( 1 for hosts in self._network.values() for h in hosts if h["status"] == "compromised" ) isolated = sum( 1 for hosts in self._network.values() for h in hosts if h["status"] == "isolated" ) total = sum(len(hosts) for hosts in self._network.values()) topology = NetworkTopology( total_hosts=total, subnets=subnet_counts, compromised_count=compromised, isolated_count=isolated, online_count=total - compromised - isolated, ) # Build alert list alerts = [Alert(**a) for a in self._alert_queue] # Build timeline timeline = [ TimelineEntry( step=t["step"], action_type=t["action_type"], target=t["target"], result=t["result"], reward=t["reward"], ) for t in self._state.timeline ] # Compute final grade if done final_score_val = None grade_breakdown_val = None if done and self._state.submitted_plan: computed_score = grade_episode( task_id=self._state.task_id, task_def=self._task_def, killed_processes=self._state.killed_processes, blocked_iocs=self._state.blocked_iocs, forensics_run=self._state.forensics_run, isolated_subnets=self._state.isolated_subnets, submitted_plan=self._state.submitted_plan, plan_entries=self._plan_entries, final_business_impact=self._state.business_impact, step_count=self._state.step_count, total_reward=self._state.total_reward, ) final_score_val = round(computed_score, 4) grade_breakdown_val = { "threats_contained": len(self._state.contained_threats), "total_threats": len(self._task_def["attack_chain"]), "iocs_blocked": len(self._state.blocked_iocs), "hosts_forensics": len(self._state.forensics_run), "subnets_isolated": len(self._state.isolated_subnets), "business_impact": round(self._state.business_impact, 4), } return SOCObservation( alert_queue=alerts, network_topology=topology, host_forensics=self._last_forensics, timeline=timeline, business_impact_score=round(self._state.business_impact, 4), step_count=self._state.step_count, active_threats=list(self._state.active_threats), max_steps=self._state.max_steps, task_id=self._state.task_id, total_reward=round(self._state.total_reward, 4), final_score=final_score_val, grade_breakdown=grade_breakdown_val, done=done, reward=round(reward, 4), ) def _get_action_target(self, action: Any) -> str: """Extract the target string from a typed action for timeline logging.""" if isinstance(action, QueryHost): return action.hostname elif isinstance(action, IsolateSegment): return action.subnet elif isinstance(action, BlockIOC): return f"{action.ioc_type}:{action.ioc_value}" elif isinstance(action, RunForensics): return action.hostname elif isinstance(action, KillProcess): return f"{action.hostname}/{action.process_name}" elif isinstance(action, SubmitContainmentPlan): return f"{len(action.plan)} entries" return "unknown" @property def state(self) -> SOCState: """Get the current internal environment state.""" return self._state