"""
FastAPI application for the API Testing Environment.
Endpoints:
- POST /reset: Reset the environment
- POST /step: Execute an action
- GET /state: Get current environment state
- GET /schema: Get action/observation schemas
- WS /ws: WebSocket endpoint for persistent sessions
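- GET /info: JSON info about the environment
- GET /tasks: List available tasks with descriptions
- GET /: Redirect to the Gradio UI (/ui) when mounted, otherwise to /info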
Usage:
uvicorn server.app:app --host 0.0.0.0 --port 8000
"""
import os
import logging
try:
from openenv.core.env_server.http_server import create_app
from ..models import APITestAction, APITestObservation
from .environment import APITestEnvironment
except ImportError:
from openenv.core.env_server.http_server import create_app
from models import APITestAction, APITestObservation
from server.environment import APITestEnvironment
from fastapi.responses import RedirectResponse
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Build the application through the OpenEnv `create_app` factory. The
# MAX_ENVS environment variable (default "1") supplies max_concurrent_envs
# and must parse as an integer.
app = create_app(
    APITestEnvironment,
    APITestAction,
    APITestObservation,
    env_name="api_testing_env",
    max_concurrent_envs=int(os.getenv("MAX_ENVS", "1")),
)

# Set to True once the Gradio UI has been mounted, so the root route knows
# whether it can redirect there.
_GRADIO_MOUNTED = False
@app.get("/info")
async def info():
    """Return a JSON description of the environment.

    This route serves the payload that used to live at `/`: the environment
    name and description, the task identifiers, and the paths of the UI,
    the API docs and the schema endpoint.
    """
    task_ids = ["basic_validation", "edge_cases", "security_workflows"]
    links = {
        "ui": "/ui",
        "docs": "/docs",
        "schema": "/schema",
    }
    return {
        "name": "API Testing Environment",
        "description": "An OpenEnv RL environment where an AI agent learns to test REST APIs intelligently",
        "tasks": task_ids,
        **links,
    }
@app.get("/tasks")
async def list_tasks():
    """List the available tasks with their metadata.

    Returns:
        A dict keyed by task id. Each value holds that task's
        ``description``, ``difficulty``, ``max_steps`` and ``total_bugs``,
        copied from the ``TASKS`` mapping of the environment module.
    """
    # TASKS is imported at call time. Mirror the module-level import
    # fallback: the relative form only works when this file is loaded as
    # part of a package; when it is run as a plain script the relative
    # import raises ImportError and the absolute path is used instead.
    try:
        from .environment import TASKS
    except ImportError:
        from server.environment import TASKS

    exposed_fields = ("description", "difficulty", "max_steps", "total_bugs")
    return {
        task_id: {field: task[field] for field in exposed_fields}
        for task_id, task in TASKS.items()
    }
# ---------------------------------------------------------------------------
# Mount Gradio UI at /ui (only if gradio is installed and ENABLE_WEB_INTERFACE)
# ---------------------------------------------------------------------------
# Opt-out switch: the UI is mounted unless ENABLE_WEB_INTERFACE is set to
# something other than 1/true/yes (case-insensitive).
if os.getenv("ENABLE_WEB_INTERFACE", "true").lower() in {"1", "true", "yes"}:
    try:
        import gradio as gr  # type: ignore

        # gradio_app does `from models import ...`, so the directory above
        # this file's directory has to be on sys.path before importing it.
        import sys

        _REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        if _REPO_ROOT not in sys.path:
            sys.path.insert(0, _REPO_ROOT)

        from gradio_app import build_ui  # type: ignore

        _gradio_ui = build_ui()
        app = gr.mount_gradio_app(app, _gradio_ui, path="/ui")
        _GRADIO_MOUNTED = True
        logger.info("Gradio UI mounted at /ui")
    except Exception as exc:  # noqa: BLE001
        # Best effort: any failure (gradio missing, UI build error, ...)
        # is logged and the server carries on without the UI.
        logger.warning(f"Skipping Gradio mount ({type(exc).__name__}: {exc})")
# ---------------------------------------------------------------------------
# Root redirect: send visitors to the Gradio UI if mounted, else to JSON info
# ---------------------------------------------------------------------------
@app.get("/", include_in_schema=False)
async def root_redirect():
    """Send `/` to `/ui` if the Gradio UI was mounted, else to `/info` (HTTP 307)."""
    destination = "/ui" if _GRADIO_MOUNTED else "/info"
    return RedirectResponse(url=destination, status_code=307)
def main(host: "str | None" = None, port: "int | None" = None):
    """Entry point for `uv run server` and `python -m server.app`.

    Args:
        host: Interface to bind. When omitted (or empty), the ``--host``
            command-line option is used, which defaults to ``0.0.0.0``.
        port: Port to bind. When ``None``, resolution order is the
            ``--port`` command-line option, then the ``PORT`` environment
            variable, then ``8000``. An explicit ``0`` is passed through
            unchanged instead of being treated as "not provided".

    The command line is only parsed when at least one of the two arguments
    is missing; unknown command-line options are ignored.
    """
    import uvicorn

    if host is None or port is None:
        import argparse

        parser = argparse.ArgumentParser(description="API Testing Environment server")
        parser.add_argument("--host", default="0.0.0.0")
        parser.add_argument("--port", type=int, default=None)
        # parse_known_args: tolerate extra options meant for other tools.
        args, _ = parser.parse_known_args()
        host = host or args.host
        # Compare against None rather than relying on truthiness, so that a
        # caller-supplied port of 0 is not silently replaced.
        if port is None:
            port = args.port

    if port is None:
        port = int(os.environ.get("PORT", "8000"))

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
    )
    # Quieten per-request chatter from the HTTP client libraries and from
    # uvicorn's access log.
    for noisy_logger in ("httpx", "httpcore", "uvicorn.access"):
        logging.getLogger(noisy_logger).setLevel(logging.WARNING)

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()
|