File size: 3,436 Bytes
6762657
 
 
 
 
af8810b
6762657
 
2f3c64b
af8810b
6762657
 
 
 
 
 
 
af8810b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6762657
 
af8810b
6762657
 
 
 
 
 
 
 
 
 
af8810b
6762657
 
 
2f3c64b
 
 
 
 
 
 
 
 
 
 
 
 
af8810b
 
 
2f3c64b
af8810b
2f3c64b
 
 
 
 
 
 
af8810b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f3c64b
 
 
6762657
af8810b
 
 
6762657
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""FastAPI composition root — wires environment, MCP, and custom endpoints."""

from __future__ import annotations

import os
from threading import Lock

from openenv.core.env_server import create_fastapi_app
from fastapi import Query
from pydantic import BaseModel

from constants import PROJECT_DESCRIPTION, VERSION
from models import CommitmentAction, CommitmentObservation, CommitmentState
from server.environment import CommitmentEnvironment
from server.mcp import router as mcp_router
from server.tasks import get_scenario_ids_grouped

# Session key used whenever a request does not supply an ``episode_id``.
_DEFAULT_SESSION_ID = "default"

# Live environments keyed by session id. The default session's environment is
# created eagerly at import time; other sessions are added by ``_get_env``.
_env_store: dict[str, CommitmentEnvironment] = {}
_env_store[_DEFAULT_SESSION_ID] = CommitmentEnvironment()

# Held by ``_get_env`` while it looks up or inserts into ``_env_store``.
_env_store_lock = Lock()


def _get_env(session_id: str) -> CommitmentEnvironment:
    """Fetch the environment registered for ``session_id``, creating it on first use.

    Each session id maps to its own ``CommitmentEnvironment`` so that callers
    using different ids do not share mutable environment state. The HTTP
    handlers in this module derive the id from the ``episode_id`` query param.

    Args:
        session_id: Key into the module-level environment store.

    Returns:
        The environment stored under ``session_id``.
    """
    # The lock covers both the membership test and the insert, so concurrent
    # first requests for the same id end up with a single stored instance.
    with _env_store_lock:
        if session_id not in _env_store:
            _env_store[session_id] = CommitmentEnvironment()
        return _env_store[session_id]


# Request body accepted by the ``/step`` handler: an object whose ``action``
# field is validated as a ``CommitmentAction``.
# NOTE(review): documented with comments rather than a class docstring because
# pydantic publishes a model's docstring as its JSON-schema description, which
# would alter the generated OpenAPI document — confirm before converting.
class StepPayload(BaseModel):
    action: CommitmentAction

# Start from the stock OpenEnv application. The ``env`` factory handed to it
# returns the default session's environment.
app = create_fastapi_app(
    env=lambda: _get_env(_DEFAULT_SESSION_ID),
    action_cls=CommitmentAction,
    observation_cls=CommitmentObservation,
)

# Project metadata shown by the application.
app.title = "CommitmentOS"
app.description = PROJECT_DESCRIPTION
app.version = VERSION

# Remove every route the factory registered on these paths so that the
# handlers declared further down in this module are the ones that serve them.
# Routes without a ``path`` attribute are kept. The list is mutated in place
# (slice assignment) rather than rebound.
app.routes[:] = [
    route
    for route in app.routes
    if getattr(route, "path", None) not in ("/state", "/mcp", "/reset", "/step")
]


@app.post("/reset")
def reset_episode(
    task_id: str | None = Query(default=None),
    difficulty: str | None = Query(default=None),
    seed: int | None = Query(default=None),
    episode_id: str | None = Query(default=None),
) -> dict[str, object]:
    """Reset endpoint with explicit query-param support.

    The default OpenEnv route did not reliably propagate ``task_id`` from
    query params in this deployment setup, which made scenario selection
    non-deterministic for demos/evaluations.
    """
    # A missing or empty ``episode_id`` falls back to the default session.
    sid = episode_id if episode_id else _DEFAULT_SESSION_ID

    # The resolved session id is also forwarded to the environment as its
    # ``episode_id``.
    observation = _get_env(sid).reset(
        seed=seed,
        episode_id=sid,
        task_id=task_id,
        difficulty=difficulty,
    )

    # Flatten the observation into JSON-friendly values and echo the session
    # id back to the caller.
    dumped = observation.model_dump()
    reward = float(observation.reward)
    done = bool(observation.done)
    return {
        "observation": dumped,
        "reward": reward,
        "done": done,
        "episode_id": sid,
    }


@app.post("/step")
def step_episode(
    payload: StepPayload,
    episode_id: str | None = Query(default=None),
) -> dict[str, object]:
    # Apply ``payload.action`` to the session's environment and report the
    # resulting observation. A missing or empty ``episode_id`` falls back to
    # the default session.
    sid = episode_id if episode_id else _DEFAULT_SESSION_ID
    observation = _get_env(sid).step(payload.action)

    # Same response shape as the ``/reset`` handler.
    dumped = observation.model_dump()
    reward = float(observation.reward)
    done = bool(observation.done)
    return {
        "observation": dumped,
        "reward": reward,
        "done": done,
        "episode_id": sid,
    }


@app.get("/state", response_model=CommitmentState)
def get_state(episode_id: str | None = Query(default=None)) -> CommitmentState:
    # Return the ``state`` of the requested session's environment; a missing
    # or empty ``episode_id`` selects the default session.
    if episode_id:
        return _get_env(episode_id).state
    return _get_env(_DEFAULT_SESSION_ID).state


@app.get("/tasks")
def list_tasks() -> dict[str, list[str]]:
    # Expose the grouped scenario ids from ``server.tasks`` unchanged.
    grouped_ids = get_scenario_ids_grouped()
    return grouped_ids


# Mount the router imported from ``server.mcp``. Any factory-provided ``/mcp``
# route was stripped from ``app.routes`` earlier in this module.
# NOTE(review): presumably this router supplies the replacement ``/mcp``
# endpoint — confirm against server/mcp.
app.include_router(mcp_router)


def main() -> None:
    """Serve ``app`` with uvicorn, listening on all interfaces.

    The port comes from the ``PORT`` environment variable and defaults to
    7860 when the variable is unset.
    """
    # Imported here so that importing this module does not require uvicorn.
    from uvicorn import run as serve

    listen_port = int(os.getenv("PORT", "7860"))
    serve(app, host="0.0.0.0", port=listen_port)


# Start the server only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()