"""HTTP and WebSocket server exposing ProcureEnvironment through FastAPI.

The module offers stateless HTTP endpoints (/reset, /step, /state), a
stateful WebSocket session (/ws), and the discovery endpoints /health,
/metadata, /schema and /mcp.
"""

import json
import os
from typing import Optional, Any

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Body, Query
from fastapi.responses import HTMLResponse
from pydantic import TypeAdapter
import uvicorn

from server.environment import ProcureEnvironment
from models import (
    ProcureObservation,
    ProcureState,
    QueryAction,
    RequestDocAction,
    OfferAction,
    AcceptAction,
    RejectAction,
    ProcureAction,
)

app = FastAPI(
    title="ProcureEnv",
    description="Industrial B2B Procurement RL Environment",
    version="1.0.0",
)


# ------------------------------------------------------------------ #
#  OpenEnv required endpoints                                          #
# ------------------------------------------------------------------ #

@app.get("/health")
def health():
    """Liveness probe; always reports a healthy status."""
    return {"status": "healthy"}


@app.get("/metadata")
def metadata():
    """Return environment name and description for OpenEnv validator."""
    return {
        "name": "procure_env",
        "description": (
            "Industrial B2B procurement negotiation environment. "
            "An agent acts as a procurement engineer: querying hidden supplier attributes, "
            "negotiating prices, verifying compliance certifications, and avoiding adversarial "
            "counterparties to fulfill purchase requirements under budget constraints."
        ),
    }


@app.get("/schema")
def schema():
    """Return JSON schemas for action, observation, and state types."""
    # ProcureAction is wrapped in a TypeAdapter rather than calling
    # model_json_schema() on it directly, unlike the two models below.
    action_schema = TypeAdapter(ProcureAction).json_schema()
    return {
        "action": action_schema,
        "observation": ProcureObservation.model_json_schema(),
        "state": ProcureState.model_json_schema(),
    }


@app.post("/mcp")
def mcp(request: dict = Body(default={})):
    """
    Minimal JSON-RPC 2.0 endpoint for MCP tool discovery.
    Returns an empty tools list -- procurement actions are exposed via /step.
    """
    return {
        "jsonrpc": "2.0",
        # Echo the caller's request id; None when the request carried none.
        "id": request.get("id"),
        "result": {"tools": []},
    }


# ------------------------------------------------------------------ #
#  Simulation endpoints                                                #
# ------------------------------------------------------------------ #

@app.post("/reset")
async def reset(
    task_id: Optional[str] = Query(None),
    body: dict = Body(default={})
):
    """Create a fresh environment for a task and return its first observation.

    The task is taken from the ``task_id`` query parameter when given,
    otherwise from ``body["task_id"]``, otherwise ``"task1_easy"``.
    """
    tid = task_id or body.get("task_id", "task1_easy")
    env = ProcureEnvironment(task_id=tid)
    obs = env.reset()
    return obs.model_dump()


@app.post("/step")
def step(action: dict = Body(default={}), task_id: str = "task1_easy"):
    """Stateless HTTP step -- resets env each call. Use /ws for stateful sessions."""
    # Two body shapes are accepted: an envelope {"action": {...}, "task_id": ...}
    # or the bare action dict itself.
    if "action" in action and isinstance(action.get("action"), dict):
        payload = action["action"]
        task_id = action.get("task_id", task_id)
    else:
        payload = action
    env = ProcureEnvironment(task_id=task_id)
    env.reset()
    obs = env.step(payload)
    return obs.model_dump()


@app.get("/state")
def state(task_id: str = "task1_easy"):
    """Return the state of a newly created and reset environment for ``task_id``.

    No environment is shared between HTTP calls, so this reflects the
    post-reset state only; use the "state" message on /ws for a live session.
    """
    env = ProcureEnvironment(task_id=task_id)
    env.reset()
    return env.state.model_dump()


# ------------------------------------------------------------------ #
#  Status page                                                         #
# ------------------------------------------------------------------ #

@app.get("/web", response_class=HTMLResponse)
def web_ui():
    """Serve the static status page describing tasks and endpoints."""
    # NOTE(review): the page body below contains no HTML tags although it is
    # served as an HTMLResponse -- confirm whether markup was lost upstream.
    # The text is kept exactly as found.
    return """ ProcureEnv

ProcureEnv

Industrial B2B Procurement Negotiation — OpenEnv RL Environment

An agent acts as a procurement engineer: querying hidden supplier attributes, negotiating prices, verifying compliance certifications, and avoiding adversarial counterparties to fulfill purchase requirements under budget constraints.

Tasks

TaskDifficultyMax StepsKey Challenge
task1_easyEasy12Conveyor belts, ₹69L budget, 3 suppliers, pure negotiation
task2_mediumMedium18Pressure valves, ATEX required, QuickSeal has quality issues
task3_hardHard25Hydraulic pumps, CE+ISO9001, FluidDyn deceptive, tight budget

Endpoints

EndpointMethodDescription
/wsWebSocketPersistent stateful session (recommended)
/resetPOSTReset environment, returns initial observation
/stepPOSTExecute action (stateless)
/stateGETCurrent environment state
/healthGETHealth check
/metadataGETEnvironment name and description
/schemaGETAction / observation / state JSON schemas
/mcpPOSTMCP tool discovery (JSON-RPC 2.0)
/docsGETOpenAPI documentation

API docsHealthMetadataSchema

"""


# ------------------------------------------------------------------ #
#  WebSocket (stateful sessions)                                       #
# ------------------------------------------------------------------ #

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """
    Stateful WebSocket session.

    Client sends:
      {"type": "reset", "task_id": "task1_easy"}
      {"type": "step",  "action": {...}}
      {"type": "state"}

    Server responds with observation JSON each time (state JSON for
    "state"), or with {"error": "..."} when the message is malformed, has an
    unknown type, or arrives before the first reset. The loop ends when the
    client disconnects.
    """
    await websocket.accept()
    # One environment per connection; created by the first "reset" message.
    env: Optional[ProcureEnvironment] = None

    try:
        while True:
            data = await websocket.receive_text()

            # A malformed frame must not tear down the whole session:
            # report it and keep waiting for the next message.
            try:
                msg = json.loads(data)
            except json.JSONDecodeError:
                await websocket.send_text(json.dumps({"error": "Invalid JSON"}))
                continue
            # Valid JSON that is not an object (list, string, number) has no
            # "type" key to dispatch on.
            if not isinstance(msg, dict):
                await websocket.send_text(
                    json.dumps({"error": "Message must be a JSON object"})
                )
                continue

            msg_type = msg.get("type")

            if msg_type == "reset":
                task_id = msg.get("task_id", "task1_easy")
                env = ProcureEnvironment(task_id=task_id)
                obs = env.reset()
                await websocket.send_text(obs.model_dump_json())

            elif msg_type == "step":
                if env is None:
                    await websocket.send_text(json.dumps({"error": "Call reset first"}))
                    continue
                action = msg.get("action", {})
                obs = env.step(action)
                await websocket.send_text(obs.model_dump_json())

            elif msg_type == "state":
                if env is None:
                    await websocket.send_text(json.dumps({"error": "Call reset first"}))
                    continue
                await websocket.send_text(env.state.model_dump_json())

            else:
                await websocket.send_text(json.dumps({"error": f"Unknown type: {msg_type}"}))

    except WebSocketDisconnect:
        # Client closed the connection; nothing to clean up beyond dropping env.
        pass


def main() -> None:
    """Run the app under uvicorn, honoring the HOST and PORT environment variables."""
    uvicorn.run(
        "server.app:app",
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "7860")),
    )


if __name__ == "__main__":
    main()