"""
Warm Codex app-server — ONE persistent process, reused across requests, to
remove per-request cold start (spawn + initialize + cold thread setup).

A lock serializes turns; this also avoids two processes refreshing (and
rotating) auth.json at once. Opt-in via CODEX_ENGINE=pool.

Same event contract as codex_engine.run_turn:
    {"type": "delta"|"reasoning", "text": ...}
    {"type": "final", "text", "thread_id", "usage", "images"}

Falls back safely: on any process/protocol failure the process is killed and
the next request respawns it (and re-reads auth.json, e.g. after a re-upload).
This includes a turn that was started but never reached "turn/completed"
(timeout, consumer went away): the process is killed so that the unfinished
turn's notifications cannot leak into the next request.
"""
import asyncio
import json
import os
from pathlib import Path
from typing import AsyncIterator, Optional

from codex_engine import CodexError

# Lower-cased substrings of an "error" notification that mean the stored login
# can no longer be refreshed.
_AUTH_DEAD = ("session has ended", "log in again", "failed to refresh token")
# StreamReader buffer limit passed to create_subprocess_exec (bytes).
_STREAM_LIMIT = 16 * 1024 * 1024


class WarmCodex:
    """One long-lived `codex app-server` child process shared by all requests.

    Requests and responses are newline-delimited JSON over the child's
    stdin/stdout. Responses (messages carrying "id" plus "result"/"error") are
    matched to waiting futures; every other message that has a "method" is
    forwarded to the queue of the turn currently running, if any.
    """

    def __init__(self, codex_bin: str, codex_home: str, read_timeout: float):
        self.bin = codex_bin
        self.home = codex_home
        self.read_timeout = read_timeout
        self.proc: Optional[asyncio.subprocess.Process] = None
        self.lock = asyncio.Lock()  # one turn at a time
        self._pending: dict[int, asyncio.Future] = {}
        self._queue: Optional[asyncio.Queue] = None  # current turn's notifications
        self._reader: Optional[asyncio.Task] = None
        self._next_id = 100
        self._threads: dict[str, str] = {}  # session_id -> thread_id (warm)

    # -- process lifecycle ---------------------------------------------------
    async def _ensure(self) -> None:
        """Make sure a live, initialized app-server process is available.

        Returns immediately when the current process is still running and its
        reader task is alive; otherwise discards whatever is left, spawns a
        new process and performs the initialize/initialized handshake.

        Raises:
            OSError: the binary could not be started or its pipe broke.
            CodexError: the handshake failed or timed out.
        """
        reader_alive = self._reader is not None and not self._reader.done()
        if self.proc and self.proc.returncode is None and reader_alive:
            return
        # A process whose reader has stopped can never deliver a response, so
        # it is treated as dead even if it has not exited yet.
        self._teardown()
        proc = await asyncio.create_subprocess_exec(
            self.bin, "app-server",
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.DEVNULL,
            cwd="/tmp",
            env={**os.environ, "CODEX_HOME": self.home},
            limit=_STREAM_LIMIT,
        )
        self.proc = proc
        self._pending = {}
        self._threads = {}  # in-memory threads are gone after a respawn
        self._reader = asyncio.create_task(self._read_loop(proc, self._pending))
        try:
            await self._request("initialize", {
                "clientInfo": {"name": "codex-as-api", "title": "Codex as API",
                               "version": "1.0.0"},
                "capabilities": {"experimentalApi": True,
                                 "requestAttestation": False},
            })
            await self._send({"method": "initialized"})
        except BaseException:
            # Never keep a half-initialized process around: the liveness check
            # above would otherwise reuse it without repeating the handshake.
            self._teardown()
            raise

    def _teardown(self) -> None:
        """Synchronously drop the current process, reader and pending requests.

        Safe to call repeatedly and from `finally` blocks (it never awaits).
        """
        reader, self._reader = self._reader, None
        if reader is not None and not reader.done():
            reader.cancel()
        proc, self.proc = self.proc, None
        if proc is not None:
            try:
                proc.kill()
            except Exception:
                pass  # best effort: the process may already have exited
        pending, self._pending = self._pending, {}
        for fut in pending.values():
            if not fut.done():
                fut.set_exception(CodexError("app-server restarted"))

    async def _kill(self) -> None:
        """Kill the app-server so that the next request respawns it."""
        self._teardown()

    async def _read_loop(self, proc=None, pending=None) -> None:
        """Read JSON lines from the child's stdout and dispatch them.

        Args:
            proc: process to read from; defaults to the current `self.proc`.
            pending: request-id -> future map to resolve; defaults to the
                current `self._pending`.

        The loop is bound to the process it was started for, so a reader that
        finishes late cannot disturb the state of a newer process.
        """
        proc = self.proc if proc is None else proc
        pending = self._pending if pending is None else pending
        try:
            while proc is not None and proc.stdout:
                line = await proc.stdout.readline()
                if not line:
                    break  # EOF: the process closed its stdout
                line = line.strip()
                if not line:
                    continue
                try:
                    msg = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if not isinstance(msg, dict):
                    continue  # valid JSON but not a message object
                mid = msg.get("id")
                if mid is not None and ("result" in msg or "error" in msg):
                    fut = pending.pop(mid, None)
                    if fut and not fut.done():
                        fut.set_result(msg)
                elif msg.get("method"):
                    if self._queue is not None:
                        self._queue.put_nowait(msg)
        except asyncio.CancelledError:
            pass
        except Exception:
            # Any read failure ends the reader; waiters are notified below and
            # _ensure() respawns the process on the next request.
            pass
        finally:
            for fut in pending.values():
                if not fut.done():
                    fut.set_exception(CodexError("app-server closed"))
            pending.clear()
            # Wake a turn that is still waiting on this process instead of
            # letting it sit until read_timeout expires.
            if proc is not None and self.proc is proc and self._queue is not None:
                self._queue.put_nowait(None)

    # -- io ------------------------------------------------------------------
    async def _send(self, obj: dict) -> None:
        """Write one JSON message, newline-terminated, to the child's stdin.

        Raises:
            CodexError: no process (or no stdin pipe) is available.
        """
        if not self.proc or not self.proc.stdin:
            raise CodexError("app-server not running")
        self.proc.stdin.write((json.dumps(obj) + "\n").encode("utf-8"))
        await self.proc.stdin.drain()

    async def _request(self, method: str, params: dict) -> dict:
        """Send a request and wait up to `read_timeout` for its response.

        Returns:
            The response's "result" value ({} when absent).

        Raises:
            CodexError: the response carried an "error", the process went
                away, or no response arrived in time. On a timeout the process
                is torn down, since the unanswered request may still take
                effect later.
        """
        self._next_id += 1
        rid = self._next_id
        fut: asyncio.Future = asyncio.get_running_loop().create_future()
        self._pending[rid] = fut
        try:
            await self._send({"method": method, "id": rid, "params": params})
            msg = await asyncio.wait_for(fut, timeout=self.read_timeout)
        except asyncio.TimeoutError as exc:
            self._teardown()
            raise CodexError(f"app-server timed out on {method}") from exc
        finally:
            # Never leave the entry behind (timeout, failed send, cancellation).
            self._pending.pop(rid, None)
        if "error" in msg:
            raise CodexError(f"app-server error on {method}: {msg['error']}")
        return msg.get("result", {})

    # -- one turn ------------------------------------------------------------
    async def run(self, *, prompt, workspace, thread_id, sandbox, model, effort,
                  input_items, output_schema, session_id,
                  developer_instructions=None) -> AsyncIterator[dict]:
        """Run one turn on the warm process and stream its events.

        The thread is taken from the warm map for `session_id`, else resumed
        from `thread_id`, else started fresh. `input_items`, when truthy,
        replaces the single text item built from `prompt`.

        Yields:
            {"type": "delta", "text"} and {"type": "reasoning", "text"} while
            the turn runs, then exactly one {"type": "final", "text",
            "thread_id", "usage", "images"}.

        Raises:
            CodexError: on any start-up, protocol, timeout or login failure.
        """
        async with self.lock:
            try:
                await self._ensure()
            except (FileNotFoundError, OSError) as e:
                raise CodexError(
                    f"could not start codex app-server ('{self.bin}'): {e}") from e
            queue: asyncio.Queue = asyncio.Queue()
            self._queue = queue
            # True from a successful turn/start until turn/completed is seen.
            turn_in_flight = False
            try:
                tid = (self._threads.get(session_id) if session_id else None) or thread_id
                resolved = None
                if tid:
                    try:
                        rp = {"threadId": tid, "cwd": str(workspace),
                              "approvalPolicy": "never", "sandbox": sandbox,
                              "excludeTurns": True}
                        if developer_instructions:
                            rp["developerInstructions"] = developer_instructions
                        res = await self._request("thread/resume", rp)
                        resolved = (res.get("thread") or {}).get("id")
                    except CodexError:
                        if self.proc is None:
                            raise  # process was torn down; nothing to fall back to
                        resolved = None
                if not resolved:
                    params = {"cwd": str(workspace), "approvalPolicy": "never",
                              "sandbox": sandbox}
                    if model:
                        params["model"] = model
                    if developer_instructions:
                        params["developerInstructions"] = developer_instructions
                    res = await self._request("thread/start", params)
                    resolved = (res.get("thread") or {}).get("id")
                if not resolved:
                    raise CodexError("app-server did not return a thread id")
                if session_id:
                    self._threads[session_id] = resolved

                turn_input = input_items or [
                    {"type": "text", "text": prompt, "text_elements": []}]
                tp = {"threadId": resolved, "input": turn_input}
                if model:
                    tp["model"] = model
                if effort:
                    tp["effort"] = effort
                if output_schema:
                    tp["outputSchema"] = output_schema
                await self._request("turn/start", tp)  # returns quickly
                turn_in_flight = True

                delta_parts: list[str] = []
                final_text: Optional[str] = None
                images: list[str] = []
                usage = {"prompt_tokens": 0, "completion_tokens": 0,
                         "total_tokens": 0,
                         "prompt_tokens_details": {"cached_tokens": 0}}
                while True:
                    try:
                        msg = await asyncio.wait_for(queue.get(),
                                                     timeout=self.read_timeout)
                    except asyncio.TimeoutError as exc:
                        raise CodexError(
                            "app-server timed out during the turn") from exc
                    if msg is None:
                        # Sentinel from the reader: the process closed its output.
                        raise CodexError("app-server closed during the turn")
                    method = msg.get("method")
                    payload = msg.get("params") or {}
                    if method == "item/agentMessage/delta":
                        d = payload.get("delta", "")
                        if d:
                            delta_parts.append(d)
                            yield {"type": "delta", "text": d}
                    elif method in ("item/reasoning/textDelta",
                                    "item/reasoning/summaryTextDelta"):
                        rd = payload.get("delta") or payload.get("text", "")
                        if rd:
                            yield {"type": "reasoning", "text": rd}
                    elif method == "item/completed":
                        item = payload.get("item") or {}
                        it = item.get("type")
                        if it == "agentMessage" and item.get("text") is not None:
                            final_text = item["text"]
                        elif it == "imageGeneration":
                            p = item.get("savedPath") or item.get("result")
                            if p:
                                images.append(p)
                    elif method == "thread/tokenUsage/updated":
                        last = (payload.get("tokenUsage") or {}).get("last") or {}
                        usage = {
                            "prompt_tokens": last.get("inputTokens", 0) or 0,
                            "completion_tokens": last.get("outputTokens", 0) or 0,
                            "total_tokens": last.get("totalTokens", 0) or 0,
                            "prompt_tokens_details": {
                                "cached_tokens": last.get("cachedInputTokens", 0) or 0},
                        }
                    elif method == "error":
                        err = payload.get("error", {}) or {}
                        blob = f"{err.get('message','')} {err.get('additionalDetails') or ''}".lower()
                        if any(s in blob for s in _AUTH_DEAD):
                            await self._kill()  # respawn next time (re-read auth)
                            raise CodexError(
                                "Codex login expired (session ended). Re-upload a "
                                "fresh auth.json to /data/.codex/auth.json.")
                    elif method == "turn/completed":
                        turn_in_flight = False
                        break
                    elif msg.get("id") is not None and method is not None:
                        # server->client request (approval); decline to avoid hang
                        await self._send({"id": msg["id"], "error": {
                            "code": -32601, "message": "approvals disabled"}})

                text = final_text if final_text is not None else "".join(delta_parts)
                yield {"type": "final", "text": text, "thread_id": resolved,
                       "usage": usage, "images": images}
            except CodexError:
                raise
            except Exception as e:
                await self._kill()
                raise CodexError(f"warm pool error: {e}") from e
            finally:
                self._queue = None
                if turn_in_flight:
                    # The turn was abandoned (timeout, error, or the consumer
                    # closed/cancelled the stream) while the server may still
                    # be producing output for it. Drop the process so those
                    # notifications cannot reach the next request's queue.
                    self._teardown()


# Process-wide singleton, created on first use.
_POOL: Optional[WarmCodex] = None


def _get_pool(codex_bin, codex_home, read_timeout) -> WarmCodex:
    """Return the shared WarmCodex, creating it on the first call.

    The arguments are only used when the pool is created; later calls return
    the existing instance regardless of the values passed.
    """
    global _POOL
    if _POOL is None:
        _POOL = WarmCodex(codex_bin, codex_home, read_timeout)
    return _POOL


async def run_turn_pool(*, codex_bin, codex_home, prompt, workspace, thread_id,
                        sandbox, model, read_timeout, effort=None,
                        input_items=None, output_schema=None, session_id=None,
                        developer_instructions=None) -> AsyncIterator[dict]:
    """Run one turn on the shared warm process, yielding WarmCodex.run events."""
    pool = _get_pool(codex_bin, codex_home, read_timeout)
    async for evt in pool.run(prompt=prompt, workspace=workspace,
                              thread_id=thread_id, sandbox=sandbox, model=model,
                              effort=effort, input_items=input_items,
                              output_schema=output_schema, session_id=session_id,
                              developer_instructions=developer_instructions):
        yield evt