Spaces:
Running
Running
| """ | |
| Warm Codex app-server — ONE persistent process, reused across requests, to remove | |
| per-request cold start (spawn + initialize + cold thread setup). A lock serializes | |
| turns; this also avoids two processes refreshing (and rotating) auth.json at once. | |
| Opt-in via CODEX_ENGINE=pool. Same event contract as codex_engine.run_turn: | |
| {"type": "delta"|"reasoning", "text": ...} | |
| {"type": "final", "text", "thread_id", "usage", "images"} | |
| Falls back safely: on any process/protocol failure the process is killed and the | |
| next request respawns it (and re-reads auth.json, e.g. after a re-upload). | |
| """ | |
| import asyncio | |
| import json | |
| import os | |
| from pathlib import Path | |
| from typing import AsyncIterator, Optional | |
| from codex_engine import CodexError | |
# Lower-case fragments searched for in an app-server "error" notification
# (message + additionalDetails); a hit is reported as an expired Codex login.
_AUTH_DEAD = (
    "session has ended",
    "log in again",
    "failed to refresh token",
)

# Passed as ``limit=`` when spawning the app-server: the largest single stdout
# line (one JSON message) asyncio will buffer -- 16 MiB.
_STREAM_LIMIT = 16 * 1024 ** 2
class WarmCodex:
    """One long-lived ``codex app-server`` child process shared by all requests.

    The process is spawned lazily, spoken to over newline-delimited JSON on
    stdin/stdout, and reused for every turn.  ``lock`` serializes turns, so at
    most one turn is in flight.  Whenever the process can no longer be trusted
    (it exited, stopped answering, reported a dead login, or was left in the
    middle of a turn) it is killed and the next request spawns a fresh one.
    """

    # Marker the reader pushes onto the current turn's queue when the server's
    # stdout ends, so a waiting turn fails immediately instead of timing out.
    # Compared by identity, never sent over the wire.
    _EOF: dict = {"method": "<app-server stdout closed>"}

    def __init__(self, codex_bin: str, codex_home: str, read_timeout: float):
        """Store configuration only; nothing is spawned until the first turn.

        Args:
            codex_bin: Executable to run (invoked as ``<codex_bin> app-server``).
            codex_home: Value exported to the child as ``CODEX_HOME``.
            read_timeout: Seconds to wait for any single reply or notification.
        """
        self.bin = codex_bin
        self.home = codex_home
        self.read_timeout = read_timeout
        self.proc: Optional[asyncio.subprocess.Process] = None
        self.lock = asyncio.Lock()  # one turn at a time
        self._pending: dict[int, asyncio.Future] = {}  # request id -> reply future
        self._queue: Optional[asyncio.Queue] = None  # current turn's notifications
        self._reader: Optional[asyncio.Task] = None
        self._next_id = 100
        self._threads: dict[str, str] = {}  # session_id -> thread_id (warm)

    # -- process lifecycle ---------------------------------------------------

    async def _ensure(self) -> None:
        """Make sure a live, initialized app-server exists, spawning one if needed.

        Raises:
            OSError: The executable could not be started.
            CodexError: The ``initialize`` handshake failed or timed out.
        """
        if self.proc is not None and self.proc.returncode is None:
            return
        # Drop whatever is left of a previous (dead) process and its reader.
        self._terminate()
        proc = await asyncio.create_subprocess_exec(
            self.bin, "app-server",
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.DEVNULL,
            cwd="/tmp",
            env={**os.environ, "CODEX_HOME": self.home},
            limit=_STREAM_LIMIT,
        )
        self.proc = proc
        self._pending = {}
        self._threads = {}  # in-memory threads are gone after a respawn
        self._reader = asyncio.create_task(self._read_loop(proc))
        try:
            await self._request("initialize", {
                "clientInfo": {"name": "codex-as-api", "title": "Codex as API",
                               "version": "1.0.0"},
                "capabilities": {"experimentalApi": True, "requestAttestation": False},
            })
            await self._send({"method": "initialized"})
        except BaseException:
            # A live but un-initialized process would pass the liveness check
            # above forever; never keep one around.
            self._terminate()
            raise

    def _terminate(self) -> None:
        """Synchronously stop the reader, kill the process and fail waiters.

        Idempotent, and free of ``await`` so it is safe to call from a
        ``finally`` block that runs while a generator is being closed.
        """
        reader, self._reader = self._reader, None
        if reader is not None and not reader.done():
            reader.cancel()
        proc, self.proc = self.proc, None
        if proc is not None:
            try:
                proc.kill()
            except (OSError, RuntimeError):
                pass  # best effort: already exited, or the loop is closing
        self._fail_pending("app-server restarted")

    async def _kill(self) -> None:
        """Coroutine wrapper around :meth:`_terminate`."""
        self._terminate()

    def _fail_pending(self, reason: str) -> None:
        """Fail every outstanding request future with ``CodexError(reason)``."""
        pending, self._pending = self._pending, {}
        for fut in pending.values():
            if not fut.done():
                fut.set_exception(CodexError(reason))

    async def _read_loop(
            self, proc: Optional[asyncio.subprocess.Process] = None) -> None:
        """Route the server's stdout: replies to futures, the rest to the queue.

        Args:
            proc: The process to read from; defaults to ``self.proc`` at the
                time the loop starts.  The loop stays bound to this process so
                it can never read from, or clean up after, a later respawn.
        """
        if proc is None:
            proc = self.proc
        stdout = proc.stdout if proc is not None else None
        try:
            while stdout is not None:
                try:
                    line = await stdout.readline()
                except ValueError:
                    # Line exceeded the stream limit; asyncio already dropped
                    # the oversized data, so keep reading.
                    continue
                if not line:
                    break  # EOF: the process closed its stdout
                line = line.strip()
                if not line:
                    continue
                try:
                    msg = json.loads(line)
                except ValueError:
                    # JSONDecodeError and UnicodeDecodeError are both ValueErrors.
                    continue
                if not isinstance(msg, dict):
                    continue  # not a JSON-RPC message object
                mid = msg.get("id")
                if mid is not None and ("result" in msg or "error" in msg):
                    fut = (self._pending.pop(mid, None)
                           if isinstance(mid, (int, str)) else None)
                    if fut is not None and not fut.done():
                        fut.set_result(msg)
                elif msg.get("method"):
                    if self._queue is not None:
                        self._queue.put_nowait(msg)
        finally:
            # Only clean up if this reader still belongs to the current process;
            # after _terminate()/respawn the shared state is no longer ours.
            if proc is not None and self.proc is proc:
                self.proc = None
                self._reader = None
                if proc.returncode is None:
                    try:
                        proc.kill()  # nobody is reading it any more
                    except (OSError, RuntimeError):
                        pass
                self._fail_pending("app-server closed")
                if self._queue is not None:
                    self._queue.put_nowait(self._EOF)

    # -- io ------------------------------------------------------------------

    async def _send(self, obj: dict) -> None:
        """Write one JSON message, newline-terminated, to the server's stdin.

        Raises:
            CodexError: No process (or no stdin pipe) is available.
        """
        proc = self.proc
        if proc is None or proc.stdin is None:
            raise CodexError("app-server not running")
        proc.stdin.write((json.dumps(obj) + "\n").encode("utf-8"))
        await proc.stdin.drain()

    async def _request(self, method: str, params: dict) -> dict:
        """Send a request and wait for its reply.

        Returns:
            The reply's ``result`` object (``{}`` when absent).

        Raises:
            CodexError: The server replied with an error, closed, or did not
                answer within ``read_timeout``.  On a timeout the process is
                killed, since an unresponsive server cannot be reused.
        """
        self._next_id += 1
        rid = self._next_id
        fut: asyncio.Future = asyncio.get_running_loop().create_future()
        self._pending[rid] = fut
        try:
            await self._send({"method": method, "id": rid, "params": params})
            try:
                msg = await asyncio.wait_for(fut, timeout=self.read_timeout)
            except asyncio.TimeoutError:
                self._terminate()
                raise CodexError(f"app-server timed out on {method}") from None
        finally:
            # Never leave a dead entry behind (timeout, send failure, cancel).
            self._pending.pop(rid, None)
        if "error" in msg:
            raise CodexError(f"app-server error on {method}: {msg['error']}")
        return msg.get("result", {})

    # -- one turn ------------------------------------------------------------

    async def run(self, *, prompt, workspace, thread_id, sandbox, model,
                  effort, input_items, output_schema, session_id,
                  developer_instructions=None) -> AsyncIterator[dict]:
        """Run one turn and stream its events.

        Resumes the thread remembered for ``session_id`` (or ``thread_id``) when
        possible, otherwise starts a new thread, then starts a turn with
        ``input_items`` (or ``prompt`` as a single text item).

        Yields:
            ``{"type": "delta"|"reasoning", "text": ...}`` while streaming, then
            one ``{"type": "final", "text", "thread_id", "usage", "images"}``.

        Raises:
            CodexError: Spawn failure, protocol error, timeout, expired login,
                or any unexpected failure (wrapped as "warm pool error").

        The lock is held for the whole turn, including while the consumer
        handles yielded events.
        """
        async with self.lock:
            try:
                await self._ensure()
            except OSError as e:  # includes FileNotFoundError
                raise CodexError(
                    f"could not start codex app-server ('{self.bin}'): {e}") from e
            events: asyncio.Queue = asyncio.Queue()
            self._queue = events
            # True while the server may be running a turn we have not seen end.
            turn_open = False
            try:
                tid = (self._threads.get(session_id) if session_id else None) or thread_id
                resolved = None
                if tid:
                    try:
                        rp = {"threadId": tid, "cwd": str(workspace),
                              "approvalPolicy": "never", "sandbox": sandbox,
                              "excludeTurns": True}
                        if developer_instructions:
                            rp["developerInstructions"] = developer_instructions
                        res = await self._request("thread/resume", rp)
                        resolved = (res.get("thread") or {}).get("id")
                    except CodexError:
                        resolved = None  # fall back to a fresh thread
                if not resolved:
                    # No-op while the process is alive; respawns it if the
                    # resume attempt timed out and took the process down.
                    await self._ensure()
                    params = {"cwd": str(workspace), "approvalPolicy": "never",
                              "sandbox": sandbox}
                    if model:
                        params["model"] = model
                    if developer_instructions:
                        params["developerInstructions"] = developer_instructions
                    res = await self._request("thread/start", params)
                    resolved = (res.get("thread") or {}).get("id")
                if not resolved:
                    raise CodexError("app-server did not return a thread id")
                if session_id:
                    self._threads[session_id] = resolved

                turn_input = input_items or [
                    {"type": "text", "text": prompt, "text_elements": []}]
                tp = {"threadId": resolved, "input": turn_input}
                if model:
                    tp["model"] = model
                if effort:
                    tp["effort"] = effort
                if output_schema:
                    tp["outputSchema"] = output_schema
                # Set before sending: if we are cancelled while waiting for the
                # reply, the turn may already be running on the server.
                turn_open = True
                try:
                    await self._request("turn/start", tp)  # returns quickly
                except CodexError:
                    # Rejected (no turn running) or the process is already gone.
                    turn_open = False
                    raise

                delta_parts: list[str] = []
                final_text: Optional[str] = None
                images: list[str] = []
                usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0,
                         "prompt_tokens_details": {"cached_tokens": 0}}
                while True:
                    try:
                        msg = await asyncio.wait_for(events.get(),
                                                     timeout=self.read_timeout)
                    except asyncio.TimeoutError:
                        raise CodexError(
                            "app-server timed out during the turn") from None
                    if msg is self._EOF:
                        raise CodexError("app-server closed")
                    method = msg.get("method")
                    payload = msg.get("params")
                    if not isinstance(payload, dict):
                        payload = {}
                    if method == "item/agentMessage/delta":
                        d = payload.get("delta", "")
                        if d:
                            delta_parts.append(d)
                            yield {"type": "delta", "text": d}
                    elif method in ("item/reasoning/textDelta",
                                    "item/reasoning/summaryTextDelta"):
                        rd = payload.get("delta") or payload.get("text", "")
                        if rd:
                            yield {"type": "reasoning", "text": rd}
                    elif method == "item/completed":
                        item = payload.get("item") or {}
                        it = item.get("type")
                        if it == "agentMessage" and item.get("text") is not None:
                            final_text = item["text"]
                        elif it == "imageGeneration":
                            p = item.get("savedPath") or item.get("result")
                            if p:
                                images.append(p)
                    elif method == "thread/tokenUsage/updated":
                        last = (payload.get("tokenUsage") or {}).get("last") or {}
                        usage = {
                            "prompt_tokens": last.get("inputTokens", 0) or 0,
                            "completion_tokens": last.get("outputTokens", 0) or 0,
                            "total_tokens": last.get("totalTokens", 0) or 0,
                            "prompt_tokens_details": {
                                "cached_tokens": last.get("cachedInputTokens", 0) or 0},
                        }
                    elif method == "error":
                        err = payload.get("error") or {}
                        blob = f"{err.get('message','')} {err.get('additionalDetails') or ''}".lower()
                        if any(s in blob for s in _AUTH_DEAD):
                            await self._kill()  # respawn next time (re-read auth)
                            raise CodexError(
                                "Codex login expired (session ended). Re-upload a "
                                "fresh auth.json to /data/.codex/auth.json.")
                    elif method == "turn/completed":
                        turn_open = False
                        break
                    elif msg.get("id") is not None and method is not None:
                        # server->client request (approval); decline to avoid hang
                        await self._send({"id": msg["id"],
                                          "error": {"code": -32601,
                                                    "message": "approvals disabled"}})

                text = final_text if final_text is not None else "".join(delta_parts)
                yield {"type": "final", "text": text, "thread_id": resolved,
                       "usage": usage, "images": images}
            except CodexError:
                raise
            except Exception as e:
                await self._kill()
                raise CodexError(f"warm pool error: {e}") from e
            finally:
                self._queue = None
                if turn_open:
                    # The turn timed out, failed, or the consumer walked away
                    # mid-turn.  Its remaining notifications would otherwise be
                    # delivered to the next request, so start over instead.
                    self._terminate()
# Process-wide singleton, created on first use.
_POOL: Optional[WarmCodex] = None


def _get_pool(codex_bin, codex_home, read_timeout) -> WarmCodex:
    """Return the shared WarmCodex, constructing it on the first call.

    The arguments are only used when the pool is created; once it exists,
    later calls return it unchanged regardless of what is passed.
    """
    global _POOL
    pool = _POOL
    if pool is None:
        pool = _POOL = WarmCodex(codex_bin, codex_home, read_timeout)
    return pool
async def run_turn_pool(*, codex_bin, codex_home, prompt, workspace, thread_id,
                        sandbox, model, read_timeout, effort=None,
                        input_items=None, output_schema=None,
                        session_id=None, developer_instructions=None) -> AsyncIterator[dict]:
    """Stream one turn's events from the shared warm pool.

    ``codex_bin``, ``codex_home`` and ``read_timeout`` select/create the pool
    via ``_get_pool``; every other argument is forwarded to ``WarmCodex.run``
    and each event it produces is re-yielded as-is.
    """
    turn_kwargs = {
        "prompt": prompt,
        "workspace": workspace,
        "thread_id": thread_id,
        "sandbox": sandbox,
        "model": model,
        "effort": effort,
        "input_items": input_items,
        "output_schema": output_schema,
        "session_id": session_id,
        "developer_instructions": developer_instructions,
    }
    shared = _get_pool(codex_bin, codex_home, read_timeout)
    async for event in shared.run(**turn_kwargs):
        yield event