"""Eval sweep — compare policy scores across scenarios for the submission table. Runs N episodes per scenario per policy against a live sre-gym env and writes a JSONL summary suitable for the hackathon comparison table. Supported policies: - `random` — emit a valid random action each turn - `heuristic` — the deterministic heuristic from collect_trajectories - `groq` — Llama-3.3-70B via Groq (uses GROQ_API_KEY) - `fireworks` — any Fireworks-served model (uses FIREWORKS_API_KEY) - `anthropic` — any Anthropic model (uses ANTHROPIC_API_KEY) - `sft_adapter` — a local HF transformers checkpoint (directory path) The output JSONL schema: {policy, model, scenario_id, episode_idx, final_score, incident_resolved, steps, elapsed_s} Intended usage (Sunday evening, after SFT and/or GRPO has landed): python train/eval_sweep.py \ --env-url https://dakshdoesdev-sre-gym.hf.space \ --scenarios all \ --episodes-per-scenario 5 \ --policies random,heuristic,groq \ --groq-model llama-3.3-70b-versatile \ --output train/data/eval_sweep.jsonl """ from __future__ import annotations import argparse import asyncio import json import os import random import sys import time import uuid from pathlib import Path from typing import Any import httpx # Reuse trajectory-collection helpers to avoid duplicating the action-parse # and fallback-action logic. sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) from train.collect_trajectories import ( # type: ignore # noqa: E402 SYSTEM_PROMPT, _build_fallback_action, _build_user_prompt, _parse_action, _request_openai_compat_output, ) from unified_incident_env import UnifiedIncidentAction, UnifiedIncidentEnv # noqa: E402 from unified_incident_env.server.challenge import SCENARIOS # noqa: E402 SUPPORTED_DIFFICULTIES = {"easy", "medium", "hard"} def _resolve_scenarios(raw: str) -> list[str]: tokens = [token.strip() for token in raw.split(",") if token.strip()] scenario_ids: list[str] = [] for token in tokens: if token == "all": scenario_ids.extend(SCENARIOS.keys()) elif token in SUPPORTED_DIFFICULTIES: scenario_ids.extend( s for s, c in SCENARIOS.items() if c["difficulty"] == token ) elif token in SCENARIOS: scenario_ids.append(token) else: raise SystemExit(f"Unknown scenario selector: {token}") seen: set[str] = set() deduped: list[str] = [] for s in scenario_ids: if s not in seen: deduped.append(s) seen.add(s) return deduped def _random_action(observation: Any) -> UnifiedIncidentAction: # Pick an allowed action, populate minimal required fields randomly. allowed = observation.allowed_actions or ["query_logs"] action_type = random.choice(allowed) services = list(observation.service_health.keys()) or ["database"] if action_type in {"query_logs", "query_dependencies", "query_deploys", "rollback_deploy", "restart_service", "isolate_service"}: return UnifiedIncidentAction(action_type=action_type, service=random.choice(services)) if action_type == "query_metrics": return UnifiedIncidentAction( action_type=action_type, service=random.choice(services), metric=random.choice(["cpu", "error_rate", "latency"]), ) if action_type == "run_check": pending = [c.name for c in observation.checks if not c.passed] or ["end_to_end"] return UnifiedIncidentAction(action_type=action_type, check_name=pending[0]) if action_type == "submit_hypothesis": return _build_fallback_action(observation) # reasonable hypothesis shape return UnifiedIncidentAction(action_type=action_type) async def _play_episode( *, env_url: str, scenario_id: str, policy: str, model: str, http_client: httpx.AsyncClient | None, api_key: str | None, base_url: str | None, max_tokens: int, max_retries: int, ) -> dict[str, Any]: started = time.perf_counter() async with UnifiedIncidentEnv(base_url=env_url) as env: obs = (await env.reset(scenario_id=scenario_id, episode_id=str(uuid.uuid4()))).observation steps = 0 while not obs.done: fallback = _build_fallback_action(obs) if policy == "random": action = _random_action(obs) elif policy == "heuristic": action = fallback elif policy in ("groq", "fireworks"): prompt = _build_user_prompt(obs) text = "" for attempt in range(1, max_retries + 1): try: text = await _request_openai_compat_output( http_client=http_client, api_key=api_key, base_url=base_url, model=model, prompt=prompt, max_tokens=max_tokens, ) if text: break except Exception: await asyncio.sleep(min(2.0 * attempt, 5.0)) parsed = _parse_action(text, obs) action = parsed or fallback else: raise SystemExit(f"Policy {policy} not implemented here; use `policies` flag.") step = await env.step(action) obs = step.observation steps += 1 return { "scenario_id": scenario_id, "policy": policy, "model": model, "final_score": float(obs.final_score), "incident_resolved": bool(obs.incident_resolved), "steps": steps, "elapsed_s": round(time.perf_counter() - started, 3), } async def _run_sweep(args: argparse.Namespace) -> None: scenarios = _resolve_scenarios(args.scenarios) policies = [p.strip() for p in args.policies.split(",") if p.strip()] # Health-probe the env async with httpx.AsyncClient(timeout=10.0) as probe: response = await probe.get(f"{args.env_url.rstrip('/')}/health") response.raise_for_status() output_path = Path(args.output) output_path.parent.mkdir(parents=True, exist_ok=True) if output_path.exists(): output_path.unlink() groq_http: httpx.AsyncClient | None = None fireworks_http: httpx.AsyncClient | None = None if "groq" in policies: if not args.groq_api_key: raise SystemExit("GROQ_API_KEY required for groq policy") groq_http = httpx.AsyncClient( timeout=httpx.Timeout(60.0), limits=httpx.Limits(max_connections=args.parallelism * 2), follow_redirects=True, ) if "fireworks" in policies: if not args.fireworks_api_key: raise SystemExit("FIREWORKS_API_KEY required for fireworks policy") fireworks_http = httpx.AsyncClient( timeout=httpx.Timeout(90.0), limits=httpx.Limits(max_connections=args.parallelism * 2), follow_redirects=True, ) semaphore = asyncio.Semaphore(args.parallelism) async def run_one(policy: str, scenario: str, idx: int) -> None: async with semaphore: model_map = { "groq": args.groq_model, "fireworks": args.fireworks_model, "random": "random", "heuristic": "heuristic", } http, key, base = None, None, None if policy == "groq": http, key, base = groq_http, args.groq_api_key, args.groq_base_url elif policy == "fireworks": http, key, base = fireworks_http, args.fireworks_api_key, args.fireworks_base_url try: record = await _play_episode( env_url=args.env_url, scenario_id=scenario, policy=policy, model=model_map.get(policy, policy), http_client=http, api_key=key, base_url=base, max_tokens=args.max_tokens, max_retries=args.max_retries, ) record["episode_idx"] = idx except Exception as exc: record = { "scenario_id": scenario, "policy": policy, "episode_idx": idx, "error": f"{type(exc).__name__}: {exc}", } with output_path.open("a") as f: f.write(json.dumps(record) + "\n") score = record.get("final_score") resolved = record.get("incident_resolved") print( f"[{policy:<10}] {scenario:<40} ep={idx} " f"score={f'{score:.3f}' if score is not None else 'err'} " f"resolved={resolved}", file=sys.stderr, flush=True, ) tasks = [] for policy in policies: for scenario in scenarios: for idx in range(args.episodes_per_scenario): tasks.append(run_one(policy, scenario, idx)) try: await asyncio.gather(*tasks) finally: if groq_http is not None: await groq_http.aclose() if fireworks_http is not None: await fireworks_http.aclose() # Print a summary table per policy per scenario. records = [json.loads(l) for l in output_path.read_text().splitlines() if l.strip()] by_policy: dict[str, list[dict[str, Any]]] = {} for r in records: by_policy.setdefault(r["policy"], []).append(r) print("\n=== SUMMARY ===", file=sys.stderr) for policy, rs in by_policy.items(): scored = [r for r in rs if "final_score" in r] if not scored: print(f" {policy}: all episodes errored ({len(rs)} errors)", file=sys.stderr) continue mean = sum(r["final_score"] for r in scored) / len(scored) resolved = sum(1 for r in scored if r.get("incident_resolved")) / len(scored) print( f" {policy:<12} n={len(scored):<3} mean_score={mean:.3f} resolved={resolved:.1%}", file=sys.stderr, ) def _parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__.splitlines()[0]) parser.add_argument("--env-url", required=True) parser.add_argument("--scenarios", required=True) parser.add_argument("--policies", required=True, help="comma-separated: random,heuristic,groq,fireworks") parser.add_argument("--episodes-per-scenario", type=int, default=5) parser.add_argument("--parallelism", type=int, default=3) parser.add_argument("--max-tokens", type=int, default=256) parser.add_argument("--max-retries", type=int, default=3) parser.add_argument("--output", required=True) parser.add_argument("--groq-api-key", default=os.getenv("GROQ_API_KEY")) parser.add_argument("--groq-base-url", default=os.getenv("GROQ_BASE_URL", "https://api.groq.com/openai/v1")) parser.add_argument("--groq-model", default="llama-3.3-70b-versatile") parser.add_argument("--fireworks-api-key", default=os.getenv("FIREWORKS_API_KEY")) parser.add_argument( "--fireworks-base-url", default=os.getenv("FIREWORKS_BASE_URL", "https://api.fireworks.ai/inference/v1"), ) parser.add_argument("--fireworks-model", default="accounts/fireworks/models/llama-v3p3-70b-instruct") return parser.parse_args() def main() -> None: args = _parse_args() asyncio.run(_run_sweep(args)) if __name__ == "__main__": main()