"""FastAPI server exposing the Adaptive AI Firewall environment. Endpoints: POST /reset — Start a new episode POST /step — Multi-session step (batch actions) POST /step_single — Single-session step (Gymnasium-compatible) GET /state — Current environment state GET /tools — List available tool names POST /tool/{name} — Call a specific tool GET /health — Health check GET /stats — Current episode statistics """ from __future__ import annotations import os from typing import Any from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import HTMLResponse from dotenv import load_dotenv from server.firewall_environment import FirewallEnvironment, ACTIONS from models import ( HealthResponse, NetworkStatsResponse, ResetRequest, StateResponse, StepRequest, StepResponse, StepSingleRequest, StepSingleResponse, ToolRequest, ToolsListResponse, ) load_dotenv() def _clean_env_value(value: str) -> str: return value.strip().strip("`").strip().strip("'").strip('"').strip() def _resolve_api_key(value: str | None) -> str: return _clean_env_value(value or os.getenv("HF_TOKEN") or "") def _resolve_model(value: str | None) -> str: return _clean_env_value(value or os.getenv("MODEL_NAME") or "") def _resolve_base_url(value: str | None) -> str: return _clean_env_value( value or os.getenv("API_BASE_URL") or "" ) PLAYGROUND_HTML = """ Adaptive Firewall Playground

Playground

Click Reset to start a new episode.

Ready
{}
""" env = FirewallEnvironment(seed=42) app = FastAPI( title="Adaptive AI Firewall OpenEnv", version="0.2.0", description="RL environment for adaptive firewall decision making on encrypted traffic.", ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @app.get("/", response_class=HTMLResponse) def root() -> HTMLResponse: """Redirect root to the playground UI.""" return HTMLResponse(content=PLAYGROUND_HTML) @app.get("/health", response_model=HealthResponse) def health() -> HealthResponse: return HealthResponse(status="ok", version="0.2.0") @app.post("/reset", response_model=StateResponse) def reset(request: ResetRequest = ResetRequest()) -> StateResponse: try: state = env.reset(task=request.task, seed=request.seed) return StateResponse(**state) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) from e @app.post("/step", response_model=StepResponse) def step(request: StepRequest = StepRequest()) -> StepResponse: result = env.step(action_map=request.actions) return StepResponse(**result) @app.post("/step_single", response_model=StepSingleResponse) def step_single(request: StepSingleRequest = None) -> StepSingleResponse: if request is None: raise HTTPException(status_code=422, detail="Body is required for /step_single") try: result = env.step_single(action=request.action) return StepSingleResponse(**result) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) from e @app.get("/state", response_model=StateResponse) def state() -> StateResponse: return StateResponse(**env.state()) @app.get("/stats", response_model=NetworkStatsResponse) def stats() -> NetworkStatsResponse: return NetworkStatsResponse(**env.get_network_stats()) @app.get("/tools", response_model=ToolsListResponse) def list_tools() -> ToolsListResponse: return ToolsListResponse(tools=env.list_tools()) @app.get("/web", response_class=HTMLResponse) def web_interface() -> HTMLResponse: return HTMLResponse(content=PLAYGROUND_HTML) @app.get("/schema") def schema() -> Any: return { "observation_space": { "type": "Box", "shape": [22], "low": 0.0, "high": 1.0, }, "action_space": { "type": "Discrete", "n": 6, "actions": ACTIONS, }, } @app.post("/tool/{name}") def call_tool(name: str, request: ToolRequest) -> Any: try: if name == "evaluate_session": return env.evaluate_session(request.kwargs["session_id"]) if name == "take_action": reward, record = env.take_action( session_id=request.kwargs["session_id"], action=int(request.kwargs["action"]), ) return {"reward": reward, "record": record} if name == "get_network_stats": return env.get_network_stats() if name == "get_threat_intelligence": return env.get_threat_intelligence() raise HTTPException(status_code=404, detail=f"unknown tool: {name}") except KeyError as exc: raise HTTPException(status_code=400, detail=f"missing key: {exc}") from exc except ValueError as exc: raise HTTPException(status_code=400, detail=str(exc)) from exc def main() -> None: import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860) if __name__ == "__main__": main()