Spaces:
Sleeping
Sleeping
File size: 4,867 Bytes
b272983 4812df3 b272983 5cf727a febcf68 b272983 4812df3 febcf68 4812df3 bb2fc43 febcf68 5cf727a b272983 5cf727a b272983 5cf727a bb2fc43 b272983 4812df3 b272983 bb2fc43 4812df3 bb2fc43 4812df3 b272983 bb2fc43 b272983 4812df3 bb2fc43 b272983 4812df3 bb2fc43 4812df3 bb2fc43 b272983 4812df3 bb2fc43 b272983 5cf727a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 | """
FastAPI application for the DataClean Environment.
Uses the OpenEnv framework's create_app() for full feature support
(WebSocket, Web UI, MCP, OpenAPI docs) while patching in session-isolated
stateful HTTP endpoints for inference script compatibility.
"""
import asyncio
import os
from uuid import uuid4
from fastapi import FastAPI, Body, Header
from pydantic import BaseModel
from typing import Any, Dict, Optional
# Enable the Gradio web interface before importing create_app
os.environ.setdefault("ENABLE_WEB_INTERFACE", "true")
from openenv.core.env_server.http_server import create_app
try:
from .environment import DataCleanEnvironment
from ..models import DataCleanAction, DataCleanObservation
except ImportError:
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from server.environment import DataCleanEnvironment
from models import DataCleanAction, DataCleanObservation
# ---------------------------------------------------------------------------
# Create the full framework app (WebSocket /ws, Web UI /web/, MCP /mcp,
# OpenAPI /docs, /health, /metadata, /schema)
# ---------------------------------------------------------------------------
# create_app() is handed the environment *class* (not an instance) together
# with the action and observation model classes. ENABLE_WEB_INTERFACE has
# already been defaulted via os.environ.setdefault before create_app was
# imported, so the web UI is enabled unless the caller overrode it.
_framework_app = create_app(
    DataCleanEnvironment,
    DataCleanAction,
    DataCleanObservation,
    env_name="data_clean_env",
)
# Paths whose framework-provided stateless HTTP handlers are dropped so the
# session-isolated stateful handlers defined below can take their place.
_STATEFUL_OVERRIDE_PATHS = ("/reset", "/step", "/state")


def _is_overridden_http_route(route) -> bool:
    """Return True when *route* is one of the framework routes to be replaced.

    Only route objects exposing both ``path`` and ``methods`` are candidates;
    anything lacking either attribute is always kept. This keeps WebSocket,
    web UI, MCP, /docs, /health, /metadata and /schema intact.
    """
    if not (hasattr(route, "path") and hasattr(route, "methods")):
        return False
    return route.path in _STATEFUL_OVERRIDE_PATHS


_framework_app.router.routes = [
    route
    for route in _framework_app.router.routes
    if not _is_overridden_http_route(route)
]
app = _framework_app
# ---------------------------------------------------------------------------
# Session-isolated stateful HTTP layer
#
# Each session gets its own DataCleanEnvironment instance. Sessions are
# identified by the X-Session-Id request header. A missing or empty header
# maps to the shared "default" session, so simple single-client usage
# (like inference.py) works out of the box. Session IDs are chosen by the
# client; the server does not generate them.
# ---------------------------------------------------------------------------
# session_id -> environment. Dict insertion order is what
# _get_or_create_session relies on when choosing a session to evict.
_sessions: Dict[str, DataCleanEnvironment] = {}
# Guards the lookup / create / evict sequence in _get_or_create_session.
_sessions_lock = asyncio.Lock()
# Upper bound on live sessions; creating one more evicts an existing session.
MAX_SESSIONS = 50
async def _get_or_create_session(session_id: str) -> DataCleanEnvironment:
    """Return the environment for *session_id*, creating it on first use.

    ``_sessions`` is kept in least-recently-used order: every lookup moves the
    session to the end of the dict. When ``MAX_SESSIONS`` sessions already
    exist, the least recently used ones are evicted before a new environment
    is created. Eviction therefore never drops a session that is still being
    actively used merely because it was created early (e.g. ``"default"``).

    Args:
        session_id: Client-chosen session key (the X-Session-Id header value,
            or ``"default"``).

    Returns:
        The DataCleanEnvironment bound to *session_id*.
    """
    async with _sessions_lock:
        env = _sessions.pop(session_id, None)
        if env is None:
            # Make room for one more session by evicting from the front of the
            # dict (least recently used). The `_sessions` check stops
            # next(iter(...)) from raising StopIteration if MAX_SESSIONS <= 0.
            while _sessions and len(_sessions) >= MAX_SESSIONS:
                del _sessions[next(iter(_sessions))]
            env = DataCleanEnvironment()
        # (Re)insert at the end so dict order tracks recency of use.
        _sessions[session_id] = env
        return env
class ResetRequest(BaseModel):
    # Body of POST /reset. Every field has a default, and extra keys are
    # accepted (and ignored by the handler).
    model_config = {"extra": "allow"}

    task_id: str = "customer_contacts"  # forwarded to env.reset(task_id=...)
    seed: Optional[int] = None
    episode_id: Optional[str] = None


class StepRequest(BaseModel):
    # Body of POST /step. `action` is expanded as keyword arguments into
    # DataCleanAction; extra top-level keys are accepted.
    model_config = {"extra": "allow"}

    action: Dict[str, Any]
def _obs_dict(obs: DataCleanObservation) -> dict:
    """Convert an observation model into a plain dict via ``model_dump()``."""
    serialized = obs.model_dump()
    return serialized
@app.post("/reset", tags=["Environment Control"])
async def stateful_reset(
    request: ResetRequest = Body(default_factory=ResetRequest),
    x_session_id: Optional[str] = Header(default="default"),
):
    """Reset the environment with a specific task. Session-isolated via X-Session-Id header."""
    # An empty header value falls back to the shared "default" session;
    # the session's environment is created here if it does not exist yet.
    resolved_id = x_session_id if x_session_id else "default"
    session_env = await _get_or_create_session(resolved_id)

    reset_kwargs = {
        "seed": request.seed,
        "episode_id": request.episode_id,
        "task_id": request.task_id,
    }
    initial_obs = session_env.reset(**reset_kwargs)

    # Reset responses always report reward=None and done=False.
    return {
        "observation": _obs_dict(initial_obs),
        "reward": None,
        "done": False,
    }
@app.post("/step", tags=["Environment Control"])
async def stateful_step(
    request: StepRequest,
    x_session_id: Optional[str] = Header(default="default"),
):
    """Execute an action. Session-isolated via X-Session-Id header.

    Responds with HTTP 422 when the ``action`` payload cannot be turned into
    a DataCleanAction.
    """
    from fastapi import HTTPException

    # Validate the client-supplied action before touching the session
    # registry, so malformed requests neither create nor evict sessions.
    # A pydantic ValidationError is a ValueError subclass; TypeError covers
    # unexpected keyword arguments if DataCleanAction is not a pydantic model
    # (NOTE(review): confirm which one DataCleanAction is).
    try:
        action = DataCleanAction(**request.action)
    except (TypeError, ValueError) as exc:
        raise HTTPException(status_code=422, detail=f"Invalid action: {exc}") from exc

    session_id = x_session_id or "default"
    env = await _get_or_create_session(session_id)
    obs = env.step(action)
    return {"observation": _obs_dict(obs), "reward": obs.reward, "done": obs.done}
@app.get("/state", tags=["State Management"])
async def stateful_state(
    x_session_id: Optional[str] = Header(default="default"),
):
    """Get current environment state for a session."""
    # An empty header value falls back to the shared "default" session.
    # NOTE(review): an unknown session id is created by this GET (and may
    # evict another session when the registry is full).
    resolved_id = x_session_id if x_session_id else "default"
    session_env = await _get_or_create_session(resolved_id)
    snapshot = session_env.state
    return snapshot.model_dump()
# ---------------------------------------------------------------------------
# Entry point for `uv run server` / `python -m server.app`
# ---------------------------------------------------------------------------
def main(host: str = "0.0.0.0", port: Optional[int] = None):
    """Serve ``app`` with uvicorn.

    Args:
        host: Interface to bind (default ``"0.0.0.0"``, all interfaces).
        port: TCP port to listen on. When omitted, the ``PORT`` environment
            variable is used if set, otherwise ``8000``.

    Raises:
        ValueError: if ``PORT`` is set to a non-integer value.
    """
    import uvicorn

    if port is None:
        # Hosting platforms commonly inject the listen port via $PORT.
        port = int(os.environ.get("PORT", "8000"))
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()
|