"""FastAPI server exposing the Adaptive AI Firewall environment.
Endpoints:
    GET  /             — Playground UI (HTML)
    GET  /web          — Playground UI (HTML), same page as /
    POST /reset        — Start a new episode
    POST /step         — Multi-session step (batch actions)
    POST /step_single  — Single-session step (Gymnasium-compatible)
    GET  /state        — Current environment state
    GET  /schema       — Observation and action space description
    GET  /tools        — List available tool names
    POST /tool/{name}  — Call a specific tool
    GET  /health       — Health check
    GET  /stats        — Current episode statistics
"""
from __future__ import annotations
import os
from typing import Any
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from dotenv import load_dotenv
from server.firewall_environment import FirewallEnvironment, ACTIONS
from models import (
HealthResponse,
NetworkStatsResponse,
ResetRequest,
StateResponse,
StepRequest,
StepResponse,
StepSingleRequest,
StepSingleResponse,
ToolRequest,
ToolsListResponse,
)
# Load variables from a local .env file (if present) into the process
# environment before any of the os.getenv() lookups in this module run.
load_dotenv()
def _clean_env_value(value: str) -> str:
return value.strip().strip("`").strip().strip("'").strip('"').strip()
def _resolve_api_key(value: str | None) -> str:
    """Pick the API key: *value* if truthy, else ``HF_TOKEN``, else ``""``.

    The chosen string is passed through ``_clean_env_value`` before return.
    """
    if value:
        candidate = value
    else:
        candidate = os.getenv("HF_TOKEN") or ""
    return _clean_env_value(candidate)
def _resolve_model(value: str | None) -> str:
    """Pick the model name: *value* if truthy, else ``MODEL_NAME``, else ``""``.

    The chosen string is passed through ``_clean_env_value`` before return.
    """
    if value:
        candidate = value
    else:
        candidate = os.getenv("MODEL_NAME") or ""
    return _clean_env_value(candidate)
def _resolve_base_url(value: str | None) -> str:
    """Pick the base URL: *value* if truthy, else ``API_BASE_URL``, else ``""``.

    The chosen string is passed through ``_clean_env_value`` before return.
    """
    if value:
        candidate = value
    else:
        candidate = os.getenv("API_BASE_URL") or ""
    return _clean_env_value(candidate)
# Response body returned verbatim by the "/" and "/web" routes.
PLAYGROUND_HTML = """
Adaptive Firewall Playground
Playground
Click Reset to start a new episode.
Ready
{}
"""
# Single module-level environment instance (constructed with seed=42). The
# environment-facing route handlers in this module all operate on this one
# object, so its state is shared by every request served by this process.
env = FirewallEnvironment(seed=42)
# FastAPI application object; the routes below are registered on it and
# main() hands it to uvicorn.
app = FastAPI(
    title="Adaptive AI Firewall OpenEnv",
    version="0.2.0",
    description="RL environment for adaptive firewall decision making on encrypted traffic.",
)
# CORS: accept requests from any origin with any method and header so that
# browser clients hosted elsewhere can call the API.
# Credentials stay disabled: a wildcard origin must not be combined with
# allow_credentials=True (the FastAPI/Starlette docs require explicit origins
# in that case), and no route in this module relies on cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/", response_class=HTMLResponse)
def root() -> HTMLResponse:
    """Serve the playground UI page directly at the site root."""
    page = HTMLResponse(content=PLAYGROUND_HTML)
    return page
@app.get("/health", response_model=HealthResponse)
def health() -> HealthResponse:
    """Report a fixed "ok" status together with the API version string."""
    fields = {"status": "ok", "version": "0.2.0"}
    return HealthResponse(**fields)
@app.post("/reset", response_model=StateResponse)
def reset(request: ResetRequest = ResetRequest()) -> StateResponse:
    """Start a new episode and return the resulting state.

    Forwards ``request.task`` and ``request.seed`` to ``env.reset``. Any
    ``ValueError`` raised while resetting or building the response is
    reported to the client as HTTP 400.
    """
    try:
        return StateResponse(**env.reset(task=request.task, seed=request.seed))
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err)) from err
@app.post("/step", response_model=StepResponse)
def step(request: StepRequest = StepRequest()) -> StepResponse:
    """Advance the environment by one multi-session (batch) step.

    Passes ``request.actions`` to ``env.step`` as ``action_map`` and wraps
    the returned mapping in a ``StepResponse``.

    Raises:
        HTTPException: 400 when the environment rejects the request with a
            ``ValueError``, matching the other environment endpoints.
    """
    try:
        result = env.step(action_map=request.actions)
    except ValueError as e:
        # Bad client input should surface as a 400, not an unhandled 500.
        raise HTTPException(status_code=400, detail=str(e)) from e
    return StepResponse(**result)
@app.post("/step_single", response_model=StepSingleResponse)
def step_single(request: StepSingleRequest = None) -> StepSingleResponse:
    """Apply one action via ``env.step_single`` and return its result.

    A missing body is answered with HTTP 422; a ``ValueError`` raised while
    stepping or building the response is answered with HTTP 400.
    """
    body_missing = request is None
    if body_missing:
        raise HTTPException(status_code=422, detail="Body is required for /step_single")
    try:
        return StepSingleResponse(**env.step_single(action=request.action))
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err)) from err
@app.get("/state", response_model=StateResponse)
def state() -> StateResponse:
    """Return the environment's current state as a ``StateResponse``."""
    snapshot = env.state()
    return StateResponse(**snapshot)
@app.get("/stats", response_model=NetworkStatsResponse)
def stats() -> NetworkStatsResponse:
    """Return the environment's network statistics."""
    network_stats = env.get_network_stats()
    return NetworkStatsResponse(**network_stats)
@app.get("/tools", response_model=ToolsListResponse)
def list_tools() -> ToolsListResponse:
    """Return the tool listing reported by the environment."""
    available = env.list_tools()
    return ToolsListResponse(tools=available)
@app.get("/web", response_class=HTMLResponse)
def web_interface() -> HTMLResponse:
    """Serve the playground UI page."""
    page = HTMLResponse(content=PLAYGROUND_HTML)
    return page
@app.get("/schema")
def schema() -> Any:
    """Describe the observation and action spaces as a plain dict.

    The observation space is reported as a ``Box`` of shape ``[22]`` bounded
    by 0.0 and 1.0; the action space as ``Discrete`` with ``n`` of 6 plus the
    imported ``ACTIONS`` value.
    """
    observation_space = {
        "type": "Box",
        "shape": [22],
        "low": 0.0,
        "high": 1.0,
    }
    action_space = {
        "type": "Discrete",
        "n": 6,
        "actions": ACTIONS,
    }
    return {
        "observation_space": observation_space,
        "action_space": action_space,
    }
@app.post("/tool/{name}")
def call_tool(name: str, request: ToolRequest) -> Any:
    """Invoke one of the environment's tools by name.

    Supported tools:
        - ``evaluate_session``: needs ``kwargs["session_id"]``.
        - ``take_action``: needs ``kwargs["session_id"]`` and
          ``kwargs["action"]`` (coerced with ``int()``); returns a dict with
          ``reward`` and ``record``.
        - ``get_network_stats`` and ``get_threat_intelligence``: read no
          kwargs.

    Raises:
        HTTPException: 404 for an unknown tool name; 400 for a missing
            kwarg, an action that cannot be converted to ``int``, or a
            ``ValueError`` raised while running the tool.
    """
    try:
        if name == "evaluate_session":
            return env.evaluate_session(request.kwargs["session_id"])
        if name == "take_action":
            session_id = request.kwargs["session_id"]
            raw_action = request.kwargs["action"]
            try:
                action = int(raw_action)
            except TypeError as exc:
                # int(None) / int([...]) raise TypeError, not ValueError;
                # that is still malformed client input, so answer 400
                # instead of letting it escape as a 500.
                raise HTTPException(
                    status_code=400,
                    detail=f"invalid action: {raw_action!r}",
                ) from exc
            reward, record = env.take_action(session_id=session_id, action=action)
            return {"reward": reward, "record": record}
        if name == "get_network_stats":
            return env.get_network_stats()
        if name == "get_threat_intelligence":
            return env.get_threat_intelligence()
        # HTTPException is neither KeyError nor ValueError, so it passes
        # through the handlers below unchanged.
        raise HTTPException(status_code=404, detail=f"unknown tool: {name}")
    except KeyError as exc:
        raise HTTPException(status_code=400, detail=f"missing key: {exc}") from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
def main() -> None:
    """Run the app under uvicorn, bound to 0.0.0.0 on port 7860."""
    # Imported here rather than at module level, as in the original layout.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)
# Start the server only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()