# File size: 2,181 Bytes | ef41468
from __future__ import annotations
import os
from fastapi import FastAPI
import uvicorn
from models import SepsisAction, SepsisObservation, SepsisState
from openenv_compat import OPENENV_AVAILABLE, create_app
from server.sepsis_environment import SepsisTreatmentEnvironment
if OPENENV_AVAILABLE and create_app is not None:
    # The compat layer reports OpenEnv as usable and exposes its factory:
    # delegate app construction to it.
    app = create_app(
        SepsisTreatmentEnvironment,
        SepsisAction,
        SepsisObservation,
        env_name="sepsis-openenv",
    )
else:
    # Fallback: a hand-written FastAPI app around a single module-level
    # environment instance that every route below shares.
    from fastapi import HTTPException

    environment = SepsisTreatmentEnvironment()
    app = FastAPI(title="Sepsis OpenEnv", version="0.1.0")

    @app.get("/health")
    def health() -> dict[str, str]:
        """Return a fixed ``{"status": "ok"}`` payload."""
        return {"status": "ok"}

    @app.get("/metadata")
    def metadata() -> dict:
        """Return whatever ``environment.metadata()`` reports."""
        return environment.metadata()

    @app.get("/schema")
    def schema() -> dict:
        """Return the JSON schemas of the action, observation and state models."""
        return {
            "action_schema": SepsisAction.model_json_schema(),
            "observation_schema": SepsisObservation.model_json_schema(),
            "state_schema": SepsisState.model_json_schema(),
        }

    @app.post("/reset")
    def reset(payload: dict | None = None) -> dict:
        """Reset the shared environment and return its first observation.

        Args:
            payload: Optional request body. Its ``"task_id"`` entry, if any,
                is forwarded to ``environment.reset``; a missing or empty
                body forwards ``None``.

        Returns:
            The dumped observation, a zero reward, ``done=False`` and an
            ``info`` dict with the available tasks and current metrics.
        """
        task_id = payload.get("task_id") if payload else None
        observation = environment.reset(task_id=task_id)
        return {
            "observation": observation.model_dump(),
            "reward": 0.0,
            "done": False,
            "info": {
                "tasks": environment.available_tasks(),
                "metrics": environment.current_metrics(),
            },
        }

    @app.post("/step")
    def step(payload: dict) -> dict:
        """Apply one action, built from the request body, to the environment.

        Args:
            payload: Keyword arguments for ``SepsisAction``.

        Returns:
            The dumped observation together with the ``reward`` and ``done``
            values read from it, and an ``info`` dict with current metrics.

        Raises:
            HTTPException: Status 422 when ``payload`` cannot be turned into
                a ``SepsisAction``.
        """
        try:
            action = SepsisAction(**payload)
        except (TypeError, ValueError) as exc:
            # A body that does not describe a valid action is a client
            # error; left uncaught it would surface as an opaque HTTP 500.
            # NOTE(review): relies on the model's validation error deriving
            # from ValueError (true for pydantic) - confirm for SepsisAction.
            raise HTTPException(status_code=422, detail=str(exc)) from exc

        observation = environment.step(action)
        return {
            "observation": observation.model_dump(),
            "reward": observation.reward,
            "done": observation.done,
            "info": {
                "metrics": environment.current_metrics(),
            },
        }

    @app.get("/state")
    def state() -> dict:
        """Return the dumped ``environment.state``."""
        return environment.state.model_dump()
def main() -> None:
    """Serve ``server.app:app`` with uvicorn on host ``0.0.0.0``.

    The port comes from the ``PORT`` environment variable; when it is unset
    or empty, 7860 is used.

    Raises:
        ValueError: If ``PORT`` holds a non-empty string that is not an
            integer.
    """
    # `or` instead of a getenv default: the default only covers an unset
    # variable, while PORT="" would reach int("") and crash at startup.
    port = int(os.getenv("PORT") or "7860")
    uvicorn.run("server.app:app", host="0.0.0.0", port=port)


if __name__ == "__main__":
    main()
|