anshumanatrey's picture
Sync: compliance mapping, anti-gaming, 55 tests, mandatory stdout format, pivoting+compliance weights
c1a5935 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
FastAPI application for the Security Audit Environment.
"""
try:
from openenv.core.env_server.http_server import create_app
except Exception as e:
raise ImportError(
"openenv is required. Install with: pip install openenv-core"
) from e
try:
from models import SecurityAuditAction, SecurityAuditObservation
from server.security_audit_env_environment import SecurityAuditEnvironment
from server.scenarios import list_scenarios
except ImportError:
from ..models import SecurityAuditAction, SecurityAuditObservation
from .security_audit_env_environment import SecurityAuditEnvironment
from .scenarios import list_scenarios
from typing import Any, Dict, List
from pydantic import BaseModel, Field
from fastapi.responses import JSONResponse
class GraderRequest(BaseModel):
    """Request body for the /grader endpoint.

    Carries the end-of-episode state that the /grader handler forwards to
    ``grade_episode``. Every field has a default, so an empty JSON object
    is accepted and grades an empty run against the "easy" scenario.
    """

    # Scenario identifier; the /grader handler resolves it via get_scenario().
    scenario_id: str = Field(default="easy", description="Scenario to grade against")
    # Findings submitted during the episode, one free-form dict per finding.
    findings: List[Dict[str, Any]] = Field(default_factory=list)
    # Hosts discovered during the episode.
    # NOTE(review): presumably address strings matching the scenario's hosts -- confirm.
    discovered_hosts: List[str] = Field(default_factory=list)
    # Maps a string key to a list of integer ports.
    # NOTE(review): key is presumably the host -- confirm against the grader.
    discovered_ports: Dict[str, List[int]] = Field(default_factory=dict)
    # Step count, passed to grade_episode as the ``steps_used`` keyword.
    steps_used: int = Field(default=0)
# Build the application object by handing the environment class and its
# action / observation models to openenv's create_app factory. The custom
# routes defined below are registered on this same object.
app = create_app(
    SecurityAuditEnvironment,
    SecurityAuditAction,
    SecurityAuditObservation,
    env_name="security_audit_env",
    # NOTE(review): presumably caps the number of simultaneously live
    # environment instances -- confirm against the openenv documentation.
    max_concurrent_envs=4,
)
# --- Health check ---
@app.get("/health")
async def health():
    """Report service liveness for container orchestration probes.

    Returns a fixed payload naming this environment; no state is inspected.
    """
    payload = dict(status="healthy", environment="security_audit_env")
    return payload
# --- Custom Hackathon Endpoints ---
@app.get("/tasks")
async def get_tasks():
    """Describe the available tasks, the action JSON schema and the tool names.

    Returns:
        JSONResponse with three keys: ``tasks`` (the result of
        ``list_scenarios()``), ``action_schema`` (the JSON schema of
        ``SecurityAuditAction``) and ``tools`` (a fixed list of tool names).
    """
    tool_names = [
        "network_scan",
        "service_fingerprint",
        "web_crawl",
        "vulnerability_scan",
        "test_injection",
        "test_xss",
        "test_auth",
        "test_config",
        "test_crypto",
        "check_secrets",
    ]
    payload = {
        "tasks": list_scenarios(),
        "action_schema": SecurityAuditAction.model_json_schema(),
        "tools": tool_names,
    }
    return JSONResponse(payload)
@app.post("/grader")
async def run_grader(data: GraderRequest):
    """Grade a completed episode described by the request body.

    Args:
        data: Scenario id plus the findings, discovered hosts/ports and step
            count recorded for the episode.

    Returns:
        JSONResponse wrapping whatever ``grade_episode`` returns.
    """
    # Absolute imports work when the server package is on sys.path; the
    # relative form covers being imported as part of a package.
    try:
        from server.scenarios import get_scenario
        from server.grader import grade_episode
    except ImportError:
        from .scenarios import get_scenario
        from .grader import grade_episode

    result = grade_episode(
        get_scenario(data.scenario_id),
        data.findings,
        data.discovered_hosts,
        data.discovered_ports,
        steps_used=data.steps_used,
    )
    return JSONResponse(result)
@app.post("/baseline")
async def run_baseline():
    """Trigger baseline inference and return scores for all 3 tasks.

    Runs a deterministic audit agent (no LLM) that scans, tests endpoints,
    parses tool output for detections, submits findings, and pivots through
    discovered vulns to unlock hidden hosts.

    For each of the "easy", "medium" and "hard" scenarios the agent:

    1. runs ``network_scan`` against the scenario's ``target_network``;
    2. for up to three passes, crawls every currently discovered host, runs
       the per-endpoint and per-host tools, submits at most one finding per
       tool output, then re-scans and stops when no new host appeared;
    3. issues ``generate_report`` and reads ``grades`` from the observation
       metadata.

    Returns:
        JSONResponse with ``baseline_scores`` (scenario id -> ``final_score``,
        0 when absent), ``reasoning_gap`` (easy minus hard score, rounded to
        4 places; 0.0 when the easy score is not positive),
        ``reasoning_gap_interpretation`` and ``details`` (the raw grades).

    NOTE(review): every environment step here is a synchronous call made
    inside an ``async def`` handler, so the event loop is presumably blocked
    for the whole run -- confirm whether that matters for /health probes.
    """
    import re

    try:
        from server.scenarios import get_scenario
    except ImportError:
        from .scenarios import get_scenario

    max_passes = 3
    # Tools that need a concrete endpoint, in the order they are tried.
    endpoint_tools = ("test_injection", "test_xss", "check_secrets")
    # Host-level tools (no endpoint needed).
    host_tools = ("test_auth", "test_config", "test_crypto", "vulnerability_scan")

    # All patterns are compiled once per request rather than once per tool
    # call. The first title pattern that matches wins, so order matters.
    title_patterns = (
        re.compile(r"\[CRITICAL\]\s*(.+?)(?:\n|$)"),
        re.compile(r"\[ALERT\]\s*(.+?)(?:\n|$)"),
        re.compile(r"\[MISCONFIGURATION\]\s*(.+?)(?:\n|$)"),
        re.compile(r"\[CRYPTO ISSUE\]\s*(.+?)(?:\n|$)"),
        re.compile(r"\[SECRET EXPOSED\]\s*(.+?)(?:\n|$)"),
        re.compile(r"\[!\] VULNERABLE:\s*(.+?)(?:\n|$)"),
        re.compile(r"\[\w+\]\s*(.+?)\s*DETECTED", re.IGNORECASE),
    )
    cwe_re = re.compile(r"CWE:\s*(CWE-\d+)")
    owasp_re = re.compile(r"OWASP:\s*(.+?)(?:\n|$)")
    cvss_re = re.compile(r"Suggested CVSS:\s*([\d.]+)\s*\((\w+)\)")
    evidence_re = re.compile(r"Evidence:\s*(.+?)(?:\n|$)")
    remediation_re = re.compile(r"Remediation:\s*(.+?)(?:\n|$)")
    severity_label_re = re.compile(r"\[(\w+)\].*DETECTED")
    endpoint_re = re.compile(r"(?:GET|POST|PUT|DELETE|PATCH)\s+(/\S+)")

    def _output_text(obs):
        """Return the observation's tool output, or "" when missing/None."""
        return getattr(obs, "tool_output", None) or ""

    def _do_step(env, **kwargs):
        """Step the environment and return (obs, done)."""
        obs = env.step(SecurityAuditAction(**kwargs))
        return obs, getattr(obs, "done", False)

    def _extract_finding(host, endpoint, text):
        """Build a finding dict from tool output, or None if nothing matched."""
        title_match = None
        for pattern in title_patterns:
            title_match = pattern.search(text)
            if title_match:
                break
        if not title_match:
            return None
        title = title_match.group(1).strip()

        # Severity preference: the CVSS rating word, then the bracketed label
        # in front of "DETECTED", then a default of "High".
        cvss_match = cvss_re.search(text)
        label_match = severity_label_re.search(text)
        if cvss_match:
            severity = cvss_match.group(2)
        elif label_match:
            severity = label_match.group(1).capitalize()
        else:
            severity = "High"

        finding = {
            "title": title,
            "host": host,
            "type": title,
            "severity": severity,
        }
        if endpoint:
            finding["endpoint"] = endpoint
        cwe_match = cwe_re.search(text)
        if cwe_match:
            finding["cwe"] = cwe_match.group(1)
        owasp_match = owasp_re.search(text)
        if owasp_match:
            finding["owasp"] = owasp_match.group(1).strip()
        if cvss_match:
            finding["cvss_score"] = float(cvss_match.group(1))
        evidence_match = evidence_re.search(text)
        if evidence_match:
            finding["evidence"] = evidence_match.group(1).strip()
        remediation_match = remediation_re.search(text)
        if remediation_match:
            finding["remediation"] = remediation_match.group(1).strip()
        return finding

    def _run_tool(env, host, tool, endpoint=None):
        """Run one tool, submit a finding if one is detected; return done."""
        arguments = {"host": host}
        if endpoint is not None:
            arguments["endpoint"] = endpoint
        obs, done = _do_step(
            env, action_type="use_tool", tool_name=tool, arguments=arguments
        )
        if done:
            return True
        finding = _extract_finding(host, endpoint, _output_text(obs))
        if finding is None:
            return False
        _, done = _do_step(env, action_type="submit_finding", arguments=finding)
        return done

    def _audit_host(env, host):
        """Crawl one host and run every tool against it; return done."""
        crawl_obs, done = _do_step(
            env, action_type="use_tool", tool_name="web_crawl",
            arguments={"host": host},
        )
        if done:
            return True

        # Extract discovered endpoints from crawl output.
        # NOTE(review): the same path listed twice (e.g. under GET and POST)
        # is tested twice; left as-is so baseline scores do not shift.
        endpoints = []
        for line in _output_text(crawl_obs).split("\n"):
            ep_match = endpoint_re.search(line)
            if ep_match:
                endpoints.append(ep_match.group(1).strip())

        for endpoint in endpoints:
            for tool in endpoint_tools:
                if _run_tool(env, host, tool, endpoint):
                    return True
        for tool in host_tools:
            if _run_tool(env, host, tool):
                return True
        return False

    results = {}
    for scenario_id in ["easy", "medium", "hard"]:
        env = SecurityAuditEnvironment()
        env.reset(scenario_id=scenario_id)
        scenario = get_scenario(scenario_id)
        scan_args = {"target": scenario["target_network"]}

        # Phase 1: Initial network scan
        _, done = _do_step(
            env, action_type="use_tool", tool_name="network_scan",
            arguments=scan_args,
        )

        # We may need multiple passes to unlock hidden hosts.
        # NOTE(review): each pass re-audits hosts already covered by earlier
        # passes -- confirm whether that is intended before changing it.
        for _ in range(max_passes):
            if done:
                break
            # Reads the environment's private host collection directly.
            hosts_snapshot = list(env._discovered_hosts)
            for host in hosts_snapshot:
                done = _audit_host(env, host)
                if done:
                    break
            if done:
                break
            # Re-scan to discover newly unlocked hosts
            _, done = _do_step(
                env, action_type="use_tool", tool_name="network_scan",
                arguments=scan_args,
            )
            # If no new hosts appeared, stop iterating
            if set(env._discovered_hosts) == set(hosts_snapshot):
                break

        # Generate final report (safe to call even after step limit --
        # step() returns _finish_episode with grades regardless)
        report_obs = env.step(SecurityAuditAction(action_type="generate_report"))
        results[scenario_id] = (
            report_obs.metadata.get("grades", {}) if report_obs.metadata else {}
        )

    scores = {sid: g.get("final_score", 0) for sid, g in results.items()}

    # Reasoning gap: how much does performance drop when labels are removed?
    # A perfect reasoning agent: gap = 0 (same score regardless of output format)
    # A pure pattern matcher: gap = 1.0 (scores high on labeled, zero on raw)
    easy_score = scores.get("easy", 0)
    hard_score = scores.get("hard", 0)
    reasoning_gap = round(easy_score - hard_score, 4) if easy_score > 0 else 0.0

    return JSONResponse({
        "baseline_scores": scores,
        "reasoning_gap": reasoning_gap,
        "reasoning_gap_interpretation": (
            "Score difference between easy (labeled output) and hard (raw output). "
            "Gap of 1.0 = pure pattern matcher. Gap of 0.0 = genuine reasoning."
        ),
        "details": results,
    })
def main(host: str = "0.0.0.0", port: int = 8000) -> None:
    """Serve ``app`` with uvicorn on the given interface and port.

    Args:
        host: Interface to bind; the default listens on all interfaces.
        port: TCP port to listen on.
    """
    # Imported lazily so importing this module does not require uvicorn.
    from uvicorn import run as serve

    serve(app, host=host, port=port)
# Start the server with main()'s default host and port when this file is
# executed directly rather than imported.
if __name__ == "__main__":
    main()