"""FastAPI application module for OpenSOC, mountable from server.py.

Endpoints follow the OpenEnv conventions plus a lightweight ``/grade``:

    POST /reset?task=&mode=&seed=
    POST /step?task=&mode=&seed=     (body: Action)
    GET  /state?task=&mode=&seed=
    POST /grade?task=&mode=&seed=
    GET  /tasks
    GET  /health

Env instances are cached per ``(task, mode, seed)`` in a process-local,
LRU-bounded dict. Concurrent clients that use *distinct* triples therefore
do not step on each other's episodes; clients that reuse the same triple
share one env instance. When more than ``OPENSOC_MAX_ENVS`` (default 1024)
distinct triples are in use, the least-recently-used env is evicted, and its
next ``/step`` or ``/grade`` asks the client to ``/reset`` again.

This module does NOT inherit from openenv-core's MCPEnvironment because the
``craft_incident``/``submit_triage`` action surface is non-MCP (single-action
unions are simpler for GRPO rollouts). Tool names are deliberately
non-reserved so an MCPEnvironment wrapper can be added later if a team wants
to expose the env over MCP transports.
"""

from __future__ import annotations

import os
import threading
from collections import OrderedDict
from typing import Any, Dict, Optional

from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from pydantic import BaseModel

from env import Action, Observation, OpenSOCEnv

# Single source of truth for the version reported by OpenAPI and /health.
_VERSION = "1.0.0"

app = FastAPI(
    title="OpenSOC",
    description="Self-play SOC triage OpenEnv environment for cybersecurity defender LLMs.",
    version=_VERSION,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Upper bound on cached env instances. ``seed`` is an arbitrary client-supplied
# int, so without a cap every new seed would permanently grow process memory.
_MAX_ENVS: int = max(1, int(os.getenv("OPENSOC_MAX_ENVS", "1024")))

# Insertion/recency-ordered cache: the oldest (least recently used) entry is
# at the front and is evicted first. Still a dict subclass, so existing
# readers of ``_envs`` keep working.
_envs: OrderedDict[str, OpenSOCEnv] = OrderedDict()

# Plain ``def`` endpoints run in FastAPI's threadpool, so the get-or-create in
# ``_get_env`` must be atomic. The lock guards the cache only; it does not
# serialize calls into a shared env instance.
_envs_lock = threading.Lock()


def _env_key(task: str, mode: str, seed: int) -> str:
    """Build the cache key identifying one ``(task, mode, seed)`` env."""
    return f"{task}::{mode}::{seed}"


def _get_env(task: str, mode: str, seed: int) -> OpenSOCEnv:
    """Return the cached env for ``(task, mode, seed)``, creating it if needed.

    Creating an entry may evict the least-recently-used env once the cache
    holds more than ``_MAX_ENVS`` instances.

    Raises:
        HTTPException: 400 if ``OpenSOCEnv`` rejects the arguments with
            ``ValueError`` (e.g. an unknown task or mode).
    """
    key = _env_key(task, mode, seed)
    with _envs_lock:
        env = _envs.get(key)
        if env is not None:
            # Mark as most recently used so active episodes are evicted last.
            _envs.move_to_end(key)
            return env
        try:
            env = OpenSOCEnv(task_id=task, mode=mode, seed=seed)  # type: ignore[arg-type]
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        _envs[key] = env
        while len(_envs) > _MAX_ENVS:
            _envs.popitem(last=False)
        return env


# ---------------------------------------------------------------------------
# Response models
# ---------------------------------------------------------------------------


class StepResult(BaseModel):
    """Response body of ``POST /step``."""

    observation: Observation
    reward: float
    done: bool
    info: Dict[str, Any]


class GradeResult(BaseModel):
    """Response body of ``POST /grade``."""

    task: str
    mode: str
    score: float
    defender_reward: Optional[float]
    attacker_reward: Optional[float]
    ground_truth: Optional[str]
    plausible: Optional[bool]
    schema_violation: bool


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------


@app.post("/reset", response_model=Observation)
def reset(
    task: str = Query("stage1_basic", description="Curriculum stage id."),
    mode: str = Query("defender_only", description="self_play | defender_only"),
    seed: int = Query(0),
):
    """Reset the environment and return the initial observation."""
    env = _get_env(task, mode, seed)
    return env.reset()


@app.post("/step", response_model=StepResult)
def step(
    action: Action,
    task: str = Query("stage1_basic"),
    mode: str = Query("defender_only"),
    seed: int = Query(0),
):
    """Execute one action and return observation, reward, done, info.

    Raises:
        HTTPException: 400 if no episode is active (``/reset`` not called, or
            the env was evicted) or if ``env.step`` raises ``RuntimeError``.
    """
    env = _get_env(task, mode, seed)
    if env._state is None:
        raise HTTPException(status_code=400, detail="Call /reset first.")
    try:
        obs, reward, done, info = env.step(action)
    except RuntimeError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return StepResult(observation=obs, reward=reward, done=done, info=info)


@app.get("/state")
def state(
    task: str = Query("stage1_basic"),
    mode: str = Query("defender_only"),
    seed: int = Query(0),
):
    """Return the full internal episode state."""
    env = _get_env(task, mode, seed)
    return env.state()


@app.get("/tasks")
def list_tasks():
    """List the available curriculum stages and supported modes."""
    from tasks.registry import STAGE_REGISTRY

    return {
        "tasks": [
            {"id": stage_id, "difficulty": cfg["difficulty"], "description": cfg["description"]}
            for stage_id, cfg in STAGE_REGISTRY.items()
        ],
        "modes": ["self_play", "defender_only"],
    }


@app.post("/grade", response_model=GradeResult)
def grade(
    task: str = Query("stage1_basic"),
    mode: str = Query("defender_only"),
    seed: int = Query(0),
):
    """Compute a normalized [0, 1] score for the just-finished episode.

    Raises:
        HTTPException: 400 if no episode is active for this ``(task, mode, seed)``.
    """
    env = _get_env(task, mode, seed)
    if env._state is None:
        raise HTTPException(status_code=400, detail="No episode to grade. Call /reset first.")
    s = env._state
    return GradeResult(
        task=task,
        mode=mode,
        score=env.grade(),
        defender_reward=s.defender_reward,
        attacker_reward=s.attacker_reward,
        ground_truth=s.ground_truth.value if s.ground_truth else None,
        plausible=s.plausible,
        schema_violation=s.schema_violation,
    )


@app.get("/health")
def health():
    """Liveness probe reporting the service name and version."""
    return {"status": "ok", "env": "OpenSOC", "version": _VERSION}


@app.get("/", include_in_schema=False)
def index():
    # Spaces iframes load the root URL; send human visitors to the Gradio
    # demo and leave the JSON API endpoints untouched for the OpenEnv judge.
    return RedirectResponse(url="/demo/", status_code=307)


def main() -> None:
    """Serve the app with uvicorn on ``$PORT`` (default 7860)."""
    import uvicorn

    port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)


if __name__ == "__main__":
    main()