Spaces:
Running
Running
File size: 6,180 Bytes
c1060df | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 | """
AdaptShield FastAPI Server
CRITICAL: Uses factory pattern (make_env function), NOT singleton.
Singleton was the Round 1 failure — always served wrong task.
Factory creates a fresh isolated instance per evaluator session.
openenv validate requires:
- def main() function present
- called as main() in if __name__ block (literal string check)
- port 7860 (HF Spaces default)
"""
import os
import sys
from typing import Any, Dict
from uuid import uuid4
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
try:
from fastapi import Body, HTTPException
from openenv.core.env_server.http_server import create_app
except Exception as e:
raise ImportError(
"openenv-core required. Install: pip install openenv-core"
) from e
from models import AdaptShieldAction, AdaptShieldObservation
from server.adaptshield_environment import AdaptShieldEnvironment
# Task used by make_env() and as the /soc/reset fallback; override with the
# ADAPTSHIELD_TASK environment variable.
DEFAULT_TASK: str = os.getenv("ADAPTSHIELD_TASK", "direct-triage")
# Process-local registry of /soc/* demo sessions, keyed by the session_id
# string handed out by /soc/reset.
SOC_SESSIONS: Dict[str, AdaptShieldEnvironment] = {}
def make_env() -> AdaptShieldEnvironment:
    """Build a brand-new ``AdaptShieldEnvironment`` for ``DEFAULT_TASK``.

    This function is handed to ``create_app`` uncalled, as a factory rather
    than a shared instance. Every call constructs a new environment object,
    so evaluator sessions never share a singleton.
    """
    fresh_env = AdaptShieldEnvironment(task_name=DEFAULT_TASK)
    return fresh_env
# openenv HTTP application. ``make_env`` is passed uncalled, as a factory, so
# each evaluator session gets its own environment instead of a singleton.
# NOTE(review): max_concurrent_envs presumably caps how many factory-created
# environments may be live at once — confirm against openenv-core create_app.
app = create_app(
    make_env,
    AdaptShieldAction,
    AdaptShieldObservation,
    env_name="adaptshield",
    max_concurrent_envs=10,
)
@app.post("/soc/reset", tags=["AdaptShield SOC Tools"])
async def soc_reset(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
    """Start a persistent demo session for SOC tool/API workflows.

    Payload keys:
        task: optional task name. A missing, null or empty value falls back
            to ``DEFAULT_TASK``.

    Returns the new ``session_id`` (a UUID4 string), the initial observation
    serialized with ``model_dump(mode="json")``, and the ``available_tools``
    list from the observation metadata (an empty list if absent).

    Sessions are kept in the process-local ``SOC_SESSIONS`` registry. Clients
    never close them, so the registry is capped (env var
    ``ADAPTSHIELD_MAX_SOC_SESSIONS``, default 100). Once the cap is reached,
    the oldest sessions are evicted first.
    """
    # Use ``or`` so that {"task": null} or {"task": ""} selects the default
    # instead of becoming the literal task name "None" or "".
    task = str(payload.get("task") or DEFAULT_TASK)
    env = AdaptShieldEnvironment(task_name=task)
    obs = env.reset()

    # Evict only after a successful reset, so a failing request never drops
    # existing sessions.
    default_cap = 100
    try:
        max_sessions = max(1, int(os.getenv("ADAPTSHIELD_MAX_SOC_SESSIONS", str(default_cap))))
    except ValueError:
        max_sessions = default_cap
    # Dicts preserve insertion order, so the first key is the oldest session.
    while len(SOC_SESSIONS) >= max_sessions:
        SOC_SESSIONS.pop(next(iter(SOC_SESSIONS)))

    session_id = str(uuid4())
    SOC_SESSIONS[session_id] = env
    return {
        "session_id": session_id,
        "observation": obs.model_dump(mode="json"),
        "available_tools": obs.metadata.get("available_tools", []),
    }
@app.post("/soc/step", tags=["AdaptShield SOC Tools"])
async def soc_step(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
    """Step a persistent SOC tool/API session.

    Payload keys:
        session_id: id returned by ``/soc/reset``. An unknown id gives HTTP 404.
        action: mapping of ``AdaptShieldAction`` fields. A non-mapping or
            invalid action gives HTTP 422.

    Returns the echoed ``session_id``, the JSON-serialized observation, the
    step ``reward`` as a float (``None`` when the observation carries no
    reward) and the ``done`` flag.
    """
    env = _soc_session(payload)
    try:
        action = AdaptShieldAction(**dict(payload.get("action", {})))
    except (TypeError, ValueError) as exc:
        # dict() rejects non-mapping "action" values with TypeError or
        # ValueError, and pydantic's ValidationError subclasses ValueError.
        raise HTTPException(status_code=422, detail=str(exc)) from exc
    obs = env.step(action)
    # float(None) would raise and turn a valid step into an HTTP 500.
    # NOTE(review): assumes the observation's reward field is Optional, as in
    # openenv's base Observation — confirm.
    reward = None if obs.reward is None else float(obs.reward)
    return {
        "session_id": payload.get("session_id"),
        "observation": obs.model_dump(mode="json"),
        "reward": reward,
        "done": bool(obs.done),
    }
@app.post("/tools/log_search", tags=["AdaptShield SOC Tools"])
async def tool_log_search(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
    """Search stateful SIEM/application logs for the active session.

    Body fields: ``session_id`` (unknown id gives HTTP 404); ``node``, or
    ``target_node`` as an alias (default ``"unknown"``); ``query`` (default
    ``""``). Returns whatever the session's ``log_search`` tool returns.
    """
    session_env = _soc_session(payload)
    alias_node = payload.get("target_node", "unknown")
    node_name = payload.get("node", alias_node)
    search_text = payload.get("query", "")
    return session_env.call_tool("log_search", node=node_name, query=search_text)
@app.post("/tools/cmdb_lookup", tags=["AdaptShield SOC Tools"])
async def tool_cmdb_lookup(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
    """Look up service ownership, criticality, and dependency blast radius.

    Body fields: ``session_id`` (unknown id gives HTTP 404); ``node``, or
    ``target_node`` as an alias (default ``"unknown"``). Returns the
    session's ``cmdb_lookup`` tool result.
    """
    session_env = _soc_session(payload)
    alias_node = payload.get("target_node", "unknown")
    node_name = payload.get("node", alias_node)
    return session_env.call_tool("cmdb_lookup", node=node_name)
@app.post("/tools/edr_status", tags=["AdaptShield SOC Tools"])
async def tool_edr_status(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
    """Check endpoint containment and persistence indicators.

    Body fields: ``session_id`` (unknown id gives HTTP 404); ``node``, or
    ``target_node`` as an alias (default ``"unknown"``). Returns the
    session's ``edr_status`` tool result.
    """
    session_env = _soc_session(payload)
    alias_node = payload.get("target_node", "unknown")
    node_name = payload.get("node", alias_node)
    return session_env.call_tool("edr_status", node=node_name)
@app.post("/tools/vuln_lookup", tags=["AdaptShield SOC Tools"])
async def tool_vuln_lookup(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
    """Query internal vulnerability/advisory evidence for a service package.

    Body fields: ``session_id`` (unknown id gives HTTP 404); ``node``, or
    ``target_node`` as an alias (default ``"unknown"``); ``package`` (default
    ``""``). Returns the session's ``vuln_lookup`` tool result.
    """
    session_env = _soc_session(payload)
    alias_node = payload.get("target_node", "unknown")
    node_name = payload.get("node", alias_node)
    package_name = payload.get("package", "")
    return session_env.call_tool("vuln_lookup", node=node_name, package=package_name)
@app.post("/tools/identity_lookup", tags=["AdaptShield SOC Tools"])
async def tool_identity_lookup(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
    """Inspect account behavior and unusual source-host affinity for a service identity.

    Body fields: ``session_id`` (unknown id gives HTTP 404); ``node``, or
    ``target_node`` as an alias (default ``"unknown"``). Returns the
    session's ``identity_lookup`` tool result.
    """
    session_env = _soc_session(payload)
    alias_node = payload.get("target_node", "unknown")
    node_name = payload.get("node", alias_node)
    return session_env.call_tool("identity_lookup", node=node_name)
@app.post("/tools/change_calendar_lookup", tags=["AdaptShield SOC Tools"])
async def tool_change_calendar_lookup(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
    """Check whether a deploy or maintenance window was actually scheduled.

    Body fields: ``session_id`` (unknown id gives HTTP 404); ``node``, or
    ``target_node`` as an alias (default ``"unknown"``). Returns the
    session's ``change_calendar_lookup`` tool result.
    """
    session_env = _soc_session(payload)
    alias_node = payload.get("target_node", "unknown")
    node_name = payload.get("node", alias_node)
    return session_env.call_tool("change_calendar_lookup", node=node_name)
@app.post("/tools/netflow_lookup", tags=["AdaptShield SOC Tools"])
async def tool_netflow_lookup(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
    """Inspect east-west and outbound traffic summaries for the active session.

    Body fields: ``session_id`` (unknown id gives HTTP 404); ``node``, or
    ``target_node`` as an alias (default ``"unknown"``). Returns the
    session's ``netflow_lookup`` tool result.
    """
    session_env = _soc_session(payload)
    alias_node = payload.get("target_node", "unknown")
    node_name = payload.get("node", alias_node)
    return session_env.call_tool("netflow_lookup", node=node_name)
def _soc_session(payload: Dict[str, Any]) -> AdaptShieldEnvironment:
    """Resolve the ``session_id`` in *payload* to its live SOC environment.

    The id is coerced with ``str()`` (a missing id becomes ``""``) and looked
    up in ``SOC_SESSIONS``. Raises ``HTTPException`` 404 when no session is
    registered under it.
    """
    wanted_id = str(payload.get("session_id", ""))
    session_env = SOC_SESSIONS.get(wanted_id)
    if session_env is not None:
        return session_env
    raise HTTPException(
        status_code=404,
        detail="Unknown SOC session. Call /soc/reset first.",
    )
def main(host: str = "0.0.0.0", port: int = 7860) -> None:
    """Serve the module-level FastAPI ``app`` with uvicorn. Call main() to run.

    Defaults bind every interface on port 7860, the Hugging Face Spaces port.
    uvicorn is imported lazily, so importing this module does not need it.
    """
    from uvicorn import run as serve_app

    serve_app(app, host=host, port=port)
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run the AdaptShield FastAPI server.")
    parser.add_argument("--port", type=int, default=7860)
    args = parser.parse_args()
    # openenv validate string-matches the literal call ``main()`` in this
    # block, so keep that call for the default port. Pass --port only when it
    # differs; previously the parsed value was silently ignored.
    if args.port == 7860:
        main()
    else:
        main(port=args.port)
|