jaivardhan2409 commited on
Commit
aeea577
·
verified ·
1 Parent(s): a27ff86

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. server/app.py +11 -46
server/app.py CHANGED
@@ -1,41 +1,19 @@
1
- from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
- from typing import Dict, Any, List
4
- import asyncio
5
 
 
 
 
6
  from env.environment import SQLEnv
7
- from env.models import Observation, Action, Reward
8
  from env.tasks import TASKS
9
 
10
- app = FastAPI(title="SQL Query Optimizer OpenEnv")
11
- env = SQLEnv()
12
-
13
- class ResetRequest(BaseModel):
14
- task_id: int
15
-
16
- @app.post("/reset", response_model=Observation)
17
- async def reset(req: ResetRequest):
18
- try:
19
- return env.reset(req.task_id)
20
- except ValueError as e:
21
- raise HTTPException(status_code=400, detail=str(e))
22
-
23
- @app.post("/step")
24
- async def step(action: Action):
25
- try:
26
- obs, reward, done, info = env.step(action)
27
- return {
28
- "observation": obs.model_dump(),
29
- "reward": reward.model_dump(),
30
- "done": done,
31
- "info": info
32
- }
33
- except RuntimeError as e:
34
- raise HTTPException(status_code=400, detail=str(e))
35
-
36
- @app.get("/state")
37
- async def state():
38
- return env.state()
39
 
40
  @app.get("/tasks")
41
  async def get_tasks():
@@ -46,12 +24,6 @@ async def get_tasks():
46
  "action_schema": action_schema
47
  }
48
 
49
- @app.get("/grader")
50
- async def grader():
51
- if not env.task:
52
- raise HTTPException(status_code=400, detail="Environment not initialized.")
53
- return {"grader_score": env.final_grader_score}
54
-
55
  class BaselineResponse(BaseModel):
56
  scores: Dict[int, float]
57
 
@@ -63,10 +35,3 @@ async def run_baseline():
63
  return BaselineResponse(scores=scores)
64
  except Exception as e:
65
  raise HTTPException(status_code=500, detail=str(e))
66
-
67
- def main(host: str = "0.0.0.0", port: int = 8000):
68
- import uvicorn
69
- uvicorn.run(app, host=host, port=port)
70
-
71
- if __name__ == '__main__':
72
- main()
 
1
+ from fastapi import HTTPException
2
  from pydantic import BaseModel
3
+ from typing import Dict
 
4
 
5
+ from openenv.core import create_app
6
+
7
+ from models import Observation, Action
8
  from env.environment import SQLEnv
 
9
  from env.tasks import TASKS
10
 
11
+ app = create_app(
12
+ env=SQLEnv,
13
+ action_cls=Action,
14
+ observation_cls=Observation,
15
+ env_name="sql-query-optimizer"
16
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  @app.get("/tasks")
19
  async def get_tasks():
 
24
  "action_schema": action_schema
25
  }
26
 
 
 
 
 
 
 
27
  class BaselineResponse(BaseModel):
28
  scores: Dict[int, float]
29
 
 
35
  return BaselineResponse(scores=scores)
36
  except Exception as e:
37
  raise HTTPException(status_code=500, detail=str(e))