Akshaykumarbm's picture
Upload folder using huggingface_hub
711341a verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
FastAPI application for the Meeting Scheduling RL Environment.
Uses OpenEnv's create_app() for standard routes + Gradio web UI,
then overrides /reset and /step with stateful (singleton) versions
so that HTTP-based interaction (curl, inference scripts) works correctly
across multiple calls within the same episode.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any, Dict, Optional
from fastapi import Body
from openenv.core.env_server.http_server import create_app
try:
from ..models import SchedulingAction, SchedulingObservation
from .scheduling_env_environment import SchedulingEnvironment
except (ModuleNotFoundError, ImportError):
from models import SchedulingAction, SchedulingObservation
from server.scheduling_env_environment import SchedulingEnvironment
# Module-level logger used by the request handlers in this file.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Create base app via OpenEnv (provides Gradio UI, WebSocket, schema, health)
# ---------------------------------------------------------------------------
# The environment *class* (not an instance) is handed to create_app(); the
# instance that the HTTP handlers in this module operate on is a separate
# module-level object.
# NOTE(review): max_concurrent_envs=1 presumably caps how many environment
# instances create_app()'s own sessions may hold at once — confirm against
# the OpenEnv create_app() documentation for the pinned version.
app = create_app(
    env=SchedulingEnvironment,
    action_cls=SchedulingAction,
    observation_cls=SchedulingObservation,
    env_name="scheduling_env",
    max_concurrent_envs=1,
)
# ---------------------------------------------------------------------------
# Remove the default stateless /reset, /step and /state routes so we can
# replace them with stateful singleton-backed versions for HTTP interaction.
# ---------------------------------------------------------------------------
# /state is removed as well: Starlette dispatches to the first route whose
# path and method match, so any default GET /state registered by
# create_app() would otherwise shadow the singleton-backed /state handler
# defined later in this module. Removing a path that is not present is a
# no-op.
_routes_to_remove = {"/reset", "/step", "/state"}
# Slice-assign so the existing routes list object is mutated in place rather
# than rebound to a new list.
app.routes[:] = [
    r for r in app.routes if getattr(r, "path", None) not in _routes_to_remove
]
# ---------------------------------------------------------------------------
# Singleton environment for stateful HTTP endpoints
# ---------------------------------------------------------------------------
# Single environment instance shared by every request to the /reset, /step
# and /state handlers below, so episode state persists across HTTP calls.
# NOTE(review): reset()/step() are run in the default thread-pool executor
# without any lock, so concurrent requests can interleave on this shared
# instance — serialize access if more than one client is expected.
_env: SchedulingEnvironment = SchedulingEnvironment()
@app.post("/reset")
async def reset_handler(
    body: Optional[Dict[str, Any]] = Body(default=None),
) -> Dict[str, Any]:
    """Start a new episode on the shared environment.

    The optional JSON body may carry ``task_id``. When the body is missing
    or empty, or the key is absent, ``"task1_easy"`` is used.

    Returns:
        The serialized initial observation under ``"observation"``, plus
        top-level ``"done"`` and ``"reward"`` keys, with the observation's
        own fields also merged in at the top level.
    """
    payload = body if body else {}
    requested_task = payload.get("task_id", "task1_easy")

    def _do_reset():
        return _env.reset(task_id=requested_task)

    # reset() is synchronous; run it in the default executor so the event
    # loop is not blocked while the episode is prepared.
    running_loop = asyncio.get_running_loop()
    first_obs = await running_loop.run_in_executor(None, _do_reset)

    if hasattr(first_obs, "model_dump"):
        serialized = first_obs.model_dump()
    else:
        serialized = first_obs.__dict__

    envelope: Dict[str, Any] = {
        "observation": serialized,
        "done": getattr(first_obs, "done", False),
        "reward": getattr(first_obs, "reward", 0.0),
    }
    # Flatten the observation fields into the top level as well; on key
    # collisions the observation's values win.
    envelope.update(serialized)
    return envelope
@app.post("/step")
async def step_handler(
    body: Dict[str, Any] = Body(...),
) -> Dict[str, Any]:
    """Execute one action against the shared environment and return the result.

    The request body may either wrap the action as ``{"action": {...}}`` or
    carry the action fields directly at the top level.

    Returns:
        A dict with the serialized observation under ``"observation"``,
        top-level ``"done"`` and ``"reward"`` keys, and the observation's
        fields also merged in at the top level. If the action cannot be
        constructed, an error observation (``success=False``, reward -1.0)
        is returned in that same shape instead of raising.
    """
    # Support both {"action": {...}} and direct action fields.
    action_data = body.get("action", body)
    try:
        action = SchedulingAction(**action_data)
    except Exception as e:
        # Deliberately broad: any construction failure (missing or mistyped
        # fields, a non-mapping "action" value) is reported back to the
        # client as a penalised step rather than surfacing as a 500.
        logger.error("Failed to deserialize action: %s", e)
        error_obs: Dict[str, Any] = {
            "success": False,
            "error_message": f"Invalid action: {e}",
            "done": False,
            "reward": -1.0,
        }
        # Use the same envelope as the success path so that clients reading
        # observation fields from the top level (e.g. resp["success"]) also
        # work when the action is invalid.
        return {
            "observation": error_obs,
            "done": False,
            "reward": -1.0,
            **error_obs,
        }

    # step() is synchronous; run it off the event loop.
    loop = asyncio.get_running_loop()
    observation = await loop.run_in_executor(None, _env.step, action)
    obs_dict = (
        observation.model_dump()
        if hasattr(observation, "model_dump")
        else observation.__dict__
    )
    return {
        "observation": obs_dict,
        "done": getattr(observation, "done", False),
        "reward": getattr(observation, "reward", 0.0),
        **obs_dict,
    }
@app.get("/state")
async def state_handler() -> Dict[str, Any]:
    """Expose the shared environment's current internal state as a plain dict.

    Uses ``model_dump()`` when the state object provides it; otherwise falls
    back to the object's ``__dict__``.
    """
    current = _env.state
    if hasattr(current, "model_dump"):
        return current.model_dump()
    return current.__dict__
def main(host: str = "0.0.0.0", port: int = 8000) -> None:
    """Serve the FastAPI ``app`` with uvicorn.

    Args:
        host: Interface to bind. Defaults to ``"0.0.0.0"`` (all interfaces).
        port: TCP port to listen on. Defaults to 8000.
    """
    # Imported inside the function so that importing this module does not
    # itself import uvicorn.
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()