Spaces:
Running
Running
| """FastAPI pool server exposing sre-gym Triage through a lease-based contract. | |
| The rollout agent drives the env with this lifecycle per episode: | |
| allocate(task_key) -> {lease_id} | |
| reset(lease_id, task_meta, run_ctx) | |
| exec_tool(lease_id, tool_call) -> observation_string # repeated | |
| evaluate(lease_id) -> score | |
| close(lease_id) | |
| We wrap a ``UnifiedIncidentEnvironment`` instance per lease. Lease state is | |
| guarded by per-lease ``asyncio.Lock`` so 8-way concurrent rollouts on the same | |
| server stay consistent. Idle leases are reaped after ``COLISEUM_LEASE_TTL_S`` | |
| seconds. | |
| Run standalone: | |
| uvicorn coliseum.server:app --host 0.0.0.0 --port 8100 | |
| Environment variables: | |
| COLISEUM_LEASE_TTL_S default 600 | |
| COLISEUM_REAPER_PERIOD_S default 30 | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import sys | |
| import time | |
| import uuid | |
| from contextlib import asynccontextmanager | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Any | |
| from fastapi import FastAPI | |
| from pydantic import BaseModel, Field | |
# Prepend the parent of this file's directory to sys.path so the sibling
# ``unified_incident_env`` package resolves regardless of the working
# directory uvicorn was launched from.
_REPO_ROOT = Path(__file__).resolve().parent.parent
if str(_REPO_ROOT) not in sys.path:
    sys.path[:0] = [str(_REPO_ROOT)]
| from unified_incident_env.models import UnifiedIncidentAction # noqa: E402 | |
| from unified_incident_env.server.challenge import SCENARIOS # noqa: E402 | |
| from unified_incident_env.server.environment import UnifiedIncidentEnvironment # noqa: E402 | |
logger = logging.getLogger("coliseum.server")
# Configure the root logger at import time; basicConfig does nothing if the
# root logger already has handlers.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)

# Idle time (seconds) after which a lease is evicted, and how often (seconds)
# the reaper checks; both may be overridden through the environment.
LEASE_TTL_S = float(os.environ.get("COLISEUM_LEASE_TTL_S", "600"))
REAPER_PERIOD_S = float(os.environ.get("COLISEUM_REAPER_PERIOD_S", "30"))
@dataclass
class Lease:
    """State for one allocated environment lease.

    Instances are created with keyword arguments (``lease_id``, ``task_key``,
    ``env``); the remaining fields get per-instance defaults, which is why the
    class must be a dataclass.
    """

    lease_id: str
    task_key: str
    env: UnifiedIncidentEnvironment
    # Per-lease lock; the reset/step handlers hold it while driving ``env``.
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    # Wall-clock time (time.time()) of the last access, compared to the TTL by the reaper.
    last_touch: float = field(default_factory=time.time)
    # Set once reset has run; exec_tool refuses to step before that.
    reset_done: bool = False
    # Score from the most recent step, or None right after allocation/reset.
    final_score: float | None = None

    def touch(self) -> None:
        """Record an access now, postponing idle eviction."""
        self.last_touch = time.time()
# Request body for allocate: the scenario key to lease, plus an optional
# client-chosen id that the handler echoes back in its response.
class AllocateRequest(BaseModel):
    task_key: str
    request_id: str | None = Field(default=None)
# Request body that carries only a lease id; used by the heartbeat,
# evaluate and close handlers.
class LeaseRequest(BaseModel):
    lease_id: str
# Request body for reset. A truthy ``task_meta["scenario_id"]`` overrides the
# lease's task_key; ``run_ctx`` and ``task_timeouts`` are accepted but not
# read by the reset handler in this file.
class ResetRequest(BaseModel):
    lease_id: str
    task_meta: dict[str, Any] = Field(default_factory=dict)
    run_ctx: dict[str, Any] = Field(default_factory=dict)
    task_timeouts: dict[str, Any] | None = Field(default=None)
# One agent tool invocation. exec_tool uses ``name`` as the action_type and
# spreads ``arguments`` into the remaining action fields.
class ToolCall(BaseModel):
    name: str
    arguments: dict[str, Any] = Field(default_factory=dict)
# Request body for exec_tool: the target lease and the tool call to apply.
class ExecToolRequest(BaseModel):
    lease_id: str
    tool_call: ToolCall
class ArenaPool:
    """Per-process lease registry for parallel sre-gym rollouts.

    Every change to the lease mapping (allocate, close, reap) and every lookup
    done by ``get`` happens while holding ``_dict_lock``.
    """

    def __init__(self) -> None:
        self._leases: dict[str, Lease] = {}
        self._dict_lock = asyncio.Lock()

    async def allocate(self, task_key: str) -> Lease:
        """Build a fresh environment for ``task_key`` and register it under a new lease id.

        Raises:
            ValueError: ``task_key`` is not a key of ``SCENARIOS``.
        """
        if task_key not in SCENARIOS:
            raise ValueError(f"Unknown task_key {task_key!r}; known: {list(SCENARIOS)}")
        fresh_env = UnifiedIncidentEnvironment()
        new_lease = Lease(lease_id=str(uuid.uuid4()), task_key=task_key, env=fresh_env)
        async with self._dict_lock:
            self._leases[new_lease.lease_id] = new_lease
        logger.info("allocate: lease=%s task=%s", new_lease.lease_id, task_key)
        return new_lease

    async def get(self, lease_id: str) -> Lease:
        """Return the live lease for ``lease_id`` and refresh its idle timer.

        Raises:
            KeyError: no live lease has that id.
        """
        async with self._dict_lock:
            found = self._leases.get(lease_id)
            if found is None:
                raise KeyError(f"Unknown lease {lease_id}")
            found.touch()
        return found

    async def close(self, lease_id: str) -> bool:
        """Drop a lease; return whether it was present."""
        async with self._dict_lock:
            removed = self._leases.pop(lease_id, None)
        if removed is not None:
            logger.info("close: lease=%s task=%s", lease_id, removed.task_key)
        return removed is not None

    async def reap(self) -> int:
        """Evict leases idle for longer than ``LEASE_TTL_S``; return how many were evicted."""
        now = time.time()
        async with self._dict_lock:
            expired = [
                lid
                for lid, candidate in self._leases.items()
                if now - candidate.last_touch > LEASE_TTL_S
            ]
            for lid in expired:
                self._leases.pop(lid, None)
        if expired:
            logger.info("reaper: evicted %d stale lease(s)", len(expired))
        return len(expired)

    def active_count(self) -> int:
        """Number of currently registered leases."""
        return len(self._leases)


# Process-wide registry shared by all request handlers and the reaper loop.
pool = ArenaPool()
async def _reaper_loop() -> None:
    """Reap stale leases every ``REAPER_PERIOD_S`` seconds until cancelled.

    A failing tick is logged with its traceback and the loop carries on.
    """

    async def _tick() -> None:
        try:
            await pool.reap()
        except Exception:
            logger.exception("reaper loop tick failed")

    while True:
        await _tick()
        await asyncio.sleep(REAPER_PERIOD_S)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the background lease reaper for as long as the application is up.

    The decorator turns this generator into the async context manager that
    FastAPI's ``lifespan=`` parameter expects. On shutdown the reaper task is
    cancelled and awaited so it does not outlive the application.
    """
    task = asyncio.create_task(_reaper_loop())
    try:
        yield
    finally:
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            # Expected: we just cancelled the reaper ourselves.
            pass
# ASGI application; ``lifespan`` keeps the lease reaper running while it is up.
app = FastAPI(
    title="sre-gym coliseum pool server",
    lifespan=lifespan,
)
def _observation_string(obs: Any, *, reward: float | None = None) -> str:
    """Serialize a UnifiedIncidentObservation into the single string returned to the agent.

    Args:
        obs: The observation; only the attributes listed in the payload below are read.
        reward: Step reward embedded verbatim; ``None`` when there is none
            (the reset handler calls this without a reward).

    Returns:
        A JSON object with compact separators and a fixed key order.
    """

    def _alert_rows(alerts: Any) -> list[dict[str, Any]]:
        # Each alert is reduced to its service, severity and message.
        return [
            {"service": alert.service, "severity": alert.severity, "message": alert.message}
            for alert in alerts
        ]

    payload = dict(
        tick=obs.tick_count,
        workflow_stage=obs.workflow_stage,
        last_action_result=obs.last_action_result,
        tool_output=obs.tool_output,
        failure_type=obs.failure_type,
        why_failed=obs.why_failed,
        loop_warning=obs.loop_warning,
        reward=reward,
        checks=[{"name": check.name, "passed": check.passed} for check in obs.checks],
        active_alerts=_alert_rows(obs.active_alerts),
        noise_alerts=_alert_rows(obs.noise_alerts),
        service_health={svc_name: svc.status for svc_name, svc in obs.service_health.items()},
        allowed_actions=obs.allowed_actions,
        required_fields_by_action=obs.required_fields_by_action,
        blast_radius=obs.blast_radius,
        final_score=obs.final_score,
        done=obs.done,
        prompt_text=obs.prompt_text,
    )
    return json.dumps(payload, separators=(",", ":"))
async def healthz() -> dict[str, Any]:
    """Report liveness, the number of live leases, and the task keys allocate accepts.

    NOTE(review): no route decorator is visible on this handler here; confirm
    it is registered on ``app``.
    """
    active = pool.active_count()
    known = list(SCENARIOS)
    return {"ok": True, "active_leases": active, "scenarios": known}
async def allocate(request: AllocateRequest) -> dict[str, Any]:
    """Allocate a lease for ``request.task_key``.

    Success returns the new lease id and task key and echoes ``request_id``.
    An unknown task key (ValueError from the pool) is reported as
    ``{"ok": False, "error": ...}`` rather than raised.

    NOTE(review): no route decorator is visible on this handler here; confirm
    it is registered on ``app``.
    """
    try:
        granted = await pool.allocate(request.task_key)
    except ValueError as exc:
        failure = {"ok": False, "error": str(exc)}
        return failure
    reply: dict[str, Any] = {"ok": True}
    reply["lease_id"] = granted.lease_id
    reply["task_key"] = granted.task_key
    reply["request_id"] = request.request_id
    return reply
async def heartbeat(request: LeaseRequest) -> dict[str, Any]:
    """Refresh a lease's idle timer so the reaper does not evict it.

    Returns ``{"ok": True}``, or ``{"ok": False, "error": "Unknown lease <id>"}``
    when the lease was never allocated, was closed, or has been reaped.

    NOTE(review): no route decorator is visible on this handler here; confirm
    it is registered on ``app``.
    """
    try:
        await pool.get(request.lease_id)
    except KeyError as exc:
        # str(KeyError) wraps the message in quotes; use the raw message so the
        # text matches the unquoted one the close handler returns.
        return {"ok": False, "error": exc.args[0] if exc.args else str(exc)}
    return {"ok": True}
async def reset(request: ResetRequest) -> dict[str, Any]:
    """Start a new episode on an allocated lease.

    The scenario is ``task_meta["scenario_id"]`` when present and truthy,
    otherwise the lease's ``task_key``. On success the lease is marked as
    reset, any previously recorded score is cleared, and the initial
    observation is returned as a JSON string. An unknown lease yields
    ``{"ok": False, "error": "Unknown lease <id>"}``.

    NOTE(review): no route decorator is visible on this handler here; confirm
    it is registered on ``app``.
    """
    try:
        lease = await pool.get(request.lease_id)
    except KeyError as exc:
        # str(KeyError) adds quotes around the message; return it unquoted.
        return {"ok": False, "error": exc.args[0] if exc.args else str(exc)}
    # Hold the per-lease lock so a concurrent exec_tool cannot step mid-reset.
    async with lease.lock:
        scenario_id = request.task_meta.get("scenario_id") or lease.task_key
        obs = lease.env.reset(scenario_id=scenario_id)
        lease.reset_done = True
        lease.final_score = None
        return {"ok": True, "observation": _observation_string(obs)}
async def exec_tool(request: ExecToolRequest) -> dict[str, Any]:
    """Apply one agent tool call to the lease's environment as a single step.

    The tool name becomes ``action_type`` and the tool arguments are spread
    after it. NOTE(review): an ``action_type`` key inside ``arguments``
    therefore overrides the tool name; this is kept as-is because callers may
    rely on it.

    Returns:
        ``{"ok": True, "observation": <json string>}`` after a step (the step
        reward and score are included, and the score is recorded on the lease).
        An action that fails validation is also ``ok: True`` with an ``error``
        observation, so training sees the failure without the request erroring.
        ``{"ok": False, "error": ...}`` for an unknown lease or one that has
        not been reset.

    NOTE(review): no route decorator is visible on this handler here; confirm
    it is registered on ``app``.
    """
    try:
        lease = await pool.get(request.lease_id)
    except KeyError as exc:
        # str(KeyError) adds quotes around the message; return it unquoted.
        return {"ok": False, "error": exc.args[0] if exc.args else str(exc)}
    if not lease.reset_done:
        return {"ok": False, "error": "reset has not been called for this lease"}
    action_kwargs = {"action_type": request.tool_call.name, **request.tool_call.arguments}
    try:
        action = UnifiedIncidentAction(**action_kwargs)
    except Exception as exc:
        # Return the validation error to the rollout agent as a no-op
        # observation so training sees the failure signal without crashing.
        return {"ok": True, "observation": json.dumps({"error": f"invalid action: {exc}", "tool_call": request.tool_call.model_dump()})}
    async with lease.lock:
        obs = lease.env.step(action)
        lease.final_score = float(obs.final_score)
        return {"ok": True, "observation": _observation_string(obs, reward=float(obs.reward))}
async def evaluate(request: LeaseRequest) -> dict[str, Any]:
    """Report the score of the lease's current episode.

    Uses the score recorded by the most recent exec_tool step. When no step
    has run since allocation or the last reset (``lease.final_score`` is
    ``None``), it falls back to ``lease.env.state.final_score``.

    Returns:
        ``{"ok": True, "score": float}``, or
        ``{"ok": False, "error": "Unknown lease <id>"}`` for an unknown lease.

    NOTE(review): no route decorator is visible on this handler here; confirm
    it is registered on ``app``.
    """
    try:
        lease = await pool.get(request.lease_id)
    except KeyError as exc:
        # str(KeyError) adds quotes around the message; return it unquoted.
        return {"ok": False, "error": exc.args[0] if exc.args else str(exc)}
    score = lease.final_score if lease.final_score is not None else float(lease.env.state.final_score)
    return {"ok": True, "score": score}
async def close(request: LeaseRequest) -> dict[str, Any]:
    """Release a lease and drop its environment from the pool.

    NOTE(review): no route decorator is visible on this handler here; confirm
    it is registered on ``app``.
    """
    if await pool.close(request.lease_id):
        return {"ok": True}
    return {"ok": False, "error": f"Unknown lease {request.lease_id}"}