"""FastAPI composition root — wires environment, MCP, and custom endpoints."""
from __future__ import annotations
import os
from threading import Lock
from openenv.core.env_server import create_fastapi_app
from fastapi import Query
from pydantic import BaseModel
from constants import PROJECT_DESCRIPTION, VERSION
from models import CommitmentAction, CommitmentObservation, CommitmentState
from server.environment import CommitmentEnvironment
from server.mcp import router as mcp_router
from server.tasks import get_scenario_ids_grouped
# Key of the session used whenever a request omits ``episode_id``.
_DEFAULT_SESSION_ID = "default"

# Held by ``_get_env`` around every lookup/insert on ``_env_store``.
_env_store_lock = Lock()

# Session id -> environment instance. The default session's environment is
# built at import time; other sessions are added on demand by ``_get_env``.
_env_store: dict[str, CommitmentEnvironment] = {}
_env_store[_DEFAULT_SESSION_ID] = CommitmentEnvironment()
def _get_env(session_id: str) -> CommitmentEnvironment:
    """Fetch the environment bound to ``session_id``, creating it on first use.

    Each session id maps to its own ``CommitmentEnvironment`` so that state
    from one user does not bleed into another through a single shared
    mutable environment. Clients isolate sessions by passing the
    ``episode_id`` query param.

    The lookup and the insertion of a new instance both happen while
    ``_env_store_lock`` is held.
    """
    with _env_store_lock:
        if session_id not in _env_store:
            # First request for this session: give it a fresh environment.
            _env_store[session_id] = CommitmentEnvironment()
        return _env_store[session_id]
class StepPayload(BaseModel):
    """Request body of ``POST /step``: a single wrapped action."""

    # The action handed to ``env.step`` by the ``/step`` handler.
    action: CommitmentAction
# Build the base OpenEnv FastAPI application. The ``env`` callable always
# resolves to the default session's environment.
app = create_fastapi_app(
    env=lambda: _get_env(_DEFAULT_SESSION_ID),
    action_cls=CommitmentAction,
    observation_cls=CommitmentObservation,
)

# Project metadata set on the FastAPI app object.
app.title = "CommitmentOS"
app.description = PROJECT_DESCRIPTION
app.version = VERSION

# Drop the routes already registered for these paths so that the handlers
# defined further down in this module take their place. Routes that have no
# ``path`` attribute are always kept. The list is rewritten in place.
app.routes[:] = [
    route
    for route in app.routes
    if getattr(route, "path", None) not in ("/state", "/mcp", "/reset", "/step")
]
@app.post("/reset")
def reset_episode(
    task_id: str | None = Query(default=None),
    difficulty: str | None = Query(default=None),
    seed: int | None = Query(default=None),
    episode_id: str | None = Query(default=None),
) -> dict[str, object]:
    """Start a new episode, taking every option from the query string.

    This handler replaces the default OpenEnv route, which did not reliably
    propagate ``task_id`` from query params in this deployment setup and so
    made scenario selection non-deterministic for demos/evaluations.

    A missing or empty ``episode_id`` selects the default session. The
    response holds the dumped observation, its reward and done flag, and the
    session id that was used.
    """
    session_id = _DEFAULT_SESSION_ID if not episode_id else episode_id
    observation = _get_env(session_id).reset(
        seed=seed,
        episode_id=session_id,
        task_id=task_id,
        difficulty=difficulty,
    )

    dumped = observation.model_dump()
    reward = float(observation.reward)
    done = bool(observation.done)
    return {
        "observation": dumped,
        "reward": reward,
        "done": done,
        "episode_id": session_id,
    }
@app.post("/step")
def step_episode(
    payload: StepPayload,
    episode_id: str | None = Query(default=None),
) -> dict[str, object]:
    """Apply ``payload.action`` to the session's environment.

    A missing or empty ``episode_id`` selects the default session. The
    response has the same shape as the ``/reset`` response: dumped
    observation, reward, done flag, and the session id that was used.
    """
    session_id = _DEFAULT_SESSION_ID if not episode_id else episode_id
    observation = _get_env(session_id).step(payload.action)

    dumped = observation.model_dump()
    reward = float(observation.reward)
    done = bool(observation.done)
    return {
        "observation": dumped,
        "reward": reward,
        "done": done,
        "episode_id": session_id,
    }
@app.get("/state", response_model=CommitmentState)
def get_state(episode_id: str | None = Query(default=None)) -> CommitmentState:
    """Return the ``state`` of the environment for the requested session.

    A missing or empty ``episode_id`` selects the default session. The
    environment is obtained through ``_get_env``, so a session id that has
    not been seen before gets a newly created environment.
    """
    env = _get_env(episode_id or _DEFAULT_SESSION_ID)
    return env.state
@app.get("/tasks")
def list_tasks() -> dict[str, list[str]]:
    """Serve the result of ``get_scenario_ids_grouped`` at ``GET /tasks``."""
    grouped_ids = get_scenario_ids_grouped()
    return grouped_ids
# Register the routes from ``server.mcp`` on the app. The "/mcp" route that
# was already on the app has been filtered out of ``app.routes`` earlier in
# this module. NOTE(review): presumably this router supplies the replacement
# "/mcp" endpoint — confirm against server/mcp.
app.include_router(mcp_router)
def main() -> None:
    """Serve ``app`` with uvicorn, listening on ``0.0.0.0``.

    The port comes from the ``PORT`` environment variable and falls back to
    7860 when the variable is not set.
    """
    # Local import: uvicorn is imported only when main() actually runs.
    import uvicorn

    listen_port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)


if __name__ == "__main__":
    main()
|