File size: 5,585 Bytes
bb6a031
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf4094f
bb6a031
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf4094f
 
 
 
 
 
 
bb6a031
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
"""FastAPI application module for OpenSOC, mountable from server.py.

Endpoints follow the OpenEnv conventions plus a lightweight `/grade`:

  POST /reset?task=<stage>&mode=<self_play|defender_only>&seed=<n>
  POST /step?task=<stage>&mode=...&seed=<n>          (body: Action)
  GET  /state?task=<stage>&mode=...&seed=<n>
  POST /grade?task=<stage>&mode=...&seed=<n>
  GET  /tasks
  GET  /health

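Env instances are cached in a process-local dict keyed by (task, mode, seed),
so concurrent clients that use distinct triples can share the FastAPI process
without stepping on each other's episodes; clients that send the same triple
share one cached env instance (and therefore one episode).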

This module does NOT inherit from openenv-core's MCPEnvironment because the
`craft_incident`/`submit_triage` action surface is non-MCP (single-action
unions are simpler for GRPO rollouts).  Tool names are deliberately
non-reserved so an MCPEnvironment wrapper can be added later if a team
wants to expose the env over MCP transports.
"""

from __future__ import annotations

import os
from typing import Any, Dict, Optional

from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from pydantic import BaseModel

from env import Action, Observation, OpenSOCEnv


# ASGI application object; the route decorators further down register their
# handlers on it.
app = FastAPI(
    title="OpenSOC",
    description="Self-play SOC triage OpenEnv environment for cybersecurity defender LLMs.",
    version="1.0.0",
)

# CORS is fully open: any origin, any method, any header.
# NOTE(review): presumably intentional for a publicly hosted demo/judge
# setup -- confirm before reusing this module behind authentication.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Process-local cache of environment instances, keyed by the string produced
# by _env_key() ("task::mode::seed").  _get_env() inserts entries on first use;
# nothing visible in this module evicts them.
_envs: Dict[str, OpenSOCEnv] = {}


def _env_key(task: str, mode: str, seed: int) -> str:
    return f"{task}::{mode}::{seed}"


def _get_env(task: str, mode: str, seed: int) -> OpenSOCEnv:
    """Return the cached env for (task, mode, seed), creating it on first use.

    Raises:
        HTTPException: status 400 carrying the error text when ``OpenSOCEnv``
            rejects the arguments with a ``ValueError``.
    """
    cache_key = _env_key(task, mode, seed)

    # Fast path: an env for this key already exists.
    try:
        return _envs[cache_key]
    except KeyError:
        pass

    # Slow path: build the env; only a successfully built env is cached.
    try:
        fresh_env = OpenSOCEnv(task_id=task, mode=mode, seed=seed)  # type: ignore[arg-type]
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    _envs[cache_key] = fresh_env
    return fresh_env


# ---------------------------------------------------------------------------
# Response models
# ---------------------------------------------------------------------------

class StepResult(BaseModel):
    # Response body of POST /step.  The four fields mirror, in order, the
    # (observation, reward, done, info) tuple that the /step handler unpacks
    # from env.step().
    # (Kept as `#` comments on purpose: a class docstring would be emitted as
    # the model's description in the generated schema.)
    observation: Observation
    reward: float
    done: bool
    info: Dict[str, Any]


class GradeResult(BaseModel):
    # Response body of POST /grade, as populated by the grade() handler.
    # (Kept as `#` comments on purpose: a class docstring would be emitted as
    # the model's description in the generated schema.)

    # Echo of the `task` and `mode` query parameters of the request.
    task: str
    mode: str
    # Value returned by env.grade(); the /grade handler documents it as a
    # normalized [0, 1] score.
    score: float
    # The remaining fields are copied from the env's episode state
    # (env._state) at the time of the request.
    defender_reward: Optional[float]
    attacker_reward: Optional[float]
    # `.value` of the state's ground_truth when that attribute is truthy,
    # otherwise None.
    ground_truth: Optional[str]
    plausible: Optional[bool]
    schema_violation: bool


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------

@app.post("/reset", response_model=Observation)
def reset(
    task: str = Query("stage1_basic", description="Curriculum stage id."),
    mode: str = Query("defender_only", description="self_play | defender_only"),
    seed: int = Query(0),
):
    """Reset the environment and return the initial observation."""
    # The env is looked up (or created) per (task, mode, seed); an invalid
    # combination surfaces as the HTTP 400 raised inside _get_env().
    initial_observation = _get_env(task, mode, seed).reset()
    return initial_observation


@app.post("/step", response_model=StepResult)
def step(
    action: Action,
    task: str = Query("stage1_basic"),
    mode: str = Query("defender_only"),
    seed: int = Query(0),
):
    """Execute one action and return observation, reward, done, info."""
    env = _get_env(task, mode, seed)

    # An env whose state is still None has never been reset; reject the step.
    never_reset = env._state is None
    if never_reset:
        raise HTTPException(status_code=400, detail="Call /reset first.")

    try:
        next_obs, step_reward, finished, extras = env.step(action)
    except RuntimeError as exc:
        # Runtime errors raised by the env are reported to the client as 400s.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    else:
        return StepResult(
            observation=next_obs,
            reward=step_reward,
            done=finished,
            info=extras,
        )


@app.get("/state")
def state(
    task: str = Query("stage1_basic"),
    mode: str = Query("defender_only"),
    seed: int = Query(0),
):
    """Return the full internal episode state."""
    # Unlike /step and /grade, this handler does not check for a prior
    # /reset; whatever env.state() returns is passed straight through.
    return _get_env(task, mode, seed).state()


@app.get("/tasks")
def list_tasks():
    """List the available curriculum stages."""
    # Imported at call time, as in the original layout of this handler.
    from tasks.registry import STAGE_REGISTRY

    # One summary entry per registered stage, in registry iteration order.
    stage_summaries = []
    for stage_id, cfg in STAGE_REGISTRY.items():
        stage_summaries.append(
            {
                "id": stage_id,
                "difficulty": cfg["difficulty"],
                "description": cfg["description"],
            }
        )

    return {
        "tasks": stage_summaries,
        "modes": ["self_play", "defender_only"],
    }


@app.post("/grade", response_model=GradeResult)
def grade(
    task: str = Query("stage1_basic"),
    mode: str = Query("defender_only"),
    seed: int = Query(0),
):
    """Compute a normalized [0, 1] score for the just-finished episode."""
    env = _get_env(task, mode, seed)
    if env._state is None:
        raise HTTPException(status_code=400, detail="No episode to grade. Call /reset first.")

    # Capture the state object first, then grade, then read the state fields --
    # the same evaluation order as building the result in a single call.
    episode = env._state
    payload = {
        "task": task,
        "mode": mode,
        "score": env.grade(),
        "defender_reward": episode.defender_reward,
        "attacker_reward": episode.attacker_reward,
        # Report the ground truth's `.value` when it is truthy, else None.
        "ground_truth": episode.ground_truth.value if episode.ground_truth else None,
        "plausible": episode.plausible,
        "schema_violation": episode.schema_violation,
    }
    return GradeResult(**payload)


@app.get("/health")
def health():
    # Liveness probe: a fixed payload that touches no environment state.
    status_payload = {"status": "ok", "env": "OpenSOC", "version": "1.0.0"}
    return status_payload


@app.get("/", include_in_schema=False)
def index():
    # The root URL is what Spaces iframes load, so human visitors are sent on
    # to the Gradio demo under /demo/; the JSON API routes used by the OpenEnv
    # judge are not affected by this redirect.
    demo_redirect = RedirectResponse(url="/demo/", status_code=307)
    return demo_redirect


def main() -> None:
    """Serve ``app`` with uvicorn on all interfaces.

    The port is read from the ``PORT`` environment variable and falls back to
    7860 when the variable is not set.
    """
    import uvicorn

    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)