hitanshjain1812's picture
Upload folder using huggingface_hub
7a0f237 verified
"""
FastAPI server for the SQL Data Analyst OpenEnv environment.
Exposes /reset, /step, /state, /observation, /health, and /tasks endpoints.
"""
from __future__ import annotations
import sys
import os
# Allow imports from project root
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from typing import Any, Dict, Optional
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from models import ResetResult, SQLAction, SQLObservation, SQLState, StepResult
from server.environment import SQLDataAnalystEnv, TASKS
# ---------------------------------------------------------------------------
# App setup
# ---------------------------------------------------------------------------
# FastAPI application object; title/description/version below feed the
# generated OpenAPI schema, served interactively at /docs and /redoc.
app = FastAPI(
    title="SQL Data Analyst Environment",
    description=(
        "An OpenEnv-compatible agentic environment where AI agents must analyze "
        "SQLite databases, fix broken queries, detect data anomalies, and repair "
        "data pipelines. Implements the OpenEnv step()/reset()/state() API."
    ),
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)
# Accept cross-origin requests from any origin, with any method and header.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True means
# any site may make credentialed requests (browsers reject a literal "*" on
# credentialed responses, so Starlette reflects the caller's Origin instead).
# No endpoint in this file reads cookies or auth headers, so credentials look
# unnecessary here — confirm no browser client depends on them before tightening.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global environment instance (single-session server — sufficient for HF Spaces + evaluation).
# Every request shares this one environment, so concurrent clients would
# interleave each other's episodes.
# NOTE(review): FastAPI runs plain `def` endpoints in a worker threadpool and
# nothing here takes a lock — confirm the server only ever serves one client.
_env = SQLDataAnalystEnv()
# Most recent observation produced by /reset or /step; returned by /observation.
_last_obs: Optional[SQLObservation] = None
# `done` flag from the most recent /reset or /step.
# NOTE(review): not read anywhere in the visible code — confirm it is still needed.
_last_done: bool = False
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/health", summary="Health check")
def health() -> Dict[str, str]:
    """Liveness probe — returns 200 if the server is running.

    Returns:
        A JSON object with a fixed ``status`` of ``"ok"``, the environment
        name, and the application version.
    """
    # Take the version from the FastAPI app itself so this payload can never
    # drift from the version advertised in the OpenAPI schema.
    return {
        "status": "ok",
        "environment": "sql-data-analyst-env",
        "version": app.version,
    }
@app.get("/tasks", summary="List available tasks")
def list_tasks() -> Dict[str, Any]:
    """Describe every registered task.

    Only the public metadata fields (id, difficulty, goal, step budget) are
    exposed, one entry per task in the order ``TASKS.values()`` yields them.
    """
    public_fields = ("id", "difficulty", "goal", "max_steps")
    summaries = [
        {field: task[field] for field in public_fields}
        for task in TASKS.values()
    ]
    return {"tasks": summaries}
@app.post("/reset", response_model=ResetResult, summary="Reset the environment")
def reset(
    task_id: Optional[str] = Query(
        default=None,
        description="Task ID to load. If omitted, cycles through tasks.",
    ),
) -> ResetResult:
    """
    Initialize a new episode.

    Delegates to ``_env.reset`` and caches the returned observation so the
    /observation endpoint can serve it afterwards.

    Args:
        task_id: Optional task identifier, passed straight through to the
            environment's ``reset``.

    Returns:
        The initial observation with ``done=False`` and the environment's
        info payload.

    Raises:
        HTTPException: 500 carrying the error message if the reset fails.
    """
    global _last_obs, _last_done
    try:
        obs, info = _env.reset(task_id=task_id)
        _last_obs = obs
        _last_done = False
        return ResetResult(observation=obs, done=False, info=info)
    except Exception as e:
        # NOTE(review): an unknown task_id presumably also lands here as a 500;
        # a 4xx would be more accurate once the environment's error type is known.
        # Chain the cause so the server log keeps the original traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/step", response_model=StepResult, summary="Take a step in the environment")
def step(action: SQLAction) -> StepResult:
    """
    Execute an action and return (observation, reward, done, info).

    Action types:
    - `execute_query`: Run a SQL query. Requires `sql_query`.
    - `describe_table`: Get schema + sample for a table. Set `sql_query` = table name.
    - `list_tables`: List all tables in the episode database.
    - `submit_answer`: Submit final answer to the grader. Requires `answer` dict.
    - `noop`: Do nothing.

    The resulting observation and ``done`` flag are cached module-wide so the
    /observation endpoint can serve them afterwards.

    Raises:
        HTTPException: 400 when the environment raises ``RuntimeError``;
            500 for any other failure. Both carry the error message.
    """
    global _last_obs, _last_done
    try:
        obs, reward, done, info = _env.step(action)
        _last_obs = obs
        _last_done = done
        return StepResult(observation=obs, reward=reward, done=done, info=info)
    except RuntimeError as e:
        # Chain the cause so the server log keeps the original traceback.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/state", response_model=SQLState, summary="Get current episode state")
def state() -> SQLState:
    """Return the current episode-level state metadata.

    Returns:
        Whatever ``_env.state()`` reports for the current episode.

    Raises:
        HTTPException: 400 when the environment raises ``RuntimeError``;
            500 for any other failure, matching /reset and /step. Both carry
            the error message.
    """
    try:
        return _env.state()
    except RuntimeError as e:
        # Chain the cause so the server log keeps the original traceback.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # Same JSON error shape as the other endpoints instead of a bare 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/observation", response_model=SQLObservation, summary="Get current observation")
def observation() -> SQLObservation:
    """Serve the observation cached by the latest /reset or /step call.

    Raises:
        HTTPException: 400 if no observation has been produced yet.
    """
    latest = _last_obs
    if latest is not None:
        return latest
    raise HTTPException(status_code=400, detail="No observation yet. Call /reset first.")
def main(host: str = "0.0.0.0", port: int = 7860) -> None:
    """Serve this module's FastAPI app with uvicorn.

    Args:
        host: Interface to bind; the default listens on all interfaces.
        port: TCP port to bind; defaults to 7860.
    """
    import uvicorn

    # Hand uvicorn the already-built app object instead of the
    # "server.app:app" import string. When this file is run as a script it is
    # loaded as ``__main__``, and the string form would import it a second
    # time (building a second app and SQLDataAnalystEnv). The object form is
    # valid because neither reload nor multiple workers are used.
    uvicorn.run(app, host=host, port=port, reload=False)


if __name__ == "__main__":
    main()