codex / codex_pool.py
sarveshpatel's picture
Upload 10 files
19491c5 verified
Raw
History Blame Contribute Delete
12.5 kB
"""
Warm Codex app-server — ONE persistent process, reused across requests, to remove
per-request cold start (spawn + initialize + cold thread setup). A lock serializes
turns; this also avoids two processes refreshing (and rotating) auth.json at once.
Opt-in via CODEX_ENGINE=pool. Same event contract as codex_engine.run_turn:
{"type": "delta"|"reasoning", "text": ...}
{"type": "final", "text", "thread_id", "usage", "images"}
Falls back safely: on any process/protocol failure the process is killed and the
next request respawns it (and re-reads auth.json, e.g. after a re-upload).
"""
import asyncio
import json
import os
from pathlib import Path
from typing import AsyncIterator, Optional
from codex_engine import CodexError
_AUTH_DEAD = ("session has ended", "log in again", "failed to refresh token")
_STREAM_LIMIT = 16 * 1024 * 1024
class WarmCodex:
    """One persistent ``codex app-server`` process reused across requests.

    Turns are serialized by ``self.lock``. JSON-RPC responses are routed to
    the futures in ``self._pending`` (keyed by request id); messages that
    carry a ``method`` (notifications and server->client requests) are pushed
    onto ``self._queue`` while a turn is running and dropped otherwise.

    Any failure that can leave the process out of sync with this client
    (failed handshake, request timeout, a turn abandoned before
    ``turn/completed``, reader death, unexpected exception) kills the
    process; the next :meth:`run` respawns it.
    """

    def __init__(self, codex_bin: str, codex_home: str, read_timeout: float):
        self.bin = codex_bin
        self.home = codex_home
        # seconds to wait for one response / one turn notification
        self.read_timeout = read_timeout
        self.proc: Optional[asyncio.subprocess.Process] = None
        self.lock = asyncio.Lock()  # one turn at a time
        self._pending: dict[int, asyncio.Future] = {}  # request id -> response future
        self._queue: Optional[asyncio.Queue] = None  # current turn's notifications
        self._reader: Optional[asyncio.Task] = None
        self._next_id = 100
        self._threads: dict[str, str] = {}  # session_id -> thread_id (warm)

    # -- process lifecycle ---------------------------------------------------
    async def _ensure(self) -> None:
        """Make sure an initialized app-server with a live reader is running.

        Raises:
            OSError: the binary could not be spawned.
            CodexError: the initialize handshake failed. The half-started
                process is killed so the next call retries from scratch.
        """
        if (self.proc is not None and self.proc.returncode is None
                and self._reader is not None and not self._reader.done()):
            return
        # Clean up whatever is left of a dead process or a dead reader
        # (a live process without a reader would never answer anything).
        await self._kill()
        proc = await asyncio.create_subprocess_exec(
            self.bin, "app-server",
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.DEVNULL,
            cwd="/tmp",
            env={**os.environ, "CODEX_HOME": self.home},
            limit=_STREAM_LIMIT,
        )
        self.proc = proc
        self._pending = {}
        self._threads = {}  # in-memory threads are gone after a respawn
        self._reader = asyncio.create_task(self._read_loop(proc))
        try:
            await self._request("initialize", {
                "clientInfo": {"name": "codex-as-api", "title": "Codex as API",
                               "version": "1.0.0"},
                "capabilities": {"experimentalApi": True, "requestAttestation": False},
            })
            await self._send({"method": "initialized"})
        except BaseException:
            # Never keep a process that missed the handshake: the early
            # return above would otherwise reuse it, uninitialized, forever.
            await self._kill()
            raise

    async def _kill(self) -> None:
        """Cancel the reader, kill the process and fail outstanding requests.

        Does not await anything internally, so it is safe to call from
        ``finally`` clauses while an async generator is being closed.
        """
        if self._reader and not self._reader.done():
            self._reader.cancel()
        proc, self.proc = self.proc, None
        if proc is not None:
            try:
                proc.kill()
            except Exception:
                pass  # best effort: the process may already be gone
        self._fail_pending("app-server restarted")

    def _fail_pending(self, reason: str) -> None:
        """Fail every outstanding request future with ``CodexError(reason)``."""
        pending, self._pending = self._pending, {}
        for fut in pending.values():
            if not fut.done():
                fut.set_exception(CodexError(reason))

    async def _read_loop(
            self, proc: Optional[asyncio.subprocess.Process] = None) -> None:
        """Route ``proc``'s stdout lines until EOF, error or cancellation.

        ``proc`` pins this loop to the process it was started for, so a stale
        reader can never touch the state of a respawned process. If the loop
        ends while ``proc`` is still the current process, pending requests
        are failed and the running turn (if any) is woken with a ``None``
        sentinel instead of waiting for ``read_timeout``.
        """
        if proc is None:
            proc = self.proc
        try:
            while proc is not None and proc.stdout is not None:
                line = await proc.stdout.readline()
                if not line:
                    break  # EOF
                line = line.strip()
                if not line:
                    continue
                try:
                    msg = json.loads(line)
                except json.JSONDecodeError:
                    continue  # non-JSON noise on stdout
                if not isinstance(msg, dict):
                    continue  # valid JSON but not a JSON-RPC object
                mid = msg.get("id")
                if mid is not None and ("result" in msg or "error" in msg):
                    fut = self._pending.pop(mid, None)
                    if fut is not None and not fut.done():
                        fut.set_result(msg)
                elif msg.get("method") and self._queue is not None:
                    self._queue.put_nowait(msg)
        except asyncio.CancelledError:
            pass  # cancelled by _kill()
        except Exception:
            pass  # e.g. a line longer than _STREAM_LIMIT; handled in finally
        finally:
            if proc is not None and self.proc is proc:
                self._fail_pending("app-server closed")
                if self._queue is not None:
                    self._queue.put_nowait(None)  # wake the running turn

    # -- io ------------------------------------------------------------------
    async def _send(self, obj: dict) -> None:
        """Write ``obj`` as one newline-terminated JSON line to stdin."""
        if not self.proc or not self.proc.stdin:
            raise CodexError("app-server not running")
        self.proc.stdin.write((json.dumps(obj) + "\n").encode("utf-8"))
        await self.proc.stdin.drain()

    async def _request(self, method: str, params: dict) -> dict:
        """Send a JSON-RPC request and return its ``result`` (``{}`` if absent).

        Raises CodexError on an error response, on server shutdown, or on
        timeout. A timeout also kills the process: a request the server may
        still execute (e.g. ``turn/start``) would otherwise emit notifications
        into a later turn.
        """
        self._next_id += 1
        rid = self._next_id
        fut: asyncio.Future = asyncio.get_running_loop().create_future()
        self._pending[rid] = fut
        try:
            await self._send({"method": method, "id": rid, "params": params})
            msg = await asyncio.wait_for(fut, timeout=self.read_timeout)
        except asyncio.TimeoutError:
            await self._kill()
            raise CodexError(f"app-server timed out on {method}") from None
        finally:
            self._pending.pop(rid, None)  # no leaked entries on any exit path
        if "error" in msg:
            raise CodexError(f"app-server error on {method}: {msg['error']}")
        return msg.get("result") or {}

    # -- one turn ------------------------------------------------------------
    async def _open_thread(self, *, workspace, thread_id, sandbox, model,
                           session_id, developer_instructions) -> str:
        """Resume the session's warm thread (or ``thread_id``), else start one.

        Returns the thread id and remembers it under ``session_id``.
        """
        tid = (self._threads.get(session_id) if session_id else None) or thread_id
        resolved = None
        if tid:
            rp = {"threadId": tid, "cwd": str(workspace),
                  "approvalPolicy": "never", "sandbox": sandbox,
                  "excludeTurns": True}
            if developer_instructions:
                rp["developerInstructions"] = developer_instructions
            try:
                res = await self._request("thread/resume", rp)
                resolved = (res.get("thread") or {}).get("id")
            except CodexError:
                resolved = None  # fall back to a fresh thread
        if not resolved:
            params = {"cwd": str(workspace), "approvalPolicy": "never",
                      "sandbox": sandbox}
            if model:
                params["model"] = model
            if developer_instructions:
                params["developerInstructions"] = developer_instructions
            res = await self._request("thread/start", params)
            resolved = (res.get("thread") or {}).get("id")
        if not resolved:
            raise CodexError("app-server did not return a thread id")
        if session_id:
            self._threads[session_id] = resolved
        return resolved

    @staticmethod
    def _usage_from(params: dict) -> dict:
        """Build the final event's ``usage`` dict from a tokenUsage payload.

        Reads ``params["tokenUsage"]["last"]``; missing or null fields count
        as 0, so ``_usage_from({})`` is the all-zero usage.
        """
        last = (params.get("tokenUsage") or {}).get("last") or {}
        return {
            "prompt_tokens": last.get("inputTokens", 0) or 0,
            "completion_tokens": last.get("outputTokens", 0) or 0,
            "total_tokens": last.get("totalTokens", 0) or 0,
            "prompt_tokens_details": {
                "cached_tokens": last.get("cachedInputTokens", 0) or 0},
        }

    async def run(self, *, prompt, workspace, thread_id, sandbox, model,
                  effort, input_items, output_schema, session_id,
                  developer_instructions=None) -> AsyncIterator[dict]:
        """Run one turn and stream its events.

        Yields ``{"type": "delta"|"reasoning", "text": ...}`` while the turn
        runs, then one ``{"type": "final", "text", "thread_id", "usage",
        "images"}`` event. Raises CodexError on any failure.
        """
        async with self.lock:
            try:
                await self._ensure()
            except OSError as e:
                raise CodexError(
                    f"could not start codex app-server ('{self.bin}'): {e}") from e
            self._queue = queue = asyncio.Queue()
            # True between an accepted turn/start and its turn/completed. If we
            # leave while it is set (timeout, consumer stopped, cancellation),
            # the server is still mid-turn and its late notifications would
            # leak into the next request, so the process is killed in finally.
            turn_open = False
            try:
                resolved = await self._open_thread(
                    workspace=workspace, thread_id=thread_id, sandbox=sandbox,
                    model=model, session_id=session_id,
                    developer_instructions=developer_instructions)
                turn_input = input_items or [
                    {"type": "text", "text": prompt, "text_elements": []}]
                tp = {"threadId": resolved, "input": turn_input}
                if model:
                    tp["model"] = model
                if effort:
                    tp["effort"] = effort
                if output_schema:
                    tp["outputSchema"] = output_schema
                await self._request("turn/start", tp)  # returns quickly
                turn_open = True
                delta_parts: list[str] = []
                final_text: Optional[str] = None
                images: list[str] = []
                usage = self._usage_from({})
                while True:
                    try:
                        msg = await asyncio.wait_for(queue.get(),
                                                     timeout=self.read_timeout)
                    except asyncio.TimeoutError:
                        raise CodexError(
                            "app-server timed out during the turn") from None
                    if msg is None:  # sentinel from a dying reader
                        raise CodexError("app-server closed during the turn")
                    method = msg.get("method")
                    params = msg.get("params") or {}
                    if method == "item/agentMessage/delta":
                        d = params.get("delta", "")
                        if d:
                            delta_parts.append(d)
                            yield {"type": "delta", "text": d}
                    elif method in ("item/reasoning/textDelta",
                                    "item/reasoning/summaryTextDelta"):
                        rd = params.get("delta") or params.get("text", "")
                        if rd:
                            yield {"type": "reasoning", "text": rd}
                    elif method == "item/completed":
                        item = params.get("item") or {}
                        it = item.get("type")
                        if it == "agentMessage" and item.get("text") is not None:
                            final_text = item["text"]
                        elif it == "imageGeneration":
                            p = item.get("savedPath") or item.get("result")
                            if p:
                                images.append(p)
                    elif method == "thread/tokenUsage/updated":
                        usage = self._usage_from(params)
                    elif method == "error":
                        err = params.get("error") or {}
                        blob = f"{err.get('message','')} {err.get('additionalDetails') or ''}".lower()
                        if any(s in blob for s in _AUTH_DEAD):
                            await self._kill()  # respawn next time (re-read auth)
                            raise CodexError(
                                "Codex login expired (session ended). Re-upload a "
                                "fresh auth.json to /data/.codex/auth.json.")
                    elif method == "turn/completed":
                        turn_open = False
                        break
                    elif msg.get("id") is not None:
                        # server->client request (approval); decline to avoid hang
                        await self._send({"id": msg["id"],
                                          "error": {"code": -32601,
                                                    "message": "approvals disabled"}})
                text = final_text if final_text is not None else "".join(delta_parts)
                yield {"type": "final", "text": text, "thread_id": resolved,
                       "usage": usage, "images": images}
            except CodexError:
                raise
            except Exception as e:
                await self._kill()
                raise CodexError(f"warm pool error: {e}") from e
            finally:
                self._queue = None
                if turn_open:
                    await self._kill()
_POOL: Optional[WarmCodex] = None  # module-level pool, created lazily


def _get_pool(codex_bin, codex_home, read_timeout) -> WarmCodex:
    """Return the module's single WarmCodex, building it on the first call.

    The arguments are only used for that first construction; later calls
    get the existing pool back unchanged.
    """
    global _POOL
    if _POOL is not None:
        return _POOL
    _POOL = WarmCodex(codex_bin, codex_home, read_timeout)
    return _POOL
async def run_turn_pool(*, codex_bin, codex_home, prompt, workspace, thread_id,
                        sandbox, model, read_timeout, effort=None,
                        input_items=None, output_schema=None,
                        session_id=None, developer_instructions=None) -> AsyncIterator[dict]:
    """Run one turn on the shared warm app-server and stream its events.

    Same event contract as ``codex_engine.run_turn`` (selected with
    CODEX_ENGINE=pool). The pool is built on the first call, so
    ``codex_bin``, ``codex_home`` and ``read_timeout`` of later calls are
    ignored.
    """
    pool = _get_pool(codex_bin, codex_home, read_timeout)
    events = pool.run(prompt=prompt, workspace=workspace, thread_id=thread_id,
                      sandbox=sandbox, model=model, effort=effort,
                      input_items=input_items, output_schema=output_schema,
                      session_id=session_id,
                      developer_instructions=developer_instructions)
    try:
        async for evt in events:
            yield evt
    finally:
        # If our consumer stops early, close the inner generator right away
        # so it releases the pool lock (which serializes every turn) and
        # cleans up its turn now, rather than whenever it is finalized by GC.
        # aclose() on an already-finished generator is a no-op.
        await events.aclose()