TanmaySK commited on
Commit
8cb7cab
·
verified ·
1 Parent(s): 8c1afa8

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +55 -47
server.py CHANGED
@@ -1,47 +1,55 @@
1
- from fastapi import FastAPI, HTTPException
2
- from pydantic import BaseModel
3
- from typing import Optional
4
- from app.models import Action, Observation, StepResponse
5
- from app.env import CrisisSimEnv
6
- from app.tasks import TASKS
7
-
8
- app = FastAPI(title="CrisisSim")
9
-
10
- env_instance: Optional[CrisisSimEnv] = None
11
-
12
- class ResetRequest(BaseModel):
13
- task_name: str = "easy"
14
-
15
- @app.post("/reset", response_model=Observation)
16
- def reset_env(req: ResetRequest):
17
- global env_instance
18
- task_config = TASKS.get(req.task_name)
19
- if not task_config:
20
- raise HTTPException(status_code=400, detail="Invalid task name")
21
- env_instance = CrisisSimEnv(task_config)
22
- return env_instance.state()
23
-
24
- @app.post("/step", response_model=StepResponse)
25
- def step_env(action: Action):
26
- global env_instance
27
- if not env_instance:
28
- raise HTTPException(status_code=400, detail="Environment not initialized. Call /reset first.")
29
-
30
- obs, reward, done, info = env_instance.step(action.action)
31
- return StepResponse(
32
- observation=obs,
33
- reward=reward,
34
- done=done,
35
- info=info
36
- )
37
-
38
- @app.get("/state", response_model=Observation)
39
- def get_state():
40
- global env_instance
41
- if not env_instance:
42
- raise HTTPException(status_code=400, detail="Environment not initialized.")
43
- return env_instance.state()
44
-
45
- if __name__ == "__main__":
46
- import uvicorn
47
- uvicorn.run("server:app", host="0.0.0.0", port=7860)
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from typing import Optional
4
+ from app.models import Action, Observation, StepResponse
5
+ from app.env import CrisisSimEnv
6
+ from app.tasks import TASKS
7
+
8
+ app = FastAPI(title="CrisisSim")
9
+
10
+ env_instance: Optional[CrisisSimEnv] = None
11
+
12
+ class ResetRequest(BaseModel):
13
+ task_name: str = "easy"
14
+
15
+ @app.post("/reset", response_model=Observation)
16
+ def reset_env(req: ResetRequest):
17
+ global env_instance
18
+ task_config = TASKS.get(req.task_name)
19
+ if not task_config:
20
+ raise HTTPException(status_code=400, detail="Invalid task name")
21
+ env_instance = CrisisSimEnv(task_config)
22
+ return env_instance.state()
23
+
24
+ @app.post("/step", response_model=StepResponse)
25
+ def step_env(action: Action):
26
+ global env_instance
27
+ if not env_instance:
28
+ raise HTTPException(status_code=400, detail="Environment not initialized. Call /reset first.")
29
+
30
+ obs, reward, done, info = env_instance.step(action.action)
31
+ return StepResponse(
32
+ observation=obs,
33
+ reward=reward,
34
+ done=done,
35
+ info=info
36
+ )
37
+
38
+ @app.get("/")
39
+ def root():
40
+ return {
41
+ "message": "CrisisSim API is running 🚀",
42
+ "endpoints": ["/reset", "/step", "/state"]
43
+ }
44
+
45
+
46
+ @app.get("/state", response_model=Observation)
47
+ def get_state():
48
+ global env_instance
49
+ if not env_instance:
50
+ raise HTTPException(status_code=400, detail="Environment not initialized.")
51
+ return env_instance.state()
52
+
53
+ if __name__ == "__main__":
54
+ import uvicorn
55
+ uvicorn.run("server:app", host="0.0.0.0", port=7860)