# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
FastAPI application for the Security Audit Environment.
"""
| try: | |
| from openenv.core.env_server.http_server import create_app | |
| except Exception as e: | |
| raise ImportError( | |
| "openenv is required. Install with: pip install openenv-core" | |
| ) from e | |
| try: | |
| from models import SecurityAuditAction, SecurityAuditObservation | |
| from server.security_audit_env_environment import SecurityAuditEnvironment | |
| from server.scenarios import list_scenarios | |
| except ImportError: | |
| from ..models import SecurityAuditAction, SecurityAuditObservation | |
| from .security_audit_env_environment import SecurityAuditEnvironment | |
| from .scenarios import list_scenarios | |
| from typing import Any, Dict, List | |
| from pydantic import BaseModel, Field | |
| from fastapi.responses import JSONResponse | |
class GraderRequest(BaseModel):
    """Request body for the /grader endpoint."""

    # Scenario to grade the episode against; presumably an id accepted by
    # get_scenario() — TODO confirm valid ids ("easy"/"medium"/"hard"?).
    scenario_id: str = Field(default="easy", description="Scenario to grade against")
    # Findings the agent submitted during the episode (free-form dicts;
    # field meanings are enforced by the grader, not here).
    findings: List[Dict[str, Any]] = Field(default_factory=list)
    # Hosts the agent discovered during the episode.
    discovered_hosts: List[str] = Field(default_factory=list)
    # Per-host list of discovered ports; assumes keys are host identifiers —
    # TODO confirm key format against the environment.
    discovered_ports: Dict[str, List[int]] = Field(default_factory=dict)
    # Number of environment steps the episode consumed (used for scoring).
    steps_used: int = Field(default=0)
# Build the FastAPI application via openenv's generic environment server
# factory, wiring in this environment's Action/Observation models.
app = create_app(
    SecurityAuditEnvironment,
    SecurityAuditAction,
    SecurityAuditObservation,
    env_name="security_audit_env",
    # Cap on simultaneously live environment instances served by this app.
    max_concurrent_envs=4,
)
# --- Health check ---
async def health():
    """Liveness probe for container orchestration.

    NOTE(review): no route decorator is visible in this file — confirm this
    handler is actually mounted on `app` (e.g. by create_app or elsewhere).
    """
    status_payload = {"status": "healthy", "environment": "security_audit_env"}
    return status_payload
# --- Custom Hackathon Endpoints ---
async def get_tasks():
    """Return the available scenarios, the action JSON schema, and the tool names."""
    payload = {
        "tasks": list_scenarios(),
        "action_schema": SecurityAuditAction.model_json_schema(),
        # Names accepted by the environment's use_tool action.
        "tools": [
            "network_scan", "service_fingerprint", "web_crawl",
            "vulnerability_scan", "test_injection", "test_xss",
            "test_auth", "test_config", "test_crypto", "check_secrets",
        ],
    }
    return JSONResponse(payload)
async def run_grader(data: GraderRequest):
    """Grade a completed episode and return the grader's scores as JSON."""
    # Lazy imports so the module loads under both layouts: flat `server`
    # package on sys.path, or package-relative execution.
    try:
        from server.scenarios import get_scenario
        from server.grader import grade_episode
    except ImportError:
        from .scenarios import get_scenario
        from .grader import grade_episode

    scenario = get_scenario(data.scenario_id)
    return JSONResponse(
        grade_episode(
            scenario,
            data.findings,
            data.discovered_hosts,
            data.discovered_ports,
            steps_used=data.steps_used,
        )
    )
async def run_baseline():
    """Trigger baseline inference and return scores for all 3 tasks.
    Runs a deterministic audit agent (no LLM) that scans, tests endpoints,
    parses tool output for detections, submits findings, and pivots through
    discovered vulns to unlock hidden hosts.
    """
    import re
    try:
        from server.scenarios import get_scenario
    except ImportError:
        # Package-relative fallback, mirroring the module-level import pattern.
        from .scenarios import get_scenario

    def _do_step(env, **kwargs):
        """Step and return (obs, done)."""
        obs = env.step(SecurityAuditAction(**kwargs))
        # `done` may be absent on the observation; treat absence as "not done".
        return obs, getattr(obs, "done", False)

    def _parse_and_submit(env, host, endpoint, tool_name, obs_text):
        """Parse tool output for detections and submit findings.

        Submits at most ONE finding per call — the first pattern that matches
        wins — and returns (submitted, done).
        NOTE(review): `tool_name` is accepted but never used in this body.
        """
        # Patterns that indicate a vulnerability was found
        patterns = {
            "CRITICAL": re.compile(r"\[CRITICAL\]\s*(.+?)(?:\n|$)"),
            "ALERT": re.compile(r"\[ALERT\]\s*(.+?)(?:\n|$)"),
            "MISCONFIGURATION": re.compile(r"\[MISCONFIGURATION\]\s*(.+?)(?:\n|$)"),
            "CRYPTO ISSUE": re.compile(r"\[CRYPTO ISSUE\]\s*(.+?)(?:\n|$)"),
            "SECRET EXPOSED": re.compile(r"\[SECRET EXPOSED\]\s*(.+?)(?:\n|$)"),
            "VULNERABLE": re.compile(r"\[!\] VULNERABLE:\s*(.+?)(?:\n|$)"),
            "DETECTED": re.compile(r"\[\w+\]\s*(.+?)\s*DETECTED", re.IGNORECASE),
        }
        # Optional structured metadata lines that may accompany a detection.
        cwe_match = re.search(r"CWE:\s*(CWE-\d+)", obs_text)
        owasp_match = re.search(r"OWASP:\s*(.+?)(?:\n|$)", obs_text)
        cvss_match = re.search(r"Suggested CVSS:\s*([\d.]+)\s*\((\w+)\)", obs_text)
        evidence_match = re.search(r"Evidence:\s*(.+?)(?:\n|$)", obs_text)
        remediation_match = re.search(r"Remediation:\s*(.+?)(?:\n|$)", obs_text)
        for severity_hint, pat in patterns.items():
            # NOTE(review): `severity_hint` (the dict key) is never consulted;
            # severity is derived from the CVSS / label matches below.
            m = pat.search(obs_text)
            if m:
                title = m.group(1).strip()
                # Also check for HIGH/MEDIUM severity labels
                sev_label_match = re.search(r"\[(\w+)\].*DETECTED", obs_text)
                severity = "High"  # default when no explicit severity appears
                if cvss_match:
                    severity = cvss_match.group(2)
                elif sev_label_match:
                    severity = sev_label_match.group(1).capitalize()
                finding = {
                    "title": title,
                    "host": host,
                    "type": title,
                    "severity": severity,
                }
                # Attach whichever optional fields the tool output provided.
                if endpoint:
                    finding["endpoint"] = endpoint
                if cwe_match:
                    finding["cwe"] = cwe_match.group(1)
                if owasp_match:
                    finding["owasp"] = owasp_match.group(1).strip()
                if cvss_match:
                    finding["cvss_score"] = float(cvss_match.group(1))
                if evidence_match:
                    finding["evidence"] = evidence_match.group(1).strip()
                if remediation_match:
                    finding["remediation"] = remediation_match.group(1).strip()
                sub_obs = env.step(SecurityAuditAction(
                    action_type="submit_finding",
                    arguments=finding,
                ))
                return True, getattr(sub_obs, "done", False)
        return False, False

    results = {}
    for scenario_id in ["easy", "medium", "hard"]:
        env = SecurityAuditEnvironment()
        env.reset(scenario_id=scenario_id)
        scenario = get_scenario(scenario_id)
        done = False
        # Phase 1: Initial network scan
        obs, done = _do_step(env,
            action_type="use_tool", tool_name="network_scan",
            arguments={"target": scenario["target_network"]}
        )
        # We may need multiple passes to unlock hidden hosts
        for _pass in range(3):
            if done:
                break
            # Snapshot so hosts unlocked mid-pass are handled on the NEXT pass.
            # NOTE(review): reads the environment's private `_discovered_hosts`.
            hosts_snapshot = list(env._discovered_hosts)
            for host in hosts_snapshot:
                if done:
                    break
                # Crawl endpoints
                crawl_obs, done = _do_step(env,
                    action_type="use_tool", tool_name="web_crawl",
                    arguments={"host": host}
                )
                if done:
                    break
                # Extract discovered endpoints from crawl output
                endpoints = []
                for line in crawl_obs.tool_output.split("\n"):
                    ep_match = re.search(r"(?:GET|POST|PUT|DELETE|PATCH)\s+(/\S+)", line)
                    if ep_match:
                        endpoints.append(ep_match.group(1).strip())
                # Test each endpoint with injection/xss tools
                for ep in endpoints:
                    if done:
                        break
                    for tool in ["test_injection", "test_xss"]:
                        if done:
                            break
                        obs, done = _do_step(env,
                            action_type="use_tool", tool_name=tool,
                            arguments={"host": host, "endpoint": ep}
                        )
                        if not done:
                            _, done = _parse_and_submit(env, host, ep, tool, obs.tool_output)
                    # check_secrets per endpoint
                    if not done:
                        obs, done = _do_step(env,
                            action_type="use_tool", tool_name="check_secrets",
                            arguments={"host": host, "endpoint": ep}
                        )
                        if not done:
                            _, done = _parse_and_submit(env, host, ep, "check_secrets", obs.tool_output)
                # Host-level tools (no endpoint needed)
                for tool in ["test_auth", "test_config", "test_crypto", "vulnerability_scan"]:
                    if done:
                        break
                    obs, done = _do_step(env,
                        action_type="use_tool", tool_name=tool,
                        arguments={"host": host}
                    )
                    if not done:
                        _, done = _parse_and_submit(env, host, None, tool, obs.tool_output)
            if done:
                break
            # Re-scan to discover newly unlocked hosts
            obs, done = _do_step(env,
                action_type="use_tool", tool_name="network_scan",
                arguments={"target": scenario["target_network"]}
            )
            # If no new hosts appeared, stop iterating
            if set(env._discovered_hosts) == set(hosts_snapshot):
                break
        # Generate final report (safe to call even after step limit —
        # step() returns _finish_episode with grades regardless)
        obs = env.step(SecurityAuditAction(action_type="generate_report"))
        grades = obs.metadata.get("grades", {}) if obs.metadata else {}
        results[scenario_id] = grades

    scores = {sid: g.get("final_score", 0) for sid, g in results.items()}
    # Reasoning gap: how much does performance drop when labels are removed?
    # A perfect reasoning agent: gap = 0 (same score regardless of output format)
    # A pure pattern matcher: gap = 1.0 (scores high on labeled, zero on raw)
    easy_score = scores.get("easy", 0)
    hard_score = scores.get("hard", 0)
    reasoning_gap = round(easy_score - hard_score, 4) if easy_score > 0 else 0.0
    return JSONResponse({
        "baseline_scores": scores,
        "reasoning_gap": reasoning_gap,
        "reasoning_gap_interpretation": (
            "Score difference between easy (labeled output) and hard (raw output). "
            "Gap of 1.0 = pure pattern matcher. Gap of 0.0 = genuine reasoning."
        ),
        "details": results,
    })
def main(host: str = "0.0.0.0", port: int = 8000):
    """Serve the application with uvicorn when executed directly.

    Args:
        host: Interface to bind.
        port: TCP port to listen on.
    """
    # Imported lazily so merely importing this module never requires uvicorn.
    import uvicorn

    server_options = {"host": host, "port": port}
    uvicorn.run(app, **server_options)


if __name__ == "__main__":
    main()