"""
server.py — FastAPI HTTP wrapper for OpenEnv Warehouse
=======================================================
Exposes the environment over HTTP so the HF Space validator can ping /reset.
Also provides full step/state/grade endpoints for remote evaluation.

Endpoints:
  POST /reset   → initial observation (validator pings this)
  POST /step    → step the environment
  GET  /state   → current internal state
  POST /grade   → run grader on current state
  GET  /health  → liveness check
  GET  /tasks   → list of valid task ids
"""

import os
import sys
from typing import Any, Dict, Optional

from fastapi import FastAPI, HTTPException
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from pydantic import BaseModel

# Ensure the current directory is in the path so the sibling modules below
# can be imported without a package prefix. This must run before those imports.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Local modules, imported without folder prefixes (relies on the path entry above).
from warehouse_env import WarehouseEnv, Action, ActionType  # noqa: E402
from graders import grade as run_grade  # noqa: E402

app = FastAPI(
    title="OpenEnv Warehouse",
    description="Warehouse inventory management benchmark for autonomous agents.",
    version="1.0.0",
)

# One env instance per task — keyed by task_id
_envs: Dict[str, WarehouseEnv] = {}

VALID_TASKS = ["triage_alerts", "resolve_stockout", "optimize_replenishment"]

# Seed used when a request does not supply one; overridable via ENV_SEED.
# A non-integer ENV_SEED raises ValueError at import time (fail fast).
DEFAULT_SEED = int(os.getenv("ENV_SEED", "42"))


def _get_env(task_id: str) -> WarehouseEnv:
    """Return the cached environment for *task_id*, creating it on first use.

    A newly created environment is built with ``DEFAULT_SEED`` and is NOT
    reset here; only ``POST /reset`` calls ``env.reset()``.
    """
    if task_id not in _envs:
        _envs[task_id] = WarehouseEnv(task_id=task_id, seed=DEFAULT_SEED)
    return _envs[task_id]


def _require_valid_task(task_id: str, *, list_choices: bool = False) -> None:
    """Raise HTTP 400 unless *task_id* is one of ``VALID_TASKS``.

    Args:
        task_id: Task identifier supplied by the client.
        list_choices: When true, append the list of valid task ids to the
            error detail (the form used by ``/reset``).

    Raises:
        HTTPException: 400 if the task id is unknown.
    """
    if task_id in VALID_TASKS:
        return
    detail = f"Unknown task_id '{task_id}'"
    if list_choices:
        detail += f". Choose from {VALID_TASKS}"
    raise HTTPException(status_code=400, detail=detail)


# ── Request bodies ────────────────────────────────────────────────────


class ResetRequest(BaseModel):
    """Body of ``POST /reset``; every field is optional."""

    task_id: str = "triage_alerts"
    # None means "use DEFAULT_SEED".
    seed: Optional[int] = None


class StepRequest(BaseModel):
    """Body of ``POST /step``."""

    task_id: str = "triage_alerts"
    # Keyword arguments forwarded verbatim to ``Action(**action)``.
    action: Dict[str, Any]


class GradeRequest(BaseModel):
    """Body of ``POST /grade``."""

    task_id: str = "triage_alerts"


# ── Endpoints ─────────────────────────────────────────────────────────


@app.get("/health")
def health():
    """Liveness check; returns a static payload."""
    return {"status": "ok", "env": "openenv-warehouse", "version": "1.0.0"}


@app.post("/reset")
def reset(req: ResetRequest = ResetRequest()):
    """
    Reset the environment and return the initial observation.
    The validator pings this endpoint with POST /reset and expects HTTP 200.

    A fresh ``WarehouseEnv`` replaces any cached instance for the task, using
    the requested seed or ``DEFAULT_SEED`` when none is given.

    Raises:
        HTTPException: 400 if ``task_id`` is not in ``VALID_TASKS``.
    """
    _require_valid_task(req.task_id, list_choices=True)

    seed = req.seed if req.seed is not None else DEFAULT_SEED
    env = WarehouseEnv(task_id=req.task_id, seed=seed)
    _envs[req.task_id] = env
    obs = env.reset()

    # JSONResponse serialises its content as-is, so run the payload through
    # FastAPI's encoder first. This gives /reset the same conversion the
    # dict-returning endpoints get, so a non-JSON-native value in the
    # observation cannot turn the validator ping into a 500.
    payload = jsonable_encoder(
        {
            "task_id": req.task_id,
            "seed": seed,
            "observation": obs.model_dump(),
        }
    )
    return JSONResponse(status_code=200, content=payload)


@app.post("/step")
def step(req: StepRequest):
    """Execute one action and return (observation, reward, done, info).

    Raises:
        HTTPException: 400 for an unknown ``task_id``; 422 if the action
            payload cannot be turned into an ``Action``; 409 if the
            environment raises ``RuntimeError`` while stepping.
    """
    _require_valid_task(req.task_id)
    env = _get_env(req.task_id)

    # Action's failure modes are not visible from this module, so any
    # construction error is reported to the client as an invalid payload.
    try:
        action = Action(**req.action)
    except Exception as exc:
        raise HTTPException(status_code=422, detail=f"Invalid action: {exc}") from exc

    try:
        obs, reward, done, info = env.step(action)
    except RuntimeError as exc:
        raise HTTPException(status_code=409, detail=str(exc)) from exc

    return {
        "observation": obs.model_dump(),
        "reward": reward.model_dump(),
        "done": done,
        "info": info,
    }


@app.get("/state")
def state(task_id: str = "triage_alerts"):
    """Return current full internal state (includes ground-truth, for graders).

    Raises:
        HTTPException: 400 if ``task_id`` is not in ``VALID_TASKS``.
    """
    _require_valid_task(task_id)
    # NOTE(review): if /reset was never called for this task, the env is
    # created here without reset() — confirm env.state() supports that.
    env = _get_env(task_id)
    return env.state()


@app.post("/grade")
def grade(req: GradeRequest):
    """Run the programmatic grader on the current state and return score.

    Raises:
        HTTPException: 400 if ``task_id`` is not in ``VALID_TASKS``.
    """
    _require_valid_task(req.task_id)
    env = _get_env(req.task_id)
    score, explanation = run_grade(req.task_id, env.state())
    return {"task_id": req.task_id, "score": score, "explanation": explanation}


@app.get("/tasks")
def list_tasks():
    """Return the task ids this server accepts."""
    return {"tasks": VALID_TASKS}