"""HTTP client for the Coliseum pool server.

Coliseum exposes the sre-gym Triage tier through a lease-based contract
(``allocate / heartbeat / reset / exec_tool / evaluate / close``) so the
parallel-rollout side of any GRPO trainer can drive the env without holding
an in-process ``UnifiedIncidentEnvironment`` per worker. The contract shape is
the lease-pool pattern that's standard in parallel-rollout RL frameworks;
nothing here is bound to a specific trainer implementation.

Environment variables:

    COLISEUM_BASE_URL               required, e.g. http://127.0.0.1:8100
    COLISEUM_HTTP_MAX_RETRIES       default 10
    COLISEUM_ALLOCATE_MAX_RETRIES   default 10
    COLISEUM_EVALUATE_MAX_RETRIES   default 1
    COLISEUM_CLOSE_MAX_RETRIES      default 3
    COLISEUM_EXEC_TOOL_MAX_RETRIES  default 3
    COLISEUM_HTTP_TIMEOUT_S         default 30

Each ``*_MAX_RETRIES`` value is the total number of attempts made for one
request (1 means "try once, never retry"); values below 1 are treated as 1.
"""

from __future__ import annotations

import asyncio
import logging
import os
from typing import Any

import httpx

logger = logging.getLogger(__name__)

# 4xx statuses that can succeed on a later attempt (request timeout,
# rate limiting). Every other 4xx is a deterministic rejection and is
# surfaced immediately; all 5xx statuses are retried.
_RETRYABLE_CLIENT_STATUSES = frozenset({408, 429})

# Upper bound on how much of an error response body is copied into
# exception messages, so a huge error page cannot bloat logs.
_MAX_ERROR_BODY_CHARS = 500


def create_arena_client() -> "ArenaClient":
    """Build an :class:`ArenaClient` from the ``COLISEUM_BASE_URL`` variable.

    Raises:
        RuntimeError: if ``COLISEUM_BASE_URL`` is unset or empty.
    """
    base_url = os.getenv("COLISEUM_BASE_URL", "")
    if not base_url:
        raise RuntimeError("COLISEUM_BASE_URL is empty.")
    return ArenaClient(base_url)


def _is_retryable_status(status_code: int) -> bool:
    """Return True when a failed HTTP status may succeed if the request is resent."""
    return status_code >= 500 or status_code in _RETRYABLE_CLIENT_STATUSES


def _describe_error(exc: Exception) -> str:
    """Render ``exc`` for an error message, appending the body of HTTP error responses.

    ``httpx.HTTPStatusError``'s own text does not include the response body,
    so server-side details (such as an "Unknown lease" message) would
    otherwise never reach callers that inspect the exception text.
    """
    text = str(exc) or type(exc).__name__
    if isinstance(exc, httpx.HTTPStatusError):
        try:
            body = exc.response.text[:_MAX_ERROR_BODY_CHARS]
        except Exception:  # best effort: an undecodable body must not mask the error
            body = ""
        if body:
            text = f"{text} | response body: {body}"
    return text


async def _post(
    url: str,
    payload: dict[str, Any],
    *,
    max_retries: int,
    timeout_s: float,
) -> dict[str, Any]:
    """POST ``payload`` as JSON to ``url`` and return the decoded JSON object.

    Args:
        url: Absolute endpoint URL.
        payload: JSON-serialisable request body.
        max_retries: Total number of attempts; values below 1 still make one
            attempt. Transport errors, timeouts, 5xx/408/429 responses and
            undecodable bodies are retried with capped exponential backoff
            (0.25 s, 0.5 s, 1 s, ... up to 5 s). Any other 4xx fails at once.
        timeout_s: httpx timeout applied to each individual attempt.

    Returns:
        The response body, which must be a JSON object.

    Raises:
        RuntimeError: when every attempt failed, on a non-retryable HTTP
            status, or when the body is valid JSON but not an object. The
            underlying exception, if any, is chained as ``__cause__``.
    """
    attempts = max(1, max_retries)
    last_exc: Exception | None = None
    last_error = ""
    for attempt in range(1, attempts + 1):
        try:
            # A fresh client per attempt keeps this helper free of shared
            # state/lifecycle management, at the cost of connection reuse.
            async with httpx.AsyncClient(timeout=timeout_s) as client:
                response = await client.post(url, json=payload)
                response.raise_for_status()
                data = response.json()
        except httpx.HTTPStatusError as exc:
            last_exc, last_error = exc, _describe_error(exc)
            if not _is_retryable_status(exc.response.status_code):
                raise RuntimeError(
                    f"POST {url} failed with non-retryable status: {last_error}"
                ) from exc
        except Exception as exc:  # transport errors, timeouts, undecodable JSON
            last_exc, last_error = exc, _describe_error(exc)
        else:
            if not isinstance(data, dict):
                raise RuntimeError(f"POST {url} returned non-object JSON: {data!r:.200}")
            return data

        logger.debug("POST %s failed (attempt %d/%d): %s", url, attempt, attempts, last_error)
        # No point sleeping after the final attempt; raise right away instead.
        if attempt < attempts:
            await asyncio.sleep(min(2 ** (attempt - 1) * 0.25, 5.0))

    raise RuntimeError(
        f"POST {url} failed after {attempts} attempt(s): {last_error}"
    ) from last_exc


class ArenaClient:
    """Lease-based HTTP client for the Coliseum pool server.

    Each method issues one JSON ``POST`` to ``{base_url}/<operation>``, with
    retries handled by the module's POST helper, and raises
    :class:`RuntimeError` when the request ultimately fails or the response
    body lacks a truthy ``"ok"`` field.
    """

    def __init__(self, base_url: str) -> None:
        self.base_url = base_url.rstrip("/")
        # Attempt budgets per operation; heartbeat and reset use the default.
        self.default_max_retries = int(os.getenv("COLISEUM_HTTP_MAX_RETRIES", "10"))
        self.allocate_max_retries = int(os.getenv("COLISEUM_ALLOCATE_MAX_RETRIES", "10"))
        self.evaluate_max_retries = int(os.getenv("COLISEUM_EVALUATE_MAX_RETRIES", "1"))
        self.close_max_retries = int(os.getenv("COLISEUM_CLOSE_MAX_RETRIES", "3"))
        self.exec_tool_max_retries = int(os.getenv("COLISEUM_EXEC_TOOL_MAX_RETRIES", "3"))
        # Per-attempt timeout in seconds.
        self.timeout_s = float(os.getenv("COLISEUM_HTTP_TIMEOUT_S", "30"))

    @staticmethod
    def _require_ok(operation: str, out: dict[str, Any]) -> None:
        """Raise ``RuntimeError("<operation> failed: <out>")`` unless ``out["ok"]`` is truthy."""
        if not out.get("ok", False):
            raise RuntimeError(f"{operation} failed: {out}")

    async def allocate(self, task_key: str, request_id: str | None = None) -> dict[str, Any]:
        """Request a lease for ``task_key`` and return the server's response.

        ``request_id`` is forwarded verbatim (``None`` when omitted).
        NOTE(review): presumably lets the server deduplicate retried
        allocations; confirm against the server implementation.

        Raises:
            RuntimeError: on transport failure or a response without ``ok``.
        """
        out = await _post(
            f"{self.base_url}/allocate",
            {"task_key": task_key, "request_id": request_id},
            max_retries=self.allocate_max_retries,
            timeout_s=self.timeout_s,
        )
        self._require_ok("allocate", out)
        return out

    async def heartbeat(self, lease_id: str) -> None:
        """Send a heartbeat for ``lease_id``.

        Raises:
            RuntimeError: on transport failure or a response without ``ok``.
        """
        out = await _post(
            f"{self.base_url}/heartbeat",
            {"lease_id": lease_id},
            max_retries=self.default_max_retries,
            timeout_s=self.timeout_s,
        )
        self._require_ok("heartbeat", out)

    async def reset(
        self,
        lease_id: str,
        task_meta: dict[str, Any],
        run_ctx: dict[str, Any],
        task_timeouts: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Reset the environment behind ``lease_id`` and return the full response.

        ``task_meta``, ``run_ctx`` and ``task_timeouts`` are forwarded to the
        server unchanged; their schema is defined server-side.

        Raises:
            RuntimeError: on transport failure or a response without ``ok``.
        """
        out = await _post(
            f"{self.base_url}/reset",
            {
                "lease_id": lease_id,
                "task_meta": task_meta,
                "run_ctx": run_ctx,
                "task_timeouts": task_timeouts,
            },
            max_retries=self.default_max_retries,
            timeout_s=self.timeout_s,
        )
        self._require_ok("reset", out)
        return out

    async def exec_tool(self, lease_id: str, tool_name: str, arguments: dict[str, Any]) -> str:
        """Execute ``tool_name(**arguments)`` in the leased env and return its observation.

        Returns:
            ``str(out["observation"])``, or ``""`` when the key is absent.

        Raises:
            RuntimeError: on transport failure or a response without ``ok``.

        NOTE(review): a retried request resends the tool call. If the server
        executed it but the response was lost (e.g. a read timeout), the
        tool runs twice. Confirm the server is idempotent here, or keep
        COLISEUM_EXEC_TOOL_MAX_RETRIES low.
        """
        out = await _post(
            f"{self.base_url}/exec_tool",
            {
                "lease_id": lease_id,
                "tool_call": {"name": tool_name, "arguments": arguments},
            },
            max_retries=self.exec_tool_max_retries,
            timeout_s=self.timeout_s,
        )
        self._require_ok("exec_tool", out)
        return str(out.get("observation", ""))

    async def evaluate(self, lease_id: str) -> float:
        """Score the episode behind ``lease_id``.

        Returns:
            ``float(out["score"])``, or ``0.0`` when the key is absent.

        Raises:
            RuntimeError: on transport failure or a response without ``ok``.
        """
        out = await _post(
            f"{self.base_url}/evaluate",
            {"lease_id": lease_id},
            max_retries=self.evaluate_max_retries,
            timeout_s=self.timeout_s,
        )
        self._require_ok("evaluate", out)
        return float(out.get("score", 0.0))

    async def close(self, lease_id: str) -> None:
        """Release ``lease_id``; a lease the server no longer knows counts as closed.

        "Unknown lease" is recognised both in a failed request (the HTTP
        error body is part of the raised message) and in an ``ok=False``
        response's ``error`` field. Either case is logged at DEBUG and
        swallowed.

        Raises:
            RuntimeError: for any other failure.
        """
        try:
            out = await _post(
                f"{self.base_url}/close",
                {"lease_id": lease_id},
                max_retries=self.close_max_retries,
                timeout_s=self.timeout_s,
            )
        except Exception as exc:
            if "Unknown lease" in str(exc):
                logger.debug("close(%s): lease already gone", lease_id)
                return
            raise
        if not out.get("ok", False):
            error_msg = str(out.get("error", ""))
            if "Unknown" in error_msg and "lease" in error_msg.lower():
                logger.debug("close(%s): lease already gone", lease_id)
                return
            raise RuntimeError(f"close failed: {out}")