Demo / models.py
Ajayyy00
Add FSP multi-agent architecture: Red Team LLM action space + alternating turns
5719ec3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Data models for the CyberSOCEnv — Enterprise Cybersecurity Operations Center.
Defines strict Pydantic models for:
- Observation: What the agent sees (alerts, forensics, network state, business impact)
- Action: What the agents can do (discriminated unions of 12 Blue Team (SOC) and 4 Red Team action types)
- Internal state: Deterministic network graph, attack chains, and task tracking
"""
from __future__ import annotations
from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
from openenv.core.env_server.types import Action, Observation, State
from pydantic import BaseModel, ConfigDict, Field
# =============================================================================
# Enums
# =============================================================================
class Severity(str, Enum):
    """Severity level attached to a SIEM/EDR alert.

    Because the enum mixes in ``str``, each member compares equal to its
    lowercase value (``Severity.HIGH == "high"``). Members are declared from
    least to most severe.
    """

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"
class ThreatType(str, Enum):
    """Threat category assigned to an alert in the SOC simulation.

    ``str``-mixin enum: members compare equal to their snake_case values,
    e.g. ``ThreatType.C2_COMMUNICATION == "c2_communication"``.
    """

    RANSOMWARE = "ransomware"
    PHISHING = "phishing"
    CREDENTIAL_THEFT = "credential_theft"
    LATERAL_MOVEMENT = "lateral_movement"
    C2_COMMUNICATION = "c2_communication"
    DATA_EXFILTRATION = "data_exfiltration"
    PRIVILEGE_ESCALATION = "privilege_escalation"
    MALWARE = "malware"
    CRYPTOMINING = "cryptomining"
    SUPPLY_CHAIN = "supply_chain"
    INSIDER_THREAT = "insider_threat"
    WEBSHELL = "webshell"
    BOTNET = "botnet"
class HostStatus(str, Enum):
    """Operational state of a simulated host (``str``-valued enum)."""

    ONLINE = "online"
    COMPROMISED = "compromised"
    ISOLATED = "isolated"
    OFFLINE = "offline"
class SubnetRole(str, Enum):
    """Business function served by a network subnet (``str``-valued enum)."""

    CORPORATE = "corporate"
    ENGINEERING = "engineering"
    FINANCE = "finance"
    DMZ = "dmz"
    DATACENTER = "datacenter"
    EXECUTIVE = "executive"
# =============================================================================
# Alert & Network Sub-Models (used in Observation)
# =============================================================================
class Alert(BaseModel):
    """One entry in the SIEM/EDR alert queue shown to the agent.

    Unknown keys are rejected (``extra="forbid"``). ``timestamp`` is a plain
    string; no ISO-8601 parsing is performed here.
    """

    model_config = ConfigDict(extra="forbid")

    # Identity and provenance
    alert_id: str = Field(..., description="Unique alert identifier")
    timestamp: str = Field(..., description="ISO-8601 timestamp of the alert")
    source_host: str = Field(..., description="Hostname that generated the alert")
    # Classification
    severity: Severity = Field(..., description="Alert severity level")
    threat_type: ThreatType = Field(..., description="Classified threat type")
    description: str = Field(..., description="Human-readable alert description")
    ioc_indicators: List[str] = Field(default_factory=list, description="Indicators of compromise (IPs, hashes, domains)")
    subnet: str = Field(..., description="Subnet where the alert originated")
    # Analyst workflow state
    is_acknowledged: bool = Field(
        default=False,
        description="Whether the SOC analyst has acknowledged this alert",
    )
class HostInfo(BaseModel):
    """Snapshot of a single host in the simulated enterprise network.

    ``criticality`` is validated to lie in [0.0, 1.0]; unknown keys are
    rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # Addressing and ownership
    hostname: str = Field(..., description="Host FQDN")
    ip_address: str = Field(..., description="IPv4 address")
    subnet: str = Field(..., description="Subnet the host belongs to")
    role: SubnetRole = Field(..., description="Business function")
    # Runtime state
    status: HostStatus = Field(
        default=HostStatus.ONLINE,
        description="Current status",
    )
    running_processes: List[str] = Field(
        default_factory=list,
        description="Running process names",
    )
    open_ports: List[int] = Field(
        default_factory=list,
        description="Open TCP ports",
    )
    # Business weighting
    criticality: float = Field(default=0.5, ge=0.0, le=1.0, description="Business criticality score (0=low, 1=mission-critical)")
class NetworkTopology(BaseModel):
    """Aggregate host counts for the enterprise network.

    Defaults describe a 500-host network with every host online. No
    cross-field validation ties the per-status counts to ``total_hosts``.
    """

    model_config = ConfigDict(extra="forbid")

    total_hosts: int = Field(
        default=500,
        description="Total hosts in the network",
    )
    subnets: Dict[str, int] = Field(default_factory=dict, description="Map of subnet name -> host count")
    # Per-status tallies
    compromised_count: int = Field(
        default=0,
        description="Number of compromised hosts",
    )
    isolated_count: int = Field(
        default=0,
        description="Number of isolated hosts",
    )
    online_count: int = Field(
        default=500,
        description="Number of online hosts",
    )
class ForensicsResult(BaseModel):
    """Findings produced by a forensic sweep of one host.

    Every list defaults to empty and ``is_compromised`` defaults to False, so
    only ``hostname`` is required. Unknown keys are rejected.
    """

    model_config = ConfigDict(extra="forbid")

    hostname: str = Field(..., description="Analyzed host")
    # Artifact categories
    malicious_processes: List[str] = Field(
        default_factory=list,
        description="Detected malicious processes",
    )
    suspicious_files: List[str] = Field(
        default_factory=list,
        description="Suspicious file paths found",
    )
    network_connections: List[str] = Field(default_factory=list, description="Suspicious outbound connections (ip:port)")
    registry_modifications: List[str] = Field(
        default_factory=list,
        description="Modified registry keys",
    )
    memory_artifacts: List[str] = Field(
        default_factory=list,
        description="In-memory IOCs found",
    )
    # Verdict
    is_compromised: bool = Field(
        default=False,
        description="Whether forensics confirm compromise",
    )
class TimelineEntry(BaseModel):
    """Record of one analyst action within an episode's timeline.

    ``reward`` defaults to 0.0; every other field is required. Unknown keys
    are rejected.
    """

    model_config = ConfigDict(extra="forbid")

    step: int = Field(
        ...,
        description="Step number when this action was taken",
    )
    action_type: str = Field(
        ...,
        description="Type of action taken",
    )
    target: str = Field(
        ...,
        description="Target of the action (host, subnet, IOC)",
    )
    result: str = Field(
        ...,
        description="Outcome description",
    )
    reward: float = Field(
        default=0.0,
        description="Reward received for this action",
    )
# =============================================================================
# Observation
# =============================================================================
class SOCObservation(Observation):
    """Per-step observation delivered to the SOC agent.

    Extends OpenEnv ``Observation`` (inherits ``done``, ``reward`` and
    ``metadata``). Every field declared here has a default, so the class can
    be instantiated with no arguments. Most optional ``*_results`` fields
    hold the output of the most recent call to the corresponding tool.
    """

    # --- Episode identity -------------------------------------------------
    episode_id: str = Field(default="", description="Unique UUID for this episode — used by the RL training loop to prevent hash collisions in GRPO batched rollouts.")

    # --- Core SOC view ----------------------------------------------------
    alert_queue: List[Alert] = Field(default_factory=list, description="Current queue of unresolved SIEM/EDR alerts")
    network_topology: NetworkTopology = Field(default_factory=NetworkTopology, description="Summary of the enterprise network state")
    host_forensics: Optional[ForensicsResult] = Field(default=None, description="Forensics results if RunForensics was the last action, else None")
    timeline: List[TimelineEntry] = Field(default_factory=list, description="Chronological log of all actions taken in this episode")
    business_impact_score: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Current business impact (0=no impact, 1=catastrophic outage)",
    )
    step_count: int = Field(
        default=0,
        ge=0,
        description="Current step number",
    )
    active_threats: List[str] = Field(default_factory=list, description="List of threat IDs that are still active/uncontained")

    # --- Episode settings and progress ------------------------------------
    max_steps: int = Field(
        default=30,
        description="Maximum steps allowed in this episode",
    )
    task_id: str = Field(
        default="easy",
        description="Current task identifier",
    )
    total_reward: float = Field(
        default=0.0,
        description="Accumulated episode reward",
    )

    # --- Post-episode grading ---------------------------------------------
    final_score: Optional[float] = Field(default=None, description="Post-episode grader score (0.0-1.0). Only set when done=True and plan submitted.")
    grade_breakdown: Optional[Dict[str, Any]] = Field(default=None, description="Detailed grading breakdown. Only set when done=True and plan submitted.")

    # --- Results of the most recent extended-tool calls -------------------
    correlation_results: Optional[Dict[str, Any]] = Field(default=None, description="Results from the most recent correlate_alerts call.")
    ioc_enrichment: Optional[Dict[str, Any]] = Field(default=None, description="Results from the most recent enrich_ioc call.")
    vulnerability_results: Optional[List[Dict[str, Any]]] = Field(default=None, description="Results from the most recent scan_host_vulnerabilities call.")
    playbook_result: Optional[Dict[str, Any]] = Field(default=None, description="Results from the most recent trigger_playbook call.")

    # --- Situational context ----------------------------------------------
    threat_graph_summary: Optional[str] = Field(default=None, description="Compact textual summary of the current threat graph.")
    available_playbooks: List[str] = Field(default_factory=list, description="Names of SOAR playbooks available to the agent.")

    # --- Live reward shaping ----------------------------------------------
    reward_dimensions: Optional[Dict[str, float]] = Field(
        default=None,
        description=(
            "Running partial scores for each of the 10 grading dimensions, "
            "updated every step for live GRPO credit-assignment signals. "
            "Keys match grade_breakdown (threat_containment, ioc_blocking, etc.)."
        ),
    )

    # --- FSP (Blue vs. Red) turn bookkeeping ------------------------------
    # NOTE(review): typed as plain str; only 'blue'/'red' are documented.
    active_turn: str = Field(default="blue", description="Whose turn it is next: 'blue' or 'red'. Used by FSP inference loops.")
    red_observation: Optional[Dict[str, Any]] = Field(
        default=None,
        description=(
            "Red Team's current view of the world (populated when active_turn='red'). "
            "Contains compromised_hosts and blue_actions_detected."
        ),
    )
# =============================================================================
# Actions (Discriminated Union)
# =============================================================================
class QueryHost(Action):
    """Blue Team: ask one host for its status, processes and connections."""

    type: Literal["query_host"] = Field(
        default="query_host",
        description="Action discriminator",
    )
    hostname: str = Field(
        ...,
        description="Target hostname to query",
    )
class IsolateSegment(Action):
    """Blue Team: cut a whole subnet, or a single host, off from the network.

    ``subnet`` defaults to ``""`` and ``target_host`` to ``None``.
    NOTE(review): the "mutually exclusive" wording is not enforced by any
    validator here. Both fields (or neither) may be supplied, and precedence
    is decided by the environment. Confirm this against the handler.
    """

    type: Literal["isolate_segment"] = Field(
        default="isolate_segment",
        description="Action discriminator",
    )
    subnet: str = Field(
        default="",
        description="Subnet name to isolate (mutually exclusive with target_host)",
    )
    target_host: Optional[str] = Field(default=None, description="If set, isolate only this single host instead of the whole subnet")
    reason: str = Field(
        default="",
        description="Justification for isolation",
    )
class BlockIOC(Action):
    """Blue Team: block an indicator of compromise at the perimeter firewall.

    ``ioc_type`` is restricted to ``"ip"``, ``"domain"`` or ``"hash"``.
    """

    type: Literal["block_ioc"] = Field(
        default="block_ioc",
        description="Action discriminator",
    )
    ioc_value: str = Field(
        ...,
        description="The IOC to block (IP, domain, or file hash)",
    )
    ioc_type: Literal["ip", "domain", "hash"] = Field(
        ...,
        description="Type of IOC",
    )
class RunForensics(Action):
    """Blue Team: perform a deep forensic analysis of one host."""

    type: Literal["run_forensics"] = Field(
        default="run_forensics",
        description="Action discriminator",
    )
    hostname: str = Field(
        ...,
        description="Target hostname for forensics",
    )
class KillProcess(Action):
    """Blue Team: terminate a process, identified by name, on a given host."""

    type: Literal["kill_process"] = Field(
        default="kill_process",
        description="Action discriminator",
    )
    hostname: str = Field(
        ...,
        description="Host where the process is running",
    )
    process_name: str = Field(
        ...,
        description="Name of the process to terminate",
    )
class ContainmentEntry(BaseModel):
    """One line item of a submitted containment plan.

    All fields are required. ``confidence`` must lie in [0.0, 1.0], and
    unknown keys are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    threat_id: str = Field(
        ...,
        description="Threat being addressed",
    )
    actions_taken: List[str] = Field(
        ...,
        description="List of actions taken to contain this threat",
    )
    root_cause: str = Field(
        ...,
        description="Identified root cause",
    )
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence in the containment (0-1)")
class SubmitContainmentPlan(Action):
    """Blue Team: hand in the final containment plan, ending the episode.

    Both the per-threat ``plan`` entries and the ``executive_summary`` are
    required.
    """

    type: Literal["submit_containment_plan"] = Field(default="submit_containment_plan", description="Action discriminator")
    plan: List[ContainmentEntry] = Field(..., description="The containment plan addressing all identified threats")
    executive_summary: str = Field(..., description="Brief executive summary for CISO reporting")
class CorrelateAlerts(Action):
    """Blue Team: correlate two or more alerts to find shared entities/IOCs.

    ``alert_ids`` must contain at least two entries (``min_length=2``).
    """

    # The discriminator is documented like the other Blue actions so the
    # generated JSON schema is uniform across the action space.
    type: Literal["correlate_alerts"] = Field(
        default="correlate_alerts", description="Action discriminator"
    )
    alert_ids: List[str] = Field(..., min_length=2, description="At least 2 alert IDs to correlate")
class EnrichIOC(Action):
    """Blue Team: enrich an IOC with threat-intel data (actor, MITRE TTPs).

    Unlike ``BlockIOC``, ``ioc_type`` here also accepts ``"filename"``.
    """

    # Discriminator documented consistently with the other Blue actions.
    type: Literal["enrich_ioc"] = Field(
        default="enrich_ioc", description="Action discriminator"
    )
    ioc_value: str = Field(..., description="The IOC value to enrich")
    ioc_type: Literal["ip", "domain", "hash", "filename"] = Field(..., description="Type of IOC")
class ScanHostVulnerabilities(Action):
    """Blue Team: run a vulnerability scan on a host to discover CVEs."""

    # Discriminator documented consistently with the other Blue actions.
    type: Literal["scan_host_vulnerabilities"] = Field(
        default="scan_host_vulnerabilities", description="Action discriminator"
    )
    hostname: str = Field(..., description="Target hostname for vulnerability scan")
class TerminatePID(Action):
    """Blue Team: terminate a process, identified by PID, on a host.

    ``pid`` is typed as ``str``, so numeric PIDs must be sent as strings.
    """

    # Discriminator documented consistently with the other Blue actions.
    type: Literal["terminate_pid"] = Field(
        default="terminate_pid", description="Action discriminator"
    )
    hostname: str = Field(..., description="Host where the PID is running")
    pid: str = Field(..., description="Target process PID")
class CreateFirewallRule(Action):
    """Blue Team: create a host-level firewall rule for a target IP.

    ``action`` selects the rule verdict and must be ``"drop"`` or ``"allow"``.
    """

    # Discriminator documented consistently with the other Blue actions.
    type: Literal["create_firewall_rule"] = Field(
        default="create_firewall_rule", description="Action discriminator"
    )
    hostname: str = Field(..., description="Host where firewall rule is applied")
    target_ip: str = Field(..., description="Target IP for the rule")
    action: Literal["drop", "allow"] = Field(..., description="Firewall action to apply")
class QuarantineFile(Action):
    """Blue Team: quarantine a suspicious file on a host."""

    # Discriminator documented consistently with the other Blue actions.
    type: Literal["quarantine_file"] = Field(
        default="quarantine_file", description="Action discriminator"
    )
    hostname: str = Field(..., description="Host where file is located")
    file_path: str = Field(..., description="File path to quarantine")
# =============================================================================
# Red Team Actions (FSP — Fictitious Self-Play)
# =============================================================================
class LateralPivot(Action):
    """Red Team: move laterally from a compromised host to a new target."""

    # Discriminator documented consistently with the Blue action models.
    type: Literal["lateral_pivot"] = Field(
        default="lateral_pivot", description="Action discriminator"
    )
    source_host: str = Field(..., description="Already-compromised host used as the pivot point")
    target_host: str = Field(..., description="Destination host to compromise")
class DeployPayload(Action):
    """Red Team: deploy a malicious payload on a host Red already controls.

    ``payload_type`` must be one of ``"ransomware"``, ``"exfiltration"`` or
    ``"c2"``.
    """

    # Discriminator documented consistently with the Blue action models.
    type: Literal["deploy_payload"] = Field(
        default="deploy_payload", description="Action discriminator"
    )
    hostname: str = Field(..., description="Compromised host to deploy payload on")
    payload_type: Literal["ransomware", "exfiltration", "c2"] = Field(
        ..., description="Class of payload to deploy"
    )
class EvadeDetection(Action):
    """Red Team: apply an evasion technique on a compromised host.

    ``technique`` is either ``"migrate_pid"`` or ``"clear_logs"``. The field
    description (exported in the schema) spells out what each one does.
    """

    # Discriminator documented consistently with the Blue action models.
    type: Literal["evade_detection"] = Field(
        default="evade_detection", description="Action discriminator"
    )
    hostname: str = Field(..., description="Compromised host to apply evasion on")
    technique: Literal["migrate_pid", "clear_logs"] = Field(
        ...,
        description=(
            "migrate_pid: rename running malicious processes to blend with system names; "
            "clear_logs: remove SIEM alerts originating from this host"
        ),
    )
class PassTurn(Action):
    """Red Team: remain stealthy and take no action this turn.

    Carries no payload beyond its discriminator.
    """

    # Discriminator documented consistently with the Blue action models.
    type: Literal["pass_turn"] = Field(
        default="pass_turn", description="Action discriminator"
    )
# Wire-level ``type`` values of every Red Team action. Used by
# dashboard_server and inference to route payloads. Keep this in sync with
# the ``type`` literals of LateralPivot, DeployPayload, EvadeDetection and
# PassTurn. The element type is spelled out so type-checkers know these
# are strings.
RED_ACTION_TYPES: frozenset[str] = frozenset(
    {"lateral_pivot", "deploy_payload", "evade_detection", "pass_turn"}
)
# Pydantic discriminated union over the four Red Team action models; each
# member's literal ``type`` field selects which model validates a payload.
RedAction = Annotated[
    Union[
        LateralPivot,
        DeployPayload,
        EvadeDetection,
        PassTurn,
    ],
    Field(discriminator="type"),
]
class RedActionWrapper(Action):
    """Loose envelope for Red Team actions at the WS/HTTP boundary.

    Mirrors ``SOCActionWrapper``. Any payload carrying a ``type`` key is
    accepted, and extra keys are retained (``extra="allow"``). Strict,
    per-action validation is deferred to :meth:`to_typed_action`.
    """

    model_config = ConfigDict(extra="allow")

    type: str = Field(..., description="Red action type discriminator")

    def to_typed_action(self):
        """Build the concrete Red action model selected by ``type``.

        Returns:
            A ``LateralPivot``, ``DeployPayload``, ``EvadeDetection`` or
            ``PassTurn`` instance built from this wrapper's fields, with
            ``metadata`` excluded.

        Raises:
            ValueError: if ``type`` is not a known Red action type. Validation
                errors raised by the typed model's constructor propagate
                unchanged.
        """
        registry = {
            "lateral_pivot": LateralPivot,
            "deploy_payload": DeployPayload,
            "evade_detection": EvadeDetection,
            "pass_turn": PassTurn,
        }
        payload = self.model_dump(exclude={"metadata"})
        kind = payload["type"]
        try:
            target_cls = registry[kind]
        except KeyError:
            raise ValueError(
                f"Unknown red action type: {kind}. "
                f"Valid types: {list(registry)}"
            ) from None
        return target_cls(**payload)
# Pydantic discriminated union over the twelve Blue Team (SOC) action models,
# keyed on each member's literal ``type`` field. Member order is kept stable
# so the generated schema's ordering does not change.
SOCAction = Annotated[
    Union[
        QueryHost, IsolateSegment, BlockIOC, RunForensics,
        KillProcess, SubmitContainmentPlan, CorrelateAlerts, EnrichIOC,
        ScanHostVulnerabilities, TerminatePID, CreateFirewallRule, QuarantineFile,
    ],
    Field(discriminator="type"),
]
# Wrapper model so OpenEnv's create_app can accept it as a single Action class
class SOCActionWrapper(Action):
    """Loose envelope that the HTTP/WS layer parses Blue Team actions into.

    OpenEnv's ``create_app`` expects a single ``Action`` subclass. This
    wrapper is not a discriminated union itself. It declares only a plain
    ``type`` string and keeps every other key (``extra="allow"``). The
    conversion to one of the 12 concrete Blue Team action models happens in
    :meth:`to_typed_action`.

    Client sends: {"action": {"type": "query_host", "hostname": "WS-001"}}
    ``to_typed_action()`` then returns QueryHost(hostname="WS-001").

    Red Team types (e.g. ``"lateral_pivot"``) are not in the mapping below
    and are rejected with ``ValueError``. ``RedActionWrapper`` handles them.
    """

    type: str = Field(..., description="Action type discriminator")
    model_config = ConfigDict(extra="allow")  # Allow action-specific fields

    def to_typed_action(self):
        """Convert the raw wrapper into the correctly typed Blue Team action.

        Returns:
            An instance of the action model whose ``type`` literal matches
            ``self.type``, built from this wrapper's fields with ``metadata``
            excluded.

        Raises:
            ValueError: if ``type`` is not one of the 12 Blue Team action
                types. Validation errors raised by the typed model's
                constructor propagate unchanged.
        """
        data = self.model_dump(exclude={"metadata"})
        action_map = {
            "query_host": QueryHost,
            "isolate_segment": IsolateSegment,
            "block_ioc": BlockIOC,
            "run_forensics": RunForensics,
            "kill_process": KillProcess,
            "submit_containment_plan": SubmitContainmentPlan,
            "correlate_alerts": CorrelateAlerts,
            "enrich_ioc": EnrichIOC,
            "scan_host_vulnerabilities": ScanHostVulnerabilities,
            "terminate_pid": TerminatePID,
            "create_firewall_rule": CreateFirewallRule,
            "quarantine_file": QuarantineFile,
        }
        action_cls = action_map.get(data["type"])
        if action_cls is None:
            raise ValueError(
                f"Unknown action type: {data['type']}. "
                f"Valid types: {list(action_map)}"
            )
        return action_cls(**data)
# =============================================================================
# Internal State (not exposed to agent directly)
# =============================================================================
class SOCState(State):
    """Internal bookkeeping for one simulated incident; not exposed to the agent directly.

    Extends OpenEnv ``State`` (inherits ``episode_id`` and ``step_count``).
    The base ``State`` uses ``extra='allow'``. Every field declared here has
    a default, so the class can be instantiated with no arguments.
    """

    # --- Episode configuration and scoring --------------------------------
    task_id: str = Field(default="easy", description="Current task: 'easy', 'medium', or 'hard'")
    max_steps: int = Field(default=30, description="Maximum steps for this episode")
    total_reward: float = Field(default=0.0, description="Accumulated reward")
    business_impact: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Current business impact score",
    )

    # --- Threat tracking --------------------------------------------------
    contained_threats: List[str] = Field(default_factory=list, description="Threat IDs that have been contained")
    active_threats: List[str] = Field(default_factory=list, description="Currently active threat IDs")

    # --- History of core Blue Team actions --------------------------------
    blocked_iocs: List[str] = Field(default_factory=list, description="IOCs blocked at perimeter")
    isolated_subnets: List[str] = Field(default_factory=list, description="Isolated network segments")
    forensics_run: List[str] = Field(default_factory=list, description="Hosts that had forensics run")
    killed_processes: List[Dict[str, str]] = Field(default_factory=list, description="Processes killed")
    queried_hosts: List[str] = Field(default_factory=list, description="Hosts queried")
    timeline: List[Dict[str, Any]] = Field(default_factory=list, description="Action timeline")

    # --- Episode lifecycle flags ------------------------------------------
    is_done: bool = Field(default=False, description="Whether episode has ended")
    submitted_plan: bool = Field(default=False, description="Whether containment plan was submitted")

    # --- History of extended investigation tools --------------------------
    enriched_iocs: List[str] = Field(default_factory=list, description="IOCs that have been threat-intel enriched")
    scanned_hosts: List[str] = Field(default_factory=list, description="Hosts that had vulnerability scans")
    correlated_alert_pairs: List[Any] = Field(default_factory=list, description="Pairs/groups of alert IDs correlated together")

    # --- Adaptive grading -------------------------------------------------
    live_requirements: Optional[Dict[str, Any]] = Field(default=None, description="Mutable copy of containment_requirements (for adaptive grading).")

    # --- FSP turn tracking ------------------------------------------------
    # NOTE(review): typed as plain str; only 'blue'/'red' are documented.
    active_turn: str = Field(default="blue", description="Current active turn in the FSP engine: 'blue' or 'red'.")