File size: 8,984 Bytes
7698d12
 
297cffc
 
 
7698d12
d81f3f0
 
 
 
 
7698d12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297cffc
 
 
 
 
 
 
 
 
 
 
698641d
 
 
297cffc
698641d
297cffc
 
 
 
7698d12
 
 
 
 
 
 
297cffc
7698d12
 
 
d81f3f0
1a3a8ee
d81f3f0
1a3a8ee
 
 
 
 
d81f3f0
1a3a8ee
 
 
 
 
 
 
 
 
 
 
 
 
d81f3f0
1a3a8ee
d81f3f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7698d12
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
"""FastAPI application for the OpenCode Environment.

Mirrors the ``e2b_desktop`` pattern: pass a ``gradio_builder`` to
``create_app`` and let OpenEnv handle the Gradio mount (including the
HF-Space-compatible ``/web`` path). No manual ``gr.mount_gradio_app``.

Also mounts a bespoke SSE endpoint at ``GET /rollouts/{rollout_id}/events``
that multiplexes opencode serve's ``/event`` stream with our proxy's
per-turn frames. MCP tools don't support streaming; this gives the UI
and interactive clients a live feed.

Usage::

    # Development:
    E2B_API_KEY=... uv run uvicorn server.app:app --reload

    # Via uv project script:
    E2B_API_KEY=... uv run --project . server

    # Docker:
    docker run -p 8000:8000 -e E2B_API_KEY=... opencode-openenv
"""

from __future__ import annotations

import os

try:
    from openenv.core.env_server.http_server import create_app
    from openenv.core.env_server.mcp_types import (
        CallToolAction,
        CallToolObservation,
    )

    from .opencode_environment import OpenCodeEnvironment
    from .gradio_ui import opencode_ui_builder
except ImportError:
    from openenv.core.env_server.http_server import create_app
    from openenv.core.env_server.mcp_types import (
        CallToolAction,
        CallToolObservation,
    )

    from server.opencode_environment import OpenCodeEnvironment
    from server.gradio_ui import opencode_ui_builder


def _custom_gradio_builder(
    web_manager,
    action_fields,
    metadata,
    is_chat_env,
    title,
    quick_start_md,
):
    """Build the custom Gradio UI when ``create_app`` asks for one.

    Every argument in this callback signature is intentionally unused:
    ``web_manager`` only exposes ``reset_environment`` /
    ``step_environment`` / ``connect_websocket`` rather than an env
    instance, so — mirroring e2b_desktop — the UI is handed the env
    class directly and manages its own instances.
    """
    return opencode_ui_builder(env_factory=OpenCodeEnvironment)


# Enable OpenEnv's built-in Gradio mounting at the standard /web path.
# ``setdefault`` respects an explicit ENABLE_WEB_INTERFACE=false override.
os.environ.setdefault("ENABLE_WEB_INTERFACE", "true")

# Module-level ASGI app; the SSE route below is attached to this instance.
app = create_app(
    OpenCodeEnvironment,
    CallToolAction,
    CallToolObservation,
    env_name="opencode_env",
    # Concurrency cap for per-session envs; tune via MAX_CONCURRENT_ENVS.
    max_concurrent_envs=int(os.getenv("MAX_CONCURRENT_ENVS", "4")),
    gradio_builder=_custom_gradio_builder,
)


def _find_active_environment(request):
    """Locate a currently-active OpenCodeEnvironment instance.

    ``create_app`` stores per-session envs internally without a public
    accessor, so we probe ``app.state`` attributes matching OpenEnv's
    naming conventions. As a last resort we create a fresh env — fine
    for single-worker Spaces, because registries live in-process and a
    default env stays idle until a tool is invoked.
    """
    # Probe likely attribute names, newest convention first.
    for candidate in ("env_cache", "envs", "environments", "_envs"):
        store = getattr(app.state, candidate, None)
        if not store:
            continue
        try:
            if isinstance(store, dict):
                return next(iter(store.values()))
            if isinstance(store, (list, tuple)):
                return store[-1]
        except Exception:
            # Best-effort probing; fall through to the next candidate.
            continue
    # Fallback — construct a new env. Safe because the SSE endpoint only
    # needs the _registry dict, which we then look up rollout_id in.
    try:
        return OpenCodeEnvironment()
    except Exception:
        return None


@app.get("/rollouts/{rollout_id}/events")
async def rollout_events(rollout_id: str):
    """Server-Sent Events feed for a rollout started via ``start_rollout``.

    Merges two streams:

    1. opencode serve's ``GET /event`` (session-level events: message
       parts, tool calls, idle/abort markers) — forwarded as-is.
    2. our proxy's ``proxy_trace.jsonl`` in the sandbox (per-turn
       LLM turns + logprobs) — tailed and emitted as
       ``{type: "proxy.turn", turn, tokens, logprobs, finish_reason, ...}``.

    Terminates on a final ``{"type": "rollout.done", ...}`` frame once the
    session has idled or erred.
    """
    # Function-local imports keep module import cheap and avoid a hard
    # module-scope dependency on starlette.
    import asyncio
    import json as _json

    from starlette.responses import StreamingResponse

    env = _find_active_environment(None)
    if env is None:
        # No env could be located or created — one error frame, then EOF.
        return StreamingResponse(
            iter([f"data: {_json.dumps({'type': 'error', 'reason': 'env not found'})}\n\n"]),
            media_type="text/event-stream",
        )

    # Per-rollout handles live in the env's private ``_registry`` dict.
    registry = getattr(env, "_registry", None)
    handle = registry.get(rollout_id) if registry else None
    if handle is None:
        async def _single_error():
            # One-shot async generator: a single SSE error frame.
            yield (
                "data: "
                + _json.dumps({"type": "error", "rollout_id": rollout_id, "reason": "unknown rollout"})
                + "\n\n"
            )
        return StreamingResponse(_single_error(), media_type="text/event-stream")

    async def _gen():
        # Wait briefly for the serve client to be wired by the worker.
        # 60 polls x 0.25s = up to ~15s before giving up on the session.
        for _ in range(60):
            if handle.session is not None and getattr(handle.session, "serve_client", None):
                break
            if handle.is_done():
                break
            await asyncio.sleep(0.25)
        session = handle.session
        if session is None:
            # Worker never created a session: report the error and close.
            yield (
                "data: "
                + _json.dumps({
                    "type": "error",
                    "rollout_id": rollout_id,
                    "reason": "session never created",
                    "detail": handle.error,
                })
                + "\n\n"
            )
            return

        sandbox = session.sandbox
        # NOTE(review): reads a private session attribute — assumes
        # _proxy_trace_path is a sandbox file path or None; confirm.
        proxy_trace_path = session._proxy_trace_path
        serve_client = getattr(session, "serve_client", None)

        # Task A: forward opencode serve events.
        # Shared queue for both producers; ``None`` is the end sentinel.
        serve_q: asyncio.Queue = asyncio.Queue()

        async def forward_serve():
            if serve_client is None:
                return
            try:
                async for ev in serve_client.astream_events():
                    await serve_q.put({"source": "serve", **ev})
                    if handle.is_done():
                        break
            except Exception as exc:  # noqa: BLE001
                # Surface the failure to the client rather than dying silently.
                await serve_q.put({"source": "serve", "type": "error", "reason": str(exc)})
            finally:
                await serve_q.put(None)

        # Task B: tail proxy trace file (incremental) from the sandbox.
        async def tail_proxy():
            # Character offset of what we've already processed, so each
            # 1s poll only parses the newly-appended JSONL suffix.
            last_len = 0
            while not handle.is_done():
                try:
                    if proxy_trace_path:
                        content = sandbox.read_text(proxy_trace_path) or ""
                        if len(content) > last_len:
                            new = content[last_len:]
                            last_len = len(content)
                            for line in new.splitlines():
                                line = line.strip()
                                if not line:
                                    continue
                                try:
                                    turn = _json.loads(line)
                                except Exception:
                                    # Likely a partially-written line; skip.
                                    continue
                                # Emit a compact summary: only the first 6
                                # tokens/logprobs to keep frames small.
                                await serve_q.put({
                                    "source": "proxy",
                                    "type": "proxy.turn",
                                    "turn": turn.get("turn"),
                                    "finish_reason": turn.get("finish_reason"),
                                    "n_tokens": len(turn.get("completion_tokens") or []),
                                    "first_tokens": (turn.get("completion_tokens") or [])[:6],
                                    "first_logps": (turn.get("per_token_logps") or [])[:6],
                                    "latency_s": turn.get("latency_s"),
                                })
                except Exception:
                    # Best-effort tailing: sandbox reads can transiently fail.
                    pass
                await asyncio.sleep(1.0)

        t_serve = asyncio.create_task(forward_serve())
        t_proxy = asyncio.create_task(tail_proxy())

        try:
            while True:
                try:
                    ev = await asyncio.wait_for(serve_q.get(), timeout=1.0)
                except asyncio.TimeoutError:
                    ev = None
                if ev is None:
                    # Queue timeout or the producer's None sentinel — either
                    # way, only exit once the rollout handle reports done.
                    if handle.is_done():
                        break
                    continue
                yield "data: " + _json.dumps(ev) + "\n\n"
        finally:
            # Stop both producers even if the client disconnects mid-stream.
            t_serve.cancel()
            t_proxy.cancel()

        # Final frame so clients know the rollout finished (or erred).
        yield "data: " + _json.dumps({
            "source": "server",
            "type": "rollout.done",
            "rollout_id": rollout_id,
            "error": handle.error,
        }) + "\n\n"

    return StreamingResponse(_gen(), media_type="text/event-stream")


def main(host: str = "0.0.0.0", port: int = 8000) -> None:
    """Serve the module-level ``app`` with uvicorn (blocks until shutdown).

    Args:
        host: Interface to bind; defaults to all interfaces.
        port: TCP port to listen on.
    """
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":  # manual / script invocation
    main()