# vir_env/models.py
# Uploaded by arun-misra via huggingface_hub (commit 34a06f7).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Data models for the Network Defense (AI-SOAR) Environment.
The vir_env environment simulates an Active Network Defense scenario
where an RL agent must contain and eliminate a spreading virus across
a 10-node enterprise network.
"""
from enum import Enum
from typing import Any, Dict, List, Optional
from openenv.core.env_server.types import Action, Observation
from pydantic import Field, field_validator, model_validator
class ActionType(str, Enum):
    """Strictly controlled action space for the network defense agent.

    Because the enum mixes in ``str``, every member is itself a string and
    compares equal to its value (``ActionType.SCAN_NETWORK == "scan_network"``).
    The three members below are the only actions the agent may take.
    """

    # Whole-network action; per the NetworkAction.target description it takes no target.
    SCAN_NETWORK = "scan_network"
    # Node-targeted action; expects NetworkAction.target to name a node.
    ISOLATE_NODE = "isolate_node"
    # Node-targeted action; expects NetworkAction.target to name a node.
    DEPLOY_PATCH = "deploy_patch"
class NetworkAction(Action):
    """
    Action for the Network Defense environment.
    The agent must choose a defensive action and optionally a target node.

    Raw payloads are normalised before field validation so that loosely
    formatted client input still maps onto the strict ``ActionType`` enum.
    This covers alias selector keys, shorthand verbs, odd casing or
    separators, and empty payloads.
    """

    action_type: ActionType = Field(
        default=ActionType.SCAN_NETWORK,
        description="The defensive action: scan_network, isolate_node, or deploy_patch.",
    )
    target: Optional[str] = Field(
        default=None,
        description="Target node name for isolate_node or deploy_patch. Omit for scan_network.",
    )
    reasoning: str = Field(
        default="",
        description="A concise justification for the chosen action.",
    )

    @model_validator(mode="before")
    @classmethod
    def coerce_input_payload(cls, data: Any) -> Any:
        """Accept common variants and gracefully handle empty action payloads.

        - ``None`` or an empty dict becomes a ``scan_network`` action.
        - A bare string is treated as the action selector (e.g. ``"scan"``).
        - A missing, ``None`` or blank ``action_type`` is filled from the first
          non-blank string under ``action``, ``type`` or ``tool``. If none is
          usable, the field default (``scan_network``) applies.
        - Once the selector has been resolved this way, the alternate keys are
          removed. Otherwise a model that forbids extra fields would reject the
          payload as containing unknown fields.

        The caller's dict is never mutated. Any other input type is returned
        unchanged so pydantic validates (or rejects) it as usual.
        """
        if data is None:
            return {"action_type": ActionType.SCAN_NETWORK.value}
        if isinstance(data, str):
            # Some clients send just the action name instead of an object.
            data = {"action_type": data}
        if not isinstance(data, dict):
            return data
        if not data:
            return {"action_type": ActionType.SCAN_NETWORK.value}

        # Shallow copy so that normalisation never leaks back into the caller's dict.
        payload = dict(data)

        # A null or blank selector carries no intent. Treat it as absent so the
        # alias lookup / field default applies instead of failing enum validation.
        selector = payload.get("action_type")
        if selector is None or (isinstance(selector, str) and not selector.strip()):
            payload.pop("action_type", None)

        # Some clients send alternate key names for the action selector.
        if "action_type" not in payload:
            alias_keys = ("action", "type", "tool")
            for key in alias_keys:
                candidate = payload.get(key)
                if isinstance(candidate, str) and candidate.strip():
                    payload["action_type"] = candidate
                    break
            # The aliases are not model fields. Drop them so a strict
            # (extra="forbid") base model does not reject the payload.
            for key in alias_keys:
                payload.pop(key, None)
        return payload

    @field_validator("action_type", mode="before")
    @classmethod
    def normalize_action(cls, v: Any) -> Any:
        """Map shorthand and loosely formatted selectors to strict Enum values.

        Case and surrounding whitespace are ignored. Runs of spaces or hyphens
        are treated as a single underscore, so ``" Scan-Network "`` and
        ``"scan network"`` both become ``"scan_network"``. The verbs ``scan``,
        ``isolate``, ``patch`` and ``fix`` expand to their full action names.
        Non-string input is passed through for the enum validator to handle.
        """
        if not isinstance(v, str):
            return v
        mapping = {
            "scan": "scan_network",
            "isolate": "isolate_node",
            "patch": "deploy_patch",
            "fix": "deploy_patch",
        }
        # split() with no argument also strips and collapses all whitespace.
        val = "_".join(v.replace("-", " ").split()).lower()
        return mapping.get(val, val)
class NetworkObservation(Observation):
    """
    Observation emitted by the Network Defense environment.

    Bundles the complete network snapshot with episode counters, threat
    totals and scoring context, so the agent has everything it needs to pick
    its next defensive action.
    """

    # -- Network snapshot ---------------------------------------------------
    network_state: Dict[str, Any] = Field(
        default_factory=dict,
        description="Current state of every node: {name: {status, connections}}.",
    )

    # -- Episode progress ---------------------------------------------------
    step: int = Field(0, description="Current step number within this episode.")
    max_steps: int = Field(20, description="Step budget for this difficulty.")

    # -- Threat totals (defaults assume an all-clean 10-node network) -------
    infected_count: int = Field(0, description="Number of currently infected nodes.")
    clean_count: int = Field(10, description="Number of clean (healthy) nodes.")
    isolated_count: int = Field(0, description="Number of isolated nodes.")

    # -- Episode context, scoring and critical-asset tracking ---------------
    task: str = Field("easy", description="Current difficulty: easy, medium, or hard.")
    message: str = Field("", description="Feedback message from the last action.")
    scenario_name: str = Field("", description="Human-readable scenario name.")
    cumulative_score: float = Field(0.0, description="Running total score this episode.")
    reward_breakdown: Dict[str, float] = Field(
        default_factory=dict,
        description="Per-component reward breakdown.",
    )
    db_infected_steps: int = Field(0, description="Consecutive steps DB has been infected.")
    auth_compromised: bool = Field(False, description="Whether Auth was ever compromised this episode.")
    spread_events: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Recent virus spread events [{from, to, step}] — last 5.",
    )