| from fastapi import FastAPI, HTTPException |
| from pydantic import BaseModel |
| from typing import Optional |
| from app.models import Action, Observation, StepResponse |
| from app.env import CrisisSimEnv |
| from app.tasks import TASKS |
|
|
# ASGI application object; the route decorators in this module register
# their handlers on it.
app = FastAPI(title="CrisisSim")


# Single environment shared by every request handled by this process.
# It stays None until the /reset handler builds a CrisisSimEnv.
env_instance: Optional[CrisisSimEnv] = None
|
|
class ResetRequest(BaseModel):
    """Request body accepted by the ``/reset`` endpoint."""

    # Key looked up in TASKS to pick the task configuration; "easy" is the
    # value used when the client leaves the field out.
    task_name: str = "easy"
|
|
@app.post("/reset", response_model=Observation)
def reset_env(req: ResetRequest):
    """Replace the shared environment with a fresh one for the requested task.

    ``req.task_name`` is looked up in ``TASKS``. When the lookup yields
    nothing (or a falsy value) the request is rejected with HTTP 400.
    Otherwise a new ``CrisisSimEnv`` is built from the task configuration,
    stored in the module-level ``env_instance``, and the result of its
    ``state()`` call is returned.
    """
    global env_instance

    config = TASKS.get(req.task_name)
    if config:
        # Rebinding the module-level name discards any previous environment.
        env_instance = CrisisSimEnv(config)
        return env_instance.state()

    raise HTTPException(status_code=400, detail="Invalid task name")
|
|
@app.post("/step", response_model=StepResponse)
def step_env(action: Action):
    """Apply one action to the shared environment and report the outcome.

    Args:
        action: Request body; its ``action`` attribute is forwarded to the
            environment's ``step`` method.

    Returns:
        A ``StepResponse`` built from the four values returned by
        ``step``: observation, reward, done flag and info.

    Raises:
        HTTPException: status 400 when no environment has been created yet.
    """
    # Compare against None explicitly: a plain truthiness test would also
    # reject a live environment whose class defines __bool__ or __len__.
    # No `global` statement is needed because the name is only read here.
    if env_instance is None:
        raise HTTPException(
            status_code=400,
            detail="Environment not initialized. Call /reset first.",
        )

    obs, reward, done, info = env_instance.step(action.action)
    return StepResponse(observation=obs, reward=reward, done=done, info=info)
|
|
@app.get("/")
def root():
    """Return a small status payload with a message and an endpoint list."""
    endpoints = ["/reset", "/step", "/state"]
    return {"message": "CrisisSim API is running 🚀", "endpoints": endpoints}
|
|
| |
@app.get("/state", response_model=Observation)
def get_state():
    """Return the current state of the shared environment.

    Returns:
        Whatever the environment's ``state()`` method returns, serialized
        through the ``Observation`` response model.

    Raises:
        HTTPException: status 400 when no environment has been created yet.
    """
    # Explicit None check so that an existing environment is never mistaken
    # for "missing" just because its class defines __bool__ or __len__.
    # The name is only read, so no `global` statement is required.
    if env_instance is None:
        raise HTTPException(status_code=400, detail="Environment not initialized.")
    return env_instance.state()
|
|
if __name__ == "__main__":
    # Imported inside the guard so uvicorn is only loaded when this file is
    # executed directly as a script.
    import uvicorn

    # NOTE(review): the "server:app" import string presumably expects this
    # file to be importable as `server` -- confirm against the actual filename.
    # Host 0.0.0.0 listens on all network interfaces.
    uvicorn.run("server:app", host="0.0.0.0", port=7860)
|
|