Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Data models for the Network Defense (AI-SOAR) Environment. | |
| The vir_env environment simulates an Active Network Defense scenario | |
| where an RL agent must contain and eliminate a spreading virus across | |
| a 10-node enterprise network. | |
| """ | |
| from enum import Enum | |
| from typing import Any, Dict, List, Optional | |
| from openenv.core.env_server.types import Action, Observation | |
| from pydantic import Field, field_validator, model_validator | |
class ActionType(str, Enum):
    """Strictly controlled action space for the network defense agent.

    Because the enum mixes in ``str``, every member compares equal to its
    wire value (e.g. ``ActionType.SCAN_NETWORK == "scan_network"``), and
    ``ActionType("<value>")`` looks a member up by that value.
    """

    # Wire values are the lower-snake-case spelling of the member names.
    SCAN_NETWORK = "scan_network"
    ISOLATE_NODE = "isolate_node"
    DEPLOY_PATCH = "deploy_patch"
class NetworkAction(Action):
    """
    Action for the Network Defense environment.

    The agent must choose a defensive action and optionally a target node.

    Raw payloads are normalised before field validation:

    * ``None`` or an empty dict becomes a ``scan_network`` action.
    * When ``action_type`` is absent, the first non-blank string found under
      ``action``, ``type`` or ``tool`` (checked in that order) is used instead.
    * Shorthand such as ``"scan"``, ``"isolate"``, ``"patch"`` or ``"fix"`` is
      mapped to the matching :class:`ActionType` value. Matching ignores case
      and surrounding whitespace, and treats hyphens or spaces as underscores.
    """

    action_type: ActionType = Field(
        default=ActionType.SCAN_NETWORK,
        description="The defensive action: scan_network, isolate_node, or deploy_patch.",
    )
    target: Optional[str] = Field(
        default=None,
        description="Target node name for isolate_node or deploy_patch. Omit for scan_network.",
    )
    reasoning: str = Field(
        default="",
        description="A concise justification for the chosen action.",
    )

    # Without these decorators pydantic never invokes the hooks below, and the
    # fallback / alias / shorthand handling they implement is dead code.
    @model_validator(mode="before")
    @classmethod
    def coerce_input_payload(cls, data: Any) -> Any:
        """Accept common variants and gracefully handle empty action payloads.

        Args:
            data: The raw input passed to model validation.

        Returns:
            A payload dict whose ``action_type`` is populated from an alias key
            when one is available. Non-dict input is returned unchanged.
        """
        # Treat "nothing supplied" as a request to scan.
        if data is None:
            return {"action_type": ActionType.SCAN_NETWORK.value}
        if not isinstance(data, dict):
            return data
        if not data:
            return {"action_type": ActionType.SCAN_NETWORK.value}

        # Some clients send alternate key names for the action selector.
        if "action_type" not in data:
            for key in ("action", "type", "tool"):
                candidate = data.get(key)
                if isinstance(candidate, str) and candidate.strip():
                    # Move the value onto ``action_type`` and drop the consumed
                    # alias so it is not left behind as an unknown extra field.
                    # NOTE(review): this matters if the Action base model
                    # forbids extra fields — confirm its model_config.
                    remaining = {k: v for k, v in data.items() if k != key}
                    return {**remaining, "action_type": candidate}
        # A new dict is built above, so the caller's payload is never mutated.
        return data

    @field_validator("action_type", mode="before")
    @classmethod
    def normalize_action(cls, v: Any) -> Any:
        """Map common shorthand to strict Enum values.

        Non-string input is returned untouched for the Enum to validate.
        ``ActionType`` members are ``str`` subclasses, so they take the string
        path and normalise to their own value.

        Strings are lower-cased, and hyphens plus whitespace runs are collapsed
        to single underscores, so ``" Isolate-Node "`` becomes
        ``"isolate_node"``. Known shorthands are then translated. Anything
        unrecognised is passed through for the Enum to reject.
        """
        if not isinstance(v, str):
            return v
        mapping = {
            "scan": "scan_network",
            "isolate": "isolate_node",
            "patch": "deploy_patch",
            "fix": "deploy_patch",
        }
        # split() with no argument also strips leading/trailing whitespace.
        val = "_".join(v.lower().replace("-", " ").split())
        return mapping.get(val, val)
class NetworkObservation(Observation):
    """
    Per-step observation emitted by the Network Defense environment.

    Bundles the complete node-by-node network snapshot with episode counters,
    threat tallies and scoring context. The agent can therefore decide on its
    next defensive action from this object alone.
    """

    # ---- Network snapshot ----------------------------------------------
    network_state: Dict[str, Any] = Field(
        default_factory=dict,
        description="Current state of every node: {name: {status, connections}}.",
    )

    # ---- Episode counters ----------------------------------------------
    step: int = Field(
        0,
        description="Current step number within this episode.",
    )
    max_steps: int = Field(
        20,
        description="Step budget for this difficulty.",
    )

    # ---- Threat summary (node tallies) ---------------------------------
    infected_count: int = Field(
        0,
        description="Number of currently infected nodes.",
    )
    clean_count: int = Field(
        10,
        description="Number of clean (healthy) nodes.",
    )
    isolated_count: int = Field(
        0,
        description="Number of isolated nodes.",
    )

    # ---- Episode context -----------------------------------------------
    task: str = Field(
        "easy",
        description="Current difficulty: easy, medium, or hard.",
    )
    message: str = Field(
        "",
        description="Feedback message from the last action.",
    )
    scenario_name: str = Field(
        "",
        description="Human-readable scenario name.",
    )
    cumulative_score: float = Field(
        0.0,
        description="Running total score this episode.",
    )
    reward_breakdown: Dict[str, float] = Field(
        default_factory=dict,
        description="Per-component reward breakdown.",
    )
    db_infected_steps: int = Field(
        0,
        description="Consecutive steps DB has been infected.",
    )
    auth_compromised: bool = Field(
        False,
        description="Whether Auth was ever compromised this episode.",
    )
    # NOTE(review): the five-event cap in the description is not enforced by
    # this model; presumably the environment trims the list — confirm.
    spread_events: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Recent virus spread events [{from, to, step}] — last 5.",
    )