Spaces:
Sleeping
Sleeping
| """FastAPI composition root — wires environment, MCP, and custom endpoints.""" | |
| from __future__ import annotations | |
| import os | |
| from threading import Lock | |
| from openenv.core.env_server import create_fastapi_app | |
| from fastapi import Query | |
| from pydantic import BaseModel | |
| from constants import PROJECT_DESCRIPTION, VERSION | |
| from models import CommitmentAction, CommitmentObservation, CommitmentState | |
| from server.environment import CommitmentEnvironment | |
| from server.mcp import router as mcp_router | |
| from server.tasks import get_scenario_ids_grouped | |
# Session key used whenever a request does not supply an ``episode_id``.
_DEFAULT_SESSION_ID = "default"

# Serialises access to ``_env_store`` (acquired inside ``_get_env``).
_env_store_lock = Lock()

# Maps session id -> environment instance. The default session's environment
# is created eagerly at import time.
_env_store: dict[str, CommitmentEnvironment] = {}
_env_store[_DEFAULT_SESSION_ID] = CommitmentEnvironment()
def _get_env(session_id: str) -> CommitmentEnvironment:
    """Look up the environment bound to ``session_id``, creating it on first use.

    Giving every session its own environment instance prevents one client's
    state from bleeding into another's, which a single shared mutable
    environment would allow. Clients select a session by passing the
    ``episode_id`` query parameter.

    Args:
        session_id: Key identifying the session in ``_env_store``.

    Returns:
        The environment already stored for the session, or a newly created
        one that has just been stored under ``session_id``.
    """
    with _env_store_lock:
        # Lookup and insertion share one critical section so that two
        # concurrent first requests cannot install different instances.
        existing = _env_store.get(session_id)
        if existing is not None:
            return existing
        fresh = CommitmentEnvironment()
        _env_store[session_id] = fresh
        return fresh
# Request body accepted by ``step_episode``: a JSON object wrapping one action.
# (Documented with comments rather than a docstring so the generated schema
# description is left unchanged.)
class StepPayload(BaseModel):
    # The action that ``step_episode`` passes to the environment's ``step``.
    action: CommitmentAction
# Build the base OpenEnv application. The environment factory handed to
# OpenEnv always resolves to the default session; the handlers defined in this
# module pick a session themselves from the ``episode_id`` query parameter.
app = create_fastapi_app(
    env=lambda: _get_env(_DEFAULT_SESSION_ID),
    action_cls=CommitmentAction,
    observation_cls=CommitmentObservation,
)

# Application metadata.
app.title = "CommitmentOS"
app.version = VERSION
app.description = PROJECT_DESCRIPTION

# Drop the stock routes that this module replaces. The list is mutated in
# place (slice assignment) so the application keeps using the same list
# object. Routes without a ``path`` attribute are always kept.
app.routes[:] = [
    route
    for route in app.routes
    if getattr(route, "path", None) not in ("/state", "/mcp", "/reset", "/step")
]
# The stock "/reset" route is stripped from ``app.routes`` earlier in this
# module, so this handler must be registered explicitly or the endpoint would
# not exist at all.
@app.post("/reset")
def reset_episode(
    task_id: str | None = Query(default=None),
    difficulty: str | None = Query(default=None),
    seed: int | None = Query(default=None),
    episode_id: str | None = Query(default=None),
) -> dict[str, object]:
    """Reset a session's environment, with explicit query-param support.

    The default OpenEnv route did not reliably propagate ``task_id`` from
    query params in this deployment setup, which made scenario selection
    non-deterministic for demos/evaluations.

    Args:
        task_id: Forwarded unchanged to ``env.reset``.
        difficulty: Forwarded unchanged to ``env.reset``.
        seed: Forwarded unchanged to ``env.reset``.
        episode_id: Session to reset. When omitted or empty, the default
            session is used.

    Returns:
        A dict with the dumped observation (``observation``), its ``reward``
        as a float, its ``done`` flag as a bool, and the resolved session id
        (``episode_id``).
    """
    # An omitted *or empty* episode_id falls back to the shared default session.
    session_id = episode_id or _DEFAULT_SESSION_ID
    env = _get_env(session_id)
    obs = env.reset(
        seed=seed,
        episode_id=session_id,
        task_id=task_id,
        difficulty=difficulty,
    )
    return {
        "observation": obs.model_dump(),
        "reward": float(obs.reward),
        "done": bool(obs.done),
        "episode_id": session_id,
    }
# The stock "/step" route is stripped from ``app.routes`` earlier in this
# module, so this handler must be registered explicitly or the endpoint would
# not exist at all.
@app.post("/step")
def step_episode(
    payload: StepPayload,
    episode_id: str | None = Query(default=None),
) -> dict[str, object]:
    """Apply one action to a session's environment and return the result.

    Args:
        payload: Request body carrying the action to execute.
        episode_id: Session whose environment should be stepped. When omitted
            or empty, the default session is used.

    Returns:
        A dict with the dumped observation (``observation``), its ``reward``
        as a float, its ``done`` flag as a bool, and the resolved session id
        (``episode_id``).
    """
    # An omitted *or empty* episode_id falls back to the shared default session.
    session_id = episode_id or _DEFAULT_SESSION_ID
    # NOTE(review): an unknown session id yields a brand-new environment from
    # ``_get_env``; confirm that stepping before a reset is acceptable.
    env = _get_env(session_id)
    obs = env.step(payload.action)
    return {
        "observation": obs.model_dump(),
        "reward": float(obs.reward),
        "done": bool(obs.done),
        "episode_id": session_id,
    }
# The stock "/state" route is stripped from ``app.routes`` earlier in this
# module, so this handler must be registered explicitly or the endpoint would
# not exist at all.
@app.get("/state")
def get_state(episode_id: str | None = Query(default=None)) -> CommitmentState:
    """Return the current state of a session's environment.

    Args:
        episode_id: Session to inspect. When omitted or empty, the default
            session is used.

    Returns:
        The ``state`` attribute of the session's environment.
    """
    session_id = episode_id or _DEFAULT_SESSION_ID
    return _get_env(session_id).state
# Without a route registration this handler is unreachable.
# NOTE(review): the "/tasks" path is inferred from the handler name — confirm
# against the clients that consume this endpoint.
@app.get("/tasks")
def list_tasks() -> dict[str, list[str]]:
    """Return the available scenario ids as grouped by ``get_scenario_ids_grouped``."""
    return get_scenario_ids_grouped()
# Mount the project's MCP router. The stock "/mcp" route was removed from
# ``app.routes`` earlier, presumably so this router provides the replacement —
# confirm against ``server.mcp``.
app.include_router(mcp_router)
def main() -> None:
    """Serve ``app`` with uvicorn, listening on all interfaces.

    The port is read from the ``PORT`` environment variable and falls back to
    7860 when the variable is unset.
    """
    # Imported here so uvicorn is only loaded when ``main()`` actually runs.
    import uvicorn

    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)


if __name__ == "__main__":
    main()