SystemTruth / train /eval_sweep.py
dakshdoesdev's picture
Add Groq driver, GRPO notebook, and eval-sweep harness
0d3e723
"""Eval sweep — compare policy scores across scenarios for the submission table.
Runs N episodes per scenario per policy against a live sre-gym env and writes
a JSONL summary suitable for the hackathon comparison table.
Supported policies:
- `random` — emit a valid random action each turn
- `heuristic` — the deterministic heuristic from collect_trajectories
- `groq` — Llama-3.3-70B via Groq (uses GROQ_API_KEY)
- `fireworks` — any Fireworks-served model (uses FIREWORKS_API_KEY)
- `anthropic` — any Anthropic model (uses ANTHROPIC_API_KEY); not implemented in this script
- `sft_adapter` — a local HF transformers checkpoint (directory path); not implemented in this script
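The output JSONL schema (one object per episode):
{policy, model, scenario_id, episode_idx, final_score, incident_resolved,
steps, elapsed_s}
Episodes that raise an exception are written as
{policy, scenario_id, episode_idx, error} instead.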
Intended usage (Sunday evening, after SFT and/or GRPO has landed):
python train/eval_sweep.py \
--env-url https://dakshdoesdev-sre-gym.hf.space \
--scenarios all \
--episodes-per-scenario 5 \
--policies random,heuristic,groq \
--groq-model llama-3.3-70b-versatile \
--output train/data/eval_sweep.jsonl
"""
from __future__ import annotations
import argparse
import asyncio
import json
import os
import random
import sys
import time
import uuid
from pathlib import Path
from typing import Any
import httpx
# Reuse trajectory-collection helpers to avoid duplicating the action-parse
# and fallback-action logic.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from train.collect_trajectories import ( # type: ignore # noqa: E402
SYSTEM_PROMPT,
_build_fallback_action,
_build_user_prompt,
_parse_action,
_request_openai_compat_output,
)
from unified_incident_env import UnifiedIncidentAction, UnifiedIncidentEnv # noqa: E402
from unified_incident_env.server.challenge import SCENARIOS # noqa: E402
# Difficulty labels accepted as scenario selectors; each one expands to every
# scenario whose config has that value under its "difficulty" key.
SUPPORTED_DIFFICULTIES = {"easy", "medium", "hard"}
def _resolve_scenarios(raw: str) -> list[str]:
    """Expand a comma-separated selector string into an ordered list of scenario ids.

    Each non-empty selector may be ``all`` (every scenario), a difficulty label
    from ``SUPPORTED_DIFFICULTIES`` (every scenario with that difficulty), or a
    literal scenario id. Ids appear once, in first-seen order.

    Raises:
        SystemExit: if a selector matches none of the forms above.
    """
    # A dict keeps insertion order and drops repeats in one structure.
    selected: dict[str, None] = {}
    for piece in raw.split(","):
        selector = piece.strip()
        if not selector:
            continue
        if selector == "all":
            matches = list(SCENARIOS.keys())
        elif selector in SUPPORTED_DIFFICULTIES:
            matches = [
                scenario_id
                for scenario_id, config in SCENARIOS.items()
                if config["difficulty"] == selector
            ]
        elif selector in SCENARIOS:
            matches = [selector]
        else:
            raise SystemExit(f"Unknown scenario selector: {selector}")
        for scenario_id in matches:
            selected.setdefault(scenario_id, None)
    return list(selected)
def _random_action(observation: Any) -> UnifiedIncidentAction:
    """Build a randomly chosen action with the minimal fields its type needs.

    The action type is drawn from ``observation.allowed_actions`` (falling back
    to ``query_logs`` when that is empty) and target services are drawn from
    the keys of ``observation.service_health`` (falling back to ``database``).
    """
    candidates = observation.allowed_actions or ["query_logs"]
    chosen = random.choice(candidates)
    service_names = list(observation.service_health.keys()) or ["database"]

    # Action types that need nothing beyond a target service.
    service_only = {
        "query_logs",
        "query_dependencies",
        "query_deploys",
        "rollback_deploy",
        "restart_service",
        "isolate_service",
    }
    if chosen in service_only:
        target = random.choice(service_names)
        return UnifiedIncidentAction(action_type=chosen, service=target)

    if chosen == "query_metrics":
        # Draw the service before the metric, matching the keyword order.
        target = random.choice(service_names)
        metric_name = random.choice(["cpu", "error_rate", "latency"])
        return UnifiedIncidentAction(
            action_type=chosen,
            service=target,
            metric=metric_name,
        )

    if chosen == "run_check":
        # Run the first check that has not passed, or the end-to-end check.
        failing = [check.name for check in observation.checks if not check.passed]
        check_name = failing[0] if failing else "end_to_end"
        return UnifiedIncidentAction(action_type=chosen, check_name=check_name)

    if chosen == "submit_hypothesis":
        # Delegate to the shared fallback builder for the action payload.
        return _build_fallback_action(observation)

    return UnifiedIncidentAction(action_type=chosen)
async def _play_episode(
    *,
    env_url: str,
    scenario_id: str,
    policy: str,
    model: str,
    http_client: httpx.AsyncClient | None,
    api_key: str | None,
    base_url: str | None,
    max_tokens: int,
    max_retries: int,
) -> dict[str, Any]:
    """Play one episode of ``scenario_id`` under ``policy`` and return its record.

    Opens a ``UnifiedIncidentEnv`` session against ``env_url``, resets it to the
    scenario with a fresh UUID episode id, and steps until the observation
    reports ``done``. The action for each step comes from:

    - ``random``: ``_random_action``
    - ``heuristic``: ``_build_fallback_action``
    - ``groq`` / ``fireworks``: the prompt from ``_build_user_prompt`` is sent
      through ``_request_openai_compat_output`` (up to ``max_retries`` attempts)
      and parsed with ``_parse_action``; the heuristic action is used when
      nothing parses.

    Args:
        env_url: Base URL of the environment server.
        scenario_id: Scenario to reset the environment to.
        policy: One of ``random``, ``heuristic``, ``groq``, ``fireworks``.
        model: Model name, forwarded to the LLM request and echoed in the record.
        http_client: Shared HTTP client for LLM requests (unused otherwise).
        api_key: API key for the LLM provider (unused otherwise).
        base_url: Base URL of the LLM provider (unused otherwise).
        max_tokens: Completion token budget forwarded to the LLM request.
        max_retries: Maximum LLM request attempts per step.

    Returns:
        A dict with ``scenario_id``, ``policy``, ``model``, ``final_score``,
        ``incident_resolved``, ``steps`` and ``elapsed_s``.

    Raises:
        SystemExit: if ``policy`` is not one of the supported names.
    """
    llm_policies = ("groq", "fireworks")
    # Reject unsupported policies before opening (and resetting) an env session.
    if policy not in ("random", "heuristic", *llm_policies):
        raise SystemExit(f"Policy {policy} not implemented here; use `policies` flag.")

    started = time.perf_counter()
    async with UnifiedIncidentEnv(base_url=env_url) as env:
        reset_result = await env.reset(scenario_id=scenario_id, episode_id=str(uuid.uuid4()))
        obs = reset_result.observation
        steps = 0
        while not obs.done:
            if policy == "random":
                action = _random_action(obs)
            elif policy == "heuristic":
                action = _build_fallback_action(obs)
            else:
                prompt = _build_user_prompt(obs)
                text = ""
                for attempt in range(1, max_retries + 1):
                    try:
                        text = await _request_openai_compat_output(
                            http_client=http_client,
                            api_key=api_key,
                            base_url=base_url,
                            model=model,
                            prompt=prompt,
                            max_tokens=max_tokens,
                        )
                    except Exception as exc:
                        # Best-effort: report the failure, then retry or fall
                        # back to the heuristic action below.
                        print(
                            f"[{policy}] request attempt {attempt}/{max_retries} failed: "
                            f"{type(exc).__name__}: {exc}",
                            file=sys.stderr,
                            flush=True,
                        )
                        # Backing off after the last attempt would only delay
                        # the fallback.
                        if attempt < max_retries:
                            await asyncio.sleep(min(2.0 * attempt, 5.0))
                        continue
                    if text:
                        break
                # An empty/None completion or an unparseable one both fall back
                # to the heuristic action so the episode can keep going.
                action = _parse_action(text or "", obs) or _build_fallback_action(obs)
            step = await env.step(action)
            obs = step.observation
            steps += 1
        return {
            "scenario_id": scenario_id,
            "policy": policy,
            "model": model,
            "final_score": float(obs.final_score),
            "incident_resolved": bool(obs.incident_resolved),
            "steps": steps,
            "elapsed_s": round(time.perf_counter() - started, 3),
        }
async def _run_sweep(args: argparse.Namespace) -> None:
    """Run every (policy, scenario, episode) combination and summarise the results.

    Steps:
      1. Validate the CLI selections (policies, parallelism, API keys).
      2. Resolve scenario selectors and health-probe ``{env_url}/health``.
      3. Recreate ``args.output`` empty, then run episodes with at most
         ``args.parallelism`` in flight, appending one JSON line per episode.
         Episodes that raise are recorded with an ``error`` field instead of
         scores.
      4. Re-read the JSONL file and print mean score and resolution rate per
         policy to stderr.

    Raises:
        SystemExit: on an unsupported policy, non-positive parallelism, a
            missing API key, or an unknown scenario selector.
        httpx.HTTPStatusError: if the health probe returns an error status.
    """
    policies = [p.strip() for p in args.policies.split(",") if p.strip()]

    # Validate everything up front so a bad flag neither deletes the previous
    # output file nor surfaces as SystemExit inside a running task.
    supported_policies = ("random", "heuristic", "groq", "fireworks")
    unknown = [p for p in policies if p not in supported_policies]
    if unknown:
        raise SystemExit(
            f"Unsupported policies: {', '.join(unknown)} "
            f"(supported: {', '.join(supported_policies)})"
        )
    if args.parallelism < 1:
        # Semaphore(0) would block every task forever.
        raise SystemExit("--parallelism must be at least 1")
    if "groq" in policies and not args.groq_api_key:
        raise SystemExit("GROQ_API_KEY required for groq policy")
    if "fireworks" in policies and not args.fireworks_api_key:
        raise SystemExit("FIREWORKS_API_KEY required for fireworks policy")

    scenarios = _resolve_scenarios(args.scenarios)

    # Health-probe the env
    async with httpx.AsyncClient(timeout=10.0) as probe:
        response = await probe.get(f"{args.env_url.rstrip('/')}/health")
        response.raise_for_status()

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    if output_path.exists():
        output_path.unlink()
    # Create the file now so the summary read below succeeds even when no
    # episode is scheduled (e.g. --episodes-per-scenario 0).
    output_path.touch()

    groq_http: httpx.AsyncClient | None = None
    fireworks_http: httpx.AsyncClient | None = None
    semaphore = asyncio.Semaphore(args.parallelism)
    # Model label recorded per policy; built once rather than per task.
    model_map = {
        "groq": args.groq_model,
        "fireworks": args.fireworks_model,
        "random": "random",
        "heuristic": "heuristic",
    }

    async def run_one(policy: str, scenario: str, idx: int) -> None:
        """Play one episode under the concurrency limit and append its record."""
        async with semaphore:
            http, key, base = None, None, None
            if policy == "groq":
                http, key, base = groq_http, args.groq_api_key, args.groq_base_url
            elif policy == "fireworks":
                http, key, base = fireworks_http, args.fireworks_api_key, args.fireworks_base_url
            try:
                record = await _play_episode(
                    env_url=args.env_url,
                    scenario_id=scenario,
                    policy=policy,
                    model=model_map.get(policy, policy),
                    http_client=http,
                    api_key=key,
                    base_url=base,
                    max_tokens=args.max_tokens,
                    max_retries=args.max_retries,
                )
                record["episode_idx"] = idx
            except Exception as exc:
                # Keep the sweep going; the failure is recorded in the output.
                record = {
                    "scenario_id": scenario,
                    "policy": policy,
                    "episode_idx": idx,
                    "error": f"{type(exc).__name__}: {exc}",
                }
            with output_path.open("a", encoding="utf-8") as f:
                f.write(json.dumps(record) + "\n")
            score = record.get("final_score")
            resolved = record.get("incident_resolved")
            score_text = f"{score:.3f}" if score is not None else "err"
            print(
                f"[{policy:<10}] {scenario:<40} ep={idx} "
                f"score={score_text} "
                f"resolved={resolved}",
                file=sys.stderr,
                flush=True,
            )

    try:
        # Clients are created inside the try so the finally block always
        # closes whatever was opened.
        if "groq" in policies:
            groq_http = httpx.AsyncClient(
                timeout=httpx.Timeout(60.0),
                limits=httpx.Limits(max_connections=args.parallelism * 2),
                follow_redirects=True,
            )
        if "fireworks" in policies:
            fireworks_http = httpx.AsyncClient(
                timeout=httpx.Timeout(90.0),
                limits=httpx.Limits(max_connections=args.parallelism * 2),
                follow_redirects=True,
            )
        tasks = [
            run_one(policy, scenario, idx)
            for policy in policies
            for scenario in scenarios
            for idx in range(args.episodes_per_scenario)
        ]
        await asyncio.gather(*tasks)
    finally:
        if groq_http is not None:
            await groq_http.aclose()
        if fireworks_http is not None:
            await fireworks_http.aclose()

    # Print a summary line per policy, aggregated over all scenarios.
    records = [
        json.loads(line)
        for line in output_path.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]
    by_policy: dict[str, list[dict[str, Any]]] = {}
    for r in records:
        by_policy.setdefault(r["policy"], []).append(r)
    print("\n=== SUMMARY ===", file=sys.stderr)
    if not records:
        print(" (no episodes were run)", file=sys.stderr)
    for policy, rs in by_policy.items():
        scored = [r for r in rs if "final_score" in r]
        if not scored:
            print(f" {policy}: all episodes errored ({len(rs)} errors)", file=sys.stderr)
            continue
        mean = sum(r["final_score"] for r in scored) / len(scored)
        resolved = sum(1 for r in scored if r.get("incident_resolved")) / len(scored)
        print(
            f" {policy:<12} n={len(scored):<3} mean_score={mean:.3f} resolved={resolved:.1%}",
            file=sys.stderr,
        )
def _parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    API keys and base URLs default to the ``GROQ_*`` / ``FIREWORKS_*``
    environment variables, read at call time.

    Returns:
        The parsed arguments namespace.
    """
    # The module docstring is absent under `python -OO` (``__doc__`` is None)
    # and could be empty; fall back to a fixed description instead of crashing.
    doc_lines = (__doc__ or "").strip().splitlines()
    description = doc_lines[0] if doc_lines else "Eval sweep: compare policy scores across scenarios."

    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("--env-url", required=True)
    parser.add_argument("--scenarios", required=True)
    parser.add_argument("--policies", required=True, help="comma-separated: random,heuristic,groq,fireworks")
    parser.add_argument("--episodes-per-scenario", type=int, default=5)
    parser.add_argument("--parallelism", type=int, default=3)
    parser.add_argument("--max-tokens", type=int, default=256)
    parser.add_argument("--max-retries", type=int, default=3)
    parser.add_argument("--output", required=True)
    parser.add_argument("--groq-api-key", default=os.getenv("GROQ_API_KEY"))
    parser.add_argument("--groq-base-url", default=os.getenv("GROQ_BASE_URL", "https://api.groq.com/openai/v1"))
    parser.add_argument("--groq-model", default="llama-3.3-70b-versatile")
    parser.add_argument("--fireworks-api-key", default=os.getenv("FIREWORKS_API_KEY"))
    parser.add_argument(
        "--fireworks-base-url",
        default=os.getenv("FIREWORKS_BASE_URL", "https://api.fireworks.ai/inference/v1"),
    )
    parser.add_argument("--fireworks-model", default="accounts/fireworks/models/llama-v3p3-70b-instruct")
    return parser.parse_args()
def main() -> None:
    """Script entry point: parse the CLI arguments and run the sweep to completion."""
    sweep = _run_sweep(_parse_args())
    asyncio.run(sweep)


# Run the sweep only when executed as a script, not when imported.
if __name__ == "__main__":
    main()