File size: 5,585 Bytes
bb6a031 bf4094f bb6a031 bf4094f bb6a031 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 | """FastAPI application module for OpenSOC, mountable from server.py.
Endpoints follow the OpenEnv conventions plus a lightweight `/grade`:
POST /reset?task=<stage>&mode=<self_play|defender_only>&seed=<n>
POST /step?task=<stage>&mode=...&seed=<n> (body: Action)
GET /state?task=<stage>&mode=...&seed=<n>
POST /grade?task=<stage>&mode=...&seed=<n>
GET /tasks
GET /health
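Per-(task, mode, seed) env instances are cached in a process-local dict so
multiple concurrent clients can share the FastAPI process without stepping
on each other's episodes, provided each client uses its own (task, mode, seed)
triple: requests that pass the same triple share one env instance and
therefore one episode.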
This module does NOT inherit from openenv-core's MCPEnvironment because the
`craft_incident`/`submit_triage` action surface is non-MCP (single-action
unions are simpler for GRPO rollouts). Tool names are deliberately
non-reserved so an MCPEnvironment wrapper can be added later if a team
wants to expose the env over MCP transports.
"""
from __future__ import annotations
import os
from typing import Any, Dict, Optional
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from env import Action, Observation, OpenSOCEnv
# ASGI application object; the module docstring describes it as mountable
# from server.py.
app = FastAPI(
    title="OpenSOC",
    version="1.0.0",
    description="Self-play SOC triage OpenEnv environment for cybersecurity defender LLMs.",
)

# Accept cross-origin requests from any origin, with any method and any
# header. allow_credentials is not set here.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# Process-local env cache keyed by _env_key(task, mode, seed). Nothing in this
# module removes entries, so every distinct triple keeps its env alive.
_envs: Dict[str, OpenSOCEnv] = dict()
def _env_key(task: str, mode: str, seed: int) -> str:
return f"{task}::{mode}::{seed}"
def _get_env(task: str, mode: str, seed: int) -> OpenSOCEnv:
    """Return the cached env for (task, mode, seed), creating it on first use.

    Args:
        task: Curriculum stage id, forwarded as ``OpenSOCEnv(task_id=...)``.
        mode: Episode mode, forwarded as ``OpenSOCEnv(mode=...)``.
        seed: Seed, forwarded as ``OpenSOCEnv(seed=...)``.

    Returns:
        The single ``OpenSOCEnv`` instance stored in ``_envs`` under this key.

    Raises:
        HTTPException: status 400 when the ``OpenSOCEnv`` constructor raises
            ``ValueError``; the exception text becomes the detail.
    """
    key = _env_key(task, mode, seed)
    env = _envs.get(key)
    if env is not None:
        return env
    try:
        candidate = OpenSOCEnv(task_id=task, mode=mode, seed=seed)  # type: ignore[arg-type]
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    # The endpoints are plain `def` functions, which FastAPI runs in a
    # threadpool, so two first requests for the same key can race here.
    # setdefault keeps whichever instance was stored first instead of
    # overwriting it (and an episode that was already reset on it), and we
    # return that stored instance rather than re-reading the dict.
    # NOTE(review): _envs is never pruned; every distinct seed a client sends
    # allocates a new env that lives as long as the process.
    return _envs.setdefault(key, candidate)
# ---------------------------------------------------------------------------
# Response models
# ---------------------------------------------------------------------------
class StepResult(BaseModel):
    # Response body of POST /step, built from the 4-tuple returned by env.step().
    # Field order is kept as-is because it defines the serialized/OpenAPI schema.
    observation: Observation  # observation returned by env.step()
    reward: float  # reward returned by env.step()
    done: bool  # done flag returned by env.step()
    info: Dict[str, Any]  # info dict returned by env.step()
class GradeResult(BaseModel):
    # Response body of POST /grade. Apart from task/mode/score, the fields are
    # copied from the env's episode state.
    task: str  # echo of the `task` query parameter
    mode: str  # echo of the `mode` query parameter
    score: float  # value of env.grade(); /grade documents it as normalized to [0, 1]
    defender_reward: Optional[float]  # episode state's defender reward; may be None
    attacker_reward: Optional[float]  # episode state's attacker reward; may be None
    ground_truth: Optional[str]  # `.value` of the episode state's ground truth, or None
    plausible: Optional[bool]  # episode state's plausibility flag; may be None
    schema_violation: bool  # episode state's schema-violation flag
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.post("/reset", response_model=Observation)
def reset(
    task: str = Query("stage1_basic", description="Curriculum stage id."),
    mode: str = Query("defender_only", description="self_play | defender_only"),
    seed: int = Query(0),
):
    """Reset the environment and return the initial observation."""
    # _get_env looks up (or lazily creates) the env for this triple; a
    # ValueError from OpenSOCEnv's constructor surfaces as HTTP 400 there.
    return _get_env(task, mode, seed).reset()
@app.post("/step", response_model=StepResult)
def step(
    action: Action,
    task: str = Query("stage1_basic"),
    mode: str = Query("defender_only"),
    seed: int = Query(0),
):
    """Execute one action and return observation, reward, done, info."""
    environment = _get_env(task, mode, seed)

    # Guard clause: without episode state there is nothing to act on.
    if environment._state is None:
        raise HTTPException(status_code=400, detail="Call /reset first.")

    try:
        observation, step_reward, is_done, step_info = environment.step(action)
    except RuntimeError as err:
        # Report env-level RuntimeErrors as a client error instead of a 500.
        raise HTTPException(status_code=400, detail=str(err)) from err

    return StepResult(
        observation=observation,
        reward=step_reward,
        done=is_done,
        info=step_info,
    )
@app.get("/state")
def state(
    task: str = Query("stage1_basic"),
    mode: str = Query("defender_only"),
    seed: int = Query(0),
):
    """Return the full internal episode state."""
    # No reset guard here (unlike /step and /grade): whatever env.state()
    # returns is passed straight through, with no response_model.
    return _get_env(task, mode, seed).state()
@app.get("/tasks")
def list_tasks():
    """List the available curriculum stages."""
    # Imported at call time rather than at module import time.
    from tasks.registry import STAGE_REGISTRY

    stages = []
    for stage_id, cfg in STAGE_REGISTRY.items():
        stages.append(
            {
                "id": stage_id,
                "difficulty": cfg["difficulty"],
                "description": cfg["description"],
            }
        )

    # The supported modes are a fixed list, not read from the registry.
    return {"tasks": stages, "modes": ["self_play", "defender_only"]}
@app.post("/grade", response_model=GradeResult)
def grade(
    task: str = Query("stage1_basic"),
    mode: str = Query("defender_only"),
    seed: int = Query(0),
):
    """Compute a normalized [0, 1] score for the just-finished episode."""
    env = _get_env(task, mode, seed)
    s = env._state
    if s is None:
        raise HTTPException(status_code=400, detail="No episode to grade. Call /reset first.")
    # NOTE(review): nothing here checks that the episode is actually finished,
    # so grading mid-episode is accepted.
    # env.grade() is evaluated before the state fields below are read; keep
    # that order in case grading updates the state (TODO confirm whether it does).
    score = env.grade()
    return GradeResult(
        task=task,
        mode=mode,
        score=score,
        defender_reward=s.defender_reward,
        attacker_reward=s.attacker_reward,
        # Compare against None rather than relying on truthiness: an Enum
        # member with a mixed-in type (e.g. IntEnum value 0 or str value "")
        # is falsy and would otherwise be reported as None.
        ground_truth=s.ground_truth.value if s.ground_truth is not None else None,
        plausible=s.plausible,
        schema_violation=s.schema_violation,
    )
@app.get("/health")
def health():
    # Liveness probe: a constant payload that touches no environment state.
    return dict(status="ok", env="OpenSOC", version="1.0.0")
@app.get("/", include_in_schema=False)
def index():
    """Redirect the root URL to the Gradio demo at ``/demo/``."""
    # Spaces iframes load "/", so human visitors land on the demo while the
    # JSON API routes stay untouched for the OpenEnv judge.
    return RedirectResponse("/demo/", status_code=307)
def main() -> None:
    """Serve ``app`` with uvicorn, bound to 0.0.0.0.

    The port is read from the ``PORT`` environment variable. When it is unset
    or blank, 7860 is used.

    Raises:
        ValueError: if ``PORT`` is set to a non-blank, non-integer value.
    """
    import uvicorn

    # A blank PORT (e.g. `PORT=` in an env file) would make int("") raise at
    # startup; treat it like an unset variable instead.
    raw_port = os.getenv("PORT", "").strip()
    port = int(raw_port) if raw_port else 7860
    uvicorn.run(app, host="0.0.0.0", port=port)
|