Spaces:
Sleeping
Sleeping
File size: 4,395 Bytes
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
FastAPI application for the Meeting Scheduling RL Environment.
Uses OpenEnv's create_app() for standard routes + Gradio web UI,
then overrides /reset and /step with stateful (singleton) versions
so that HTTP-based interaction (curl, inference scripts) works correctly
across multiple calls within the same episode.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any, Dict, Optional
from fastapi import Body
from openenv.core.env_server.http_server import create_app
try:
from ..models import SchedulingAction, SchedulingObservation
from .scheduling_env_environment import SchedulingEnvironment
except (ModuleNotFoundError, ImportError):
from models import SchedulingAction, SchedulingObservation
from server.scheduling_env_environment import SchedulingEnvironment
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Create base app via OpenEnv (provides Gradio UI, WebSocket, schema, health)
# ---------------------------------------------------------------------------
app = create_app(
    env=SchedulingEnvironment,
    action_cls=SchedulingAction,
    observation_cls=SchedulingObservation,
    env_name="scheduling_env",
    max_concurrent_envs=1,
)

# ---------------------------------------------------------------------------
# Remove the default stateless /reset, /step and /state routes so they can be
# replaced with stateful singleton-backed versions for HTTP interaction.
#
# Starlette dispatches a request to the FIRST route in ``app.routes`` whose
# path and method match. A handler added later under the same path would
# therefore be shadowed by the default one. This module defines its own
# GET /state handler, so the default /state route must be removed too.
# NOTE(review): assumes create_app() registers GET /state; removing a path
# that was never registered is a harmless no-op either way.
# ---------------------------------------------------------------------------
_routes_to_remove = {"/reset", "/step", "/state"}
# Slice-assign so the router's own list object is mutated in place.
app.routes[:] = [
    route
    for route in app.routes
    if getattr(route, "path", None) not in _routes_to_remove
]

# ---------------------------------------------------------------------------
# Singleton environment for stateful HTTP endpoints
# ---------------------------------------------------------------------------
# One instance shared by every HTTP request handled in this module, so
# consecutive /reset and /step calls operate on the same episode.
_env: SchedulingEnvironment = SchedulingEnvironment()
@app.post("/reset")
async def reset_handler(
    body: Optional[Dict[str, Any]] = Body(default=None),
) -> Dict[str, Any]:
    """Reset the shared environment to a new episode.

    Args:
        body: Optional JSON object. Only ``task_id`` is read; it defaults to
            ``"task1_easy"`` when the key is absent. A missing/``null`` body
            is treated as ``{}``.

    Returns:
        A dict with ``observation`` (the serialised observation), ``done``
        and ``reward``, followed by every observation field flattened to the
        top level so clients can read either shape. The flattened fields are
        spread last, so they win on any key collision.
    """
    body = body or {}
    task_id = body.get("task_id", "task1_easy")

    # get_running_loop() is the recommended way to obtain the loop from inside
    # a coroutine; get_event_loop() is discouraged in that context.
    loop = asyncio.get_running_loop()
    # _env.reset() is a plain synchronous call; running it in the default
    # executor keeps it off the event-loop thread.
    observation = await loop.run_in_executor(
        None, lambda: _env.reset(task_id=task_id)
    )
    obs_dict = (
        observation.model_dump()
        if hasattr(observation, "model_dump")
        else observation.__dict__
    )
    return {
        "observation": obs_dict,
        "done": getattr(observation, "done", False),
        "reward": getattr(observation, "reward", 0.0),
        **obs_dict,
    }
@app.post("/step")
async def step_handler(
    body: Dict[str, Any] = Body(...),
) -> Dict[str, Any]:
    """Apply one action to the shared environment and return the result.

    Accepts either ``{"action": {...}}`` or the action fields directly at the
    top level of the JSON body.

    Returns:
        The same envelope as ``/reset``: ``observation``, ``done`` and
        ``reward``, plus every observation field flattened to the top level.
        If the action cannot be constructed, the environment is NOT stepped;
        instead a synthetic failure observation (``success=False``,
        ``reward=-1.0``, ``done=False``) is returned in that same envelope.
    """
    # Support both {"action": {...}} and direct action fields.
    action_data = body.get("action", body)
    try:
        action = SchedulingAction(**action_data)
    except Exception as e:  # HTTP boundary: any bad payload -> error reply
        logger.error("Failed to deserialize action: %s", e)
        error_obs: Dict[str, Any] = {
            "success": False,
            "error_message": f"Invalid action: {e}",
            "done": False,
            "reward": -1.0,
        }
        # Flatten exactly like the success path so clients that read the
        # top-level ``success`` / ``error_message`` keys also see the failure.
        return {
            "observation": error_obs,
            "done": False,
            "reward": -1.0,
            **error_obs,
        }

    # get_running_loop() is the recommended API inside a coroutine.
    loop = asyncio.get_running_loop()
    # _env.step() is synchronous; run it off the event-loop thread.
    observation = await loop.run_in_executor(None, _env.step, action)
    obs_dict = (
        observation.model_dump()
        if hasattr(observation, "model_dump")
        else observation.__dict__
    )
    return {
        "observation": obs_dict,
        "done": getattr(observation, "done", False),
        "reward": getattr(observation, "reward", 0.0),
        **obs_dict,
    }
@app.get("/state")
async def state_handler() -> Dict[str, Any]:
    """Serialise the shared environment's current internal state.

    Uses the state object's ``model_dump()`` when it provides one, and
    otherwise falls back to its instance ``__dict__``.
    """
    current = _env.state
    if hasattr(current, "model_dump"):
        return current.model_dump()
    return current.__dict__
def main(host: str = "0.0.0.0", port: int = 8000) -> None:
    """Serve ``app`` with uvicorn.

    Args:
        host: Interface to bind. The default ``"0.0.0.0"`` listens on all
            interfaces, matching the previous hard-coded value.
        port: TCP port to bind. The default ``8000`` matches the previous
            hard-coded value.
    """
    # Imported lazily so importing this module does not require uvicorn.
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()
|