File size: 1,801 Bytes
b641d3d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import logging
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException

from .constants import TaskName
from .env import DistributedDebugEnv
from .models import Action, Observation, StepResult
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the shared ``DistributedDebugEnv`` for the application's lifetime.

    Before the app starts serving, one environment is created, started, and
    published on ``app.state.env`` so the request handlers can reach it.
    When the app shuts down, the environment is closed.

    Note: ``close()`` is only reached once ``start()`` has returned; if
    ``start()`` itself raises, no cleanup is attempted here.
    """
    debug_env = DistributedDebugEnv()
    debug_env.start()
    app.state.env = debug_env
    try:
        yield
    finally:
        # Runs on normal shutdown and when an error is thrown into the
        # generator, so the environment is always released after startup.
        debug_env.close()
# ASGI application. The ``lifespan`` hook creates, starts and closes the
# single DistributedDebugEnv that every endpoint below operates on.
app = FastAPI(
    title="Distributed Systems Debug Environment",
    version="1.0.0",
    lifespan=lifespan,
)
@app.post("/reset", response_model=Observation)
async def reset(task_name: str | None = None) -> Observation:
    """Start a new episode and return its initial observation.

    Args:
        task_name: Optional task identifier. When omitted (or empty), the
            ``TaskName.CASCADING_TIMEOUT`` task is used.

    Returns:
        The observation returned by ``DistributedDebugEnv.reset``.

    Raises:
        HTTPException: 400 if ``TaskName.parse`` rejects the task name with a
            ``ValueError``; 500 if the environment fails to reset (the
            original traceback is logged first).
    """
    # Validator and sample inference call /reset without task input.
    # Use a deterministic default task for reproducible episode bootstrapping.
    selected_task_name = task_name or TaskName.CASCADING_TIMEOUT.value
    try:
        task = TaskName.parse(selected_task_name)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    try:
        env: DistributedDebugEnv = app.state.env
        return env.reset(task_name=task)
    except Exception as exc:
        # FastAPI renders HTTPException without logging the underlying
        # traceback, so record it here before converting the failure to a 500.
        logging.getLogger(__name__).exception("reset failed for task %s", task)
        raise HTTPException(status_code=500, detail=str(exc)) from exc
@app.post("/step", response_model=StepResult)
async def step(action: Action) -> StepResult:
    """Apply one agent action to the current episode.

    Args:
        action: The action to apply, as parsed by FastAPI from the request.

    Returns:
        The ``StepResult`` returned by ``DistributedDebugEnv.step``.

    Raises:
        HTTPException: 500 if the environment raises while stepping; the
            original traceback is logged first.
    """
    try:
        env: DistributedDebugEnv = app.state.env
        return env.step(action)
    except Exception as exc:
        # FastAPI renders HTTPException without logging the underlying
        # traceback, so record it here before converting the failure to a 500.
        logging.getLogger(__name__).exception("step failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
@app.get("/state")
async def state() -> dict:
    """Return the environment's current state as reported by ``env.state()``.

    Raises:
        HTTPException: 500 if reading the state fails; the original traceback
            is logged first.
    """
    try:
        env: DistributedDebugEnv = app.state.env
        return env.state()
    except Exception as exc:
        # FastAPI renders HTTPException without logging the underlying
        # traceback, so record it here before converting the failure to a 500.
        logging.getLogger(__name__).exception("state lookup failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
@app.get("/health")
async def health() -> dict:
    """Liveness probe.

    Returns:
        ``{"status": "ok", "version": ...}``. The version is read from the
        FastAPI application object rather than repeated as a literal, so it
        always matches the version declared when ``app`` was constructed.
    """
    return {"status": "ok", "version": app.version}
|