File size: 4,395 Bytes
7bdbe90
 
 
 
 
 
 
 
 
711341a
 
 
 
7bdbe90
 
 
 
 
 
 
 
711341a
 
7bdbe90
 
711341a
7bdbe90
 
711341a
7bdbe90
 
 
 
711341a
 
 
 
 
7bdbe90
 
711341a
 
7bdbe90
711341a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bdbe90
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
FastAPI application for the Meeting Scheduling RL Environment.

Uses OpenEnv's create_app() for standard routes + Gradio web UI,
then overrides /reset and /step with stateful (singleton) versions
so that HTTP-based interaction (curl, inference scripts) works correctly
across multiple calls within the same episode.
"""

from __future__ import annotations

import asyncio
import logging
from typing import Any, Dict, Optional

from fastapi import Body
from openenv.core.env_server.http_server import create_app

try:
    from ..models import SchedulingAction, SchedulingObservation
    from .scheduling_env_environment import SchedulingEnvironment
except (ModuleNotFoundError, ImportError):
    from models import SchedulingAction, SchedulingObservation
    from server.scheduling_env_environment import SchedulingEnvironment

# Module-level logger; step_handler uses it to report actions that fail to
# deserialize.
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Create base app via OpenEnv (provides Gradio UI, WebSocket, schema, health)
# ---------------------------------------------------------------------------
# NOTE: the environment *class* (not an instance) is handed to create_app();
# the HTTP handlers defined below do not use it and talk to the separate
# module-level `_env` instance instead.
app = create_app(
    env=SchedulingEnvironment,
    action_cls=SchedulingAction,
    observation_cls=SchedulingObservation,
    env_name="scheduling_env",
    max_concurrent_envs=1,
)

# ---------------------------------------------------------------------------
# Remove default stateless /reset and /step routes so we can replace them
# with stateful singleton-backed versions for HTTP interaction.
# ---------------------------------------------------------------------------
# Every route whose `path` is exactly "/reset" or "/step" is dropped,
# whatever its HTTP method. Routes without a `path` attribute are kept
# (getattr falls back to None, which is never in the set). The slice
# assignment mutates the existing route list in place instead of rebinding it.
_routes_to_remove = {"/reset", "/step"}
app.routes[:] = [r for r in app.routes if getattr(r, "path", None) not in _routes_to_remove]

# ---------------------------------------------------------------------------
# Singleton environment for stateful HTTP endpoints
# ---------------------------------------------------------------------------
# Instantiated once at import time and shared by the /reset, /step and /state
# handlers below, so episode state persists across separate HTTP requests.
_env: SchedulingEnvironment = SchedulingEnvironment()


@app.post("/reset")
async def reset_handler(
    body: Optional[Dict[str, Any]] = Body(default=None),
) -> Dict[str, Any]:
    """Reset the shared environment and return its first observation.

    Args:
        body: Optional JSON object. Its ``task_id`` entry is forwarded to the
            environment's ``reset``; when the body is missing/empty or has no
            ``task_id`` key, ``"task1_easy"`` is used.

    Returns:
        A dict with the serialized observation under ``"observation"``,
        top-level ``"done"`` and ``"reward"`` (``False`` / ``0.0`` when the
        observation has no such attributes), and every observation field
        flattened into the top level as well.
    """
    payload = body or {}
    task_id = payload.get("task_id", "task1_easy")

    # Inside a coroutine a loop is always running; get_running_loop() is the
    # documented accessor for that case (get_event_loop() is discouraged here).
    loop = asyncio.get_running_loop()
    # reset() is a synchronous call, so it runs on the default executor while
    # this coroutine awaits the result.
    observation = await loop.run_in_executor(
        None, lambda: _env.reset(task_id=task_id)
    )

    # Prefer model_dump() when the observation offers it; otherwise fall back
    # to the raw attribute dict.
    if hasattr(observation, "model_dump"):
        obs_dict = observation.model_dump()
    else:
        obs_dict = observation.__dict__

    # The flattened fields are spread last, so on a key collision the
    # observation's own values win over the explicit entries above them.
    return {
        "observation": obs_dict,
        "done": getattr(observation, "done", False),
        "reward": getattr(observation, "reward", 0.0),
        **obs_dict,
    }


@app.post("/step")
async def step_handler(
    body: Dict[str, Any] = Body(...),
) -> Dict[str, Any]:
    """Apply one action to the shared environment and return the observation.

    Args:
        body: JSON object that is either ``{"action": {...}}`` or the action
            fields themselves at the top level.

    Returns:
        A dict with the serialized observation under ``"observation"``,
        top-level ``"done"`` and ``"reward"``, and every observation field
        flattened into the top level. If the action cannot be constructed,
        the environment is not stepped and a synthetic observation with
        ``success=False``, an ``error_message`` and ``reward=-1.0`` is
        returned in the same envelope shape.
    """
    # Support both {"action": {...}} and direct action fields
    action_data = body.get("action", body)

    try:
        action = SchedulingAction(**action_data)
    except Exception as e:
        # Request boundary: any construction failure (including a non-mapping
        # "action" value) is logged and reported to the client rather than
        # surfacing as a server error.
        logger.error("Failed to deserialize action: %s", e)
        error_obs = {
            "success": False,
            "error_message": f"Invalid action: {e}",
            "done": False,
            "reward": -1.0,
        }
        # Mirror the success envelope (observation + flattened fields) so
        # clients can read the same top-level keys on both paths.
        return {
            "observation": error_obs,
            "done": False,
            "reward": -1.0,
            **error_obs,
        }

    # Inside a coroutine a loop is always running; get_running_loop() is the
    # documented accessor for that case (get_event_loop() is discouraged here).
    loop = asyncio.get_running_loop()
    # step() is a synchronous call, so it runs on the default executor while
    # this coroutine awaits the result.
    observation = await loop.run_in_executor(None, _env.step, action)

    # Prefer model_dump() when the observation offers it; otherwise fall back
    # to the raw attribute dict.
    if hasattr(observation, "model_dump"):
        obs_dict = observation.model_dump()
    else:
        obs_dict = observation.__dict__

    # The flattened fields are spread last, so on a key collision the
    # observation's own values win over the explicit entries above them.
    return {
        "observation": obs_dict,
        "done": getattr(observation, "done", False),
        "reward": getattr(observation, "reward", 0.0),
        **obs_dict,
    }


@app.get("/state")
async def state_handler() -> Dict[str, Any]:
    """Expose the shared environment's ``state`` attribute as a JSON object.

    Returns:
        The result of ``state.model_dump()`` when the state object provides
        that method; otherwise the state's ``__dict__`` (the live attribute
        mapping, not a copy).
    """
    current = _env.state
    if hasattr(current, "model_dump"):
        return current.model_dump()
    return current.__dict__


def main(host: str = "0.0.0.0", port: int = 8000) -> None:
    """Serve the module-level ``app`` with uvicorn.

    Args:
        host: Interface address to bind. Defaults to ``"0.0.0.0"``.
        port: TCP port to listen on. Defaults to ``8000``.

    Calling ``main()`` with no arguments binds the same address and port as
    before these parameters existed.
    """
    # Imported here rather than at module top, so importing this module for
    # its `app` object does not require uvicorn.
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()