| |
| |
| |
| |
| |
|
|
| """ |
| Data models for the CyberSOCEnv — Enterprise Cybersecurity Operations Center. |
| |
| Defines strict Pydantic models for: |
| - Observation: What the agent sees (alerts, forensics, network state, business impact) |
- Action: What the agent can do (discriminated unions of 12 Blue Team and 4 Red Team action types)
| - Internal state: Deterministic network graph, attack chains, and task tracking |
| """ |
|
|
| from __future__ import annotations |
|
|
| from enum import Enum |
| from typing import Annotated, Any, Dict, List, Literal, Optional, Union |
|
|
| from openenv.core.env_server.types import Action, Observation, State |
| from pydantic import BaseModel, ConfigDict, Field |
|
|
|
|
| |
| |
| |
|
|
|
|
class Severity(str, Enum):
    """Severity levels a SIEM alert can carry, declared from least to most severe.

    Mixing in ``str`` lets members compare equal to their lowercase wire value.
    """

    LOW = "low"            # informational / low urgency
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"  # highest level defined here
|
|
|
|
class ThreatType(str, Enum):
    """Closed set of threat categories an alert can be classified under.

    Members are ``str`` subclasses, so each compares equal to its snake_case value.
    """

    # Payload / impact oriented
    RANSOMWARE = "ransomware"
    # Initial access and credential abuse
    PHISHING = "phishing"
    CREDENTIAL_THEFT = "credential_theft"
    # Post-compromise activity
    LATERAL_MOVEMENT = "lateral_movement"
    C2_COMMUNICATION = "c2_communication"
    DATA_EXFILTRATION = "data_exfiltration"
    PRIVILEGE_ESCALATION = "privilege_escalation"
    # Generic / other families
    MALWARE = "malware"
    CRYPTOMINING = "cryptomining"
    SUPPLY_CHAIN = "supply_chain"
    INSIDER_THREAT = "insider_threat"
    WEBSHELL = "webshell"
    BOTNET = "botnet"
|
|
|
|
class HostStatus(str, Enum):
    """Operational state of a single host; values are lowercase strings."""

    ONLINE = "online"
    COMPROMISED = "compromised"
    ISOLATED = "isolated"
    OFFLINE = "offline"
|
|
|
|
class SubnetRole(str, Enum):
    """Business function served by a subnet (and therefore by its hosts)."""

    # User-facing business units
    CORPORATE = "corporate"
    ENGINEERING = "engineering"
    FINANCE = "finance"
    # Infrastructure zones
    DMZ = "dmz"
    DATACENTER = "datacenter"
    # Leadership
    EXECUTIVE = "executive"
|
|
|
|
| |
| |
| |
|
|
|
|
class Alert(BaseModel):
    """One SIEM/EDR alert waiting in the analyst queue.

    Unknown keys are rejected because the model uses ``extra="forbid"``.
    """

    model_config = ConfigDict(extra="forbid")

    # Identity and provenance
    alert_id: str = Field(..., description="Unique alert identifier")
    # Stored as a plain string; no ISO-8601 parsing is enforced here.
    timestamp: str = Field(..., description="ISO-8601 timestamp of the alert")
    source_host: str = Field(..., description="Hostname that generated the alert")

    # Classification
    severity: Severity = Field(..., description="Alert severity level")
    threat_type: ThreatType = Field(..., description="Classified threat type")
    description: str = Field(..., description="Human-readable alert description")
    ioc_indicators: List[str] = Field(
        description="Indicators of compromise (IPs, hashes, domains)",
        default_factory=list,
    )

    # Location and triage state
    subnet: str = Field(..., description="Subnet where the alert originated")
    is_acknowledged: bool = Field(
        False,
        description="Whether the SOC analyst has acknowledged this alert",
    )
|
|
|
|
class HostInfo(BaseModel):
    """Snapshot of one host in the enterprise network.

    Extra keys are forbidden; ``criticality`` is validated to lie in [0, 1].
    """

    model_config = ConfigDict(extra="forbid")

    # Addressing
    hostname: str = Field(..., description="Host FQDN")
    ip_address: str = Field(..., description="IPv4 address")
    subnet: str = Field(..., description="Subnet the host belongs to")
    role: SubnetRole = Field(..., description="Business function")

    # Runtime state (hosts start ONLINE unless told otherwise)
    status: HostStatus = Field(HostStatus.ONLINE, description="Current status")
    running_processes: List[str] = Field(
        description="Running process names", default_factory=list
    )
    open_ports: List[int] = Field(description="Open TCP ports", default_factory=list)

    # Business value used when weighing disruption
    criticality: float = Field(
        0.5,
        description="Business criticality score (0=low, 1=mission-critical)",
        ge=0.0,
        le=1.0,
    )
|
|
|
|
class NetworkTopology(BaseModel):
    """Aggregate counts describing the enterprise network (500 hosts by default).

    Defaults describe a clean network: every host online, none compromised
    or isolated. The counts are not cross-validated against each other.
    """

    model_config = ConfigDict(extra="forbid")

    total_hosts: int = Field(500, description="Total hosts in the network")
    subnets: Dict[str, int] = Field(
        description="Map of subnet name -> host count",
        default_factory=dict,
    )
    # Per-status tallies
    compromised_count: int = Field(0, description="Number of compromised hosts")
    isolated_count: int = Field(0, description="Number of isolated hosts")
    online_count: int = Field(500, description="Number of online hosts")
|
|
|
|
class ForensicsResult(BaseModel):
    """Findings produced by a forensic sweep of a single host.

    All evidence lists default to empty, and ``is_compromised`` defaults to
    ``False``. Extra keys are rejected.
    """

    model_config = ConfigDict(extra="forbid")

    hostname: str = Field(..., description="Analyzed host")

    # Evidence categories
    malicious_processes: List[str] = Field(
        description="Detected malicious processes", default_factory=list
    )
    suspicious_files: List[str] = Field(
        description="Suspicious file paths found", default_factory=list
    )
    network_connections: List[str] = Field(
        description="Suspicious outbound connections (ip:port)",
        default_factory=list,
    )
    registry_modifications: List[str] = Field(
        description="Modified registry keys", default_factory=list
    )
    memory_artifacts: List[str] = Field(
        description="In-memory IOCs found", default_factory=list
    )

    # Verdict
    is_compromised: bool = Field(False, description="Whether forensics confirm compromise")
|
|
|
|
class TimelineEntry(BaseModel):
    """One row of the analyst's action log: when, what, on what, and the outcome."""

    model_config = ConfigDict(extra="forbid")

    step: int = Field(..., description="Step number when this action was taken")
    action_type: str = Field(..., description="Type of action taken")
    target: str = Field(..., description="Target of the action (host, subnet, IOC)")
    result: str = Field(..., description="Outcome description")
    # Zero unless the environment assigns a reward for the step.
    reward: float = Field(0.0, description="Reward received for this action")
|
|
|
|
| |
| |
| |
|
|
|
|
class SOCObservation(Observation):
    """Everything the SOC agent can see after each environment step.

    Subclasses the OpenEnv ``Observation`` base, which supplies ``done``,
    ``reward`` and ``metadata``. The optional per-tool result fields
    (forensics, correlation, enrichment, vulnerability, playbook) default
    to ``None``.
    """

    # --- episode identity ---------------------------------------------------
    episode_id: str = Field(
        "",
        description="Unique UUID for this episode — used by the RL training loop to prevent hash collisions in GRPO batched rollouts.",
    )

    # --- core SOC picture ---------------------------------------------------
    alert_queue: List[Alert] = Field(
        description="Current queue of unresolved SIEM/EDR alerts",
        default_factory=list,
    )
    network_topology: NetworkTopology = Field(
        description="Summary of the enterprise network state",
        default_factory=NetworkTopology,
    )
    host_forensics: Optional[ForensicsResult] = Field(
        None,
        description="Forensics results if RunForensics was the last action, else None",
    )
    timeline: List[TimelineEntry] = Field(
        description="Chronological log of all actions taken in this episode",
        default_factory=list,
    )
    business_impact_score: float = Field(
        0.0,
        description="Current business impact (0=no impact, 1=catastrophic outage)",
        ge=0.0,
        le=1.0,
    )
    step_count: int = Field(0, description="Current step number", ge=0)
    active_threats: List[str] = Field(
        description="List of threat IDs that are still active/uncontained",
        default_factory=list,
    )

    # --- episode bookkeeping ------------------------------------------------
    max_steps: int = Field(30, description="Maximum steps allowed in this episode")
    task_id: str = Field("easy", description="Current task identifier")
    total_reward: float = Field(0.0, description="Accumulated episode reward")

    # --- grading (populated only at episode end) ----------------------------
    final_score: Optional[float] = Field(
        None,
        description="Post-episode grader score (0.0-1.0). Only set when done=True and plan submitted.",
    )
    grade_breakdown: Optional[Dict[str, Any]] = Field(
        None,
        description="Detailed grading breakdown. Only set when done=True and plan submitted.",
    )

    # --- most-recent tool outputs -------------------------------------------
    correlation_results: Optional[Dict[str, Any]] = Field(
        None, description="Results from the most recent correlate_alerts call."
    )
    ioc_enrichment: Optional[Dict[str, Any]] = Field(
        None, description="Results from the most recent enrich_ioc call."
    )
    vulnerability_results: Optional[List[Dict[str, Any]]] = Field(
        None, description="Results from the most recent scan_host_vulnerabilities call."
    )
    # NOTE(review): no trigger_playbook action class is among the SOCAction
    # union members in this module; confirm where that action is produced.
    playbook_result: Optional[Dict[str, Any]] = Field(
        None, description="Results from the most recent trigger_playbook call."
    )
    threat_graph_summary: Optional[str] = Field(
        None, description="Compact textual summary of the current threat graph."
    )
    available_playbooks: List[str] = Field(
        description="Names of SOAR playbooks available to the agent.",
        default_factory=list,
    )

    # --- dense reward shaping -----------------------------------------------
    reward_dimensions: Optional[Dict[str, float]] = Field(
        None,
        description=(
            "Running partial scores for each of the 10 grading dimensions, "
            "updated every step for live GRPO credit-assignment signals. "
            "Keys match grade_breakdown (threat_containment, ioc_blocking, etc.)."
        ),
    )

    # --- Blue/Red turn-taking -----------------------------------------------
    active_turn: str = Field(
        "blue",
        description="Whose turn it is next: 'blue' or 'red'. Used by FSP inference loops.",
    )
    red_observation: Optional[Dict[str, Any]] = Field(
        None,
        description=(
            "Red Team's current view of the world (populated when active_turn='red'). "
            "Contains compromised_hosts and blue_actions_detected."
        ),
    )
|
|
|
|
| |
| |
| |
|
|
|
|
class QueryHost(Action):
    """Blue action: look up a host's status, running processes and connections."""

    type: Literal["query_host"] = Field(
        "query_host", description="Action discriminator"
    )
    hostname: str = Field(..., description="Target hostname to query")
|
|
|
|
class IsolateSegment(Action):
    """Blue action: cut a whole subnet, or one host, off from the network.

    When ``target_host`` is provided only that host is isolated; otherwise
    ``subnet`` names the segment to isolate.
    """

    type: Literal["isolate_segment"] = Field(
        "isolate_segment", description="Action discriminator"
    )
    # NOTE(review): mutual exclusivity with target_host is documented but not
    # validated here; both (or neither) may be supplied.
    subnet: str = Field(
        "",
        description="Subnet name to isolate (mutually exclusive with target_host)",
    )
    target_host: Optional[str] = Field(
        None,
        description="If set, isolate only this single host instead of the whole subnet",
    )
    reason: str = Field("", description="Justification for isolation")
|
|
|
|
class BlockIOC(Action):
    """Blue action: add an indicator of compromise to the perimeter firewall blocklist."""

    type: Literal["block_ioc"] = Field("block_ioc", description="Action discriminator")
    ioc_value: str = Field(
        ..., description="The IOC to block (IP, domain, or file hash)"
    )
    # Only these three indicator kinds can be blocked at the perimeter.
    ioc_type: Literal["ip", "domain", "hash"] = Field(..., description="Type of IOC")
|
|
|
|
class RunForensics(Action):
    """Blue action: perform an in-depth forensic examination of one host."""

    type: Literal["run_forensics"] = Field(
        "run_forensics", description="Action discriminator"
    )
    hostname: str = Field(..., description="Target hostname for forensics")
|
|
|
|
class KillProcess(Action):
    """Blue action: terminate a process, identified by name, on a given host."""

    type: Literal["kill_process"] = Field(
        "kill_process", description="Action discriminator"
    )
    hostname: str = Field(..., description="Host where the process is running")
    process_name: str = Field(..., description="Name of the process to terminate")
|
|
|
|
class ContainmentEntry(BaseModel):
    """One threat's line item inside a submitted containment plan.

    Every field is required; ``confidence`` must lie in [0, 1] and unknown
    keys are rejected.
    """

    model_config = ConfigDict(extra="forbid")

    threat_id: str = Field(..., description="Threat being addressed")
    actions_taken: List[str] = Field(
        ..., description="List of actions taken to contain this threat"
    )
    root_cause: str = Field(..., description="Identified root cause")
    confidence: float = Field(
        ...,
        description="Confidence in the containment (0-1)",
        ge=0.0,
        le=1.0,
    )
|
|
|
|
class SubmitContainmentPlan(Action):
    """Blue action: hand in the final containment plan, which ends the episode."""

    type: Literal["submit_containment_plan"] = Field(
        "submit_containment_plan", description="Action discriminator"
    )
    plan: List[ContainmentEntry] = Field(
        ...,
        description="The containment plan addressing all identified threats",
    )
    executive_summary: str = Field(
        ...,
        description="Brief executive summary for CISO reporting",
    )
|
|
|
|
class CorrelateAlerts(Action):
    """Correlate two or more alerts to find shared entities/IOCs.

    Validation fails if fewer than two alert IDs are supplied (``min_length=2``).
    """

    # Discriminator documented the same way as the other action models so the
    # generated JSON schema is uniform across the action union.
    type: Literal["correlate_alerts"] = Field(
        default="correlate_alerts", description="Action discriminator"
    )
    alert_ids: List[str] = Field(..., min_length=2, description="At least 2 alert IDs to correlate")
|
|
|
|
class EnrichIOC(Action):
    """Enrich an IOC with threat-intelligence data (actor, MITRE TTPs).

    Unlike ``BlockIOC``, ``ioc_type`` here also accepts ``"filename"``.
    """

    # Discriminator documented consistently with the other action models.
    type: Literal["enrich_ioc"] = Field(
        default="enrich_ioc", description="Action discriminator"
    )
    ioc_value: str = Field(..., description="The IOC value to enrich")
    ioc_type: Literal["ip", "domain", "hash", "filename"] = Field(..., description="Type of IOC")
|
|
|
|
class ScanHostVulnerabilities(Action):
    """Run a vulnerability scan on a host to discover CVEs."""

    # Discriminator documented consistently with the other action models.
    type: Literal["scan_host_vulnerabilities"] = Field(
        default="scan_host_vulnerabilities", description="Action discriminator"
    )
    hostname: str = Field(..., description="Target hostname for vulnerability scan")
|
|
|
|
class TerminatePID(Action):
    """Terminate a process by PID on a host.

    ``pid`` is typed as a string, so PIDs are passed as text.
    """

    # Discriminator documented consistently with the other action models.
    type: Literal["terminate_pid"] = Field(
        default="terminate_pid", description="Action discriminator"
    )
    hostname: str = Field(..., description="Host where the PID is running")
    pid: str = Field(..., description="Target process PID")
|
|
|
|
class CreateFirewallRule(Action):
    """Create a host-level firewall rule for a target IP.

    ``action`` is restricted to ``"drop"`` or ``"allow"``.
    """

    # Discriminator documented consistently with the other action models.
    type: Literal["create_firewall_rule"] = Field(
        default="create_firewall_rule", description="Action discriminator"
    )
    hostname: str = Field(..., description="Host where firewall rule is applied")
    target_ip: str = Field(..., description="Target IP for the rule")
    action: Literal["drop", "allow"] = Field(..., description="Firewall action to apply")
|
|
|
|
class QuarantineFile(Action):
    """Quarantine a suspicious file on a host."""

    # Discriminator documented consistently with the other action models.
    type: Literal["quarantine_file"] = Field(
        default="quarantine_file", description="Action discriminator"
    )
    hostname: str = Field(..., description="Host where file is located")
    file_path: str = Field(..., description="File path to quarantine")
|
|
|
|
| |
| |
| |
|
|
class LateralPivot(Action):
    """Red Team: move laterally from a compromised host to a new target."""

    # Discriminator documented consistently with the Blue action models.
    type: Literal["lateral_pivot"] = Field(
        default="lateral_pivot", description="Action discriminator"
    )
    source_host: str = Field(..., description="Already-compromised host used as the pivot point")
    target_host: str = Field(..., description="Destination host to compromise")
|
|
|
|
class DeployPayload(Action):
    """Red Team: deploy a malicious payload on a host Red already controls."""

    # Discriminator documented consistently with the Blue action models.
    type: Literal["deploy_payload"] = Field(
        default="deploy_payload", description="Action discriminator"
    )
    hostname: str = Field(..., description="Compromised host to deploy payload on")
    payload_type: Literal["ransomware", "exfiltration", "c2"] = Field(
        ..., description="Class of payload to deploy"
    )
|
|
|
|
class EvadeDetection(Action):
    """Red Team: apply an evasion technique on a compromised host."""

    # Discriminator documented consistently with the Blue action models.
    type: Literal["evade_detection"] = Field(
        default="evade_detection", description="Action discriminator"
    )
    hostname: str = Field(..., description="Compromised host to apply evasion on")
    technique: Literal["migrate_pid", "clear_logs"] = Field(
        ...,
        description=(
            "migrate_pid: rename running malicious processes to blend with system names; "
            "clear_logs: remove SIEM alerts originating from this host"
        ),
    )
|
|
|
|
class PassTurn(Action):
    """Red Team: remain stealthy and take no action this turn.

    Carries no fields beyond the discriminator (and whatever the base
    ``Action`` provides).
    """

    # Discriminator documented consistently with the Blue action models.
    type: Literal["pass_turn"] = Field(
        default="pass_turn", description="Action discriminator"
    )
|
|
|
|
| |
# Discriminator strings that belong to the Red Team action space; any other
# ``type`` value is a Blue (SOC analyst) action.
RED_ACTION_TYPES: frozenset = frozenset(
    ("lateral_pivot", "deploy_payload", "evade_detection", "pass_turn")
)
|
|
| |
# Pydantic discriminated union over every Red Team action model, selected by
# the literal value of each model's ``type`` field.
RedAction = Annotated[
    Union[
        LateralPivot,
        DeployPayload,
        EvadeDetection,
        PassTurn,
    ],
    Field(discriminator="type"),
]
|
|
|
|
class RedActionWrapper(Action):
    """Loosely-typed envelope for a Red Team action on the WS/HTTP layer.

    Plays the same role as ``SOCActionWrapper`` but dispatches only to the
    four Red action models. ``extra="allow"`` keeps the action-specific
    fields on the wrapper until :meth:`to_typed_action` is called.
    """

    type: str = Field(..., description="Red action type discriminator")
    model_config = ConfigDict(extra="allow")

    def to_typed_action(self):
        """Build the concrete Red action model named by ``type``.

        All dumped fields except ``metadata`` are forwarded to that model's
        constructor.

        Raises:
            ValueError: when ``type`` names no known Red action.
        """
        payload = self.model_dump(exclude={"metadata"})
        registry = {
            "lateral_pivot": LateralPivot,
            "deploy_payload": DeployPayload,
            "evade_detection": EvadeDetection,
            "pass_turn": PassTurn,
        }
        kind = payload["type"]
        if kind not in registry:
            raise ValueError(
                f"Unknown red action type: {kind}. "
                f"Valid types: {list(registry)}"
            )
        return registry[kind](**payload)
|
|
|
|
| |
# Pydantic discriminated union over all Blue Team (SOC analyst) actions,
# resolved by each model's literal ``type`` field.
SOCAction = Annotated[
    Union[
        # original investigation / response actions
        QueryHost,
        IsolateSegment,
        BlockIOC,
        RunForensics,
        KillProcess,
        SubmitContainmentPlan,
        # threat-intel and analysis tools
        CorrelateAlerts,
        EnrichIOC,
        ScanHostVulnerabilities,
        # fine-grained remediation
        TerminatePID,
        CreateFirewallRule,
        QuarantineFile,
    ],
    Field(discriminator="type"),
]
|
|
| |
class SOCActionWrapper(Action):
    """Flat, loosely-typed envelope for any Blue Team (SOC analyst) action.

    OpenEnv's ``create_app`` expects a single ``Action`` subclass, so the
    HTTP/WS layer parses incoming actions into this wrapper. ``extra="allow"``
    keeps the action-specific fields, and :meth:`to_typed_action` then builds
    one of the 12 concrete Blue action models.

    Client sends: {"action": {"type": "query_host", "hostname": "WS-001"}}
    ``to_typed_action()`` then returns QueryHost(hostname="WS-001").
    """

    type: str = Field(..., description="Action type discriminator")

    model_config = ConfigDict(extra="allow")

    def to_typed_action(self) -> Action:
        """Convert the raw wrapper into the correctly typed action.

        Every dumped field except ``metadata`` (including extra,
        action-specific keys) is forwarded to the selected model's constructor.

        Returns:
            An instance of the concrete action class selected by ``type``.

        Raises:
            ValueError: if ``type`` is not a known Blue action type.
            pydantic.ValidationError: if the forwarded fields do not satisfy
                the selected action model.
        """
        data = self.model_dump(exclude={"metadata"})
        action_map = {
            "query_host": QueryHost,
            "isolate_segment": IsolateSegment,
            "block_ioc": BlockIOC,
            "run_forensics": RunForensics,
            "kill_process": KillProcess,
            "submit_containment_plan": SubmitContainmentPlan,
            "correlate_alerts": CorrelateAlerts,
            "enrich_ioc": EnrichIOC,
            "scan_host_vulnerabilities": ScanHostVulnerabilities,
            "terminate_pid": TerminatePID,
            "create_firewall_rule": CreateFirewallRule,
            "quarantine_file": QuarantineFile,
        }
        cls = action_map.get(data["type"])
        if cls is None:
            raise ValueError(
                f"Unknown action type: {data['type']}. "
                f"Valid types: {list(action_map.keys())}"
            )
        return cls(**data)
|
|
|
|
| |
| |
| |
|
|
|
|
class SOCState(State):
    """Server-side bookkeeping for one attack-simulation episode.

    Builds on the OpenEnv ``State`` base (which provides ``episode_id`` and
    ``step_count``; the original notes it uses ``extra='allow'``). Every
    tracking list starts empty, and ``business_impact`` is validated to [0, 1].
    """

    # --- task configuration -------------------------------------------------
    task_id: str = Field("easy", description="Current task: 'easy', 'medium', or 'hard'")
    max_steps: int = Field(30, description="Maximum steps for this episode")

    # --- scoring ------------------------------------------------------------
    total_reward: float = Field(0.0, description="Accumulated reward")
    business_impact: float = Field(
        0.0,
        description="Current business impact score",
        ge=0.0,
        le=1.0,
    )

    # --- threat status ------------------------------------------------------
    contained_threats: List[str] = Field(
        description="Threat IDs that have been contained", default_factory=list
    )
    active_threats: List[str] = Field(
        description="Currently active threat IDs", default_factory=list
    )

    # --- record of agent actions --------------------------------------------
    blocked_iocs: List[str] = Field(
        description="IOCs blocked at perimeter", default_factory=list
    )
    isolated_subnets: List[str] = Field(
        description="Isolated network segments", default_factory=list
    )
    forensics_run: List[str] = Field(
        description="Hosts that had forensics run", default_factory=list
    )
    killed_processes: List[Dict[str, str]] = Field(
        description="Processes killed", default_factory=list
    )
    queried_hosts: List[str] = Field(description="Hosts queried", default_factory=list)
    timeline: List[Dict[str, Any]] = Field(
        description="Action timeline", default_factory=list
    )

    # --- episode lifecycle --------------------------------------------------
    is_done: bool = Field(False, description="Whether episode has ended")
    submitted_plan: bool = Field(False, description="Whether containment plan was submitted")

    # --- extended tool usage ------------------------------------------------
    enriched_iocs: List[str] = Field(
        description="IOCs that have been threat-intel enriched", default_factory=list
    )
    scanned_hosts: List[str] = Field(
        description="Hosts that had vulnerability scans", default_factory=list
    )
    correlated_alert_pairs: List[Any] = Field(
        description="Pairs/groups of alert IDs correlated together",
        default_factory=list,
    )
    live_requirements: Optional[Dict[str, Any]] = Field(
        None,
        description="Mutable copy of containment_requirements (for adaptive grading).",
    )

    # --- Blue/Red turn-taking -----------------------------------------------
    active_turn: str = Field(
        "blue",
        description="Current active turn in the FSP engine: 'blue' or 'red'.",
    )
|
|