File size: 1,586 Bytes
8cb7cab | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 | from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
from app.models import Action, Observation, StepResponse
from app.env import CrisisSimEnv
from app.tasks import TASKS
# The one simulation environment every request operates on. It stays None
# until POST /reset builds it, and each later /reset replaces it wholesale.
env_instance: Optional[CrisisSimEnv] = None

# ASGI application object that the route handlers register against.
app = FastAPI(
    title="CrisisSim",
)
class ResetRequest(BaseModel):
    """Request body for POST /reset."""

    # Name looked up in TASKS to choose the scenario; "easy" when omitted.
    task_name: str = "easy"
@app.post("/reset", response_model=Observation)
def reset_env(req: ResetRequest):
    """Start a new episode for the requested task and return its initial observation."""
    global env_instance  # this handler rebinds the module-level environment
    task_config = TASKS.get(req.task_name)
    # TASKS.get yields None for an unknown name. Test for that explicitly so a
    # registered-but-falsy config (e.g. an empty dict) is not reported as invalid.
    if task_config is None:
        raise HTTPException(status_code=400, detail="Invalid task name")
    # Keep a local reference to the environment built here. A concurrent
    # /reset could rebind the global before state() runs, and this call must
    # still return the state of its own environment.
    new_env = CrisisSimEnv(task_config)
    env_instance = new_env
    return new_env.state()
@app.post("/step", response_model=StepResponse)
def step_env(action: Action):
    """Apply one action to the current environment and return the resulting transition."""
    # Read the module-level environment exactly once. Plain `def` handlers run
    # in a worker thread pool, so a concurrent /reset could otherwise swap the
    # instance between the None check and the step() call. No `global`
    # statement is needed because this handler never rebinds the name.
    env = env_instance
    if env is None:
        raise HTTPException(
            status_code=400,
            detail="Environment not initialized. Call /reset first.",
        )
    # step() returns a 4-tuple that maps field-for-field onto StepResponse.
    obs, reward, done, info = env.step(action.action)
    return StepResponse(
        observation=obs,
        reward=reward,
        done=done,
        info=info,
    )
@app.get("/")
def root():
    """Health check: confirm the API is up and list its main endpoints."""
    available = ["/reset", "/step", "/state"]
    return dict(
        message="CrisisSim API is running 🚀",
        endpoints=available,
    )
@app.get("/state", response_model=Observation)
def get_state():
    """Return the current observation without advancing the environment."""
    # Snapshot the module-level environment once, so the None check and the
    # state() call see the same instance even if /reset runs concurrently.
    # The name is only read here, so no `global` statement is needed.
    env = env_instance
    if env is None:
        raise HTTPException(status_code=400, detail="Environment not initialized.")
    return env.state()
if __name__ == "__main__":
    import uvicorn

    # Hand uvicorn the app object rather than the "server:app" import string.
    # The string only resolves when this file is named server.py. It also
    # makes uvicorn import the module a second time, leaving a separate app
    # and env_instance. An import string is only required for reload/workers,
    # and neither is used here.
    uvicorn.run(app, host="0.0.0.0", port=7860)
|