# openenv-warehouse / server.py
# Last commit: abfa97e ("Update server.py") by flamingo44333 (verified).
"""
server.py — FastAPI HTTP wrapper for OpenEnv Warehouse
=======================================================
Exposes the environment over HTTP so the HF Space validator can ping /reset.
Also provides full step/state/grade endpoints for remote evaluation.
Endpoints:
POST /reset → initial observation (validator pings this)
POST /step → step the environment
GET /state → current internal state
POST /grade → run grader on current state
GET /health → liveness check
GET /tasks → list of valid task ids
"""
import os
import sys
from typing import Any, Dict, Optional
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
# Put this file's own directory (not the working directory) first on sys.path so the sibling modules below resolve
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# Import sibling modules directly, without a package/folder prefix
from warehouse_env import WarehouseEnv, Action, ActionType
from graders import grade as run_grade
app = FastAPI(
    title="OpenEnv Warehouse",
    version="1.0.0",
    description="Warehouse inventory management benchmark for autonomous agents.",
)

# Registry of live environments — one per task, keyed by task_id.
_envs: Dict[str, WarehouseEnv] = dict()

# Task identifiers accepted by every endpoint (also served by GET /tasks).
VALID_TASKS = [
    "triage_alerts",
    "resolve_stockout",
    "optimize_replenishment",
]

# Seed used whenever a request does not provide one; override with ENV_SEED.
DEFAULT_SEED = int(os.environ.get("ENV_SEED", "42"))
def _get_env(task_id: str) -> WarehouseEnv:
    """Return the cached environment for *task_id*, creating it on first use.

    A newly created environment is seeded with ``DEFAULT_SEED`` and stored in
    ``_envs`` so later calls for the same task get the same instance.
    """
    cached = _envs.get(task_id)
    if cached is None:
        cached = WarehouseEnv(task_id=task_id, seed=DEFAULT_SEED)
        _envs[task_id] = cached
    return cached
# ── Request bodies ────────────────────────────────────────────────────
class ResetRequest(BaseModel):
    """Body for POST /reset; every field has a default, so ``{}`` is a valid body."""
    # Task to reset. Checked against VALID_TASKS by the endpoint, not by this model.
    task_id: str = "triage_alerts"
    # Seed for the new environment; None tells /reset to fall back to DEFAULT_SEED.
    seed: Optional[int] = None
class StepRequest(BaseModel):
    """Body for POST /step: which task to step and the action to apply."""
    # Task whose environment is stepped. Checked against VALID_TASKS by the endpoint.
    task_id: str = "triage_alerts"
    # Keyword arguments expanded into Action(**action). The accepted keys are
    # defined by warehouse_env.Action, not here (no default: field is required).
    action: Dict[str, Any]
class GradeRequest(BaseModel):
    """Body for POST /grade: which task's current state to grade."""
    # Task to grade. Checked against VALID_TASKS by the endpoint, not by this model.
    task_id: str = "triage_alerts"
# ── Endpoints ─────────────────────────────────────────────────────────
@app.get("/health")
def health():
    """Liveness probe: report that the server is up, with its name and version."""
    payload = dict(status="ok", env="openenv-warehouse", version="1.0.0")
    return payload
@app.post("/reset")
def reset(req: Optional[ResetRequest] = None):
    """
    Reset the environment and return the initial observation.

    The validator pings this endpoint with POST /reset and expects HTTP 200.
    For that reason the body is optional: an omitted body behaves like ``{}``,
    and every field of ``ResetRequest`` falls back to its default.

    Each call does the following:

    1. Builds a brand-new ``WarehouseEnv``, seeded with ``req.seed`` or with
       ``DEFAULT_SEED`` when no seed is given.
    2. Stores it as the cached environment for ``task_id``.
    3. Resets it.

    Later /step, /state and /grade calls for that task therefore act on this
    fresh instance.

    Returns:
        ``{"task_id": ..., "seed": ..., "observation": ...}`` with HTTP 200.

    Raises:
        HTTPException: 400 if ``task_id`` is not in ``VALID_TASKS``.
    """
    # Build the default per request. The alternative evaluates ResetRequest()
    # once at import time and shares that one instance as a default argument.
    if req is None:
        req = ResetRequest()
    if req.task_id not in VALID_TASKS:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown task_id '{req.task_id}'. Choose from {VALID_TASKS}",
        )
    seed = req.seed if req.seed is not None else DEFAULT_SEED
    env = WarehouseEnv(task_id=req.task_id, seed=seed)
    _envs[req.task_id] = env
    obs = env.reset()
    # Return a plain dict, as /step and /grade do. FastAPI then serialises it
    # with jsonable_encoder; 200 is the default status for this route.
    # A hand-built JSONResponse skips that encoder and would raise if the
    # observation ever held a non-JSON-native value (Enum, datetime, ...).
    return {
        "task_id": req.task_id,
        "seed": seed,
        "observation": obs.model_dump(),
    }
@app.post("/step")
def step(req: StepRequest):
    """
    Execute one action and return (observation, reward, done, info).

    ``req.action`` is expanded into ``Action(**req.action)``. The action is
    applied to the cached environment for ``req.task_id``; ``_get_env`` creates
    one with ``DEFAULT_SEED`` if none is cached for this task. The step result
    is returned as JSON.

    Raises:
        HTTPException:
            - 400 for an unknown ``task_id``.
            - 422 if the action dict cannot be turned into an ``Action``.
            - 409 if ``env.step`` raises ``RuntimeError``. This is presumably
              an episode that is finished or not yet reset — confirm against
              WarehouseEnv.
    """
    if req.task_id not in VALID_TASKS:
        raise HTTPException(status_code=400, detail=f"Unknown task_id '{req.task_id}'")
    env = _get_env(req.task_id)
    try:
        action = Action(**req.action)
    except Exception as exc:
        # Any failure while building the action is reported as a client error.
        # Chaining keeps the original error attached as __cause__ for debugging.
        raise HTTPException(status_code=422, detail=f"Invalid action: {exc}") from exc
    try:
        obs, reward, done, info = env.step(action)
    except RuntimeError as exc:
        raise HTTPException(status_code=409, detail=str(exc)) from exc
    return {
        "observation": obs.model_dump(),
        "reward": reward.model_dump(),
        "done": done,
        "info": info,
    }
@app.get("/state")
def state(task_id: str = "triage_alerts"):
    """Return the full internal state of a task's environment.

    The payload includes ground truth (it is what the graders consume). An
    environment is created on demand if the task has none cached yet.
    """
    if task_id in VALID_TASKS:
        return _get_env(task_id).state()
    raise HTTPException(status_code=400, detail=f"Unknown task_id '{task_id}'")
@app.post("/grade")
def grade(req: GradeRequest):
    """Score the task's current environment state with the programmatic grader.

    Returns the task id together with the ``(score, explanation)`` pair
    produced by ``graders.grade``.
    """
    task = req.task_id
    if task not in VALID_TASKS:
        raise HTTPException(status_code=400, detail=f"Unknown task_id '{task}'")
    snapshot = _get_env(task).state()
    score, explanation = run_grade(task, snapshot)
    return {"task_id": task, "score": score, "explanation": explanation}
@app.get("/tasks")
def list_tasks():
    """Return the task identifiers accepted by the other endpoints."""
    return dict(tasks=VALID_TASKS)