# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Data models for the Network Defense (AI-SOAR) Environment.

The vir_env environment simulates an Active Network Defense scenario where
an RL agent must contain and eliminate a spreading virus across a 10-node
enterprise network.
"""

from enum import Enum
from typing import Any, Dict, List, Optional

from openenv.core.env_server.types import Action, Observation
from pydantic import Field, field_validator, model_validator

# Alternate payload keys some clients use in place of "action_type",
# checked in this order.
_ACTION_KEY_ALIASES = ("action", "type", "tool")

# Shorthand action names mapped to their canonical ActionType values.
_ACTION_SHORTHAND = {
    "scan": "scan_network",
    "isolate": "isolate_node",
    "patch": "deploy_patch",
    "fix": "deploy_patch",
}


class ActionType(str, Enum):
    """Strictly controlled action space for the network defense agent."""

    SCAN_NETWORK = "scan_network"
    ISOLATE_NODE = "isolate_node"
    DEPLOY_PATCH = "deploy_patch"


class NetworkAction(Action):
    """
    Action for the Network Defense environment.

    The agent must choose a defensive action and optionally a target node.
    """

    action_type: ActionType = Field(
        default=ActionType.SCAN_NETWORK,
        description="The defensive action: scan_network, isolate_node, or deploy_patch.",
    )
    target: Optional[str] = Field(
        default=None,
        description="Target node name for isolate_node or deploy_patch. Omit for scan_network.",
    )
    reasoning: str = Field(
        default="",
        description="A concise justification for the chosen action.",
    )

    @model_validator(mode="before")
    @classmethod
    def coerce_input_payload(cls, data: Any) -> Any:
        """Accept common variants and gracefully handle empty action payloads.

        Args:
            data: The raw input handed to the model before field validation.

        Returns:
            A payload for field validation:

            * ``None`` or an empty dict becomes a default ``scan_network``
              payload.
            * Any other non-dict input is returned unchanged.
            * A dict without ``"action_type"`` gets the first non-blank string
              found under ``"action"``, ``"type"`` or ``"tool"`` moved to
              ``"action_type"``; the caller's dict is never mutated.
            * Every other dict is returned unchanged.
        """
        if data is None:
            return {"action_type": ActionType.SCAN_NETWORK.value}
        if not isinstance(data, dict):
            return data
        if not data:
            return {"action_type": ActionType.SCAN_NETWORK.value}
        if "action_type" in data:
            return data

        # Some clients send alternate key names for the action selector.
        for alias in _ACTION_KEY_ALIASES:
            candidate = data.get(alias)
            if isinstance(candidate, str) and candidate.strip():
                # Move (rather than copy) the value: leaving the alias key in
                # the payload would make it an unknown field.
                # NOTE(review): the base Action model is expected to reject
                # unknown fields (extra="forbid") — confirm against openenv.
                promoted = {k: v for k, v in data.items() if k != alias}
                promoted["action_type"] = candidate
                return promoted
        return data

    @field_validator("action_type", mode="before")
    @classmethod
    def normalize_action(cls, v: Any) -> Any:
        """Map common shorthand to strict Enum values.

        Args:
            v: The raw ``action_type`` value before enum validation.

        Returns:
            Non-string values unchanged. Strings are lower-cased and stripped,
            then translated through the shorthand table (``scan``,
            ``isolate``, ``patch``, ``fix``); strings without a shorthand
            entry are returned in their lower-cased, stripped form.
        """
        if not isinstance(v, str):
            return v
        val = v.lower().strip()
        return _ACTION_SHORTHAND.get(val, val)


class NetworkObservation(Observation):
    """
    Observation from the Network Defense environment.

    Contains the full current network state so the agent can make an
    informed defensive decision.
    """

    # Full network snapshot
    network_state: Dict[str, Any] = Field(
        default_factory=dict,
        description="Current state of every node: {name: {status, connections}}.",
    )

    # Episode counters
    step: int = Field(default=0, description="Current step number within this episode.")
    max_steps: int = Field(default=20, description="Step budget for this difficulty.")

    # Threat summary
    infected_count: int = Field(default=0, description="Number of currently infected nodes.")
    clean_count: int = Field(default=10, description="Number of clean (healthy) nodes.")
    isolated_count: int = Field(default=0, description="Number of isolated nodes.")

    # Episode context
    task: str = Field(default="easy", description="Current difficulty: easy, medium, or hard.")
    message: str = Field(default="", description="Feedback message from the last action.")
    scenario_name: str = Field(default="", description="Human-readable scenario name.")
    cumulative_score: float = Field(default=0.0, description="Running total score this episode.")
    reward_breakdown: Dict[str, float] = Field(
        default_factory=dict,
        description="Per-component reward breakdown.",
    )
    db_infected_steps: int = Field(default=0, description="Consecutive steps DB has been infected.")
    auth_compromised: bool = Field(
        default=False,
        description="Whether Auth was ever compromised this episode.",
    )
    spread_events: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Recent virus spread events [{from, to, step}] — last 5.",
    )