Spaces:
Sleeping
Sleeping
"""
FastAPI server for the SQL Data Analyst OpenEnv environment.
Exposes /reset, /step, /state, /health, and /tasks endpoints, plus a
convenience endpoint that returns the most recent observation.
"""
| from __future__ import annotations | |
| import sys | |
| import os | |
| # Allow imports from project root | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from typing import Any, Dict, Optional | |
| from fastapi import FastAPI, HTTPException, Query | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from models import ResetResult, SQLAction, SQLObservation, SQLState, StepResult | |
| from server.environment import SQLDataAnalystEnv, TASKS | |
# ---------------------------------------------------------------------------
# App setup
# ---------------------------------------------------------------------------
# Shown on the generated OpenAPI pages (/docs and /redoc).
_APP_DESCRIPTION = (
    "An OpenEnv-compatible agentic environment where AI agents must analyze "
    "SQLite databases, fix broken queries, detect data anomalies, and repair "
    "data pipelines. Implements the OpenEnv step()/reset()/state() API."
)

app = FastAPI(
    title="SQL Data Analyst Environment",
    version="1.0.0",
    description=_APP_DESCRIPTION,
    docs_url="/docs",
    redoc_url="/redoc",
)

# Fully permissive CORS: any origin, method and header may call the API.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; acceptable only while the server has no authentication — confirm.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Module-level session state. A single environment instance is shared by every
# request (single-session server — sufficient for HF Spaces + evaluation).
# NOTE(review): plain `def` handlers may be run concurrently by FastAPI, so
# simultaneous clients would interleave on this one environment — confirm OK.
_last_obs: Optional[SQLObservation]  # latest observation produced by reset/step
_last_done: bool  # `done` flag recorded by the latest reset (False) or step
_env = SQLDataAnalystEnv()
_last_obs, _last_done = None, False
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/health")
def health() -> Dict[str, str]:
    """Liveness probe — returns 200 if server is running.

    The payload is static: a fixed status, the environment name and the
    server version.
    """
    return {"status": "ok", "environment": "sql-data-analyst-env", "version": "1.0.0"}
@app.get("/tasks")
def list_tasks() -> Dict[str, Any]:
    """Return metadata for all available tasks.

    Only four fields are copied from each ``TASKS`` entry: ``id``,
    ``difficulty``, ``goal`` and ``max_steps``.
    """
    return {
        "tasks": [
            {
                "id": t["id"],
                "difficulty": t["difficulty"],
                "goal": t["goal"],
                "max_steps": t["max_steps"],
            }
            for t in TASKS.values()
        ]
    }
@app.post("/reset")
def reset(
    task_id: Optional[str] = Query(
        default=None,
        description="Task ID to load. If omitted, cycles through tasks.",
    ),
) -> ResetResult:
    """
    Initialize a new episode.
    Returns the initial observation.
    """
    global _last_obs, _last_done
    try:
        obs, info = _env.reset(task_id=task_id)
        # Remember the fresh observation so the last-observation endpoint can serve it.
        _last_obs = obs
        _last_done = False
        return ResetResult(observation=obs, done=False, info=info)
    except Exception as e:
        # Every failure here is reported as a server error; chain the cause so
        # the original traceback is preserved in server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/step")
def step(action: SQLAction) -> StepResult:
    """
    Execute an action and return (observation, reward, done, info).

    Action types:

    - `execute_query`: Run a SQL query. Requires `sql_query`.
    - `describe_table`: Get schema + sample for a table. Set `sql_query` = table name.
    - `list_tables`: List all tables in the episode database.
    - `submit_answer`: Submit final answer to the grader. Requires `answer` dict.
    - `noop`: Do nothing.
    """
    global _last_obs, _last_done
    try:
        obs, reward, done, info = _env.step(action)
        _last_obs = obs
        _last_done = done
        return StepResult(observation=obs, reward=reward, done=done, info=info)
    except RuntimeError as e:
        # RuntimeError from the environment is treated as client misuse
        # (presumably e.g. stepping without an active episode — confirm in env).
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/state")
def state() -> SQLState:
    """Return the current episode-level state metadata."""
    try:
        return _env.state()
    except RuntimeError as e:
        # Same mapping as /step: environment RuntimeError -> client error.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # Match /reset and /step: surface unexpected failures as 500 with detail.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/observation")
def observation() -> SQLObservation:
    """Return the last observation (convenience endpoint).

    The value is whatever the most recent reset or step stored; it is not
    cleared when an episode ends.
    """
    if _last_obs is None:
        # Nothing has been stored yet because no episode was ever started.
        raise HTTPException(status_code=400, detail="No observation yet. Call /reset first.")
    return _last_obs
def main(host: str = "0.0.0.0", port: int = 7860) -> None:
    """Serve the API with uvicorn (blocks until the server stops).

    Args:
        host: Interface to bind; the default binds all interfaces.
        port: TCP port to listen on; defaults to 7860, the port this server
            has always used (the HF Spaces default).
    """
    import uvicorn

    # Load the app by import string as ``server.app``; the project root is put
    # on sys.path at the top of this module so that package path resolves.
    uvicorn.run("server.app:app", host=host, port=port, reload=False)


if __name__ == "__main__":
    main()