"""Baseline inference script for OpenEnv-Sentinel. Drives an LLM agent through all 3 SRE incident triage tasks. Environment variables: ENV_URL - Sentinel environment server URL (e.g. http://localhost:8000) API_BASE_URL - LLM API base URL (default: https://router.huggingface.co/v1) MODEL_NAME - Model or deployment name (default: openai/gpt-oss-120b:novita) HF_TOKEN - Hugging Face token (used as API key) LOCAL_IMAGE_NAME - Docker image name when using from_docker_image() (optional) """ import asyncio import json import os import re import sys import time from openai import OpenAI # ── configuration ─────────────────────────────────────────────────── # Environment server URL (where the Sentinel env is running) ENV_URL = os.getenv("ENV_URL", "http://localhost:8000") # LLM configuration (aligned with official OpenEnv inference examples) API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1") MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-120b:novita") HF_TOKEN = os.getenv("HF_TOKEN") # Optional — if you use from_docker_image(): LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME") # API key: prefer API_KEY, fall back to HF_TOKEN API_KEY = os.getenv("API_KEY") or HF_TOKEN or os.getenv("OPENAI_API_KEY", "") TASK_TIMEOUT = 360 # 6 minutes per task MAX_PARSE_RETRIES = 3 MAX_COMPLETION_TOKENS = 16384 # reasoning models need room for chain-of-thought SYSTEM_PROMPT = """You are an expert SRE agent triaging a production incident. You have access to diagnostic tools. Respond with ONLY a single JSON object — no markdown, no explanation, no extra text. Available tools: - get_service_status: {"tool_name": "get_service_status", "parameters": {"service": ""}} - query_logs: {"tool_name": "query_logs", "parameters": {"service": "", "query": ""}} - query_metrics: {"tool_name": "query_metrics", "parameters": {"service": "", "metric": ""}} - get_dependency_map: {"tool_name": "get_dependency_map", "parameters": {"service": ""}} - consult_runbook: {"tool_name": "consult_runbook", "parameters": {"topic": ""}} - check_recent_changes: {"tool_name": "check_recent_changes", "parameters": {"service": ""}} - submit_resolution: {"tool_name": "submit_resolution", "parameters": {"root_cause": "", "affected_service": "", "recommendation": ""}} INVESTIGATION PLAN — you have only 20 steps total, be extremely efficient: Step 1: get_dependency_map (no service param) to see full architecture Step 2: check_recent_changes (no service param) to see all recent deploys and changes Step 3-4: get_service_status for the UNHEALTHY/DEGRADED services mentioned in the incident Step 5-6: query_logs for unhealthy services (use "" as query to get all logs) Step 7-8: query_metrics for the suspicious root-cause service (error_rate, memory, connections) Step 9: submit_resolution with your findings CRITICAL RULES: - The ROOT CAUSE is often UPSTREAM — a dependency of the symptomatic service, not the alerted service itself - Look for: bad deployments, missing env vars, OOM/memory issues, connection pool exhaustion, long-running queries - affected_service MUST be the root-cause service, NOT the symptom service - root_cause must mention specific service names, error types, versions, and technical details - recommendation must be specific and actionable (e.g. rollback, increase memory limit, kill query, set timeout) - Do NOT repeat the same tool call — you already have that data - You MUST call submit_resolution by step 10 at the latest — do not keep investigating - Respond with ONLY a JSON object. No markdown fences, no explanation.""" FORCE_RESOLUTION_PROMPT = """URGENT: You MUST call submit_resolution NOW. No more investigation. Synthesize everything you have gathered. Your response MUST be ONLY: {"tool_name": "submit_resolution", "parameters": {"root_cause": "", "affected_service": "", "recommendation": ""}} Do NOT call any other tool. Submit NOW.""" # ── action parsing ────────────────────────────────────────────────── def parse_action(text: str) -> dict | None: """Parse LLM output into an action dict with multiple fallbacks.""" # 1. Direct JSON parse try: obj = json.loads(text.strip()) if isinstance(obj, dict) and "tool_name" in obj: return obj except json.JSONDecodeError: pass # 2. Extract from markdown code fence fence_match = re.search(r"```(?:json)?\s*\n?(.*?)\n?\s*```", text, re.DOTALL) if fence_match: try: obj = json.loads(fence_match.group(1).strip()) if isinstance(obj, dict) and "tool_name" in obj: return obj except json.JSONDecodeError: pass # 3. Regex: first {...} block brace_match = re.search(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", text, re.DOTALL) if brace_match: try: obj = json.loads(brace_match.group(0)) if isinstance(obj, dict) and "tool_name" in obj: return obj except json.JSONDecodeError: pass return None # ── history management ────────────────────────────────────────────── def build_initial_prompt(observation: dict) -> str: """Build the first user prompt from the reset observation.""" parts = [] parts.append(f"INCIDENT: {observation.get('incident_summary', '')}") tool_descs = observation.get("tool_descriptions") if tool_descs: parts.append("\n--- AVAILABLE TOOL PARAMETERS ---") for tool, meta in tool_descs.items(): parts.append(f" {tool}: {json.dumps(meta)}") parts.append( f"\nStep {observation.get('step_number', 0)}/{observation.get('max_steps', 20)}" ) parts.append("\nBegin your investigation. Respond with your first action as JSON:") return "\n".join(parts) def build_tool_response_prompt(observation: dict) -> str: """Build a follow-up user prompt after a tool call.""" parts = [] tool_output = observation.get("tool_output", "") if tool_output: parts.append(f"Tool output:\n{tool_output}") if observation.get("last_action_error"): parts.append(f"\n⚠ ERROR: {observation['last_action_error']}") step_num = observation.get("step_number", 0) max_steps = observation.get("max_steps", 20) parts.append( f"\nStep {step_num}/{max_steps} | " f"Cumulative reward: {observation.get('cumulative_reward', 0.0):.2f}" ) # Add urgency nudges based on step progress if step_num >= max_steps - 5: parts.append("\n⚠⚠ CRITICAL: You MUST call submit_resolution NOW! No more investigation!") elif step_num >= max_steps - 8: parts.append("\n⚠ WARNING: Submit your resolution NOW. Call submit_resolution with your best analysis.") elif step_num >= max_steps - 12: parts.append("\nNote: Start forming your resolution. You should submit within the next 2-3 steps.") parts.append("\nRespond with your next action as JSON:") return "\n".join(parts) # ── main loop ─────────────────────────────────────────────────────── async def run_task(task_id: int, base_url: str, client: OpenAI) -> float: """Run a single task against the environment via WebSocket. Returns grader score.""" import websockets ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://") ws_url = ws_url.rstrip("/") + "/ws" async with websockets.connect(ws_url, ping_interval=120, ping_timeout=300) as ws: # Reset await ws.send(json.dumps({"type": "reset", "data": {"task_id": task_id}})) resp = json.loads(await ws.recv()) data = resp["data"] observation = data.get("observation", data) done = data.get("done", False) # [START] — mandatory structured log print(f"[START] task=task_{task_id} env=sentinel_env model={MODEL_NAME}") # Build multi-turn conversation messages: list[dict] = [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": build_initial_prompt(observation)}, ] final_score = 0.0 local_step = 0 # track client-side loop iterations rewards_list: list[str] = [] # collect per-step rewards for [END] line while not done: local_step += 1 step_num = observation.get("step_number", local_step) max_steps = observation.get("max_steps", 20) # Safety: break if we've looped too many times without env advancing if local_step > max_steps + 10: print(f" Exceeded max loop iterations ({local_step}), breaking", flush=True) break print(f" Step {step_num} (iter {local_step}): calling LLM...", flush=True) # Force resolution when nearing step limit force_resolution = step_num >= max_steps - 8 if force_resolution: print(f" Step {step_num}: FORCING resolution submission", flush=True) # Add force prompt as an additional system nudge in the conversation force_msg = {"role": "system", "content": FORCE_RESOLUTION_PROMPT} call_messages = messages + [force_msg] else: call_messages = messages # Trim conversation if too long (keep system + first user + last 20 messages) if len(call_messages) > 30: call_messages = call_messages[:2] + call_messages[-20:] # Try to get a valid action from the LLM action_dict = None for attempt in range(MAX_PARSE_RETRIES): try: # Run sync LLM call in a thread so the event loop can # still handle WebSocket pings during long reasoning calls response = await asyncio.to_thread( client.chat.completions.create, model=MODEL_NAME, messages=call_messages, max_completion_tokens=MAX_COMPLETION_TOKENS, ) raw = response.choices[0].message.content or "" action_dict = parse_action(raw) if action_dict: # Store the assistant response in conversation messages.append({"role": "assistant", "content": raw}) print(f" Step {step_num}: action={action_dict.get('tool_name', '?')}", flush=True) break else: print(f" Step {step_num}: parse failed (attempt {attempt + 1}), raw={raw[:200]}", flush=True) # On parse failure, add a nudge and retry if attempt < MAX_PARSE_RETRIES - 1: call_messages = call_messages + [ {"role": "assistant", "content": raw}, {"role": "user", "content": "That was not valid JSON. Respond with ONLY a JSON object like {\"tool_name\": \"...\", \"parameters\": {...}}"}, ] except Exception as e: print(f" LLM error (attempt {attempt + 1}): {e}", file=sys.stderr) if action_dict is None: # Fallback: send an invalid action to let the env handle it action_dict = {"tool_name": "_invalid_", "parameters": {}} messages.append({"role": "assistant", "content": json.dumps(action_dict)}) # Normalize "all" query param to empty string (handler uses it as substring filter) if action_dict.get("parameters", {}).get("query") == "all": action_dict["parameters"]["query"] = "" if action_dict.get("parameters", {}).get("severity") == "all": action_dict["parameters"].pop("severity", None) # Step the environment via WebSocket await ws.send(json.dumps({"type": "step", "data": action_dict})) resp = json.loads(await ws.recv()) data = resp["data"] observation = data.get("observation", data) done = data.get("done", False) if observation.get("done"): done = True step_reward = observation.get("reward", 0.0) rewards_list.append(f"{step_reward:.2f}") # [STEP] — mandatory structured log print(f"[STEP] step={step_num} action={action_dict.get('tool_name')} " f"reward={step_reward:.2f} done={str(done).lower()} " f"error={observation.get('last_action_error', 'null')}") # Add tool response as user message in the conversation if not done: messages.append({"role": "user", "content": build_tool_response_prompt(observation)}) # Get final state for score await ws.send(json.dumps({"type": "state"})) resp = json.loads(await ws.recv()) state_data = resp["data"] final_score = state_data.get("final_score", 0.0) # [END] — mandatory structured log print(f"[END] success={str(final_score > 0).lower()} steps={local_step} " f"score={final_score:.2f} rewards={','.join(rewards_list)}") return final_score async def main() -> None: llm_client = OpenAI( base_url=API_BASE_URL, api_key=API_KEY, ) print(f"Using LLM API: {API_BASE_URL} / model={MODEL_NAME}") print(f"Environment URL: {ENV_URL}") scores: dict[int, float] = {} for task_id in [1, 2, 3]: print(f"\n{'='*50}") print(f"Running Task {task_id}...") print(f"{'='*50}") try: score = await asyncio.wait_for( run_task(task_id, ENV_URL, llm_client), timeout=TASK_TIMEOUT, ) scores[task_id] = score except asyncio.TimeoutError: print(f" Task {task_id} timed out after {TASK_TIMEOUT}s", file=sys.stderr) scores[task_id] = 0.0 except Exception as e: print(f" Task {task_id} failed: {e}", file=sys.stderr) scores[task_id] = 0.0 print(f"Task {task_id}: {scores[task_id]:.2f}") avg = sum(scores.values()) / len(scores) if scores else 0.0 print(f"\n{'='*50}") print(f"Task 1: {scores.get(1, 0.0):.2f}") print(f"Task 2: {scores.get(2, 0.0):.2f}") print(f"Task 3: {scores.get(3, 0.0):.2f}") print(f"Average: {avg:.2f}") print(f"{'='*50}") if __name__ == "__main__": asyncio.run(main())