Spaces:
Running
Running
| """HTTP client for the Coliseum pool server. | |
| Coliseum exposes the sre-gym Triage tier through a lease-based contract | |
| (``allocate / heartbeat / reset / exec_tool / evaluate / close``) so the | |
| parallel-rollout side of any GRPO trainer can drive the env without holding | |
| an in-process ``UnifiedIncidentEnvironment`` per worker. | |
| The contract shape is the lease-pool pattern that's standard in | |
| parallel-rollout RL frameworks; nothing here is bound to a specific trainer | |
| implementation. | |
| Environment variables: | |
| COLISEUM_BASE_URL required, e.g. http://127.0.0.1:8100 | |
| COLISEUM_HTTP_MAX_RETRIES default 10 | |
| COLISEUM_ALLOCATE_MAX_RETRIES default 10 | |
| COLISEUM_EVALUATE_MAX_RETRIES default 1 | |
| COLISEUM_CLOSE_MAX_RETRIES default 3 | |
| COLISEUM_EXEC_TOOL_MAX_RETRIES default 3 | |
| COLISEUM_HTTP_TIMEOUT_S default 30 | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import logging | |
| import os | |
| from typing import Any | |
| import httpx | |
| logger = logging.getLogger(__name__) | |
def create_arena_client() -> "ArenaClient":
    """Build an ``ArenaClient`` from the ``COLISEUM_BASE_URL`` environment variable.

    Returns:
        An ``ArenaClient`` constructed with the configured base URL.

    Raises:
        RuntimeError: if ``COLISEUM_BASE_URL`` is unset or empty.
    """
    configured_url = os.environ.get("COLISEUM_BASE_URL", "")
    if configured_url:
        return ArenaClient(configured_url)
    raise RuntimeError("COLISEUM_BASE_URL is empty.")
async def _post(
    url: str,
    payload: dict[str, Any],
    *,
    max_retries: int,
    timeout_s: float,
) -> dict[str, Any]:
    """POST ``payload`` as JSON to ``url`` and return the decoded JSON body.

    Any ``Exception`` raised by an attempt (transport errors, non-2xx
    statuses surfaced by ``raise_for_status``, undecodable bodies) triggers a
    retry. Between attempts the coroutine sleeps ``0.25 * 2**attempt``
    seconds, capped at 5 seconds; no sleep happens after the final attempt.

    Args:
        url: Absolute endpoint URL.
        payload: JSON-serialisable request body.
        max_retries: Total number of attempts. Values below 1 are treated as
            a single attempt so that a request is always sent.
        timeout_s: Timeout handed to ``httpx.AsyncClient``.

    Returns:
        The decoded JSON response of the first successful attempt.

    Raises:
        RuntimeError: if every attempt failed. The error is chained to the
            last underlying exception and, for HTTP status errors, its
            message also carries the (truncated) response body so callers
            can match on server-side error text.
    """
    attempts = max(1, max_retries)
    last_exc: Exception | None = None
    for attempt in range(attempts):
        try:
            async with httpx.AsyncClient(timeout=timeout_s) as client:
                response = await client.post(url, json=payload)
                response.raise_for_status()
                return response.json()
        except Exception as exc:  # deliberately broad: every failure is retried
            last_exc = exc
            logger.debug("POST %s failed (attempt %d/%d): %s", url, attempt + 1, attempts, exc)
        # Back off only when another attempt will follow; sleeping after the
        # last failure would just delay the error reaching the caller.
        if attempt + 1 < attempts:
            await asyncio.sleep(min(2 ** attempt * 0.25, 5.0))

    detail = str(last_exc)
    if isinstance(last_exc, httpx.HTTPStatusError):
        # str(HTTPStatusError) holds only the status line and URL; the
        # server's own explanation lives in the response body.
        body = last_exc.response.text.strip()
        if body:
            detail = f"{detail} (response body: {body[:500]})"
    raise RuntimeError(f"POST {url} failed after {attempts} retries: {detail}") from last_exc
class ArenaClient:
    """Lease-based HTTP client for the Coliseum pool server.

    Each public coroutine POSTs a JSON body to ``<base_url>/<endpoint>`` via
    the module-level ``_post`` helper and treats a reply whose ``ok`` field
    is missing or falsy as a failure. Per-endpoint retry budgets and the
    HTTP timeout are read from ``COLISEUM_*`` environment variables when the
    client is constructed.
    """

    def __init__(self, base_url: str) -> None:
        """Store the base URL (trailing slashes removed) and load tunables."""

        def retry_budget(variable: str, fallback: str) -> int:
            # Environment values arrive as strings; fall back to the default.
            return int(os.environ.get(variable, fallback))

        self.base_url = base_url.rstrip("/")
        self.default_max_retries = retry_budget("COLISEUM_HTTP_MAX_RETRIES", "10")
        self.allocate_max_retries = retry_budget("COLISEUM_ALLOCATE_MAX_RETRIES", "10")
        self.evaluate_max_retries = retry_budget("COLISEUM_EVALUATE_MAX_RETRIES", "1")
        self.close_max_retries = retry_budget("COLISEUM_CLOSE_MAX_RETRIES", "3")
        self.exec_tool_max_retries = retry_budget("COLISEUM_EXEC_TOOL_MAX_RETRIES", "3")
        self.timeout_s = float(os.environ.get("COLISEUM_HTTP_TIMEOUT_S", "30"))

    async def _call(
        self,
        endpoint: str,
        body: dict[str, Any],
        max_retries: int,
    ) -> dict[str, Any]:
        """POST ``body`` to ``endpoint`` and return the reply if it reports ``ok``.

        Raises:
            RuntimeError: if the reply's ``ok`` field is missing or falsy.
        """
        reply = await _post(
            f"{self.base_url}/{endpoint}",
            body,
            max_retries=max_retries,
            timeout_s=self.timeout_s,
        )
        if reply.get("ok", False):
            return reply
        raise RuntimeError(f"{endpoint} failed: {reply}")

    async def allocate(self, task_key: str, request_id: str | None = None) -> dict[str, Any]:
        """Call ``/allocate`` for ``task_key`` and return the full server reply."""
        body = {"task_key": task_key, "request_id": request_id}
        return await self._call("allocate", body, self.allocate_max_retries)

    async def heartbeat(self, lease_id: str) -> None:
        """Call ``/heartbeat`` for ``lease_id``; raises if the server rejects it."""
        await self._call("heartbeat", {"lease_id": lease_id}, self.default_max_retries)

    async def reset(
        self,
        lease_id: str,
        task_meta: dict[str, Any],
        run_ctx: dict[str, Any],
        task_timeouts: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Call ``/reset`` with the task metadata and return the full server reply."""
        body = {
            "lease_id": lease_id,
            "task_meta": task_meta,
            "run_ctx": run_ctx,
            "task_timeouts": task_timeouts,
        }
        return await self._call("reset", body, self.default_max_retries)

    async def exec_tool(self, lease_id: str, tool_name: str, arguments: dict[str, Any]) -> str:
        """Call ``/exec_tool`` and return the reply's ``observation`` as a string.

        A reply without an ``observation`` field yields the empty string.
        """
        body = {
            "lease_id": lease_id,
            "tool_call": {"name": tool_name, "arguments": arguments},
        }
        reply = await self._call("exec_tool", body, self.exec_tool_max_retries)
        return str(reply.get("observation", ""))

    async def evaluate(self, lease_id: str) -> float:
        """Call ``/evaluate`` and return the reply's ``score`` as a float.

        A reply without a ``score`` field yields ``0.0``.
        """
        reply = await self._call("evaluate", {"lease_id": lease_id}, self.evaluate_max_retries)
        return float(reply.get("score", 0.0))

    async def close(self, lease_id: str) -> None:
        """Call ``/close`` for ``lease_id``, tolerating a lease the server no longer knows.

        An "unknown lease" outcome, whether it surfaces as an exception from
        the POST or as an ``error`` field in a non-``ok`` reply, is logged at
        DEBUG level and treated as success.

        Raises:
            RuntimeError: if the reply is not ``ok`` for any other reason.
            Exception: any other failure from the underlying POST is re-raised.
        """
        try:
            reply = await _post(
                f"{self.base_url}/close",
                {"lease_id": lease_id},
                max_retries=self.close_max_retries,
                timeout_s=self.timeout_s,
            )
        except Exception as exc:
            if "Unknown lease" not in str(exc):
                raise
            logger.debug("close(%s): lease already gone", lease_id)
            return

        if reply.get("ok", False):
            return

        reason = str(reply.get("error", ""))
        if "Unknown" in reason and "lease" in reason.lower():
            logger.debug("close(%s): lease already gone", lease_id)
            return
        raise RuntimeError(f"close failed: {reply}")