# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Data models for the CyberSOCEnv — Enterprise Cybersecurity Operations Center.

Defines strict Pydantic models for:
- Observation: What the agent sees (alerts, forensics, network state, business impact)
- Action: What the agent can do (discriminated union of 6 action types)
- Internal state: Deterministic network graph, attack chains, and task tracking
"""

from __future__ import annotations

from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Union

from openenv.core.env_server.types import Action, Observation, State
from pydantic import BaseModel, ConfigDict, Field


# =============================================================================
# Enums
# =============================================================================


class Severity(str, Enum):
    """Severity level attached to a SIEM alert."""

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"


class ThreatType(str, Enum):
    """Threat categories an alert can be classified as."""

    RANSOMWARE = "ransomware"
    PHISHING = "phishing"
    CREDENTIAL_THEFT = "credential_theft"
    LATERAL_MOVEMENT = "lateral_movement"
    C2_COMMUNICATION = "c2_communication"
    DATA_EXFILTRATION = "data_exfiltration"
    PRIVILEGE_ESCALATION = "privilege_escalation"
    MALWARE = "malware"
    CRYPTOMINING = "cryptomining"
    SUPPLY_CHAIN = "supply_chain"
    INSIDER_THREAT = "insider_threat"
    WEBSHELL = "webshell"
    BOTNET = "botnet"


class HostStatus(str, Enum):
    """Operational status of a host."""

    ONLINE = "online"
    COMPROMISED = "compromised"
    ISOLATED = "isolated"
    OFFLINE = "offline"


class SubnetRole(str, Enum):
    """Business function served by a network subnet."""

    CORPORATE = "corporate"
    ENGINEERING = "engineering"
    FINANCE = "finance"
    DMZ = "dmz"
    DATACENTER = "datacenter"
    EXECUTIVE = "executive"


# =============================================================================
# Alert & Network Sub-Models (used in Observation)
# =============================================================================


class Alert(BaseModel):
    """One SIEM/EDR alert as it appears in the alert queue.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    alert_id: str = Field(..., description="Unique alert identifier")
    timestamp: str = Field(..., description="ISO-8601 timestamp of the alert")
    source_host: str = Field(..., description="Hostname that generated the alert")
    severity: Severity = Field(..., description="Alert severity level")
    threat_type: ThreatType = Field(..., description="Classified threat type")
    description: str = Field(..., description="Human-readable alert description")
    ioc_indicators: List[str] = Field(
        default_factory=list,
        description="Indicators of compromise (IPs, hashes, domains)",
    )
    subnet: str = Field(..., description="Subnet where the alert originated")
    is_acknowledged: bool = Field(
        default=False,
        description="Whether the SOC analyst has acknowledged this alert",
    )


class HostInfo(BaseModel):
    """Summary record for a single network host.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    hostname: str = Field(..., description="Host FQDN")
    ip_address: str = Field(..., description="IPv4 address")
    subnet: str = Field(..., description="Subnet the host belongs to")
    role: SubnetRole = Field(..., description="Business function")
    status: HostStatus = Field(
        default=HostStatus.ONLINE,
        description="Current status",
    )
    running_processes: List[str] = Field(
        default_factory=list,
        description="Running process names",
    )
    open_ports: List[int] = Field(
        default_factory=list,
        description="Open TCP ports",
    )
    # Constrained to the closed interval [0.0, 1.0] by the validator bounds.
    criticality: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Business criticality score (0=low, 1=mission-critical)",
    )


class NetworkTopology(BaseModel):
    """Aggregate counters describing the enterprise network.

    The defaults describe a 500-host network with every host online.
    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    total_hosts: int = Field(
        default=500,
        description="Total hosts in the network",
    )
    subnets: Dict[str, int] = Field(
        default_factory=dict,
        description="Map of subnet name -> host count",
    )
    compromised_count: int = Field(
        default=0,
        description="Number of compromised hosts",
    )
    isolated_count: int = Field(
        default=0,
        description="Number of isolated hosts",
    )
    online_count: int = Field(
        default=500,
        description="Number of online hosts",
    )


class ForensicsResult(BaseModel):
    """Findings produced by a forensic analysis of one host.

    Every finding list defaults to empty and ``is_compromised`` to False.
    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    hostname: str = Field(..., description="Analyzed host")
    malicious_processes: List[str] = Field(
        default_factory=list,
        description="Detected malicious processes",
    )
    suspicious_files: List[str] = Field(
        default_factory=list,
        description="Suspicious file paths found",
    )
    network_connections: List[str] = Field(
        default_factory=list,
        description="Suspicious outbound connections (ip:port)",
    )
    registry_modifications: List[str] = Field(
        default_factory=list,
        description="Modified registry keys",
    )
    memory_artifacts: List[str] = Field(
        default_factory=list,
        description="In-memory IOCs found",
    )
    is_compromised: bool = Field(
        default=False,
        description="Whether forensics confirm compromise",
    )


class TimelineEntry(BaseModel):
    """One row of the analyst action timeline.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    step: int = Field(..., description="Step number when this action was taken")
    action_type: str = Field(..., description="Type of action taken")
    target: str = Field(..., description="Target of the action (host, subnet, IOC)")
    result: str = Field(..., description="Outcome description")
    reward: float = Field(default=0.0, description="Reward received for this action")


# =============================================================================
# Observation
# =============================================================================


class SOCObservation(Observation):
    """Per-step view of the environment handed to the SOC agent.

    Extends OpenEnv Observation (inherits: done, reward, metadata).
    Every field has a default, so an empty observation can be constructed.
    """

    alert_queue: List[Alert] = Field(
        default_factory=list,
        description="Current queue of unresolved SIEM/EDR alerts",
    )
    network_topology: NetworkTopology = Field(
        default_factory=NetworkTopology,
        description="Summary of the enterprise network state",
    )
    host_forensics: Optional[ForensicsResult] = Field(
        default=None,
        description="Forensics results if RunForensics was the last action, else None",
    )
    timeline: List[TimelineEntry] = Field(
        default_factory=list,
        description="Chronological log of all actions taken in this episode",
    )
    business_impact_score: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Current business impact (0=no impact, 1=catastrophic outage)",
    )
    step_count: int = Field(default=0, ge=0, description="Current step number")
    active_threats: List[str] = Field(
        default_factory=list,
        description="List of threat IDs that are still active/uncontained",
    )
    max_steps: int = Field(
        default=30,
        description="Maximum steps allowed in this episode",
    )
    task_id: str = Field(default="easy", description="Current task identifier")
    total_reward: float = Field(
        default=0.0,
        description="Accumulated episode reward",
    )
    final_score: Optional[float] = Field(
        default=None,
        description="Post-episode grader score (0.0-1.0). Only set when done=True and plan submitted.",
    )
    grade_breakdown: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Detailed grading breakdown. Only set when done=True and plan submitted.",
    )


# =============================================================================
# Actions (Discriminated Union)
# =============================================================================


class QueryHost(Action):
    """Action: query one host for status, processes, and connections."""

    type: Literal["query_host"] = Field(
        default="query_host",
        description="Action discriminator",
    )
    hostname: str = Field(..., description="Target hostname to query")


class IsolateSegment(Action):
    """Action: isolate a whole network segment from the rest of the network."""

    type: Literal["isolate_segment"] = Field(
        default="isolate_segment",
        description="Action discriminator",
    )
    subnet: str = Field(
        ...,
        description="Subnet name to isolate (e.g. 'finance', 'engineering')",
    )
    reason: str = Field(default="", description="Justification for isolation")


class BlockIOC(Action):
    """Action: block an Indicator of Compromise at the perimeter firewall."""

    type: Literal["block_ioc"] = Field(
        default="block_ioc",
        description="Action discriminator",
    )
    ioc_value: str = Field(
        ...,
        description="The IOC to block (IP, domain, or file hash)",
    )
    ioc_type: Literal["ip", "domain", "hash"] = Field(..., description="Type of IOC")


class RunForensics(Action):
    """Action: run deep forensic analysis on one host."""

    type: Literal["run_forensics"] = Field(
        default="run_forensics",
        description="Action discriminator",
    )
    hostname: str = Field(..., description="Target hostname for forensics")


class KillProcess(Action):
    """Action: terminate a named process on a host."""

    type: Literal["kill_process"] = Field(
        default="kill_process",
        description="Action discriminator",
    )
    hostname: str = Field(..., description="Host where the process is running")
    process_name: str = Field(..., description="Name of the process to terminate")


class ContainmentEntry(BaseModel):
    """One threat's row in a containment plan.

    All four fields are required; ``confidence`` must lie in [0.0, 1.0].
    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    threat_id: str = Field(..., description="Threat being addressed")
    actions_taken: List[str] = Field(
        ...,
        description="List of actions taken to contain this threat",
    )
    root_cause: str = Field(..., description="Identified root cause")
    confidence: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Confidence in the containment (0-1)",
    )


class SubmitContainmentPlan(Action):
    """Action: submit the final containment plan to end the episode."""

    type: Literal["submit_containment_plan"] = Field(
        default="submit_containment_plan",
        description="Action discriminator",
    )
    plan: List[ContainmentEntry] = Field(
        ...,
        description="The containment plan addressing all identified threats",
    )
    executive_summary: str = Field(
        ...,
        description="Brief executive summary for CISO reporting",
    )


# Discriminated union of all SOC actions
SOCAction = Annotated[
    Union[
        QueryHost,
        IsolateSegment,
        BlockIOC,
        RunForensics,
        KillProcess,
        SubmitContainmentPlan,
    ],
    Field(discriminator="type"),
]


# Wrapper model so OpenEnv's create_app can accept it as a single Action class
class SOCActionWrapper(Action):
    """Loosely-typed envelope that is later narrowed to a concrete SOC action.

    OpenEnv's create_app expects a single Action subclass. This model only
    declares the ``type`` string and permits arbitrary extra fields
    (``extra="allow"``), so any of the 6 action payloads validates against
    it; ``to_typed_action`` then builds the matching concrete class.

    Client sends: {"action": {"type": "query_host", "hostname": "WS-001"}}
    The wrapper validates -> QueryHost(hostname="WS-001")
    """

    type: str = Field(..., description="Action type discriminator")

    model_config = ConfigDict(extra="allow")  # Allow action-specific fields

    def to_typed_action(
        self,
    ) -> Union[
        QueryHost,
        IsolateSegment,
        BlockIOC,
        RunForensics,
        KillProcess,
        SubmitContainmentPlan,
    ]:
        """Build the concrete action class selected by ``self.type``.

        The wrapper is dumped to a dict without its ``metadata`` entry and
        the remaining keys are passed as keyword arguments to the class
        registered for the ``type`` value.

        Returns:
            An instance of the action class matching ``type``.

        Raises:
            ValueError: If ``type`` is not one of the six known action types.
        """
        # NOTE(review): ``metadata`` is dropped here, so the typed action is
        # built without the wrapper's metadata — confirm this is intended.
        payload = self.model_dump(exclude={"metadata"})
        registry = {
            "query_host": QueryHost,
            "isolate_segment": IsolateSegment,
            "block_ioc": BlockIOC,
            "run_forensics": RunForensics,
            "kill_process": KillProcess,
            "submit_containment_plan": SubmitContainmentPlan,
        }
        action_cls = registry.get(payload["type"])
        if action_cls is not None:
            return action_cls(**payload)
        raise ValueError(
            f"Unknown action type: {payload['type']}. "
            f"Valid types: {list(registry)}"
        )


# =============================================================================
# Internal State (not exposed to agent directly)
# =============================================================================


class SOCState(State):
    """Environment-side bookkeeping for one attack-simulation episode.

    Extends OpenEnv State (inherits: episode_id, step_count).
    Uses extra='allow' from base State.
    """

    task_id: str = Field(
        default="easy",
        description="Current task: 'easy', 'medium', or 'hard'",
    )
    max_steps: int = Field(
        default=30,
        description="Maximum steps for this episode",
    )
    total_reward: float = Field(default=0.0, description="Accumulated reward")
    business_impact: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Current business impact score",
    )
    contained_threats: List[str] = Field(
        default_factory=list,
        description="Threat IDs that have been contained",
    )
    active_threats: List[str] = Field(
        default_factory=list,
        description="Currently active threat IDs",
    )
    blocked_iocs: List[str] = Field(
        default_factory=list,
        description="IOCs blocked at perimeter",
    )
    isolated_subnets: List[str] = Field(
        default_factory=list,
        description="Isolated network segments",
    )
    forensics_run: List[str] = Field(
        default_factory=list,
        description="Hosts that had forensics run",
    )
    killed_processes: List[Dict[str, str]] = Field(
        default_factory=list,
        description="Processes killed",
    )
    queried_hosts: List[str] = Field(
        default_factory=list,
        description="Hosts queried",
    )
    timeline: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Action timeline",
    )
    is_done: bool = Field(default=False, description="Whether episode has ended")
    submitted_plan: bool = Field(
        default=False,
        description="Whether containment plan was submitted",
    )