| """FastAPI application module for OpenSOC, mountable from server.py. |
| |
| Endpoints follow the OpenEnv conventions plus a lightweight `/grade`: |
| |
| POST /reset?task=<stage>&mode=<self_play|defender_only>&seed=<n> |
| POST /step?task=<stage>&mode=...&seed=<n> (body: Action) |
| GET /state?task=<stage>&mode=...&seed=<n> |
| POST /grade?task=<stage>&mode=...&seed=<n> |
| GET /tasks |
| GET /health |
| |
| Per-(task, mode, seed) env instances are cached in a process-local dict so |
| multiple concurrent clients can share the FastAPI process without stepping |
| on each other's episodes. |
| |
| This module does NOT inherit from openenv-core's MCPEnvironment because the |
| `craft_incident`/`submit_triage` action surface is non-MCP (single-action |
| unions are simpler for GRPO rollouts). Tool names are deliberately |
| non-reserved so an MCPEnvironment wrapper can be added later if a team |
| wants to expose the env over MCP transports. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
| from typing import Any, Dict, Optional |
|
|
| from fastapi import FastAPI, HTTPException, Query |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import RedirectResponse |
| from pydantic import BaseModel |
|
|
| from env import Action, Observation, OpenSOCEnv |
|
|
|
|
# ASGI application object; the route decorators below register against it.
app = FastAPI(
    version="1.0.0",
    title="OpenSOC",
    description="Self-play SOC triage OpenEnv environment for cybersecurity defender LLMs.",
)

# Fully permissive CORS: every origin, method, and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Process-local cache of environment instances, keyed by the string that
# `_env_key` builds from (task, mode, seed).
_envs: Dict[str, OpenSOCEnv] = {}
|
|
|
|
| def _env_key(task: str, mode: str, seed: int) -> str: |
| return f"{task}::{mode}::{seed}" |
|
|
|
|
def _get_env(task: str, mode: str, seed: int) -> OpenSOCEnv:
    """Return the cached environment for (task, mode, seed), creating it on first use.

    Raises:
        HTTPException: status 400 carrying the message of a ``ValueError``
            raised by the ``OpenSOCEnv`` constructor.
    """
    cache_key = _env_key(task, mode, seed)
    try:
        return _envs[cache_key]
    except KeyError:
        pass  # First request for this combination; build the env below.

    try:
        created = OpenSOCEnv(task_id=task, mode=mode, seed=seed)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    # Only a successfully constructed env is stored in the cache.
    _envs[cache_key] = created
    return created
|
|
|
|
| |
| |
| |
|
|
# Response body of POST /step: the four values unpacked from `env.step(action)`.
# Documented with comments rather than a class docstring so the generated
# schema description stays as it was.
class StepResult(BaseModel):
    # Observation returned by the environment after the action.
    observation: Observation
    # Reward value returned for this step.
    reward: float
    # Third element of the step tuple (the "done" flag).
    done: bool
    # Free-form extra data returned alongside the step.
    info: Dict[str, Any]
|
|
|
|
# Response body of POST /grade. `task` and `mode` echo the query parameters,
# `score` comes from `env.grade()`, and the remaining fields are copied from
# the environment's episode state.
class GradeResult(BaseModel):
    # Echo of the request's query parameters.
    task: str
    mode: str
    # Value returned by `env.grade()`.
    score: float
    # Reward fields copied from the episode state; may be absent.
    defender_reward: Optional[float]
    attacker_reward: Optional[float]
    # `.value` of the state's ground truth, or None when it is unset.
    ground_truth: Optional[str]
    # Copied from the episode state; may be absent.
    plausible: Optional[bool]
    # Copied from the episode state.
    schema_violation: bool
|
|
|
|
| |
| |
| |
|
|
@app.post("/reset", response_model=Observation)
def reset(
    task: str = Query("stage1_basic", description="Curriculum stage id."),
    mode: str = Query("defender_only", description="self_play | defender_only"),
    seed: int = Query(0),
):
    """Reset the environment and return the initial observation."""
    # The env is looked up (or created) per (task, mode, seed) before resetting.
    initial_observation = _get_env(task, mode, seed).reset()
    return initial_observation
|
|
|
|
@app.post("/step", response_model=StepResult)
def step(
    action: Action,
    task: str = Query("stage1_basic"),
    mode: str = Query("defender_only"),
    seed: int = Query(0),
):
    """Execute one action and return observation, reward, done, info."""
    target_env = _get_env(task, mode, seed)

    # An env whose `_state` is still None is rejected with a 400 that tells
    # the client to call /reset first.
    if target_env._state is None:
        raise HTTPException(status_code=400, detail="Call /reset first.")

    try:
        observation, reward, finished, extra = target_env.step(action)
    except RuntimeError as exc:
        # RuntimeError from the env is surfaced to the client as a 400.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    else:
        return StepResult(
            observation=observation,
            reward=reward,
            done=finished,
            info=extra,
        )
|
|
|
|
@app.get("/state")
def state(
    task: str = Query("stage1_basic"),
    mode: str = Query("defender_only"),
    seed: int = Query(0),
):
    """Return the full internal episode state."""
    # Whatever `env.state()` returns is passed straight back to the client.
    return _get_env(task, mode, seed).state()
|
|
|
|
@app.get("/tasks")
def list_tasks():
    """List the available curriculum stages."""
    # Imported at call time, as in the original, rather than at module load.
    from tasks.registry import STAGE_REGISTRY

    stages = []
    for stage_id, cfg in STAGE_REGISTRY.items():
        entry = {
            "id": stage_id,
            "difficulty": cfg["difficulty"],
            "description": cfg["description"],
        }
        stages.append(entry)

    return {"tasks": stages, "modes": ["self_play", "defender_only"]}
|
|
|
|
@app.post("/grade", response_model=GradeResult)
def grade(
    task: str = Query("stage1_basic"),
    mode: str = Query("defender_only"),
    seed: int = Query(0),
):
    """Compute a normalized [0, 1] score for the just-finished episode."""
    target_env = _get_env(task, mode, seed)
    episode = target_env._state
    if episode is None:
        raise HTTPException(status_code=400, detail="No episode to grade. Call /reset first.")

    # Grade first, then read the state fields, matching the original
    # evaluation order.
    score = target_env.grade()

    truth = episode.ground_truth
    truth_label = truth.value if truth else None

    return GradeResult(
        task=task,
        mode=mode,
        score=score,
        defender_reward=episode.defender_reward,
        attacker_reward=episode.attacker_reward,
        ground_truth=truth_label,
        plausible=episode.plausible,
        schema_violation=episode.schema_violation,
    )
|
|
|
|
@app.get("/health")
def health():
    # Static liveness payload; nothing about the cached environments is inspected.
    payload = {
        "status": "ok",
        "env": "OpenSOC",
        "version": "1.0.0",
    }
    return payload
|
|
|
|
@app.get("/", include_in_schema=False)
def index():
    # The root path sends the client to /demo/ with a 307 (temporary) redirect.
    destination = "/demo/"
    return RedirectResponse(url=destination, status_code=307)
|
|
|
|
def main() -> None:
    """Serve `app` with uvicorn on all interfaces.

    The port is read from the ``PORT`` environment variable and falls back
    to 7860 when the variable is not set.
    """
    import uvicorn

    port_setting = os.getenv("PORT", 7860)
    uvicorn.run(app, host="0.0.0.0", port=int(port_setting))
|
|