File size: 4,662 Bytes
7a0f237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""
FastAPI server for the SQL Data Analyst OpenEnv environment.
Exposes /reset, /step, /state, /observation, /health, and /tasks endpoints.
"""

from __future__ import annotations

import sys
import os

# Allow imports from project root
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from typing import Any, Dict, Optional

from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from models import ResetResult, SQLAction, SQLObservation, SQLState, StepResult
from server.environment import SQLDataAnalystEnv, TASKS

# ---------------------------------------------------------------------------
# App setup
# ---------------------------------------------------------------------------

# The ASGI application object; the route decorators below register on it.
app = FastAPI(
    title="SQL Data Analyst Environment",
    description=(
        "An OpenEnv-compatible agentic environment where AI agents must analyze "
        "SQLite databases, fix broken queries, detect data anomalies, and repair "
        "data pipelines. Implements the OpenEnv step()/reset()/state() API."
    ),
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# CORS: any origin, method, and header is accepted. Credentialed cross-origin
# requests stay disabled because FastAPI's CORS documentation states that
# wildcard origins/methods/headers cannot be combined with
# allow_credentials=True, and none of the endpoints in this module read
# cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Module-level session state shared by every request handler below
# (single-session server — sufficient for HF Spaces + evaluation).
_last_obs: Optional[SQLObservation] = None  # latest observation stored by /reset or /step
_last_done: bool = False  # `done` flag from the latest /step; /reset sets it back to False
_env = SQLDataAnalystEnv()  # the one environment instance all endpoints operate on


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------

@app.get("/health", summary="Health check")
def health() -> Dict[str, str]:
    """Liveness probe — responds with a static status payload whenever the server is up."""
    # The payload is constant; no environment state is consulted.
    payload: Dict[str, str] = {}
    payload["status"] = "ok"
    payload["environment"] = "sql-data-analyst-env"
    payload["version"] = "1.0.0"
    return payload


@app.get("/tasks", summary="List available tasks")
def list_tasks() -> Dict[str, Any]:
    """List every available task with its id, difficulty, goal, and max_steps."""
    # Only these fields of each task record are exposed, in this order.
    exposed_fields = ("id", "difficulty", "goal", "max_steps")
    summaries = []
    for task in TASKS.values():
        summaries.append({field: task[field] for field in exposed_fields})
    return {"tasks": summaries}


@app.post("/reset", response_model=ResetResult, summary="Reset the environment")
def reset(
    task_id: Optional[str] = Query(
        default=None,
        description="Task ID to load. If omitted, cycles through tasks.",
    ),
) -> ResetResult:
    """
    Initialize a new episode.
    Returns the initial observation.
    """
    # Handler notes (kept out of the docstring, which FastAPI publishes as the
    # endpoint description):
    # - task_id is forwarded unchanged to the environment's reset().
    # - On success the module-level last-observation/done cache is refreshed.
    # - Any failure is reported as HTTP 500 carrying the error text.
    global _last_obs, _last_done
    try:
        obs, info = _env.reset(task_id=task_id)
        _last_obs = obs
        _last_done = False
        return ResetResult(observation=obs, done=False, info=info)
    except Exception as e:
        # Chain explicitly so the underlying error stays attached as __cause__.
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/step", response_model=StepResult, summary="Take a step in the environment")
def step(action: SQLAction) -> StepResult:
    """
    Execute an action and return (observation, reward, done, info).

    Action types:
    - `execute_query`: Run a SQL query. Requires `sql_query`.
    - `describe_table`: Get schema + sample for a table. Set `sql_query` = table name.
    - `list_tables`: List all tables in the episode database.
    - `submit_answer`: Submit final answer to the grader. Requires `answer` dict.
    - `noop`: Do nothing.
    """
    # Handler notes (kept out of the docstring, which FastAPI publishes as the
    # endpoint description):
    # - The action is passed straight to the environment's step().
    # - On success the module-level last-observation/done cache is refreshed.
    # - RuntimeError from the environment maps to HTTP 400; anything else to 500.
    global _last_obs, _last_done
    try:
        obs, reward, done, info = _env.step(action)
        _last_obs = obs
        _last_done = done
        return StepResult(observation=obs, reward=reward, done=done, info=info)
    except RuntimeError as e:
        # Chain explicitly so the underlying error stays attached as __cause__.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/state", response_model=SQLState, summary="Get current episode state")
def state() -> SQLState:
    """Return the current episode-level state metadata."""
    # A RuntimeError from the environment is reported as HTTP 400; other
    # exceptions are not intercepted here.
    try:
        return _env.state()
    except RuntimeError as e:
        # Chain explicitly so the underlying error stays attached as __cause__.
        raise HTTPException(status_code=400, detail=str(e)) from e


@app.get("/observation", response_model=SQLObservation, summary="Get current observation")
def observation() -> SQLObservation:
    """Return the most recent observation without advancing the episode (convenience endpoint)."""
    # The cache is filled by /reset and /step; until one of them has
    # succeeded there is nothing to hand back.
    if _last_obs is not None:
        return _last_obs
    raise HTTPException(status_code=400, detail="No observation yet. Call /reset first.")

def main(host: str = "0.0.0.0", port: int = 7860) -> None:
    """Launch the API under uvicorn.

    Args:
        host: Interface address to bind. Defaults to "0.0.0.0".
        port: TCP port to listen on. Defaults to 7860.
    """
    # Local import: uvicorn is only needed when the server is launched via main().
    import uvicorn

    # The app is referenced by import string; auto-reload is explicitly off.
    uvicorn.run("server.app:app", host=host, port=port, reload=False)


if __name__ == "__main__":
    main()