"""FastAPI server exposing the Adaptive AI Firewall environment.

Endpoints:
    POST /reset         — Start a new episode
    POST /step          — Multi-session step (batch actions)
    POST /step_single   — Single-session step (Gymnasium-compatible)
    GET  /state         — Current environment state
    GET  /tools         — List available tool names
    POST /tool/{name}   — Call a specific tool
    GET  /health        — Health check
    GET  /stats         — Current episode statistics
    GET  /schema        — Observation and action space description
    GET  /              — HTML dashboard (playground UI)
    GET  /web           — Same HTML dashboard as ``/``
"""
| from __future__ import annotations |
|
|
| import csv |
| import json |
| import os |
| from pathlib import Path |
| from typing import Any |
|
|
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import HTMLResponse |
| from dotenv import load_dotenv |
|
|
| from server.embedded_dashboard_images import IMAGES |
| from server.firewall_environment import FirewallEnvironment, ACTIONS |
| from models import ( |
| HealthResponse, |
| NetworkStatsResponse, |
| ResetRequest, |
| StateResponse, |
| StepRequest, |
| StepResponse, |
| StepSingleRequest, |
| StepSingleResponse, |
| ToolRequest, |
| ToolsListResponse, |
| ) |
|
|
# Populate the process environment from a .env file (python-dotenv), if one is
# found, so the os.getenv() fallbacks in the _resolve_* helpers below can see it.
load_dotenv()
|
|
|
|
def _clean_env_value(value: str) -> str:
    """Strip surrounding whitespace, quotes and backticks from a config value.

    Values pasted into .env files or secret stores are often wrapped in
    quotes or backticks, sometimes nested (for example "'hf_abc'"). Whitespace
    and the wrapper characters (single quote, double quote, backtick) are
    peeled off both ends repeatedly until nothing more changes, so every
    nesting order is handled. Characters inside the value are left untouched.

    Args:
        value: Raw string to clean.

    Returns:
        The cleaned string (possibly empty).
    """
    wrappers = "`'\""
    cleaned = value.strip()
    # A single fixed strip sequence misses orders like "'x'" or "`x`" wrapped
    # in double quotes; loop until a pass makes no further progress.
    while True:
        stripped = cleaned.strip(wrappers).strip()
        if stripped == cleaned:
            return cleaned
        cleaned = stripped
|
|
|
|
def _resolve_api_key(value: str | None) -> str:
    """Return the API key to use, cleaned with ``_clean_env_value``.

    A truthy *value* takes precedence; otherwise the ``HF_TOKEN`` environment
    variable is used. Returns ``""`` when neither provides a value.
    """
    candidate = value if value else os.getenv("HF_TOKEN")
    return _clean_env_value(candidate if candidate else "")
|
|
|
|
def _resolve_model(value: str | None) -> str:
    """Return the model name to use, cleaned with ``_clean_env_value``.

    A truthy *value* wins; otherwise falls back to the ``MODEL_NAME``
    environment variable, and finally to ``""``.
    """
    chosen = value
    if not chosen:
        chosen = os.getenv("MODEL_NAME")
    if not chosen:
        chosen = ""
    return _clean_env_value(chosen)
|
|
|
|
def _resolve_base_url(value: str | None) -> str:
    """Return the API base URL to use, cleaned with ``_clean_env_value``.

    Precedence: a truthy *value*, then the ``API_BASE_URL`` environment
    variable, then ``""``.
    """
    if value:
        return _clean_env_value(value)
    from_env = os.getenv("API_BASE_URL")
    return _clean_env_value(from_env if from_env else "")
|
|
# Project root: the parent of the directory containing this file.
ROOT_DIR = Path(__file__).resolve().parent.parent
# CSV previewed in the dashboard's "Performance Matrix" table.
PERFORMANCE_MATRIX_PATH = ROOT_DIR.joinpath("output", "performance_matrix.csv")


# Dashboard graph cards in display order. Each entry is a tuple of
# (card title, image filename looked up in IMAGES, one-line caption).
GRAPH_SPECS = [
    ("Training Loss", "01_training_loss.png",
     "Loss trend across self-play rounds."),
    ("Reward Analysis", "02_reward_analysis.png",
     "Raw score versus difficulty-normalized reward."),
    ("Elo Progression", "03_elo_progression.png",
     "Agent Elo compared with adaptive difficulty Elo."),
    ("Win Rate", "04_win_rate.png",
     "Rolling win rate and Elo delta per round."),
    ("Detection vs FP", "05_detection_fp_rate.png",
     "Detection, false-positive rate, and efficiency."),
    ("Difficulty Curve", "06_difficulty_progression.png",
     "Adaptive curriculum difficulty progression."),
]
|
|
|
|
| def _load_performance_matrix(limit: int = 10) -> tuple[list[str], list[dict[str, str]]]: |
| """Load a preview of the performance matrix CSV for dashboard rendering.""" |
| if not PERFORMANCE_MATRIX_PATH.exists(): |
| return [], [] |
|
|
| with PERFORMANCE_MATRIX_PATH.open("r", encoding="utf-8", newline="") as handle: |
| reader = csv.DictReader(handle) |
| rows = list(reader) |
|
|
| headers = [ |
| "Round", |
| "Raw_Score", |
| "Abs_Training_Loss", |
| "Detection_Rate", |
| "FP_Rate", |
| "Efficiency", |
| "Agent_Elo", |
| "Difficulty_Elo", |
| ] |
| preview = [{key: row.get(key, "") for key in headers} for row in rows[:limit]] |
| return headers, preview |
|
|
|
|
def _build_graph_cards() -> str:
    """Render one ``article.viz-card`` element per entry in ``GRAPH_SPECS``.

    Each card shows the graph title and caption and embeds the image found
    under its filename in ``IMAGES``. When no image is available (missing or
    empty entry), a ``div.img-fallback`` placeholder is rendered instead.

    Returns:
        The card markup, one card per spec, joined with newlines.
    """

    def render_card(title: str, filename: str, description: str) -> str:
        image_src = IMAGES.get(filename, "")
        if image_src:
            media = f'<img src="{image_src}" alt="{title}" loading="lazy"/>'
        else:
            media = '<div class="img-fallback">Image unavailable</div>'
        return (
            "\n"
            '<article class="viz-card">\n'
            '<div class="viz-copy">\n'
            f"<h3>{title}</h3>\n"
            f"<p>{description}</p>\n"
            "</div>\n"
            '<div class="viz-media">\n'
            f"{media}\n"
            "</div>\n"
            "</article>\n"
        )

    return "\n".join(
        render_card(title, filename, description)
        for title, filename, description in GRAPH_SPECS
    )
|
|
|
|
def _build_table_html() -> str:
    """Render the performance-matrix preview as an HTML table.

    Rows come from ``_load_performance_matrix()``. If it returns no headers or
    no rows, a ``div.table-empty`` placeholder is returned instead.

    Header and cell text is HTML-escaped because cells are read verbatim from
    a CSV file on disk: a value containing ``<`` or ``&`` must not break (or
    inject into) the dashboard markup. ``None`` cell values render as empty
    cells rather than the text "None".

    Returns:
        The table markup wrapped in ``div.table-scroll``, or the placeholder.
    """
    from html import escape

    headers, rows = _load_performance_matrix()
    if not headers or not rows:
        return '<div class="table-empty">No performance matrix data available.</div>'

    def cell_text(row: dict[str, str], header: str) -> str:
        value = row.get(header, "")
        return "" if value is None else escape(str(value))

    header_html = "".join(f"<th>{escape(header)}</th>" for header in headers)
    body_html = "\n".join(
        "<tr>" + "".join(f"<td>{cell_text(row, header)}</td>" for header in headers) + "</tr>"
        for row in rows
    )
    return (
        '<div class="table-scroll"><table><thead><tr>'
        + header_html
        + "</tr></thead><tbody>"
        + body_html
        + "</tbody></table></div>"
    )
|
|
|
|
def build_playground_html() -> str:
    """Render the main dashboard HTML page.

    The page is a static HTML/CSS/JS template with two placeholders:
    ``__GRAPH_CARDS__`` is replaced by ``_build_graph_cards()`` and
    ``__TABLE_HTML__`` by ``_build_table_html()``. Both helpers run on every
    call, so the performance-matrix CSV is re-read each time the page is built.

    The embedded script calls ``/reset``, ``/step_single``, ``/state`` and
    ``/stats`` and mirrors their JSON responses into the metric cards. It only
    updates a metric when the response carries that field, and it stringifies
    structured error ``detail`` values (e.g. validation-error lists) before
    showing them in the status pill.

    Returns:
        The complete HTML document as a string.
    """
    template = """<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width,initial-scale=1"/>
<title>Adaptive Firewall Dashboard</title>
<style>
:root{
--bg:#030d23;
--bg-accent:#071835;
--card:#0d1930;
--card-soft:#102142;
--border:#1d3157;
--text:#edf4ff;
--muted:#95add3;
--primary:#22c55e;
--secondary:#6f819f;
--highlight:#5eead4;
--danger:#fb7185;
--shadow:0 18px 48px rgba(0,0,0,.35);
}
*{box-sizing:border-box}
body{
margin:0;
font-family:Inter,Arial,sans-serif;
color:var(--text);
background:
radial-gradient(circle at top left, rgba(45,212,191,.12), transparent 24%),
radial-gradient(circle at top right, rgba(59,130,246,.12), transparent 22%),
linear-gradient(180deg, #051227 0%, var(--bg) 100%);
}
.page{
max-width:1360px;
margin:0 auto;
padding:32px 24px 56px;
}
.hero{
display:flex;
align-items:flex-end;
justify-content:space-between;
gap:16px;
margin-bottom:24px;
}
.hero h1{
margin:0 0 8px;
font-size:36px;
line-height:1.1;
letter-spacing:-0.02em;
}
.hero p{
margin:0;
color:var(--muted);
max-width:760px;
line-height:1.55;
}
.badge{
display:inline-flex;
align-items:center;
gap:8px;
padding:8px 12px;
border-radius:999px;
background:rgba(34,197,94,.12);
border:1px solid rgba(34,197,94,.35);
color:#b7f7ca;
font-size:13px;
font-weight:700;
white-space:nowrap;
}
.layout{
display:grid;
grid-template-columns: minmax(0, 1.2fr) minmax(320px, 0.8fr);
gap:20px;
align-items:start;
}
.card{
background:linear-gradient(180deg, rgba(13,25,48,.96), rgba(10,21,39,.98));
border:1px solid var(--border);
border-radius:22px;
box-shadow:var(--shadow);
overflow:hidden;
}
.card-inner{padding:22px}
.card h2{
margin:0 0 8px;
font-size:20px;
letter-spacing:-0.01em;
}
.subtitle{
margin:0 0 18px;
color:var(--muted);
font-size:14px;
}
.controls{
display:grid;
grid-template-columns:repeat(2, minmax(0,1fr));
gap:14px;
margin-bottom:14px;
}
.field label{
display:block;
margin-bottom:6px;
color:#d7e4fb;
font-size:13px;
font-weight:600;
}
select,input,button{
width:100%;
border-radius:12px;
border:1px solid #2a3f68;
background:#081426;
color:var(--text);
padding:14px 14px;
font-size:15px;
}
input:focus, select:focus{
outline:2px solid rgba(94,234,212,.28);
border-color:#51d6c1;
}
.actions{
display:grid;
grid-template-columns:repeat(4, minmax(0,1fr));
gap:12px;
margin:8px 0 14px;
}
button{
font-weight:700;
cursor:pointer;
transition:transform .15s ease, filter .15s ease;
}
button:hover{transform:translateY(-1px);filter:brightness(1.04)}
button:disabled{opacity:.65;cursor:not-allowed;transform:none}
.btn-primary{
background:linear-gradient(180deg,#29d365,#22c55e);
border-color:#22c55e;
color:#06260f;
}
.btn-secondary{
background:#6d7c96;
border-color:#6d7c96;
color:#f7fbff;
}
.status-row{
display:flex;
align-items:center;
justify-content:space-between;
gap:12px;
margin-bottom:14px;
flex-wrap:wrap;
}
.status-pill{
display:inline-flex;
align-items:center;
gap:8px;
padding:8px 12px;
border-radius:999px;
border:1px solid #23406e;
color:#cce0ff;
background:#091629;
font-size:13px;
font-weight:600;
}
.status-pill.ready::before,
.status-pill.success::before,
.status-pill.error::before{
content:"";
width:8px;
height:8px;
border-radius:50%;
background:var(--highlight);
display:inline-block;
}
.status-pill.success::before{background:var(--primary)}
.status-pill.error::before{background:var(--danger)}
.metric-grid{
display:grid;
grid-template-columns:repeat(3, minmax(0,1fr));
gap:12px;
margin-top:6px;
}
.metric{
background:var(--card-soft);
border:1px solid var(--border);
border-radius:16px;
padding:14px;
}
.metric span{
display:block;
color:var(--muted);
font-size:12px;
text-transform:uppercase;
letter-spacing:.06em;
margin-bottom:8px;
}
.metric strong{
font-size:22px;
letter-spacing:-0.02em;
}
.metric small{
display:block;
margin-top:8px;
color:#b7c9e8;
font-size:12px;
}
.mapping{
display:flex;
flex-wrap:wrap;
gap:8px;
margin-top:16px;
}
.chip{
border:1px solid #264067;
border-radius:999px;
padding:8px 12px;
background:#0a1730;
color:#dce9fb;
font-size:13px;
}
pre{
margin:14px 0 0;
min-height:300px;
padding:16px;
border-radius:16px;
border:1px solid #253a60;
background:#071224;
color:#d8e5fb;
overflow:auto;
white-space:pre-wrap;
word-break:break-word;
font-size:13px;
line-height:1.55;
}
.stack{
display:grid;
gap:20px;
}
.section{
margin-top:22px;
}
.section-head{
display:flex;
justify-content:space-between;
align-items:flex-end;
gap:16px;
margin-bottom:14px;
}
.section-head h2{
margin:0 0 6px;
font-size:24px;
}
.section-head p{
margin:0;
color:var(--muted);
max-width:760px;
}
.table-wrap{
background:linear-gradient(180deg, rgba(13,25,48,.96), rgba(10,21,39,.98));
border:1px solid var(--border);
border-radius:22px;
box-shadow:var(--shadow);
padding:20px;
}
.table-scroll{
overflow:auto;
border:1px solid #223658;
border-radius:16px;
background:#081426;
}
table{
width:100%;
border-collapse:collapse;
min-width:920px;
}
th,td{
padding:12px 14px;
border-bottom:1px solid rgba(36,56,92,.85);
text-align:left;
font-size:13px;
white-space:nowrap;
}
th{
position:sticky;
top:0;
background:#0c1a31;
color:#d7e6ff;
z-index:1;
}
td{color:#c8d9f5}
.table-empty{
padding:18px;
border:1px dashed #325180;
border-radius:16px;
color:var(--muted);
background:#081426;
}
.viz-grid{
display:grid;
grid-template-columns:repeat(2, minmax(0,1fr));
gap:18px;
}
.viz-card{
background:linear-gradient(180deg, rgba(13,25,48,.96), rgba(10,21,39,.98));
border:1px solid var(--border);
border-radius:22px;
box-shadow:var(--shadow);
overflow:hidden;
}
.viz-copy{
padding:18px 18px 8px;
}
.viz-copy h3{
margin:0 0 6px;
font-size:18px;
}
.viz-copy p{
margin:0;
color:var(--muted);
font-size:14px;
}
.viz-media{
padding:14px 18px 18px;
}
.viz-media img,
.img-fallback{
width:100%;
display:block;
border-radius:16px;
border:1px solid #28416b;
background:#091426;
}
.img-fallback{
min-height:220px;
display:flex;
align-items:center;
justify-content:center;
color:var(--muted);
font-size:14px;
}
@media (max-width: 1080px){
.layout{grid-template-columns:1fr}
.viz-grid{grid-template-columns:1fr}
}
@media (max-width: 720px){
.page{padding:20px 14px 40px}
.hero{flex-direction:column;align-items:flex-start}
.controls,.actions,.metric-grid{grid-template-columns:1fr}
}
</style>
</head>
<body>
<div class="page">
<header class="hero">
<div>
<h1>Adaptive Firewall Dashboard</h1>
<p>
Monitor the firewall environment, step through actions, inspect live state,
and review the training visuals from the current performance matrix in one place.
</p>
</div>
<div class="badge">Running on Hugging Face Space</div>
</header>

<section class="layout">
<div class="card">
<div class="card-inner">
<h2>Playground</h2>
<p class="subtitle">Click Reset to start a new episode, then use Step to apply a single action.</p>

<div class="controls">
<div class="field">
<label for="task_select">Task</label>
<select id="task_select">
<option value="easy">easy</option>
<option value="medium">medium</option>
<option value="hard">hard</option>
</select>
</div>
<div class="field">
<label for="action_input">Action ID</label>
<input id="action_input" type="number" value="0" min="0" max="5" />
</div>
</div>

<div class="actions">
<button id="btn_step" class="btn-primary">Step</button>
<button id="btn_reset" class="btn-secondary">Reset</button>
<button id="btn_state" class="btn-secondary">Get state</button>
<button id="btn_stats" class="btn-secondary">Get stats</button>
</div>

<div class="status-row">
<div id="status" class="status-pill ready">Ready</div>
<div class="status-pill">Action space: 6 discrete actions</div>
</div>

<div class="mapping">
<div class="chip">0: ALLOW</div>
<div class="chip">1: BLOCK</div>
<div class="chip">2: INSPECT</div>
<div class="chip">3: SANDBOX</div>
<div class="chip">4: RATE_LIMIT</div>
<div class="chip">5: QUARANTINE</div>
</div>

<pre id="output">{}</pre>
</div>
</div>

<div class="stack">
<div class="card">
<div class="card-inner">
<h2>Live Episode</h2>
<p class="subtitle">Core state values update after every API action.</p>
<div class="metric-grid">
<div class="metric">
<span>Task</span>
<strong id="metric_task">easy</strong>
<small>Current difficulty level</small>
</div>
<div class="metric">
<span>Step Count</span>
<strong id="metric_step">0</strong>
<small>Episode steps processed</small>
</div>
<div class="metric">
<span>Queue Length</span>
<strong id="metric_queue">0</strong>
<small>Sessions waiting for action</small>
</div>
<div class="metric">
<span>Budget</span>
<strong id="metric_budget">0</strong>
<small>Remaining budget</small>
</div>
<div class="metric">
<span>Total Reward</span>
<strong id="metric_reward">0</strong>
<small>Accumulated environment reward</small>
</div>
<div class="metric">
<span>Focus Session</span>
<strong id="metric_focus">-</strong>
<small>Current session in focus</small>
</div>
</div>
</div>
</div>

<div class="card">
<div class="card-inner">
<h2>Firewall Summary</h2>
<p class="subtitle">The dashboard below uses the generated training artifacts and the latest API responses.</p>
<div class="metric-grid">
<div class="metric">
<span>Detection Rate</span>
<strong id="metric_det">-</strong>
<small>Updated from /stats</small>
</div>
<div class="metric">
<span>FP Rate</span>
<strong id="metric_fp">-</strong>
<small>Updated from /stats</small>
</div>
<div class="metric">
<span>Efficiency</span>
<strong id="metric_eff">-</strong>
<small>Updated from /stats</small>
</div>
</div>
</div>
</div>
</div>
</section>

<section class="section">
<div class="section-head">
<div>
<h2>Performance Matrix</h2>
<p>Preview of the tracked performance matrix used for training analysis.</p>
</div>
</div>
<div class="table-wrap">
__TABLE_HTML__
</div>
</section>

<section class="section">
<div class="section-head">
<div>
<h2>Performance Visuals</h2>
<p>Embedded graphs generated from the available dashboard images so the Space can display them reliably.</p>
</div>
</div>
<div class="viz-grid">
__GRAPH_CARDS__
</div>
</section>
</div>

<script>
const output = document.getElementById("output");
const status = document.getElementById("status");
const taskSelect = document.getElementById("task_select");
const actionInput = document.getElementById("action_input");
const metricEls = {
task: document.getElementById("metric_task"),
step: document.getElementById("metric_step"),
queue: document.getElementById("metric_queue"),
budget: document.getElementById("metric_budget"),
reward: document.getElementById("metric_reward"),
focus: document.getElementById("metric_focus"),
det: document.getElementById("metric_det"),
fp: document.getElementById("metric_fp"),
eff: document.getElementById("metric_eff"),
};

function setStatus(message, kind="ready") {
status.textContent = message;
status.className = "status-pill " + kind;
}

function setButtonsDisabled(disabled) {
document.querySelectorAll("button").forEach((button) => {
button.disabled = disabled;
});
}

function renderJson(data) {
output.textContent = JSON.stringify(data, null, 2);
}

async function refreshStats() {
try {
const response = await fetch("/stats");
const contentType = response.headers.get("content-type") || "";
const data = contentType.includes("application/json")
? await response.json()
: null;
if (response.ok) {
updateStatsMetrics(data);
}
} catch (error) {
console.error("Unable to refresh stats", error);
}
}

function updateStateMetrics(payload) {
const state = payload && payload.state ? payload.state : payload;
if (!state || typeof state !== "object") return;

if (state.task !== undefined) {
metricEls.task.textContent = String(state.task);
if (taskSelect.querySelector(`option[value="${state.task}"]`)) {
taskSelect.value = String(state.task);
}
}
if (state.step_count !== undefined) metricEls.step.textContent = String(state.step_count);
if (state.queue_length !== undefined) metricEls.queue.textContent = String(state.queue_length);
if (state.budget_remaining !== undefined) metricEls.budget.textContent = Number(state.budget_remaining).toFixed(2);
if (state.total_reward !== undefined) metricEls.reward.textContent = Number(state.total_reward).toFixed(2);
// Only touch the focus metric when the payload carries the field; responses
// such as /stats may not, and must not wipe the last known focus session.
if (state.focus_session_id !== undefined) metricEls.focus.textContent = state.focus_session_id || "-";
}

function updateStatsMetrics(payload) {
if (!payload || typeof payload !== "object") return;
if (payload.detection_rate !== undefined) metricEls.det.textContent = Number(payload.detection_rate).toFixed(3);
if (payload.false_positive_rate !== undefined) metricEls.fp.textContent = Number(payload.false_positive_rate).toFixed(3);
if (payload.efficiency !== undefined) metricEls.eff.textContent = Number(payload.efficiency).toFixed(3);
}

async function call(path, method="GET", body=null) {
setButtonsDisabled(true);
setStatus("Calling " + path + "...", "ready");
try {
const options = {
method,
headers: {"Content-Type": "application/json"},
};
if (body !== null) options.body = JSON.stringify(body);

const response = await fetch(path, options);
const contentType = response.headers.get("content-type") || "";
const data = contentType.includes("application/json")
? await response.json()
: {raw: await response.text()};

renderJson(data);

if (!response.ok) {
const detail = data && data.detail ? data.detail : "Request failed";
// FastAPI validation errors carry a list of objects in `detail`;
// stringify non-string details instead of showing "[object Object]".
const detailText = typeof detail === "string" ? detail : JSON.stringify(detail);
setStatus("Error: " + detailText, "error");
return;
}

updateStateMetrics(data);
updateStatsMetrics(data);
if (path !== "/stats") {
await refreshStats();
}
setStatus("Success", "success");
return data;
} catch (error) {
renderJson({error: String(error)});
setStatus("Error: " + error, "error");
} finally {
setButtonsDisabled(false);
}
}

document.getElementById("btn_step").addEventListener("click", () => {
const rawValue = parseInt(actionInput.value || "0", 10);
const action = Math.min(5, Math.max(0, Number.isNaN(rawValue) ? 0 : rawValue));
actionInput.value = String(action);
call("/step_single", "POST", {action});
});

document.getElementById("btn_reset").addEventListener("click", () => {
call("/reset", "POST", {task: taskSelect.value});
});

document.getElementById("btn_state").addEventListener("click", () => {
call("/state", "GET");
});

document.getElementById("btn_stats").addEventListener("click", () => {
call("/stats", "GET");
});

call("/state", "GET");
</script>
</body>
</html>"""
    return template.replace("__GRAPH_CARDS__", _build_graph_cards()).replace(
        "__TABLE_HTML__", _build_table_html()
    )
|
|
|
|
# Single environment instance shared by the API endpoints below. It is seeded
# with 42 and reset to the "easy" task at import time, presumably so /state
# has an episode to report before the first /reset call.
env = FirewallEnvironment(seed=42)
env.reset(task="easy")

app = FastAPI(
    title="Adaptive AI Firewall OpenEnv",
    description="RL environment for adaptive firewall decision making on encrypted traffic.",
    version="0.2.0",
)

# NOTE(review): the FastAPI/Starlette CORS docs say wildcard origins should not
# be combined with allow_credentials=True; confirm whether credentialed
# cross-origin requests are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
|
|
|
|
@app.get("/", response_class=HTMLResponse)
def root() -> HTMLResponse:
    """Serve the dashboard/playground page at the site root.

    The page is rendered in place by ``build_playground_html`` (no redirect).
    """
    page = build_playground_html()
    return HTMLResponse(content=page)
|
|
|
|
@app.get("/health", response_model=HealthResponse)
def health() -> HealthResponse:
    """Health check reporting service status and version.

    The version is read from the FastAPI app, so it cannot drift from the
    ``version=`` passed to ``FastAPI(...)``.
    """
    return HealthResponse(status="ok", version=app.version)
|
|
|
|
@app.post("/reset", response_model=StateResponse)
def reset(request: ResetRequest = ResetRequest()) -> StateResponse:
    """Start a new episode and return the resulting environment state.

    The request body is optional; without one, a default ``ResetRequest()``
    supplies ``task`` and ``seed``.

    Raises:
        HTTPException: 400 when ``env.reset`` raises ``ValueError``
            (presumably e.g. an unknown task name).
    """
    try:
        new_state = env.reset(task=request.task, seed=request.seed)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    # Built outside the try: pydantic's ValidationError subclasses ValueError,
    # so a malformed env state would otherwise be misreported as a client 400.
    return StateResponse(**new_state)
|
|
|
|
@app.post("/step", response_model=StepResponse)
def step(request: StepRequest = StepRequest()) -> StepResponse:
    """Apply a batch of per-session actions and return the step result.

    ``request.actions`` is forwarded to ``env.step`` as ``action_map``.

    Raises:
        HTTPException: 400 when ``env.step`` raises ``ValueError``, matching
            the error handling of ``/reset`` and ``/step_single``.
    """
    try:
        result = env.step(action_map=request.actions)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    # Response model built outside the try so server-side validation errors
    # (pydantic's ValidationError is a ValueError) still surface as 500s.
    return StepResponse(**result)
|
|
|
|
@app.post("/step_single", response_model=StepSingleResponse)
def step_single(request: StepSingleRequest = None) -> StepSingleResponse:
    """Apply a single action via ``env.step_single`` and return the result.

    The body parameter defaults to ``None``, so FastAPI accepts a request
    without a body; that case is rejected here with an explicit 422 message.

    Raises:
        HTTPException: 422 when the body is missing; 400 when
            ``env.step_single`` raises ``ValueError``.
    """
    if request is None:
        raise HTTPException(status_code=422, detail="Body is required for /step_single")
    try:
        result = env.step_single(action=request.action)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    # Built outside the try: pydantic's ValidationError subclasses ValueError
    # and would otherwise be misreported as a client error.
    return StepSingleResponse(**result)
|
|
|
|
@app.get("/state", response_model=StateResponse)
def state() -> StateResponse:
    """Return a snapshot of the current environment state."""
    snapshot = env.state()
    return StateResponse(**snapshot)
|
|
|
|
@app.get("/stats", response_model=NetworkStatsResponse)
def stats() -> NetworkStatsResponse:
    """Return the statistics reported by ``env.get_network_stats()``."""
    network_stats = env.get_network_stats()
    return NetworkStatsResponse(**network_stats)
|
|
|
|
@app.get("/tools", response_model=ToolsListResponse)
def list_tools() -> ToolsListResponse:
    """List the tool names exposed by the environment."""
    available = env.list_tools()
    return ToolsListResponse(tools=available)
|
|
|
|
@app.get("/web", response_class=HTMLResponse)
def web_interface() -> HTMLResponse:
    """Serve the dashboard page (same content as ``GET /``)."""
    page = build_playground_html()
    return HTMLResponse(content=page)
|
|
|
|
@app.get("/schema")
def schema() -> Any:
    """Describe the observation and action spaces using Gymnasium space names.

    The action count is derived from ``ACTIONS`` so it always matches the
    action table returned alongside it, instead of being hard-coded.
    """
    return {
        "observation_space": {
            "type": "Box",
            # NOTE(review): shape is hard-coded; confirm it matches the length
            # of the environment's observation vector.
            "shape": [22],
            "low": 0.0,
            "high": 1.0,
        },
        "action_space": {
            "type": "Discrete",
            "n": len(ACTIONS),
            "actions": ACTIONS,
        },
    }
|
|
|
|
@app.post("/tool/{name}")
def call_tool(name: str, request: ToolRequest) -> Any:
    """Invoke an environment tool by name with arguments from ``request.kwargs``.

    Supported tools (and the kwargs each one reads):
        evaluate_session         -- session_id
        take_action              -- session_id, action (converted with int())
        get_network_stats        -- (none)
        get_threat_intelligence  -- (none)

    Raises:
        HTTPException: 404 for an unknown tool name; 400 when a required
            kwarg is missing, ``action`` cannot be converted to an int, or
            the environment raises ``KeyError``/``ValueError``.
    """
    # Treat an absent/empty kwargs payload as "no arguments", so a tool that
    # needs one reports a 400 "missing key" instead of failing on None[...].
    kwargs = request.kwargs or {}
    try:
        if name == "evaluate_session":
            return env.evaluate_session(kwargs["session_id"])
        if name == "take_action":
            session_id = kwargs["session_id"]
            raw_action = kwargs["action"]
            try:
                action = int(raw_action)
            except TypeError as exc:
                # int(None) / int([...]) raise TypeError, which the handlers
                # below do not cover; int("abc") raises ValueError (400 below).
                raise HTTPException(status_code=400, detail=str(exc)) from exc
            reward, record = env.take_action(session_id=session_id, action=action)
            return {"reward": reward, "record": record}
        if name == "get_network_stats":
            return env.get_network_stats()
        if name == "get_threat_intelligence":
            return env.get_threat_intelligence()
        raise HTTPException(status_code=404, detail=f"unknown tool: {name}")
    except KeyError as exc:
        raise HTTPException(status_code=400, detail=f"missing key: {exc}") from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
|
|
|
|
def main() -> None:
    """Serve the FastAPI app with uvicorn on 0.0.0.0, port 7860."""
    # Imported here rather than at module top, presumably so importing the
    # app (e.g. under another ASGI server) does not require uvicorn.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)
|
|
|
|
# Running this file directly starts the server via main().
if __name__ == "__main__":
    main()
|
|