| |
| |
| |
| |
| |
|
|
"""
Data models for the CyberSOCEnv — Enterprise Cybersecurity Operations Center.

Defines strict Pydantic models for:

- Observation: what the agent sees (alerts, forensics, network state, business impact).
- Action: what the agent can do (a discriminated union of six action types).
- Internal state: deterministic network graph, attack chains, and task tracking.
"""
|
|
| from __future__ import annotations |
|
|
| from enum import Enum |
| from typing import Annotated, Any, Dict, List, Literal, Optional, Union |
|
|
| from openenv.core.env_server.types import Action, Observation, State |
| from pydantic import BaseModel, ConfigDict, Field |
|
|
|
|
| |
| |
| |
|
|
|
|
class Severity(str, Enum):
    """Severity level attached to a SIEM alert.

    Members also subclass ``str``, so each one compares equal to its
    lowercase string value.
    """

    # Declared from least to most severe.
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"
|
|
|
|
class ThreatType(str, Enum):
    """Threat categories an alert can be classified as.

    Each member is also a ``str`` whose value is the snake_case form of
    its name.
    """

    RANSOMWARE = "ransomware"
    PHISHING = "phishing"
    CREDENTIAL_THEFT = "credential_theft"
    LATERAL_MOVEMENT = "lateral_movement"
    C2_COMMUNICATION = "c2_communication"
    DATA_EXFILTRATION = "data_exfiltration"
    PRIVILEGE_ESCALATION = "privilege_escalation"
    MALWARE = "malware"
    CRYPTOMINING = "cryptomining"
    SUPPLY_CHAIN = "supply_chain"
    INSIDER_THREAT = "insider_threat"
    WEBSHELL = "webshell"
    BOTNET = "botnet"
|
|
|
|
class HostStatus(str, Enum):
    """Operational state a host can be in.

    Members are ``str`` subclasses and compare equal to their lowercase
    values.
    """

    ONLINE = "online"
    COMPROMISED = "compromised"
    ISOLATED = "isolated"
    OFFLINE = "offline"
|
|
|
|
class SubnetRole(str, Enum):
    """Business function served by a network subnet.

    Members are ``str`` subclasses and compare equal to their lowercase
    values.
    """

    CORPORATE = "corporate"
    ENGINEERING = "engineering"
    FINANCE = "finance"
    DMZ = "dmz"
    DATACENTER = "datacenter"
    EXECUTIVE = "executive"
|
|
|
|
| |
| |
| |
|
|
|
|
class Alert(BaseModel):
    """One SIEM/EDR alert as it appears in the analyst queue.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # --- identity and origin -------------------------------------------
    alert_id: str = Field(
        ...,
        description="Unique alert identifier",
    )
    timestamp: str = Field(
        ...,
        description="ISO-8601 timestamp of the alert",
    )
    source_host: str = Field(
        ...,
        description="Hostname that generated the alert",
    )

    # --- classification ------------------------------------------------
    severity: Severity = Field(
        ...,
        description="Alert severity level",
    )
    threat_type: ThreatType = Field(
        ...,
        description="Classified threat type",
    )
    description: str = Field(
        ...,
        description="Human-readable alert description",
    )

    # Optional: defaults to a fresh empty list per instance.
    ioc_indicators: List[str] = Field(
        description="Indicators of compromise (IPs, hashes, domains)",
        default_factory=list,
    )
    subnet: str = Field(
        ...,
        description="Subnet where the alert originated",
    )

    # --- analyst workflow ----------------------------------------------
    is_acknowledged: bool = Field(
        default=False,
        description="Whether the SOC analyst has acknowledged this alert",
    )
|
|
|
|
class HostInfo(BaseModel):
    """Snapshot of a single host on the network.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # --- required addressing / placement --------------------------------
    hostname: str = Field(
        ...,
        description="Host FQDN",
    )
    ip_address: str = Field(
        ...,
        description="IPv4 address",
    )
    subnet: str = Field(
        ...,
        description="Subnet the host belongs to",
    )
    role: SubnetRole = Field(
        ...,
        description="Business function",
    )

    # --- optional runtime details ----------------------------------------
    status: HostStatus = Field(
        default=HostStatus.ONLINE,
        description="Current status",
    )
    running_processes: List[str] = Field(
        description="Running process names",
        default_factory=list,
    )
    open_ports: List[int] = Field(
        description="Open TCP ports",
        default_factory=list,
    )

    # Validated to lie within [0.0, 1.0]; defaults to the midpoint.
    criticality: float = Field(
        default=0.5,
        description="Business criticality score (0=low, 1=mission-critical)",
        ge=0.0,
        le=1.0,
    )
|
|
|
|
class NetworkTopology(BaseModel):
    """Aggregate counters describing the enterprise network.

    The defaults describe a 500-host network with every host online.
    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    total_hosts: int = Field(
        default=500,
        description="Total hosts in the network",
    )
    subnets: Dict[str, int] = Field(
        description="Map of subnet name -> host count",
        default_factory=dict,
    )

    # Per-status host counters.
    compromised_count: int = Field(
        default=0,
        description="Number of compromised hosts",
    )
    isolated_count: int = Field(
        default=0,
        description="Number of isolated hosts",
    )
    online_count: int = Field(
        default=500,
        description="Number of online hosts",
    )
|
|
|
|
class ForensicsResult(BaseModel):
    """Findings produced by a forensic analysis of one host.

    Only ``hostname`` is required; every evidence list defaults to empty
    and ``is_compromised`` defaults to ``False``. Unknown fields are
    rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    hostname: str = Field(
        ...,
        description="Analyzed host",
    )

    # --- evidence collections (each defaults to a fresh empty list) -----
    malicious_processes: List[str] = Field(
        description="Detected malicious processes",
        default_factory=list,
    )
    suspicious_files: List[str] = Field(
        description="Suspicious file paths found",
        default_factory=list,
    )
    network_connections: List[str] = Field(
        description="Suspicious outbound connections (ip:port)",
        default_factory=list,
    )
    registry_modifications: List[str] = Field(
        description="Modified registry keys",
        default_factory=list,
    )
    memory_artifacts: List[str] = Field(
        description="In-memory IOCs found",
        default_factory=list,
    )

    # --- verdict ----------------------------------------------------------
    is_compromised: bool = Field(
        default=False,
        description="Whether forensics confirm compromise",
    )
|
|
|
|
class TimelineEntry(BaseModel):
    """Record of one analyst action in the episode timeline.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    step: int = Field(
        ...,
        description="Step number when this action was taken",
    )
    action_type: str = Field(
        ...,
        description="Type of action taken",
    )
    target: str = Field(
        ...,
        description="Target of the action (host, subnet, IOC)",
    )
    result: str = Field(
        ...,
        description="Outcome description",
    )

    # The only optional field; defaults to zero.
    reward: float = Field(
        default=0.0,
        description="Reward received for this action",
    )
|
|
|
|
| |
| |
| |
|
|
|
|
class SOCObservation(Observation):
    """Per-step view of the environment handed to the SOC agent.

    Bundles the alert queue, a network summary, the latest forensics
    result (if any), the action timeline, and episode progress/score
    fields on top of the OpenEnv ``Observation`` base class. Every field
    has a default, so an empty observation can be constructed.
    """

    # NOTE(review): the original docstring states that the base class
    # supplies ``done``, ``reward`` and ``metadata``. The base is defined
    # outside this module — confirm against OpenEnv.

    # --- what the agent can see ---------------------------------------------
    alert_queue: List[Alert] = Field(
        description="Current queue of unresolved SIEM/EDR alerts",
        default_factory=list,
    )
    network_topology: NetworkTopology = Field(
        description="Summary of the enterprise network state",
        default_factory=NetworkTopology,
    )
    host_forensics: Optional[ForensicsResult] = Field(
        description="Forensics results if RunForensics was the last action, else None",
        default=None,
    )
    timeline: List[TimelineEntry] = Field(
        description="Chronological log of all actions taken in this episode",
        default_factory=list,
    )

    # --- episode progress ---------------------------------------------------
    # Validated to lie within [0.0, 1.0].
    business_impact_score: float = Field(
        default=0.0,
        description="Current business impact (0=no impact, 1=catastrophic outage)",
        ge=0.0,
        le=1.0,
    )
    step_count: int = Field(
        default=0,
        description="Current step number",
        ge=0,
    )
    active_threats: List[str] = Field(
        description="List of threat IDs that are still active/uncontained",
        default_factory=list,
    )
    max_steps: int = Field(
        default=30,
        description="Maximum steps allowed in this episode",
    )
    task_id: str = Field(
        default="easy",
        description="Current task identifier",
    )
    total_reward: float = Field(
        default=0.0,
        description="Accumulated episode reward",
    )

    # --- grading (both default to None) ---------------------------------------
    final_score: Optional[float] = Field(
        description="Post-episode grader score (0.0-1.0). Only set when done=True and plan submitted.",
        default=None,
    )
    grade_breakdown: Optional[Dict[str, Any]] = Field(
        description="Detailed grading breakdown. Only set when done=True and plan submitted.",
        default=None,
    )
|
|
|
|
| |
| |
| |
|
|
|
|
class QueryHost(Action):
    """Action: query one host for its status, processes, and connections."""

    # Fixed tag identifying this action class; defaults to its only legal value.
    type: Literal["query_host"] = Field(
        default="query_host",
        description="Action discriminator",
    )
    hostname: str = Field(
        ...,
        description="Target hostname to query",
    )
|
|
|
|
class IsolateSegment(Action):
    """Action: cut a whole network segment off from the rest of the network."""

    # Fixed tag identifying this action class; defaults to its only legal value.
    type: Literal["isolate_segment"] = Field(
        default="isolate_segment",
        description="Action discriminator",
    )
    subnet: str = Field(
        ...,
        description="Subnet name to isolate (e.g. 'finance', 'engineering')",
    )

    # Optional free-text justification; empty string when omitted.
    reason: str = Field(
        default="",
        description="Justification for isolation",
    )
|
|
|
|
class BlockIOC(Action):
    """Action: block an indicator of compromise at the perimeter firewall."""

    # Fixed tag identifying this action class; defaults to its only legal value.
    type: Literal["block_ioc"] = Field(
        default="block_ioc",
        description="Action discriminator",
    )
    ioc_value: str = Field(
        ...,
        description="The IOC to block (IP, domain, or file hash)",
    )

    # Restricted to exactly these three kinds.
    ioc_type: Literal["ip", "domain", "hash"] = Field(
        ...,
        description="Type of IOC",
    )
|
|
|
|
class RunForensics(Action):
    """Action: run a deep forensic analysis on one host."""

    # Fixed tag identifying this action class; defaults to its only legal value.
    type: Literal["run_forensics"] = Field(
        default="run_forensics",
        description="Action discriminator",
    )
    hostname: str = Field(
        ...,
        description="Target hostname for forensics",
    )
|
|
|
|
class KillProcess(Action):
    """Action: terminate a named process on a given host."""

    # Fixed tag identifying this action class; defaults to its only legal value.
    type: Literal["kill_process"] = Field(
        default="kill_process",
        description="Action discriminator",
    )
    hostname: str = Field(
        ...,
        description="Host where the process is running",
    )
    process_name: str = Field(
        ...,
        description="Name of the process to terminate",
    )
|
|
|
|
class ContainmentEntry(BaseModel):
    """One line item of a containment plan, covering a single threat.

    All four fields are required. Unknown fields are rejected
    (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    threat_id: str = Field(
        ...,
        description="Threat being addressed",
    )
    actions_taken: List[str] = Field(
        ...,
        description="List of actions taken to contain this threat",
    )
    root_cause: str = Field(
        ...,
        description="Identified root cause",
    )

    # Validated to lie within [0.0, 1.0].
    confidence: float = Field(
        ...,
        description="Confidence in the containment (0-1)",
        ge=0.0,
        le=1.0,
    )
|
|
|
|
class SubmitContainmentPlan(Action):
    """Action: hand in the final containment plan, ending the episode."""

    # Fixed tag identifying this action class; defaults to its only legal value.
    type: Literal["submit_containment_plan"] = Field(
        default="submit_containment_plan",
        description="Action discriminator",
    )

    # Both payload fields are required.
    plan: List[ContainmentEntry] = Field(
        ...,
        description="The containment plan addressing all identified threats",
    )
    executive_summary: str = Field(
        ...,
        description="Brief executive summary for CISO reporting",
    )
|
|
|
|
| |
# Union of the six concrete action classes, tagged so that the "type"
# field acts as the discriminator between them.
SOCAction = Annotated[
    Union[
        QueryHost,
        IsolateSegment,
        BlockIOC,
        RunForensics,
        KillProcess,
        SubmitContainmentPlan,
    ],
    Field(discriminator="type"),
]
|
|
| |
class SOCActionWrapper(Action):
    """Loosely typed envelope that accepts any of the six SOC actions.

    Only ``type`` is declared. Every other key in the payload is kept as
    an extra field (``extra="allow"``) and is validated only when
    :meth:`to_typed_action` builds the concrete action class.

    Example payload: ``{"type": "query_host", "hostname": "WS-001"}``
    converts to ``QueryHost(hostname="WS-001")``.
    """

    # NOTE(review): the original docstring says this wrapper exists because
    # OpenEnv's create_app expects a single Action subclass, and that the
    # client nests the payload under an "action" key — confirm against OpenEnv.
    model_config = ConfigDict(extra="allow")

    type: str = Field(..., description="Action type discriminator")

    def to_typed_action(
        self,
    ) -> Union[QueryHost, IsolateSegment, BlockIOC, RunForensics, KillProcess, SubmitContainmentPlan]:
        """Build the concrete action class named by ``type``.

        The wrapper is dumped to a dict (without ``metadata``) and that
        dict is passed as keyword arguments to the matching class.

        Returns:
            An instance of the action class registered for ``type``.

        Raises:
            ValueError: If ``type`` is not one of the six known action
                names. Errors raised by the target class's constructor
                are not caught here and propagate unchanged.
        """
        # "metadata" is deliberately left out of the forwarded payload.
        payload = self.model_dump(exclude={"metadata"})

        registry = {
            "query_host": QueryHost,
            "isolate_segment": IsolateSegment,
            "block_ioc": BlockIOC,
            "run_forensics": RunForensics,
            "kill_process": KillProcess,
            "submit_containment_plan": SubmitContainmentPlan,
        }

        requested = payload["type"]
        if requested not in registry:
            raise ValueError(
                f"Unknown action type: {requested}. "
                f"Valid types: {list(registry.keys())}"
            )

        target_cls = registry[requested]
        return target_cls(**payload)
|
|
|
|
| |
| |
| |
|
|
|
|
class SOCState(State):
    """Environment-side bookkeeping for one attack-simulation episode.

    Tracks task settings, accumulated reward and business impact, and the
    running record of what the analyst has contained, blocked, isolated,
    examined, and killed. Every field has a default.
    """

    # NOTE(review): the original docstring states that the OpenEnv ``State``
    # base supplies ``episode_id`` and ``step_count`` and is configured with
    # extra='allow'. The base is defined outside this module — confirm.

    # --- task configuration and scoring ---------------------------------------
    task_id: str = Field(
        default="easy",
        description="Current task: 'easy', 'medium', or 'hard'",
    )
    max_steps: int = Field(
        default=30,
        description="Maximum steps for this episode",
    )
    total_reward: float = Field(
        default=0.0,
        description="Accumulated reward",
    )
    # Validated to lie within [0.0, 1.0].
    business_impact: float = Field(
        default=0.0,
        description="Current business impact score",
        ge=0.0,
        le=1.0,
    )

    # --- threat tracking ------------------------------------------------------
    contained_threats: List[str] = Field(
        description="Threat IDs that have been contained",
        default_factory=list,
    )
    active_threats: List[str] = Field(
        description="Currently active threat IDs",
        default_factory=list,
    )

    # --- record of analyst actions (each defaults to a fresh empty list) -------
    blocked_iocs: List[str] = Field(
        description="IOCs blocked at perimeter",
        default_factory=list,
    )
    isolated_subnets: List[str] = Field(
        description="Isolated network segments",
        default_factory=list,
    )
    forensics_run: List[str] = Field(
        description="Hosts that had forensics run",
        default_factory=list,
    )
    killed_processes: List[Dict[str, str]] = Field(
        description="Processes killed",
        default_factory=list,
    )
    queried_hosts: List[str] = Field(
        description="Hosts queried",
        default_factory=list,
    )
    # Stored as plain dicts here rather than TimelineEntry models.
    timeline: List[Dict[str, Any]] = Field(
        description="Action timeline",
        default_factory=list,
    )

    # --- episode termination flags ----------------------------------------------
    is_done: bool = Field(
        default=False,
        description="Whether episode has ended",
    )
    submitted_plan: bool = Field(
        default=False,
        description="Whether containment plan was submitted",
    )
|
|