Spaces:
Sleeping
Sleeping
"""
server.py — FastAPI HTTP wrapper for OpenEnv Warehouse
=======================================================
Exposes the environment over HTTP so the HF Space validator can ping /reset.
Also provides full step/state/grade endpoints for remote evaluation.

Endpoints:
    POST /reset   → initial observation (validator pings this)
    POST /step    → step the environment
    GET  /state   → current internal state
    POST /grade   → run grader on current state
    GET  /health  → liveness check
"""
| import os | |
| import sys | |
| from typing import Any, Dict, Optional | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel | |
| # Ensure the current directory is in the path | |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) | |
| # UPDATED IMPORTS: Removing folder prefixes | |
| from warehouse_env import WarehouseEnv, Action, ActionType | |
| from graders import grade as run_grade | |
# Task identifiers accepted by every endpoint; anything else is rejected with HTTP 400.
VALID_TASKS = ["triage_alerts", "resolve_stockout", "optimize_replenishment"]

# Seed used when a caller does not supply one; override with the ENV_SEED env var.
DEFAULT_SEED = int(os.getenv("ENV_SEED", "42"))

# Environment cache — one WarehouseEnv per task, keyed by task_id.
# Filled lazily by _get_env() and overwritten by reset().
_envs: Dict[str, WarehouseEnv] = {}

app = FastAPI(
    title="OpenEnv Warehouse",
    description="Warehouse inventory management benchmark for autonomous agents.",
    version="1.0.0",
)
def _get_env(task_id: str) -> WarehouseEnv:
    """Return the cached environment for *task_id*, creating it on first use.

    A newly created environment uses ``DEFAULT_SEED`` and is stored in
    ``_envs`` so later calls for the same task get the same instance.
    """
    cached = _envs.get(task_id)
    if cached is None:
        cached = WarehouseEnv(task_id=task_id, seed=DEFAULT_SEED)
        _envs[task_id] = cached
    return cached
# ── Request bodies ─────────────────────────────────────────────────────


# Body accepted by reset(); every field has a default, so `{}` is valid.
class ResetRequest(BaseModel):
    task_id: str = "triage_alerts"  # validated against VALID_TASKS by the endpoint
    seed: Optional[int] = None  # None -> the endpoint falls back to DEFAULT_SEED


# Body accepted by step().
class StepRequest(BaseModel):
    task_id: str = "triage_alerts"
    action: Dict[str, Any]  # required; expanded as Action(**action) by the endpoint


# Body accepted by grade().
class GradeRequest(BaseModel):
    task_id: str = "triage_alerts"
# ── Endpoints ──────────────────────────────────────────────────────────


# The route decorator was missing, so the function was never registered
# on `app` and GET /health returned 404.
@app.get("/health")
def health():
    """Liveness check: report that the server is up, with env name and version."""
    return {"status": "ok", "env": "openenv-warehouse", "version": "1.0.0"}
@app.post("/reset")
def reset(req: Optional[ResetRequest] = None):
    """
    Reset the environment for a task and return its initial observation.

    The validator pings this endpoint with POST /reset and expects HTTP 200.
    The request body is optional, so a bare POST also works.

    A new ``WarehouseEnv`` is built on every call, so the requested seed
    takes effect. It replaces any env cached for the same task.

    Args:
        req: Optional body with ``task_id`` and ``seed``. When omitted, the
            ``ResetRequest`` defaults are used. A missing ``seed`` falls back
            to ``DEFAULT_SEED``.

    Returns:
        ``{"task_id": ..., "seed": ..., "observation": ...}`` with status 200.

    Raises:
        HTTPException: 400 if ``task_id`` is not one of ``VALID_TASKS``.
    """
    if req is None:
        # Build a fresh default per call. A model instance used as the default
        # argument would be a single mutable object shared by every call.
        req = ResetRequest()
    if req.task_id not in VALID_TASKS:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown task_id '{req.task_id}'. Choose from {VALID_TASKS}",
        )
    seed = req.seed if req.seed is not None else DEFAULT_SEED
    env = WarehouseEnv(task_id=req.task_id, seed=seed)
    _envs[req.task_id] = env
    obs = env.reset()
    # Return a plain dict, as the other endpoints do. FastAPI then runs its
    # JSON encoder over it, which copes with non-JSON-native values (e.g.
    # enums or datetimes) that a raw JSONResponse would reject with a 500.
    # The status code is still 200.
    return {
        "task_id": req.task_id,
        "seed": seed,
        "observation": obs.model_dump(),
    }
@app.post("/step")
def step(req: StepRequest):
    """
    Execute one action and return (observation, reward, done, info).

    Uses the cached env for ``req.task_id``, creating one with
    ``DEFAULT_SEED`` if the task was never reset.

    Returns:
        ``{"observation": ..., "reward": ..., "done": ..., "info": ...}``.

    Raises:
        HTTPException: 400 for an unknown ``task_id``; 422 if ``req.action``
            cannot be turned into an ``Action``; 409 if ``env.step`` raises
            ``RuntimeError``.
    """
    if req.task_id not in VALID_TASKS:
        raise HTTPException(status_code=400, detail=f"Unknown task_id '{req.task_id}'")
    env = _get_env(req.task_id)
    try:
        action = Action(**req.action)
    except Exception as exc:  # request boundary: any construction failure is a client error
        raise HTTPException(status_code=422, detail=f"Invalid action: {exc}") from exc
    try:
        obs, reward, done, info = env.step(action)
    except RuntimeError as exc:
        # NOTE(review): presumably WarehouseEnv raises RuntimeError for an
        # invalid episode state (e.g. stepping after done) — confirm.
        raise HTTPException(status_code=409, detail=str(exc)) from exc
    return {
        "observation": obs.model_dump(),
        "reward": reward.model_dump(),
        "done": done,
        "info": info,
    }
@app.get("/state")
def state(task_id: str = "triage_alerts"):
    """
    Return the current full internal state (includes ground-truth, for graders).

    ``task_id`` is read from the query string. If the task was never reset,
    an env is created with ``DEFAULT_SEED``.

    NOTE(review): serving ground truth on an unauthenticated endpoint lets an
    agent under evaluation read the answers — confirm this is intended.

    Raises:
        HTTPException: 400 for an unknown ``task_id``.
    """
    if task_id not in VALID_TASKS:
        raise HTTPException(status_code=400, detail=f"Unknown task_id '{task_id}'")
    env = _get_env(task_id)
    return env.state()
@app.post("/grade")
def grade(req: GradeRequest):
    """
    Run the programmatic grader on the current state and return the score.

    The grader receives ``env.state()`` for the cached env of ``req.task_id``.

    Returns:
        ``{"task_id": ..., "score": ..., "explanation": ...}``.

    Raises:
        HTTPException: 400 for an unknown ``task_id``.
    """
    if req.task_id not in VALID_TASKS:
        raise HTTPException(status_code=400, detail=f"Unknown task_id '{req.task_id}'")
    env = _get_env(req.task_id)
    score, explanation = run_grade(req.task_id, env.state())
    return {"task_id": req.task_id, "score": score, "explanation": explanation}
# NOTE(review): the module docstring does not document this route; "/tasks"
# is an assumed path — confirm against clients before relying on it.
@app.get("/tasks")
def list_tasks():
    """Return the task identifiers accepted by the other endpoints."""
    return {"tasks": VALID_TASKS}