CyberSOC / models.py
Ajayyy00
Initial commit: CyberSOC Enterprise Environment Baseline
bb0d7fd
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Data models for the CyberSOCEnv — Enterprise Cybersecurity Operations Center.
Defines strict Pydantic models for:
- Observation: What the agent sees (alerts, forensics, network state, business impact)
- Action: What the agent can do (discriminated union of 6 action types)
- Internal state: Deterministic network graph, attack chains, and task tracking
"""
from __future__ import annotations
from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
from openenv.core.env_server.types import Action, Observation, State
from pydantic import BaseModel, ConfigDict, Field
# =============================================================================
# Enums
# =============================================================================
class Severity(str, Enum):
    """SIEM alert severity levels."""

    # The ``str`` mixin makes every member compare equal to its plain string
    # value. Note that ``<`` / ``>`` therefore compare the underlying strings
    # lexicographically, which does NOT follow the low -> critical ranking.
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"
class ThreatType(str, Enum):
    """Classification of threat types in the SOC environment."""

    # Each value is the lower-case form of its member name; the ``str`` mixin
    # lets members be compared directly against those plain strings.
    RANSOMWARE = "ransomware"
    PHISHING = "phishing"
    CREDENTIAL_THEFT = "credential_theft"
    LATERAL_MOVEMENT = "lateral_movement"
    C2_COMMUNICATION = "c2_communication"
    DATA_EXFILTRATION = "data_exfiltration"
    PRIVILEGE_ESCALATION = "privilege_escalation"
    MALWARE = "malware"
    CRYPTOMINING = "cryptomining"
    SUPPLY_CHAIN = "supply_chain"
    INSIDER_THREAT = "insider_threat"
    WEBSHELL = "webshell"
    BOTNET = "botnet"
class HostStatus(str, Enum):
    """Host operational status."""

    # String-valued members: each one compares equal to its plain string value.
    ONLINE = "online"
    COMPROMISED = "compromised"
    ISOLATED = "isolated"
    OFFLINE = "offline"
class SubnetRole(str, Enum):
    """Business function of a network subnet."""

    # String-valued members: each one compares equal to its plain string value.
    CORPORATE = "corporate"
    ENGINEERING = "engineering"
    FINANCE = "finance"
    DMZ = "dmz"
    DATACENTER = "datacenter"
    EXECUTIVE = "executive"
# =============================================================================
# Alert & Network Sub-Models (used in Observation)
# =============================================================================
class Alert(BaseModel):
    """A single SIEM/EDR alert in the queue."""

    # Unknown keys are rejected at validation time instead of being ignored.
    model_config = ConfigDict(extra="forbid")

    # --- Identity and origin (all required) ---
    alert_id: str = Field(
        ...,
        description="Unique alert identifier",
    )
    timestamp: str = Field(
        ...,
        description="ISO-8601 timestamp of the alert",
    )
    source_host: str = Field(
        ...,
        description="Hostname that generated the alert",
    )

    # --- Classification (required) ---
    severity: Severity = Field(
        ...,
        description="Alert severity level",
    )
    threat_type: ThreatType = Field(
        ...,
        description="Classified threat type",
    )
    description: str = Field(
        ...,
        description="Human-readable alert description",
    )

    # Optional; every instance gets its own fresh empty list.
    ioc_indicators: List[str] = Field(
        default_factory=list,
        description="Indicators of compromise (IPs, hashes, domains)",
    )

    subnet: str = Field(
        ...,
        description="Subnet where the alert originated",
    )

    # Alerts start out unacknowledged.
    is_acknowledged: bool = Field(
        default=False,
        description="Whether the SOC analyst has acknowledged this alert",
    )
class HostInfo(BaseModel):
    """Summary information about a single network host."""

    # Unknown keys are rejected at validation time instead of being ignored.
    model_config = ConfigDict(extra="forbid")

    # --- Addressing and placement (all required) ---
    hostname: str = Field(
        ...,
        description="Host FQDN",
    )
    ip_address: str = Field(
        ...,
        description="IPv4 address",
    )
    subnet: str = Field(
        ...,
        description="Subnet the host belongs to",
    )
    role: SubnetRole = Field(
        ...,
        description="Business function",
    )

    # --- Runtime state (optional) ---
    status: HostStatus = Field(
        default=HostStatus.ONLINE,
        description="Current status",
    )
    running_processes: List[str] = Field(
        default_factory=list,
        description="Running process names",
    )
    open_ports: List[int] = Field(
        default_factory=list,
        description="Open TCP ports",
    )

    # Constrained to the closed interval [0.0, 1.0]; defaults to the midpoint.
    criticality: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Business criticality score (0=low, 1=mission-critical)",
    )
class NetworkTopology(BaseModel):
    """Summarized view of the 500-node enterprise network."""

    # Unknown keys are rejected at validation time instead of being ignored.
    model_config = ConfigDict(extra="forbid")

    # Every field is optional. The defaults describe a 500-host network in
    # which all hosts are online and none are compromised or isolated.
    total_hosts: int = Field(
        default=500,
        description="Total hosts in the network",
    )
    subnets: Dict[str, int] = Field(
        default_factory=dict,
        description="Map of subnet name -> host count",
    )

    # Per-status tallies. The model does not itself check that they add up
    # to ``total_hosts``.
    compromised_count: int = Field(
        default=0,
        description="Number of compromised hosts",
    )
    isolated_count: int = Field(
        default=0,
        description="Number of isolated hosts",
    )
    online_count: int = Field(
        default=500,
        description="Number of online hosts",
    )
class ForensicsResult(BaseModel):
    """Results from running forensics on a host."""

    # Unknown keys are rejected at validation time instead of being ignored.
    model_config = ConfigDict(extra="forbid")

    # The only required field: which host was analyzed.
    hostname: str = Field(
        ...,
        description="Analyzed host",
    )

    # --- Findings; each list defaults to a fresh empty list ---
    malicious_processes: List[str] = Field(
        default_factory=list,
        description="Detected malicious processes",
    )
    suspicious_files: List[str] = Field(
        default_factory=list,
        description="Suspicious file paths found",
    )
    network_connections: List[str] = Field(
        default_factory=list,
        description="Suspicious outbound connections (ip:port)",
    )
    registry_modifications: List[str] = Field(
        default_factory=list,
        description="Modified registry keys",
    )
    memory_artifacts: List[str] = Field(
        default_factory=list,
        description="In-memory IOCs found",
    )

    # Overall verdict; defaults to "not compromised".
    is_compromised: bool = Field(
        default=False,
        description="Whether forensics confirm compromise",
    )
class TimelineEntry(BaseModel):
    """A single entry in the analyst action timeline."""

    # Unknown keys are rejected at validation time instead of being ignored.
    model_config = ConfigDict(extra="forbid")

    # --- What happened and when (all required) ---
    step: int = Field(
        ...,
        description="Step number when this action was taken",
    )
    action_type: str = Field(
        ...,
        description="Type of action taken",
    )
    target: str = Field(
        ...,
        description="Target of the action (host, subnet, IOC)",
    )
    result: str = Field(
        ...,
        description="Outcome description",
    )

    # Optional; an entry without an explicit reward records 0.0.
    reward: float = Field(
        default=0.0,
        description="Reward received for this action",
    )
# =============================================================================
# Observation
# =============================================================================
class SOCObservation(Observation):
    """What the SOC agent sees at each step.

    Extends OpenEnv Observation (inherits: done, reward, metadata).
    Every field declared here has a default, so an observation can be built
    without passing any of them.
    """

    # --- Situational picture ---
    alert_queue: List[Alert] = Field(
        default_factory=list,
        description="Current queue of unresolved SIEM/EDR alerts",
    )
    network_topology: NetworkTopology = Field(
        default_factory=NetworkTopology,
        description="Summary of the enterprise network state",
    )
    host_forensics: Optional[ForensicsResult] = Field(
        default=None,
        description="Forensics results if RunForensics was the last action, else None",
    )
    timeline: List[TimelineEntry] = Field(
        default_factory=list,
        description="Chronological log of all actions taken in this episode",
    )

    # Constrained to the closed interval [0.0, 1.0].
    business_impact_score: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Current business impact (0=no impact, 1=catastrophic outage)",
    )

    # Must be non-negative.
    step_count: int = Field(
        default=0,
        ge=0,
        description="Current step number",
    )
    active_threats: List[str] = Field(
        default_factory=list,
        description="List of threat IDs that are still active/uncontained",
    )

    # --- Episode bookkeeping ---
    max_steps: int = Field(
        default=30,
        description="Maximum steps allowed in this episode",
    )
    task_id: str = Field(
        default="easy",
        description="Current task identifier",
    )
    total_reward: float = Field(
        default=0.0,
        description="Accumulated episode reward",
    )

    # --- Grading output; both stay None unless explicitly populated ---
    final_score: Optional[float] = Field(
        default=None,
        description="Post-episode grader score (0.0-1.0). Only set when done=True and plan submitted.",
    )
    grade_breakdown: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Detailed grading breakdown. Only set when done=True and plan submitted.",
    )
# =============================================================================
# Actions (Discriminated Union)
# =============================================================================
class QueryHost(Action):
    """Query a specific host for status, processes, and connections."""

    # Fixed tag that identifies this action inside the SOCAction union.
    type: Literal["query_host"] = Field(
        default="query_host",
        description="Action discriminator",
    )

    # Required.
    hostname: str = Field(
        ...,
        description="Target hostname to query",
    )
class IsolateSegment(Action):
    """Isolate an entire network segment from the rest of the network."""

    # Fixed tag that identifies this action inside the SOCAction union.
    type: Literal["isolate_segment"] = Field(
        default="isolate_segment",
        description="Action discriminator",
    )

    # Required.
    subnet: str = Field(
        ...,
        description="Subnet name to isolate (e.g. 'finance', 'engineering')",
    )

    # Optional free-text justification; empty string when omitted.
    reason: str = Field(
        default="",
        description="Justification for isolation",
    )
class BlockIOC(Action):
    """Block an Indicator of Compromise at the perimeter firewall."""

    # Fixed tag that identifies this action inside the SOCAction union.
    type: Literal["block_ioc"] = Field(
        default="block_ioc",
        description="Action discriminator",
    )

    # Required.
    ioc_value: str = Field(
        ...,
        description="The IOC to block (IP, domain, or file hash)",
    )

    # Required; restricted to exactly these three literal strings.
    ioc_type: Literal["ip", "domain", "hash"] = Field(
        ...,
        description="Type of IOC",
    )
class RunForensics(Action):
    """Run deep forensic analysis on a specific host."""

    # Fixed tag that identifies this action inside the SOCAction union.
    type: Literal["run_forensics"] = Field(
        default="run_forensics",
        description="Action discriminator",
    )

    # Required.
    hostname: str = Field(
        ...,
        description="Target hostname for forensics",
    )
class KillProcess(Action):
    """Terminate a specific process on a host."""

    # Fixed tag that identifies this action inside the SOCAction union.
    type: Literal["kill_process"] = Field(
        default="kill_process",
        description="Action discriminator",
    )

    # Both the host and the process name are required.
    hostname: str = Field(
        ...,
        description="Host where the process is running",
    )
    process_name: str = Field(
        ...,
        description="Name of the process to terminate",
    )
class ContainmentEntry(BaseModel):
    """A single entry in the containment plan."""

    # Unknown keys are rejected at validation time instead of being ignored.
    model_config = ConfigDict(extra="forbid")

    # All four fields are required; none has a default.
    threat_id: str = Field(
        ...,
        description="Threat being addressed",
    )
    actions_taken: List[str] = Field(
        ...,
        description="List of actions taken to contain this threat",
    )
    root_cause: str = Field(
        ...,
        description="Identified root cause",
    )

    # Constrained to the closed interval [0.0, 1.0].
    confidence: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Confidence in the containment (0-1)",
    )
class SubmitContainmentPlan(Action):
    """Submit the final containment plan to end the episode."""

    # Fixed tag that identifies this action inside the SOCAction union.
    type: Literal["submit_containment_plan"] = Field(
        default="submit_containment_plan",
        description="Action discriminator",
    )

    # Required; each element is validated as a ContainmentEntry.
    plan: List[ContainmentEntry] = Field(
        ...,
        description="The containment plan addressing all identified threats",
    )

    # Required.
    executive_summary: str = Field(
        ...,
        description="Brief executive summary for CISO reporting",
    )
# Discriminated union of all SOC actions
# The concrete class is selected by the literal value of each member's
# ``type`` field, so every member must declare a distinct ``type`` literal.
SOCAction = Annotated[
    Union[
        QueryHost,
        IsolateSegment,
        BlockIOC,
        RunForensics,
        KillProcess,
        SubmitContainmentPlan,
    ],
    Field(discriminator="type"),
]
# Wrapper model so OpenEnv's create_app can accept it as a single Action class
class SOCActionWrapper(Action):
    """Loosely-typed envelope that is converted into a concrete SOC action.

    OpenEnv's create_app expects a single Action subclass. This wrapper
    declares only a string ``type`` field and allows arbitrary extra fields,
    so the HTTP/WS layer can accept any of the 6 action types from a flat
    JSON payload. Per-action validation happens later, in
    :meth:`to_typed_action`.

    Client sends: {"action": {"type": "query_host", "hostname": "WS-001"}}
    ``to_typed_action()`` then yields QueryHost(hostname="WS-001").
    """

    # Extra keys are kept on the instance: they carry the action-specific
    # fields (hostname, subnet, plan, ...).
    model_config = ConfigDict(extra="allow")

    type: str = Field(
        ...,
        description="Action type discriminator",
    )

    def to_typed_action(
        self,
    ) -> Union[
        QueryHost,
        IsolateSegment,
        BlockIOC,
        RunForensics,
        KillProcess,
        SubmitContainmentPlan,
    ]:
        """Convert the raw wrapper into the correctly typed action.

        Returns:
            An instance of the action class registered for ``self.type``,
            built from this wrapper's dumped fields.

        Raises:
            ValueError: If ``self.type`` is not one of the known action types.
                Constructing the concrete action may additionally raise that
                model's own validation error.
        """
        # Maps each discriminator string to the model that validates it.
        registry = dict(
            query_host=QueryHost,
            isolate_segment=IsolateSegment,
            block_ioc=BlockIOC,
            run_forensics=RunForensics,
            kill_process=KillProcess,
            submit_containment_plan=SubmitContainmentPlan,
        )

        # ``metadata`` is excluded from the dump, so it is not forwarded to
        # the typed action.
        payload = self.model_dump(exclude={"metadata"})
        requested = payload["type"]

        if requested not in registry:
            raise ValueError(
                f"Unknown action type: {requested}. "
                f"Valid types: {list(registry)}"
            )

        return registry[requested](**payload)
# =============================================================================
# Internal State (not exposed to agent directly)
# =============================================================================
class SOCState(State):
    """Internal environment state tracking the attack simulation.

    Extends OpenEnv State (inherits: episode_id, step_count).
    Uses extra='allow' from base State.
    Every field declared here has a default, so a state can be built without
    passing any of them.
    """

    # --- Episode configuration and running totals ---
    task_id: str = Field(
        default="easy",
        description="Current task: 'easy', 'medium', or 'hard'",
    )
    max_steps: int = Field(
        default=30,
        description="Maximum steps for this episode",
    )
    total_reward: float = Field(
        default=0.0,
        description="Accumulated reward",
    )

    # Constrained to the closed interval [0.0, 1.0].
    business_impact: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Current business impact score",
    )

    # --- Threat tracking ---
    contained_threats: List[str] = Field(
        default_factory=list,
        description="Threat IDs that have been contained",
    )
    active_threats: List[str] = Field(
        default_factory=list,
        description="Currently active threat IDs",
    )

    # --- Record of analyst actions; each list starts empty ---
    blocked_iocs: List[str] = Field(
        default_factory=list,
        description="IOCs blocked at perimeter",
    )
    isolated_subnets: List[str] = Field(
        default_factory=list,
        description="Isolated network segments",
    )
    forensics_run: List[str] = Field(
        default_factory=list,
        description="Hosts that had forensics run",
    )
    killed_processes: List[Dict[str, str]] = Field(
        default_factory=list,
        description="Processes killed",
    )
    queried_hosts: List[str] = Field(
        default_factory=list,
        description="Hosts queried",
    )

    # Stored as plain dicts here, not as TimelineEntry models.
    timeline: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Action timeline",
    )

    # --- Termination flags ---
    is_done: bool = Field(
        default=False,
        description="Whether episode has ended",
    )
    submitted_plan: bool = Field(
        default=False,
        description="Whether containment plan was submitted",
    )