Sync: compliance mapping, anti-gaming, 55 tests, mandatory stdout format, pivoting+compliance weights
c1a5935 verified | # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Data models for the Security Audit Environment. | |
| Simulates real-world VAPT (Vulnerability Assessment & Penetration Testing) | |
| engagements where an AI agent audits infrastructure for security compliance. | |
| """ | |
| from typing import Any, Dict, List, Literal, Optional | |
| from openenv.core.env_server.types import Action, Observation, State | |
| from pydantic import Field | |
class SecurityAuditAction(Action):
    """A single agent move in a security-audit episode.

    Agents drive the audit through tool calls: host discovery, service
    scanning, vulnerability testing, finding submission and report
    generation.
    """

    # Which kind of move this is; required (no default is provided).
    action_type: Literal[
        "list_tools",
        "use_tool",
        "submit_finding",
        "generate_report",
    ] = Field(..., description="Type of action to take")

    # Only meaningful for action_type == "use_tool". This model does NOT
    # enforce that it is set -- presumably the environment checks it when
    # handling the step. TODO(review): confirm.
    tool_name: Optional[str] = Field(
        None,
        description="Tool to invoke (required when action_type='use_tool')",
    )

    # Free-form, tool-specific payload; default_factory gives each
    # instance its own empty dict.
    arguments: Dict[str, Any] = Field(
        description="Tool-specific arguments",
        default_factory=dict,
    )
class SecurityAuditObservation(Observation):
    """What the agent sees after every step.

    Bundles the executed tool's output, everything discovered so far,
    and the agent's progress through the audit.
    """

    # --- Result of the most recent action -------------------------------
    tool_output: str = Field(
        "",
        description="Text output from the executed tool",
    )
    available_tools: Optional[List[Dict[str, Any]]] = Field(
        None,
        description="List of available tools (populated by list_tools action)",
    )

    # --- Accumulated discovery state (fresh container per instance) -----
    discovered_hosts: List[str] = Field(
        description="Hosts discovered so far",
        default_factory=list,
    )
    discovered_services: Dict[str, List[str]] = Field(
        description="Services discovered per host (host → [service descriptions])",
        default_factory=dict,
    )

    # --- Audit progress ---------------------------------------------------
    findings_submitted: int = Field(
        0,
        description="Number of findings submitted so far",
    )
    steps_remaining: int = Field(
        0,
        description="Steps remaining before episode ends",
    )
    message: str = Field(
        "",
        description="Human-readable status message",
    )
    # Separates truncation (step limit hit) from termination (agent chose to
    # finish via generate_report).
    truncated: bool = Field(
        False,
        description=(
            "True if episode ended due to step limit (truncation), "
            "False if agent called generate_report (termination). "
            "Important for RL value function estimation."
        ),
    )
    # Typed as a plain str: the phase names listed in the description are
    # not validated by this model.
    current_phase: str = Field(
        "reconnaissance",
        description="Current audit phase: reconnaissance, enumeration, exploitation, or reporting",
    )
class SecurityAuditState(State):
    """Complete per-episode state of a security audit.

    Layers audit bookkeeping on top of the base State fields
    (episode_id, step_count).
    """

    # Scenario identity and configuration.
    scenario_id: str = Field("", description="Current scenario identifier")
    scenario_name: str = Field("", description="Human-readable scenario name")
    target_network: str = Field("", description="Target network CIDR")
    max_steps: int = Field(50, description="Maximum steps allowed")

    # Discovery bookkeeping; default_factory gives each episode its own
    # empty containers. Dict keys are presumably host identifiers (mirroring
    # the per-host services mapping) -- TODO(review): confirm.
    discovered_hosts: List[str] = Field(default_factory=list)
    discovered_ports: Dict[str, List[int]] = Field(default_factory=dict)
    discovered_services: Dict[str, List[str]] = Field(default_factory=dict)

    # Findings submitted by the agent; the per-finding dict schema is not
    # defined in this model.
    submitted_findings: List[Dict[str, Any]] = Field(default_factory=list)
    total_reward: float = Field(0.0)