# security-audit-env / server / security_audit_env_environment.py
# Uploaded by anshumanatrey — commit d3c5d92 (verified)
# KB-driven dynamic architecture: 26 vuln types, procedural generation,
# parameter-level testing
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Security Audit Environment Implementation.
Simulates real-world VAPT engagements where an AI agent audits
infrastructure for security vulnerabilities and compliance gaps.
"""
import random
from copy import deepcopy
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
try:
from ..models import SecurityAuditAction, SecurityAuditObservation, SecurityAuditState
except ImportError:
from models import SecurityAuditAction, SecurityAuditObservation, SecurityAuditState
try:
from .scenarios import get_scenario, list_scenarios
from .tools_engine import TOOL_DEFINITIONS, execute_tool
from .grader import grade_episode, match_single_finding
except ImportError:
from server.scenarios import get_scenario, list_scenarios
from server.tools_engine import TOOL_DEFINITIONS, execute_tool
from server.grader import grade_episode, match_single_finding
class SecurityAuditEnvironment(Environment):
"""
AI Security Audit Training Environment.
Simulates real-world Vulnerability Assessment & Penetration Testing (VAPT)
engagements. The agent discovers hosts, scans services, identifies
vulnerabilities, and submits structured findings — just like a
professional security auditor.
Three scenarios with increasing difficulty:
- Easy: Startup web app (2 hosts, 3 vulns)
- Medium: E-commerce platform (4 hosts, 6 vulns)
- Hard: Enterprise SOC2 audit (6 hosts, 10 vulns + honeypots)
"""
SUPPORTS_CONCURRENT_SESSIONS: bool = True
# Difficulty multiplier for per-step tool/finding rewards
_DIFFICULTY_REWARD_MULTIPLIER = {"easy": 1.0, "medium": 1.3, "hard": 1.6}
def __init__(self):
super().__init__()
self._state = SecurityAuditState()
self._scenario = None
self._discovered_hosts: list = []
self._discovered_ports: dict = {}
self._discovered_services: dict = {}
self._submitted_findings: list = []
self._action_history: list = []
self._discovered_vulns: set = set()
self._episode_reward: float = 0.0
self._last_tool_call: tuple = ()
self._rng: random.Random = random.Random()
def reset(self, seed=None, episode_id=None, **kwargs) -> SecurityAuditObservation:
"""Reset the environment for a new audit engagement.
kwargs:
scenario_id: "easy", "medium", or "hard" (default: "easy")
"""
scenario_id = kwargs.get("scenario_id", "easy")
self._scenario = deepcopy(get_scenario(scenario_id))
self._discovered_hosts = []
self._discovered_ports = {}
self._discovered_services = {}
self._submitted_findings = []
self._action_history = []
self._discovered_vulns = set()
self._episode_reward = 0.0
self._last_tool_call = ()
self._rng = random.Random(seed) if seed is not None else random.Random()
eid = episode_id or str(uuid4())
self._state = SecurityAuditState(
episode_id=eid,
step_count=0,
scenario_id=scenario_id,
scenario_name=self._scenario["name"],
target_network=self._scenario["target_network"],
max_steps=self._scenario["max_steps"],
)
self._reset_rubric()
return SecurityAuditObservation(
tool_output="",
message=self._scenario["briefing"],
discovered_hosts=[],
discovered_services={},
findings_submitted=0,
steps_remaining=self._scenario["max_steps"],
done=False,
reward=0.0,
)
def step(self, action: SecurityAuditAction, **kwargs) -> SecurityAuditObservation:
self._state.step_count += 1
steps_remaining = self._state.max_steps - self._state.step_count
self._action_history.append({
"step": self._state.step_count,
"action_type": action.action_type,
"tool_name": action.tool_name,
"arguments": action.arguments,
})
if steps_remaining <= 0:
return self._finish_episode("Step limit reached. Audit terminated.", truncated=True)
if action.action_type == "list_tools":
return self._handle_list_tools(steps_remaining)
elif action.action_type == "use_tool":
return self._handle_use_tool(action, steps_remaining)
elif action.action_type == "submit_finding":
return self._handle_submit_finding(action, steps_remaining)
elif action.action_type == "generate_report":
return self._finish_episode("Audit report generated.", truncated=False)
else:
return SecurityAuditObservation(
tool_output=f"Unknown action_type: {action.action_type}",
message="Use list_tools, use_tool, submit_finding, or generate_report.",
discovered_hosts=self._discovered_hosts,
discovered_services=self._discovered_services,
findings_submitted=len(self._submitted_findings),
steps_remaining=steps_remaining,
current_phase=self._current_phase(),
done=False,
reward=-0.05,
)
@property
def state(self) -> SecurityAuditState:
self._state.discovered_hosts = list(self._discovered_hosts)
self._state.discovered_ports = dict(self._discovered_ports)
self._state.discovered_services = dict(self._discovered_services)
self._state.submitted_findings = list(self._submitted_findings)
self._state.total_reward = self._episode_reward
return self._state
def _current_phase(self) -> str:
"""Determine current audit phase from agent progress."""
if len(self._submitted_findings) > 0:
return "exploitation"
if len(self._discovered_hosts) > 0:
return "enumeration"
return "reconnaissance"
# --- Action Handlers ---
def _handle_list_tools(self, steps_remaining):
tools_text = "Available security audit tools:\n\n"
for tool in TOOL_DEFINITIONS:
params = ", ".join(f"{k}: {v}" for k, v in tool["parameters"].items())
tools_text += f" {tool['name']}\n"
tools_text += f" Description: {tool['description']}\n"
tools_text += f" Parameters: {params}\n\n"
return SecurityAuditObservation(
tool_output=tools_text, available_tools=TOOL_DEFINITIONS,
message="Use 'use_tool' action with tool_name and arguments to run a tool.",
discovered_hosts=self._discovered_hosts, discovered_services=self._discovered_services,
findings_submitted=len(self._submitted_findings), steps_remaining=steps_remaining,
current_phase=self._current_phase(), done=False, reward=0.0,
)
def _handle_use_tool(self, action, steps_remaining):
if not action.tool_name:
return SecurityAuditObservation(
tool_output="Error: tool_name is required for use_tool action.",
message="Specify which tool to use.",
discovered_hosts=self._discovered_hosts, discovered_services=self._discovered_services,
findings_submitted=len(self._submitted_findings), steps_remaining=steps_remaining,
current_phase=self._current_phase(), done=False, reward=-0.02,
)
target = action.arguments.get("host", action.arguments.get("target", ""))
honeypots = self._scenario.get("honeypots", [])
honeypot_penalty = -0.10 if target in honeypots else 0.0
# Detect redundant tool calls
current_call = (action.tool_name, tuple(sorted(action.arguments.items())))
redundancy_penalty = -0.01 if current_call == self._last_tool_call else 0.0
self._last_tool_call = current_call
output, new_hosts, new_ports, tool_reward = execute_tool(
action.tool_name, action.arguments, self._scenario,
self._discovered_hosts, self._discovered_ports, self._discovered_vulns,
)
# Difficulty multiplier on positive rewards
difficulty = self._scenario.get("id", "easy")
multiplier = self._DIFFICULTY_REWARD_MULTIPLIER.get(difficulty, 1.0)
if tool_reward > 0:
tool_reward *= multiplier
for h in new_hosts:
if h not in self._discovered_hosts:
self._discovered_hosts.append(h)
for h, ports in new_ports.items():
if h not in self._discovered_ports:
self._discovered_ports[h] = []
for p in ports:
if p not in self._discovered_ports[h]:
self._discovered_ports[h].append(p)
for h in self._discovered_hosts:
host_ports = self._scenario.get("ports", {}).get(h, [])
known_ports = self._discovered_ports.get(h, [])
services = [f"{p['service']}/{p['port']}" for p in host_ports if p["port"] in known_ports]
if services:
self._discovered_services[h] = services
reward = tool_reward + honeypot_penalty + redundancy_penalty
self._episode_reward += reward
message = ""
if honeypot_penalty < 0:
message = "WARNING: Activity detected on a monitoring system. This may be a honeypot."
if redundancy_penalty < 0:
message += " Note: Repeated identical tool call — consider a different action."
return SecurityAuditObservation(
tool_output=output, message=message.strip(),
discovered_hosts=self._discovered_hosts, discovered_services=self._discovered_services,
findings_submitted=len(self._submitted_findings), steps_remaining=steps_remaining,
current_phase=self._current_phase(), done=False, reward=reward,
)
def _handle_submit_finding(self, action, steps_remaining):
finding = action.arguments
required = ["title", "host", "severity"]
missing = [k for k in required if k not in finding]
if missing:
return SecurityAuditObservation(
tool_output=f"Error: Missing required fields: {', '.join(missing)}",
message="Finding must include at least: title, host, severity. "
"Recommended: cvss_score, cwe, owasp, endpoint, evidence, remediation.",
discovered_hosts=self._discovered_hosts, discovered_services=self._discovered_services,
findings_submitted=len(self._submitted_findings), steps_remaining=steps_remaining,
current_phase=self._current_phase(), done=False, reward=-0.02,
)
self._submitted_findings.append(finding)
# Match using same logic as grader for consistency
gt_vulns = self._scenario.get("vulnerabilities", [])
matched_id = match_single_finding(finding, gt_vulns, self._discovered_vulns)
difficulty = self._scenario.get("id", "easy")
multiplier = self._DIFFICULTY_REWARD_MULTIPLIER.get(difficulty, 1.0)
if matched_id:
reward = 0.12 * multiplier
self._discovered_vulns.add(matched_id)
else:
# Diminishing reward for unmatched findings to prevent spam
unmatched = len(self._submitted_findings) - len(self._discovered_vulns)
if unmatched <= 2:
reward = 0.02
elif unmatched <= 4:
reward = 0.01
else:
reward = 0.0
self._episode_reward += reward
return SecurityAuditObservation(
tool_output=f"Finding #{len(self._submitted_findings)} recorded: {finding.get('title', 'Untitled')}",
message=f"Finding submitted. Total findings: {len(self._submitted_findings)}.",
discovered_hosts=self._discovered_hosts, discovered_services=self._discovered_services,
findings_submitted=len(self._submitted_findings), steps_remaining=steps_remaining,
current_phase=self._current_phase(), done=False, reward=reward,
)
    def _finish_episode(self, message, truncated=False):
        """End the audit and compute final grade.

        Args:
            message: Human-readable reason the episode ended.
            truncated: True when the step limit forced termination rather
                than the agent choosing generate_report.

        Returns:
            Terminal SecurityAuditObservation with done=True, the full text
            report in tool_output, and structured grades in metadata.
        """
        # Delegate all scoring to the grader so the terminal reward uses the
        # same rubric as intermediate finding matching.
        grades = grade_episode(
            self._scenario, self._submitted_findings,
            self._discovered_hosts, self._discovered_ports,
            steps_used=self._state.step_count,
        )
        final_score = grades["final_score"]
        # Terminal grade stacks on top of accumulated per-step rewards.
        self._episode_reward += final_score
        # Human-readable engagement report; the structured grades dict also
        # travels in observation metadata for programmatic consumers.
        report_lines = [
            "=" * 60, "SECURITY AUDIT REPORT", "=" * 60,
            f"Scenario: {self._scenario['name']}",
            f"Company: {self._scenario['company']}",
            f"Compliance Framework: {self._scenario['compliance_context']}",
            "",
            "RESULTS:",
            f" Final Score: {final_score:.2f} / 1.00",
            f" Detection Rate: {grades['detection_rate']:.2f} ({grades['true_positives']}/{grades['total_vulnerabilities']} vulnerabilities found)",
            f" Coverage: {grades['coverage']:.2f} ({grades['hosts_examined']}/{grades['total_hosts']} hosts examined)",
            f" Severity Accuracy: {grades['severity_accuracy']:.2f}",
            f" Classification: CWE {grades['cwe_accuracy']:.2f} | OWASP {grades['owasp_accuracy']:.2f} | Combined {grades['classification_accuracy']:.2f}",
            f" Report Quality: {grades['report_quality']:.2f} (fields: {grades['field_completeness']:.2f}, narrative: {grades['narrative_quality']:.2f})",
            f" Pivoting Score: {grades['pivoting_score']:.2f}",
            f" Exploitation Proof: {grades['exploitation_proof']:.2f}",
            f" False Positives: {grades['false_positives']} (penalty: -{grades['fp_penalty']:.2f})",
            f" Honeypot Penalty: -{grades['honeypot_penalty']:.2f}",
            "",
            "COMPLIANCE:",
            f" Framework: {grades['compliance_framework']}",
            f" Controls Covered: {grades['compliance_controls_covered']}/{grades['compliance_controls_expected']}",
            f" Compliance Coverage: {grades['compliance_coverage']:.2f}",
            "",
            f"Steps Used: {self._state.step_count} / {self._scenario['max_steps']} (efficiency: {grades['efficiency']:.2f})",
            f"Findings Submitted: {len(self._submitted_findings)}",
            "=" * 60,
        ]
        return SecurityAuditObservation(
            tool_output="\n".join(report_lines), message=message,
            discovered_hosts=self._discovered_hosts, discovered_services=self._discovered_services,
            findings_submitted=len(self._submitted_findings), steps_remaining=0,
            done=True, truncated=truncated, current_phase="reporting",
            reward=final_score, metadata={"grades": grades},
        )