Abhinav Singh commited on
Commit
e8bc352
Β·
1 Parent(s): 7841be7

feat(server): add FastAPI server with all OpenEnv endpoints

Browse files

server/app.py β€” uvicorn FastAPI app on port 7860:
POST /reset β€” start new episode, accepts optional {task_id} body,
defaults to task_1_basic_antipatterns
POST /step β€” submit Action, returns StepResult with Observation+Reward
GET /state β€” current EnvironmentState (non-destructive)
GET /tasks β€” list all tasks with descriptions and action schema
POST /grader β€” grade an action against active task without advancing
POST /baseline β€” runs inference.py subprocess and returns stdout
GET / β€” health check returning env name and available tasks
CORS middleware enabled for all origins (required for HF Space)

Files changed (2) hide show
  1. server/__init__.py +1 -0
  2. server/app.py +121 -0
server/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # server package
server/app.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Request
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ import os
4
+ import sys
5
+ import json
6
+
7
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
8
+
9
+ from env import SQLOptimEnv
10
+ from models import Action, StepResult, EnvironmentState, Observation
11
+ from tasks import get_task_list
12
+ from graders import grade
13
+
14
+ app = FastAPI(
15
+ title="SQL Query Optimization Environment",
16
+ description=(
17
+ "OpenEnv-compliant RL environment where AI agents learn to analyze, "
18
+ "diagnose, and optimize SQL queries across three difficulty levels."
19
+ ),
20
+ version="1.0.0",
21
+ )
22
+
23
+ app.add_middleware(
24
+ CORSMiddleware,
25
+ allow_origins=["*"],
26
+ allow_methods=["*"],
27
+ allow_headers=["*"],
28
+ )
29
+
30
+ env = SQLOptimEnv()
31
+
32
+
33
+ @app.get("/")
34
+ def root():
35
+ return {
36
+ "status": "ok",
37
+ "environment": "sql-optim-env",
38
+ "version": "1.0.0",
39
+ "tasks": [t["task_id"] for t in get_task_list()],
40
+ }
41
+
42
+
43
+ @app.post("/reset", response_model=Observation)
44
+ async def reset(request: Request):
45
+ """
46
+ Start a new episode. Optionally pass {"task_id": "..."} in the body.
47
+ Defaults to task_1_basic_antipatterns.
48
+ """
49
+ try:
50
+ body = await request.body()
51
+ task_id = "task_1_basic_antipatterns"
52
+ if body:
53
+ try:
54
+ data = json.loads(body)
55
+ task_id = data.get("task_id", task_id) or task_id
56
+ except Exception:
57
+ pass
58
+ obs = env.reset(task_id=task_id)
59
+ return obs
60
+ except ValueError as e:
61
+ raise HTTPException(status_code=400, detail=str(e))
62
+
63
+
64
+ @app.post("/step", response_model=StepResult)
65
+ def step(action: Action):
66
+ """Take one action (submit SQL analysis + optimized query)."""
67
+ try:
68
+ result = env.step(action)
69
+ return result
70
+ except RuntimeError as e:
71
+ raise HTTPException(status_code=400, detail=str(e))
72
+
73
+
74
+ @app.get("/state", response_model=EnvironmentState)
75
+ def state():
76
+ """Get current environment state without advancing the episode."""
77
+ return env.state()
78
+
79
+
80
+ @app.get("/tasks")
81
+ def tasks():
82
+ """List all available tasks with descriptions and action schema."""
83
+ return {"tasks": get_task_list()}
84
+
85
+
86
+ @app.post("/grader")
87
+ def grader(action: Action):
88
+ """Grade an action against the current task without advancing the episode."""
89
+ if env._task_data is None:
90
+ raise HTTPException(status_code=400, detail="No active episode. Call /reset first.")
91
+ reward = grade(env._task_data, action)
92
+ return reward
93
+
94
+
95
+ @app.post("/baseline")
96
+ def baseline():
97
+ """Run the baseline agent and return scores for all tasks."""
98
+ try:
99
+ import subprocess
100
+ result = subprocess.run(
101
+ ["python", "inference.py"],
102
+ capture_output=True, text=True, timeout=300,
103
+ cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
104
+ )
105
+ return {
106
+ "stdout": result.stdout,
107
+ "stderr": result.stderr,
108
+ "returncode": result.returncode,
109
+ }
110
+ except Exception as e:
111
+ raise HTTPException(status_code=500, detail=f"Baseline failed: {str(e)}")
112
+
113
+
114
+
115
+ def main():
116
+ import uvicorn
117
+ uvicorn.run(app, host="0.0.0.0", port=7860)
118
+
119
+
120
+ if __name__ == "__main__":
121
+ main()