Demo / server /app.py
Ajayyy00
Add alternating self-play training scaffolding.
292f6a5
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
FastAPI application for the CyberSOCEnv Environment.
Endpoints:
- POST /reset: Reset the environment (pass task_id in body)
- POST /step: Execute an action
- GET /state: Get current environment state
- GET /schema: Get action/observation schemas
- WS /ws: WebSocket endpoint for persistent sessions
Usage:
# Development (with auto-reload):
uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
# Production:
uvicorn server.app:app --host 0.0.0.0 --port 8000 --workers 4
"""
import os
import random
try:
from openenv.core.env_server.http_server import create_app
except Exception as e: # pragma: no cover
raise ImportError(
"openenv is required. Install with: pip install 'openenv-core[core]'"
) from e
try:
from ..models import SOCObservation, SOCActionWrapper
from .play_environment import CyberSOCEnvironment
except (ImportError, ModuleNotFoundError):
from models import SOCObservation, SOCActionWrapper
from server.play_environment import CyberSOCEnvironment
class FrozenCheckpointRedPolicy:
"""Lightweight frozen red policy shim keyed by checkpoint identity."""
def __init__(self, checkpoint: str):
self.checkpoint = checkpoint
self._rng = random.Random(hash(checkpoint))
def act(self, red_observation):
blue_action = red_observation.get("blue_action_type", "")
# Deterministic-ish heuristic controlled by checkpoint hash and episode step.
trigger = blue_action in {"kill_process", "isolate_segment"}
if trigger and self._rng.random() < 0.5:
return {
"action_type": "lateral_pivot",
"source_host": red_observation.get("blue_action_target", ""),
}
return {"action_type": "noop"}
_frozen_checkpoint = os.environ.get("CYBERSOC_FROZEN_RED_CHECKPOINT", "").strip()
_adaptive = os.environ.get("CYBERSOC_ADAPTIVE", "1").strip() not in {"0", "false", "False"}
_red_policy = FrozenCheckpointRedPolicy(_frozen_checkpoint) if _frozen_checkpoint else None
class ConfiguredCyberSOCEnvironment(CyberSOCEnvironment):
def __init__(self):
super().__init__(adaptive=_adaptive, neural_red_policy=_red_policy)
# Create the app with the CyberSOCEnv environment
app = create_app(
ConfiguredCyberSOCEnvironment,
SOCActionWrapper,
SOCObservation,
env_name="cybersocenv",
max_concurrent_envs=16,
)
# CORS middleware — allows the dashboard frontend to communicate with the server
from starlette.middleware.cors import CORSMiddleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
from fastapi.responses import JSONResponse
@app.get("/health")
async def health():
"""Lightweight liveness probe used by the dashboard to check server reachability.
Returns only status — never returns episode state, so the dashboard
cannot accidentally render stale data from a previous session.
"""
return JSONResponse({"status": "ok", "service": "cybersocenv"})
def main(host: str = "0.0.0.0", port: int = 8000):
"""Entry point for direct execution.
Usage:
python -m play.server.app
python -m play.server.app --port 8001
"""
import uvicorn
uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
main()