Spaces:
Running
Running
| """FastAPI application module for OpenSOC, mountable from server.py. | |
| Endpoints follow the OpenEnv conventions plus a lightweight `/grade`: | |
| POST /reset?task=<stage>&mode=<self_play|defender_only>&seed=<n> | |
| POST /step?task=<stage>&mode=...&seed=<n> (body: Action) | |
| GET /state?task=<stage>&mode=...&seed=<n> | |
| POST /grade?task=<stage>&mode=...&seed=<n> | |
| GET /tasks | |
| GET /health | |
| Per-(task, mode, seed) env instances are cached in a process-local dict so | |
| multiple concurrent clients can share the FastAPI process without stepping | |
| on each other's episodes. | |
| This module does NOT inherit from openenv-core's MCPEnvironment because the | |
| `craft_incident`/`submit_triage` action surface is non-MCP (single-action | |
| unions are simpler for GRPO rollouts). Tool names are deliberately | |
| non-reserved so an MCPEnvironment wrapper can be added later if a team | |
| wants to expose the env over MCP transports. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| from typing import Any, Dict, Optional | |
| from fastapi import FastAPI, HTTPException, Query | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import RedirectResponse | |
| from pydantic import BaseModel | |
| from env import Action, Observation, OpenSOCEnv | |
# The one FastAPI application object of this module; `main()` serves it.
app = FastAPI(
    title="OpenSOC",
    description="Self-play SOC triage OpenEnv environment for cybersecurity defender LLMs.",
    version="1.0.0",
)

# Wide-open CORS: every origin, method and header is accepted. Credentialed
# requests are not enabled here (allow_credentials is left at its default).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Process-local cache of environment instances, keyed by the string that
# `_env_key` builds from (task, mode, seed).
_envs: Dict[str, OpenSOCEnv] = {}
| def _env_key(task: str, mode: str, seed: int) -> str: | |
| return f"{task}::{mode}::{seed}" | |
def _get_env(task: str, mode: str, seed: int) -> OpenSOCEnv:
    """Return the cached environment for (task, mode, seed), creating it on first use.

    Raises:
        HTTPException: status 400 when ``OpenSOCEnv`` rejects the arguments
            with a ``ValueError``; the error text becomes the detail.
    """
    cache_key = _env_key(task, mode, seed)

    # Fast path: an environment for this triple already exists.
    if cache_key in _envs:
        return _envs[cache_key]

    try:
        created = OpenSOCEnv(task_id=task, mode=mode, seed=seed)  # type: ignore[arg-type]
    except ValueError as exc:
        # Bad task/mode values are a client error, not a server fault.
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    # NOTE(review): this function never evicts entries, so the cache grows
    # with every distinct (task, mode, seed) a client sends.
    _envs[cache_key] = created
    return created
| # --------------------------------------------------------------------------- | |
| # Response models | |
| # --------------------------------------------------------------------------- | |
# Response payload of the `step` handler: the four values unpacked from
# `env.step(action)`, in the same order.
class StepResult(BaseModel):
    # Observation produced by the step.
    observation: Observation
    # Reward reported by the environment for this step.
    reward: float
    # Flag reported by the environment; presumably True once the episode has
    # ended -- confirm against OpenSOCEnv.step.
    done: bool
    # Free-form extra data returned by the environment alongside the step.
    info: Dict[str, Any]
# Response payload of the `grade` handler. `task` and `mode` echo the request
# parameters, `score` comes from `env.grade()`, and the remaining fields are
# copied from the environment's current episode state.
class GradeResult(BaseModel):
    # Curriculum stage id the episode was run with.
    task: str
    # Mode the episode was run with ("self_play" or "defender_only").
    mode: str
    # Value returned by `env.grade()`.
    score: float
    # Rewards recorded on the episode state; may be None.
    defender_reward: Optional[float]
    attacker_reward: Optional[float]
    # `.value` of the state's ground-truth label, or None when it is not set.
    ground_truth: Optional[str]
    # Copied from the episode state; may be None.
    plausible: Optional[bool]
    # Copied from the episode state.
    schema_violation: bool
| # --------------------------------------------------------------------------- | |
| # Endpoints | |
| # --------------------------------------------------------------------------- | |
# Registered as POST /reset, as listed in the module docstring; without the
# decorator the handler is never attached to `app`.
@app.post("/reset")
def reset(
    task: str = Query("stage1_basic", description="Curriculum stage id."),
    mode: str = Query("defender_only", description="self_play | defender_only"),
    seed: int = Query(0),
):
    """Reset the environment and return the initial observation."""
    # `_get_env` answers with HTTP 400 when the task/mode combination is
    # rejected by OpenSOCEnv.
    env = _get_env(task, mode, seed)
    return env.reset()
# Registered as POST /step, as listed in the module docstring; without the
# decorator the handler is never attached to `app`. The request body is the
# Action; task/mode/seed select the cached environment.
@app.post("/step", response_model=StepResult)
def step(
    action: Action,
    task: str = Query("stage1_basic"),
    mode: str = Query("defender_only"),
    seed: int = Query(0),
):
    """Execute one action and return observation, reward, done, info."""
    env = _get_env(task, mode, seed)

    # A fresh environment has no episode state until /reset has been called.
    # NOTE(review): this reads the private `_state` attribute of OpenSOCEnv.
    if env._state is None:
        raise HTTPException(status_code=400, detail="Call /reset first.")

    try:
        obs, reward, done, info = env.step(action)
    except RuntimeError as exc:
        # The environment refused the step; report it as a client error.
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    return StepResult(observation=obs, reward=reward, done=done, info=info)
# Registered as GET /state, as listed in the module docstring; without the
# decorator the handler is never attached to `app`.
@app.get("/state")
def state(
    task: str = Query("stage1_basic"),
    mode: str = Query("defender_only"),
    seed: int = Query(0),
):
    """Return the full internal episode state."""
    # NOTE(review): unlike /step and /grade there is no "call /reset first"
    # guard here; what `env.state()` returns before a reset should be
    # confirmed against OpenSOCEnv.
    env = _get_env(task, mode, seed)
    return env.state()
# Registered as GET /tasks, as listed in the module docstring; without the
# decorator the handler is never attached to `app`.
@app.get("/tasks")
def list_tasks():
    """List the available curriculum stages."""
    # Imported at call time rather than at module import.
    from tasks.registry import STAGE_REGISTRY

    stages = [
        {
            "id": stage_id,
            "difficulty": cfg["difficulty"],
            "description": cfg["description"],
        }
        for stage_id, cfg in STAGE_REGISTRY.items()
    ]
    return {
        "tasks": stages,
        "modes": ["self_play", "defender_only"],
    }
# Registered as POST /grade, as listed in the module docstring; without the
# decorator the handler is never attached to `app`.
@app.post("/grade", response_model=GradeResult)
def grade(
    task: str = Query("stage1_basic"),
    mode: str = Query("defender_only"),
    seed: int = Query(0),
):
    """Compute a normalized [0, 1] score for the just-finished episode."""
    env = _get_env(task, mode, seed)

    # Nothing to grade until /reset has created an episode.
    # NOTE(review): this reads the private `_state` attribute of OpenSOCEnv.
    if env._state is None:
        raise HTTPException(status_code=400, detail="No episode to grade. Call /reset first.")

    episode = env._state
    return GradeResult(
        task=task,
        mode=mode,
        score=env.grade(),
        defender_reward=episode.defender_reward,
        attacker_reward=episode.attacker_reward,
        # The label is exposed through its `.value`; None when no label is set.
        ground_truth=episode.ground_truth.value if episode.ground_truth else None,
        plausible=episode.plausible,
        schema_violation=episode.schema_violation,
    )
# Registered as GET /health, as listed in the module docstring; without the
# decorator the handler is never attached to `app`.
@app.get("/health")
def health():
    """Return a static liveness payload naming the environment and version."""
    return {"status": "ok", "env": "OpenSOC", "version": "1.0.0"}
# Registered on the root path, which is the URL the comment below refers to;
# without the decorator the handler is never attached to `app`.
@app.get("/")
def index():
    """Redirect the root URL to the demo mounted under /demo/."""
    # Spaces iframes load the root URL; send human visitors to the Gradio
    # demo and leave the JSON API endpoints untouched for the OpenEnv judge.
    return RedirectResponse(url="/demo/", status_code=307)
def main() -> None:
    """Serve `app` with uvicorn on all interfaces.

    The port is taken from the ``PORT`` environment variable and falls back
    to 7860 when the variable is unset or empty.

    Raises:
        ValueError: if ``PORT`` is set to a non-integer value.
    """
    import uvicorn

    # `or` covers both an unset and an empty PORT, and the string fallback
    # keeps the value passed to int() a str in every case.
    port = int(os.getenv("PORT") or "7860")
    uvicorn.run(app, host="0.0.0.0", port=port)