opensoc-env / app_runtime.py
shivam2k3's picture
Add GET / -> /demo/ redirect for Space iframe
bf4094f
"""FastAPI application module for OpenSOC, mountable from server.py.
Endpoints follow the OpenEnv conventions plus a lightweight `/grade`:
POST /reset?task=<stage>&mode=<self_play|defender_only>&seed=<n>
POST /step?task=<stage>&mode=...&seed=<n> (body: Action)
GET /state?task=<stage>&mode=...&seed=<n>
POST /grade?task=<stage>&mode=...&seed=<n>
GET /tasks
GET /health
GET /  (redirects browsers to the Gradio demo at /demo/)
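Env instances are cached in a process-local dict keyed by (task, mode, seed),
so concurrent clients can share one FastAPI process without stepping on each
other's episodes as long as each uses a distinct (task, mode, seed) triple;
requests with the same triple act on the same episode.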
The environment exposed here does NOT subclass openenv-core's MCPEnvironment because the
`craft_incident`/`submit_triage` action surface is non-MCP (single-action
unions are simpler for GRPO rollouts). Tool names are deliberately
non-reserved so an MCPEnvironment wrapper can be added later if a team
wants to expose the env over MCP transports.
"""
from __future__ import annotations

import os
import threading
from typing import Any, Dict, Optional

from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from pydantic import BaseModel

from env import Action, Observation, OpenSOCEnv
# The ASGI application. Title, description and version are the metadata
# FastAPI publishes in the generated OpenAPI schema / docs pages.
app = FastAPI(
    version="1.0.0",
    title="OpenSOC",
    description="Self-play SOC triage OpenEnv environment for cybersecurity defender LLMs.",
)

# Wide-open CORS: browsers on any origin may call any route with any method
# and any header. Credentials are not enabled here.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# Process-local cache: one OpenSOCEnv per "task::mode::seed" key.
# NOTE(review): the seed is client-controlled (any integer), so this dict can
# grow without bound over the process lifetime; consider an eviction policy
# if the server is long-lived and publicly reachable.
_envs: Dict[str, OpenSOCEnv] = {}
# Guards the check-then-insert in _get_env. The endpoints below are plain
# `def` functions, which FastAPI runs in a worker threadpool. Without the lock,
# two first-time requests for the same key could each build an env, and the
# later store would silently replace an env that was already handed out.
# Only the registry is protected; calls on a shared env instance are not.
_envs_lock = threading.Lock()


def _env_key(task: str, mode: str, seed: int) -> str:
    """Return the cache key for a triple, e.g. ``"stage1_basic::defender_only::0"``."""
    return f"{task}::{mode}::{seed}"


def _get_env(task: str, mode: str, seed: int) -> OpenSOCEnv:
    """Return the cached env for ``(task, mode, seed)``, creating it on first use.

    Raises:
        HTTPException: status 400 when the ``OpenSOCEnv`` constructor rejects
            the arguments with ``ValueError``. A failed construction is not
            cached, so a later request with the same arguments retries it.
    """
    key = _env_key(task, mode, seed)
    with _envs_lock:
        env = _envs.get(key)
        if env is None:
            try:
                env = OpenSOCEnv(task_id=task, mode=mode, seed=seed)  # type: ignore[arg-type]
            except ValueError as exc:
                raise HTTPException(status_code=400, detail=str(exc)) from exc
            _envs[key] = env
    return env
# ---------------------------------------------------------------------------
# Response models
# ---------------------------------------------------------------------------
class StepResult(BaseModel):
    # Response body for POST /step: the 4-tuple returned by OpenSOCEnv.step()
    # repackaged as named fields. Documented with comments rather than a class
    # docstring because pydantic would copy a docstring into the JSON schema.
    observation: Observation
    reward: float
    done: bool
    # Extra diagnostics passed through unchanged from env.step().
    info: Dict[str, Any]
class GradeResult(BaseModel):
    # Response body for POST /grade. Documented with comments rather than a
    # class docstring because pydantic would copy a docstring into the schema.
    # Echo of the query parameters the episode was looked up by.
    task: str
    mode: str
    # Normalized score from OpenSOCEnv.grade(); /grade documents it as [0, 1].
    score: float
    # Raw per-role rewards copied from the episode state. May be None
    # (presumably when that role has not been scored yet -- TODO confirm).
    defender_reward: Optional[float]
    attacker_reward: Optional[float]
    # `.value` of the episode's ground_truth (presumably an Enum), or None
    # when ground_truth is unset/falsy.
    ground_truth: Optional[str]
    plausible: Optional[bool]
    schema_violation: bool
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.post("/reset", response_model=Observation)
def reset(
    task: str = Query("stage1_basic", description="Curriculum stage id."),
    mode: str = Query("defender_only", description="self_play | defender_only"),
    seed: int = Query(0),
):
    """Start a fresh episode for ``(task, mode, seed)`` and return its first observation.

    The env is fetched from (or added to) the process-local cache; if the env
    constructor rejects the arguments, the request fails with HTTP 400.
    """
    return _get_env(task, mode, seed).reset()
@app.post("/step", response_model=StepResult)
def step(
    action: Action,
    task: str = Query("stage1_basic"),
    mode: str = Query("defender_only"),
    seed: int = Query(0),
):
    """Apply ``action`` to the current episode; return observation, reward, done, info.

    Fails with HTTP 400 when /reset has not been called for this env yet, or
    when ``env.step`` raises ``RuntimeError``.
    """
    env = _get_env(task, mode, seed)
    if env._state is None:
        raise HTTPException(status_code=400, detail="Call /reset first.")
    try:
        observation, reward, done, info = env.step(action)
    except RuntimeError as err:
        # Report env-side step failures to the client as a 400, not a 500.
        raise HTTPException(status_code=400, detail=str(err)) from err
    return StepResult(observation=observation, reward=reward, done=done, info=info)
@app.get("/state")
def state(
    task: str = Query("stage1_basic"),
    mode: str = Query("defender_only"),
    seed: int = Query(0),
):
    """Return the full internal episode state reported by ``env.state()``.

    Unlike /step and /grade, this does not require /reset to have been called.
    NOTE(review): confirm env.state() behaves sensibly before the first reset.
    """
    return _get_env(task, mode, seed).state()
@app.get("/tasks")
def list_tasks():
    """List every curriculum stage in the stage registry plus the supported modes."""
    # Imported inside the handler, so the registry is loaded at request time
    # rather than when this module is imported.
    from tasks.registry import STAGE_REGISTRY

    stages = [
        {
            "id": stage_id,
            "difficulty": stage_cfg["difficulty"],
            "description": stage_cfg["description"],
        }
        for stage_id, stage_cfg in STAGE_REGISTRY.items()
    ]
    supported_modes = ["self_play", "defender_only"]
    return {"tasks": stages, "modes": supported_modes}
@app.post("/grade", response_model=GradeResult)
def grade(
    task: str = Query("stage1_basic"),
    mode: str = Query("defender_only"),
    seed: int = Query(0),
):
    """Compute a normalized [0, 1] score for the just-finished episode.

    Besides the score from ``env.grade()``, the response carries the raw
    per-role rewards and outcome flags copied from the episode state.
    Fails with HTTP 400 if /reset has never been called for this env.
    """
    env = _get_env(task, mode, seed)
    if env._state is None:
        raise HTTPException(status_code=400, detail="No episode to grade. Call /reset first.")
    # The state object is captured before grading, and grade() runs before
    # any field is copied out.
    # NOTE(review): this ordering is kept deliberately in case grading touches
    # the state; confirm whether it does.
    episode = env._state
    normalized = env.grade()
    truth = episode.ground_truth
    return GradeResult(
        task=task,
        mode=mode,
        score=normalized,
        defender_reward=episode.defender_reward,
        attacker_reward=episode.attacker_reward,
        ground_truth=truth.value if truth else None,
        plausible=episode.plausible,
        schema_violation=episode.schema_violation,
    )
@app.get("/health")
def health():
    """Report that the OpenSOC service is up, along with the app version."""
    # Read the version from the FastAPI app object instead of repeating the
    # literal, so /health and the OpenAPI metadata cannot drift apart when
    # the version is bumped.
    return {"status": "ok", "env": "OpenSOC", "version": app.version}
@app.get("/", include_in_schema=False)
def index():
    """Redirect the root URL to the Gradio demo at /demo/."""
    # Hugging Face Spaces iframes load "/", so human visitors land on the
    # demo, while the JSON API routes stay untouched for the OpenEnv judge.
    # 307 is a temporary redirect that preserves the request method.
    demo_url = "/demo/"
    return RedirectResponse(url=demo_url, status_code=307)
_DEFAULT_PORT = 7860


def _resolve_port(raw: Optional[str]) -> int:
    """Turn a raw ``PORT`` environment value into a port number.

    Args:
        raw: The value of ``$PORT``, or ``None`` when it is unset.

    Returns:
        ``_DEFAULT_PORT`` (7860) when ``raw`` is ``None`` or blank; otherwise
        ``int(raw)`` (surrounding whitespace is tolerated).

    Raises:
        ValueError: ``raw`` is non-blank but not an integer.
    """
    # An empty PORT (e.g. `PORT=` in a container env) means "not configured",
    # not "crash with int('')".
    if raw is None or not raw.strip():
        return _DEFAULT_PORT
    try:
        return int(raw)
    except ValueError:
        raise ValueError(f"PORT must be an integer, got {raw!r}") from None


def main() -> None:
    """Serve ``app`` with uvicorn on all interfaces, on ``$PORT`` (default 7860)."""
    # Imported here so uvicorn is only needed when actually serving, not when
    # this module is merely imported/mounted.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=_resolve_port(os.getenv("PORT")))