# SystemTruth / coliseum / client.py
# Madhav189's picture
# data: 30 expert episodes + training notebook
# f337985
"""HTTP client for the Coliseum pool server.
Coliseum exposes the sre-gym Triage tier through a lease-based contract
(``allocate / heartbeat / reset / exec_tool / evaluate / close``) so the
parallel-rollout side of any GRPO trainer can drive the env without holding
an in-process ``UnifiedIncidentEnvironment`` per worker.
The contract shape is the lease-pool pattern that's standard in
parallel-rollout RL frameworks; nothing here is bound to a specific trainer
implementation.
Environment variables:
    COLISEUM_BASE_URL                required, e.g. http://127.0.0.1:8100
    COLISEUM_HTTP_MAX_RETRIES        default 10 (used by heartbeat and reset)
    COLISEUM_ALLOCATE_MAX_RETRIES    default 10
    COLISEUM_EVALUATE_MAX_RETRIES    default 1
    COLISEUM_CLOSE_MAX_RETRIES       default 3
    COLISEUM_EXEC_TOOL_MAX_RETRIES   default 3
    COLISEUM_HTTP_TIMEOUT_S          default 30 (seconds)
The ``*_MAX_RETRIES`` values count total attempts per request, so ``1`` means
a single attempt with no retry.
"""
from __future__ import annotations
import asyncio
import logging
import os
from typing import Any
import httpx
# Module-level logger; _post reports each failed POST attempt at DEBUG level
# and ArenaClient.close logs already-gone leases at DEBUG level.
logger = logging.getLogger(__name__)
def create_arena_client() -> "ArenaClient":
    """Build an ``ArenaClient`` pointed at ``COLISEUM_BASE_URL``.

    Returns:
        A client for the configured server root.

    Raises:
        RuntimeError: if ``COLISEUM_BASE_URL`` is unset or set to "".
    """
    configured_url = os.environ.get("COLISEUM_BASE_URL") or ""
    if configured_url:
        return ArenaClient(configured_url)
    raise RuntimeError("COLISEUM_BASE_URL is empty.")
async def _post(
    url: str,
    payload: dict[str, Any],
    *,
    max_retries: int,
    timeout_s: float,
) -> dict[str, Any]:
    """POST ``payload`` as JSON to ``url`` and return the decoded JSON response.

    Any failure inside an attempt is retried with exponential backoff
    (0.25 s, 0.5 s, 1 s, ... capped at 5 s). That covers transport errors,
    non-2xx statuses raised by ``raise_for_status()``, and a body that is not
    valid JSON. There is no sleep after the final attempt.

    Args:
        url: Absolute endpoint URL.
        payload: JSON-serialisable request body.
        max_retries: Total number of attempts (not additional retries).
            Values below 1 are treated as 1, so the request is always sent
            at least once.
        timeout_s: Timeout passed to ``httpx.AsyncClient``.

    Returns:
        The decoded JSON body.

    Raises:
        RuntimeError: if every attempt failed. The last underlying error is
            included in the message and chained as ``__cause__``.
    """
    attempts = max(1, max_retries)
    last_exc: Exception | None = None
    # One client for all attempts, so retries can reuse pooled connections
    # instead of paying connection setup on every try.
    async with httpx.AsyncClient(timeout=timeout_s) as client:
        for attempt in range(attempts):
            try:
                response = await client.post(url, json=payload)
                response.raise_for_status()
                return response.json()
            except Exception as exc:  # deliberately broad: HTTP-status and JSON errors retry too
                last_exc = exc
                logger.debug(
                    "POST %s failed (attempt %d/%d): %s", url, attempt + 1, attempts, exc
                )
                # Back off only if another attempt follows; sleeping after the
                # last failure would just delay the error.
                if attempt + 1 < attempts:
                    await asyncio.sleep(min(2 ** attempt * 0.25, 5.0))
    raise RuntimeError(
        f"POST {url} failed after {attempts} retries: {last_exc}"
    ) from last_exc
class ArenaClient:
"""Lease-based HTTP client for the Coliseum pool server."""
def __init__(self, base_url: str) -> None:
self.base_url = base_url.rstrip("/")
self.default_max_retries = int(os.getenv("COLISEUM_HTTP_MAX_RETRIES", "10"))
self.allocate_max_retries = int(os.getenv("COLISEUM_ALLOCATE_MAX_RETRIES", "10"))
self.evaluate_max_retries = int(os.getenv("COLISEUM_EVALUATE_MAX_RETRIES", "1"))
self.close_max_retries = int(os.getenv("COLISEUM_CLOSE_MAX_RETRIES", "3"))
self.exec_tool_max_retries = int(os.getenv("COLISEUM_EXEC_TOOL_MAX_RETRIES", "3"))
self.timeout_s = float(os.getenv("COLISEUM_HTTP_TIMEOUT_S", "30"))
async def allocate(self, task_key: str, request_id: str | None = None) -> dict[str, Any]:
out = await _post(
f"{self.base_url}/allocate",
{"task_key": task_key, "request_id": request_id},
max_retries=self.allocate_max_retries,
timeout_s=self.timeout_s,
)
if not out.get("ok", False):
raise RuntimeError(f"allocate failed: {out}")
return out
async def heartbeat(self, lease_id: str) -> None:
out = await _post(
f"{self.base_url}/heartbeat",
{"lease_id": lease_id},
max_retries=self.default_max_retries,
timeout_s=self.timeout_s,
)
if not out.get("ok", False):
raise RuntimeError(f"heartbeat failed: {out}")
async def reset(
self,
lease_id: str,
task_meta: dict[str, Any],
run_ctx: dict[str, Any],
task_timeouts: dict[str, Any] | None = None,
) -> dict[str, Any]:
out = await _post(
f"{self.base_url}/reset",
{
"lease_id": lease_id,
"task_meta": task_meta,
"run_ctx": run_ctx,
"task_timeouts": task_timeouts,
},
max_retries=self.default_max_retries,
timeout_s=self.timeout_s,
)
if not out.get("ok", False):
raise RuntimeError(f"reset failed: {out}")
return out
async def exec_tool(self, lease_id: str, tool_name: str, arguments: dict[str, Any]) -> str:
out = await _post(
f"{self.base_url}/exec_tool",
{
"lease_id": lease_id,
"tool_call": {"name": tool_name, "arguments": arguments},
},
max_retries=self.exec_tool_max_retries,
timeout_s=self.timeout_s,
)
if not out.get("ok", False):
raise RuntimeError(f"exec_tool failed: {out}")
return str(out.get("observation", ""))
async def evaluate(self, lease_id: str) -> float:
out = await _post(
f"{self.base_url}/evaluate",
{"lease_id": lease_id},
max_retries=self.evaluate_max_retries,
timeout_s=self.timeout_s,
)
if not out.get("ok", False):
raise RuntimeError(f"evaluate failed: {out}")
return float(out.get("score", 0.0))
async def close(self, lease_id: str) -> None:
try:
out = await _post(
f"{self.base_url}/close",
{"lease_id": lease_id},
max_retries=self.close_max_retries,
timeout_s=self.timeout_s,
)
except Exception as exc:
if "Unknown lease" in str(exc):
logger.debug("close(%s): lease already gone", lease_id)
return
raise
if not out.get("ok", False):
error_msg = str(out.get("error", ""))
if "Unknown" in error_msg and "lease" in error_msg.lower():
logger.debug("close(%s): lease already gone", lease_id)
return
raise RuntimeError(f"close failed: {out}")