File size: 4,867 Bytes
b272983
 
 
 
4812df3
 
b272983
5cf727a
febcf68
b272983
4812df3
febcf68
4812df3
bb2fc43
febcf68
5cf727a
b272983
 
 
 
 
5cf727a
 
 
 
b272983
5cf727a
 
 
 
bb2fc43
b272983
 
 
 
 
 
 
 
 
 
 
 
4812df3
b272983
 
 
 
 
 
 
 
 
 
bb2fc43
 
 
4812df3
 
 
 
 
 
bb2fc43
4812df3
 
 
 
 
 
 
 
 
 
 
 
 
 
b272983
 
bb2fc43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b272983
4812df3
 
 
 
 
 
 
 
 
 
 
 
bb2fc43
 
 
b272983
4812df3
 
 
 
 
 
 
bb2fc43
4812df3
bb2fc43
 
 
b272983
4812df3
 
 
 
 
 
 
bb2fc43
 
b272983
 
 
5cf727a
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""
FastAPI application for the DataClean Environment.

Uses the OpenEnv framework's create_app() for full feature support
(WebSocket, Web UI, MCP, OpenAPI docs) while patching in session-isolated
stateful HTTP endpoints for inference script compatibility.
"""

import asyncio
import os
from typing import Any, Dict, Optional
from uuid import uuid4

from fastapi import Body, FastAPI, Header, HTTPException
from pydantic import BaseModel, ValidationError

# Enable the Gradio web interface before importing create_app
os.environ.setdefault("ENABLE_WEB_INTERFACE", "true")

from openenv.core.env_server.http_server import create_app

try:
    from .environment import DataCleanEnvironment
    from ..models import DataCleanAction, DataCleanObservation
except ImportError:
    import sys
    sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from server.environment import DataCleanEnvironment
    from models import DataCleanAction, DataCleanObservation


# ---------------------------------------------------------------------------
# Create the full framework app (WebSocket /ws, Web UI /web/, MCP /mcp,
# OpenAPI /docs, /health, /metadata, /schema)
# ---------------------------------------------------------------------------
_framework_app = create_app(
    DataCleanEnvironment,
    DataCleanAction,
    DataCleanObservation,
    env_name="data_clean_env",
)


def _is_stateless_http_route(route) -> bool:
    """Tell whether *route* is one of the framework's HTTP /reset, /step, /state handlers.

    Requiring both ``path`` and ``methods`` restricts the match to plain HTTP
    routes; objects lacking ``methods`` (e.g. WebSocket routes or mounts)
    are never matched and therefore survive the filtering below.
    """
    is_http_route = hasattr(route, "path") and hasattr(route, "methods")
    return is_http_route and route.path in ("/reset", "/step", "/state")


# Drop the framework's stateless /reset, /step, /state HTTP handlers; the
# session-aware replacements are registered further down in this module.
# Everything else (WebSocket, web UI, MCP, /docs, /health, /metadata,
# /schema) is kept as-is.
_framework_app.router.routes = [
    route
    for route in _framework_app.router.routes
    if not _is_stateless_http_route(route)
]

app = _framework_app


# ---------------------------------------------------------------------------
# Session-isolated stateful HTTP layer
#
# Each session gets its own DataCleanEnvironment instance. Sessions are
# identified by the X-Session-Id header; the server does not generate IDs.
# A default session ("default") is used when no header is provided,
# so simple single-client usage (like inference.py) works out of the box.
# Once MAX_SESSIONS sessions exist, an existing one is evicted to make room.
# ---------------------------------------------------------------------------
# Cap on how many session environments are kept alive at once.
MAX_SESSIONS = 50

# Session id -> that session's private environment instance.
_sessions: Dict[str, DataCleanEnvironment] = dict()
# Guards lookup / insertion / eviction on _sessions in _get_or_create_session.
_sessions_lock = asyncio.Lock()


async def _get_or_create_session(session_id: str) -> DataCleanEnvironment:
    """Return the environment bound to *session_id*, creating it on first use.

    ``_sessions`` is kept in least-recently-used order: every lookup moves the
    entry to the end of the dict (dicts preserve insertion order). When the
    table is full, the entry evicted is the one idle the longest, not merely
    the one created first. This keeps an actively used session (such as
    ``"default"``) from being dropped mid-episode.

    Args:
        session_id: Key taken from the X-Session-Id header (or "default").

    Returns:
        The DataCleanEnvironment for this session.
    """
    async with _sessions_lock:
        env = _sessions.pop(session_id, None)
        if env is None:
            # Make room for the new session. The ``_sessions`` check avoids
            # StopIteration on an empty table when MAX_SESSIONS <= 0.
            while _sessions and len(_sessions) >= MAX_SESSIONS:
                del _sessions[next(iter(_sessions))]
            env = DataCleanEnvironment()
        # (Re-)inserting places the entry in the most-recently-used position.
        _sessions[session_id] = env
        return env


class ResetRequest(BaseModel):
    # Body of POST /reset. Every field has a default, so an empty body is valid;
    # unrecognized keys are accepted rather than rejected.
    model_config = dict(extra="allow")

    task_id: str = "customer_contacts"  # forwarded to env.reset(task_id=...)
    seed: Optional[int] = None
    episode_id: Optional[str] = None


class StepRequest(BaseModel):
    # Body of POST /step; unrecognized top-level keys are accepted.
    model_config = dict(extra="allow")

    # Raw action fields, unpacked into DataCleanAction by the /step handler.
    action: Dict[str, Any]


def _obs_dict(obs: DataCleanObservation) -> dict:
    """Serialize *obs* via ``model_dump()`` for embedding in an endpoint response."""
    as_dict = obs.model_dump()
    return as_dict


@app.post("/reset", tags=["Environment Control"])
async def stateful_reset(
    request: ResetRequest = Body(default_factory=ResetRequest),
    x_session_id: Optional[str] = Header(default="default"),
):
    """Reset the environment with a specific task. Session-isolated via X-Session-Id header."""
    # A missing or empty header maps onto the shared "default" session.
    env = await _get_or_create_session(x_session_id or "default")
    initial_obs = env.reset(
        task_id=request.task_id,
        episode_id=request.episode_id,
        seed=request.seed,
    )
    # Reset responses always report no reward and an unfinished episode.
    payload = {"observation": _obs_dict(initial_obs), "reward": None, "done": False}
    return payload


@app.post("/step", tags=["Environment Control"])
async def stateful_step(
    request: StepRequest,
    x_session_id: Optional[str] = Header(default="default"),
):
    """Execute an action. Session-isolated via X-Session-Id header."""
    # Validate the action before touching the session table, so a malformed
    # request neither creates a blank session nor evicts an existing one.
    # NOTE(review): assumes DataCleanAction is a pydantic model (raises
    # ValidationError on bad/unknown fields), consistent with the
    # model_dump() usage elsewhere in this file; confirm.
    try:
        action = DataCleanAction(**request.action)
    except ValidationError as exc:
        # A bad action is a client error: answer 422, as FastAPI does for
        # body validation, instead of letting it escape as a 500.
        raise HTTPException(
            status_code=422,
            detail=exc.errors(include_url=False, include_context=False),
        ) from exc

    # A missing or empty header maps onto the shared "default" session.
    session_id = x_session_id or "default"
    env = await _get_or_create_session(session_id)
    obs = env.step(action)
    return {"observation": _obs_dict(obs), "reward": obs.reward, "done": obs.done}


@app.get("/state", tags=["State Management"])
async def stateful_state(
    x_session_id: Optional[str] = Header(default="default"),
):
    """Get current environment state for a session."""
    # An unknown session id gets a freshly constructed environment here.
    env = await _get_or_create_session(x_session_id or "default")
    current_state = env.state
    return current_state.model_dump()


# ---------------------------------------------------------------------------
# Entry point for `uv run server` / `python -m server.app`
# ---------------------------------------------------------------------------
def main(host: str = "0.0.0.0", port: int = 8000):
    """Serve ``app`` with uvicorn.

    Called with no arguments by the ``uv run server`` entry point and by
    ``python -m server.app``; the parameters allow overriding the bind
    address without editing the code.

    Args:
        host: Address to bind; the default "0.0.0.0" listens on all IPv4 interfaces.
        port: TCP port to listen on.
    """
    # Imported lazily so importing this module (e.g. by an external ASGI
    # server) does not require uvicorn until main() is actually run.
    import uvicorn
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()