"""HTTP and WebSocket server exposing ``SupportOpsEnv`` through FastAPI.

The HTTP routes share one module-level environment instance; every WebSocket
connection gets its own environment and episode id.
"""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Dict, Optional
from uuid import uuid4

from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from pydantic import BaseModel, ConfigDict, Field, ValidationError

from support_ops_env.env import SupportOpsEnv
from support_ops_env.models import Action, Observation, StateModel


class ResetRequest(BaseModel):
    """Body of a reset call; unknown extra fields are accepted."""

    model_config = ConfigDict(extra="allow")

    # NOTE(review): validated here but not forwarded to ``env.reset`` by the
    # handlers in this module.
    seed: Optional[int] = Field(default=None, ge=0)
    # When omitted, the handlers generate a fresh UUID4 episode id.
    episode_id: Optional[str] = None
    task_id: Optional[str] = None


class StepRequest(BaseModel):
    """Body of an HTTP step call; unknown extra fields are accepted."""

    model_config = ConfigDict(extra="allow")

    # Raw action payload; validated against ``Action`` inside the handler.
    action: Dict[str, Any]
    # NOTE(review): ``timeout_s`` and ``request_id`` are validated but not
    # used by the handlers in this module.
    timeout_s: Optional[float] = Field(default=None, gt=0)
    request_id: Optional[str] = None


class StepResponse(BaseModel):
    """Observation/reward/done triple returned by ``/reset`` and ``/step``."""

    observation: Dict[str, Any]
    reward: Optional[float] = None
    done: bool = False


class HealthResponse(BaseModel):
    """Payload of the ``/health`` route."""

    status: str = "healthy"


class SchemaResponse(BaseModel):
    """JSON schemas of the action, observation and state models."""

    action: Dict[str, Any]
    observation: Dict[str, Any]
    state: Dict[str, Any]


class EnvironmentMetadata(BaseModel):
    """Descriptive metadata returned by the ``/metadata`` route."""

    name: str
    description: str
    readme_content: Optional[str] = None
    version: str
    documentation_url: Optional[str] = None


# README.md is looked up two directories above this file.
README_PATH = Path(__file__).resolve().parent.parent / "README.md"

app = FastAPI(
    title="SupportOpsEnv Server",
    description="OpenEnv-compatible server for the SupportOpsEnv benchmark.",
    version="0.1.0",
)

# Environment and episode id shared by all HTTP routes.
_http_env = SupportOpsEnv()
_http_episode_id = str(uuid4())


def _serialize_observation(
    observation: Observation,
    reward: Optional[float] = None,
    done: bool = False,
) -> Dict[str, Any]:
    """Return the ``observation``/``reward``/``done`` dict for a response."""
    return {
        "observation": observation.model_dump(),
        "reward": reward,
        "done": done,
    }


def _state_payload(env: SupportOpsEnv, episode_id: str) -> Dict[str, Any]:
    """Dump ``env.state()`` to a dict and add the given ``episode_id``."""
    state = env.state().model_dump()
    state["episode_id"] = episode_id
    return state


def _metadata() -> EnvironmentMetadata:
    """Build the metadata record, reading README.md if the file exists."""
    return EnvironmentMetadata(
        name="support-ops-env",
        description=(
            "Multi-step customer support triage and escalation benchmark "
            "for OpenEnv-style agents."
        ),
        readme_content=(
            README_PATH.read_text(encoding="utf-8")
            if README_PATH.exists()
            else None
        ),
        version="0.1.0",
        documentation_url="https://huggingface.co/spaces",
    )


def _ws_error(message: str, code: str, **extra: Any) -> Dict[str, Any]:
    """Build a WebSocket ``error`` frame; ``extra`` keys follow ``code``."""
    return {
        "type": "error",
        "data": {"message": message, "code": code, **extra},
    }


@app.get("/")
def root() -> dict[str, str]:
    """Return a static banner describing the server."""
    return {
        "name": "support-ops-env",
        "status": "ok",
        "message": "SupportOpsEnv OpenEnv server is available.",
    }


@app.get("/health", response_model=HealthResponse)
def health() -> HealthResponse:
    """Return the default health payload."""
    return HealthResponse()


@app.get("/metadata", response_model=EnvironmentMetadata)
def metadata() -> EnvironmentMetadata:
    """Return the environment metadata."""
    return _metadata()


@app.get("/schema", response_model=SchemaResponse)
def schema() -> SchemaResponse:
    """Return the JSON schemas of ``Action``, ``Observation``, ``StateModel``."""
    return SchemaResponse(
        action=Action.model_json_schema(),
        observation=Observation.model_json_schema(),
        state=StateModel.model_json_schema(),
    )


@app.get("/state")
def state() -> Dict[str, Any]:
    """Return the shared HTTP environment's state with its episode id."""
    return _state_payload(_http_env, _http_episode_id)


@app.post("/reset", response_model=StepResponse)
def reset(request: ResetRequest = ResetRequest()) -> StepResponse:
    """Reset the shared HTTP environment and return the first observation.

    The episode id is taken from the request, or a new UUID4 is generated.
    """
    global _http_episode_id
    _http_episode_id = request.episode_id or str(uuid4())
    observation = _http_env.reset(task_id=request.task_id)
    return StepResponse(**_serialize_observation(observation))


@app.post("/step", response_model=StepResponse)
def step(request: StepRequest) -> StepResponse:
    """Validate the action, step the shared HTTP environment, return result.

    Raises:
        HTTPException: status 422 when the action fails ``Action`` validation.
    """
    try:
        action = Action.model_validate(request.action)
    except ValidationError as exc:
        raise HTTPException(status_code=422, detail=exc.errors()) from exc

    observation, reward, done, _info = _http_env.step(action)
    return StepResponse(
        **_serialize_observation(observation, reward=reward.value, done=done)
    )


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """Serve one environment session over a WebSocket connection.

    Each text frame must be a JSON object with a ``type`` of ``reset``,
    ``step``, ``state`` or ``close`` and an optional ``data`` member.
    Malformed or unknown messages are answered with an ``error`` frame and
    the session continues; ``close`` ends the session.
    """
    await websocket.accept()
    env = SupportOpsEnv()
    episode_id = str(uuid4())
    peer_disconnected = False

    try:
        while True:
            raw_message = await websocket.receive_text()
            try:
                payload = json.loads(raw_message)
            except json.JSONDecodeError as exc:
                await websocket.send_json(
                    _ws_error(f"Invalid JSON: {exc}", "invalid_json")
                )
                continue

            # Valid JSON that is not an object (list, string, number, ...)
            # has no ``.get``; report it instead of crashing the session.
            if not isinstance(payload, dict):
                await websocket.send_json(
                    _ws_error("Message must be a JSON object", "invalid_message")
                )
                continue

            msg_type = payload.get("type")
            data = payload.get("data", {})

            try:
                if msg_type == "reset":
                    reset_request = ResetRequest.model_validate(data)
                    episode_id = reset_request.episode_id or str(uuid4())
                    observation = env.reset(task_id=reset_request.task_id)
                    await websocket.send_json(
                        {
                            "type": "observation",
                            "data": _serialize_observation(observation),
                        }
                    )
                elif msg_type == "step":
                    action = Action.model_validate(data)
                    observation, reward, done, _info = env.step(action)
                    await websocket.send_json(
                        {
                            "type": "observation",
                            "data": _serialize_observation(
                                observation,
                                reward=reward.value,
                                done=done,
                            ),
                        }
                    )
                elif msg_type == "state":
                    await websocket.send_json(
                        {
                            "type": "state",
                            "data": _state_payload(env, episode_id),
                        }
                    )
                elif msg_type == "close":
                    break
                else:
                    await websocket.send_json(
                        _ws_error(
                            f"Unknown message type: {msg_type}",
                            "unknown_type",
                        )
                    )
            except ValidationError as exc:
                await websocket.send_json(
                    _ws_error(
                        "Validation error",
                        "validation_error",
                        errors=exc.errors(),
                    )
                )
            except WebSocketDisconnect:
                # Do not try to report a disconnect on the dead socket.
                raise
            except Exception as exc:  # pragma: no cover
                await websocket.send_json(
                    _ws_error(str(exc), "execution_error")
                )
    except WebSocketDisconnect:
        peer_disconnected = True
    finally:
        # The peer already closed the socket after a disconnect; sending a
        # second close frame makes the ASGI server raise RuntimeError.
        if not peer_disconnected:
            await websocket.close()


def main(host: str = "0.0.0.0", port: int = 8000) -> None:
    """Run the app under uvicorn on the given host and port."""
    import uvicorn

    uvicorn.run(app, host=host, port=port)


def uv_main() -> FastAPI:
    """Return the FastAPI application object."""
    return app


if __name__ == "__main__":
    main()