| """ |
| Baseline inference script. |
| |
| Uses an LLM (via OpenAI-compatible API) to play through all 3 incident |
| scenarios. The conversation history acts as a soft belief tracker β |
| the LLM accumulates evidence across steps. |
| |
| stdout format: [START], [STEP], [END] blocks with exact field names |
| as required by the OpenEnv automated evaluator. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import os |
| import sys |
| import time |
| import traceback |
| from typing import Any, Dict, List, Optional |
|
|
| import requests |
| from openai import OpenAI |
|
|
|
|
| |
| |
| |
|
|
# Configuration — every value is overridable via an environment variable.
ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:8000")  # OpenEnv server URL
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")  # chat model identifier
HF_TOKEN = os.environ.get("HF_TOKEN", "")  # Hugging Face token (auth fallback)
API_KEY = os.environ.get("API_KEY", "")  # explicit API key; takes precedence over HF_TOKEN
MAX_STEPS = 20  # hard cap on environment steps per episode
TEMPERATURE = 0.3  # low temperature: we want near-deterministic JSON actions
|
|
|
|
| |
| |
| |
|
|
| SYSTEM_PROMPT = """You are an expert Site Reliability Engineer (SRE) responding to a production incident. |
| |
| You are interacting with a simulated microservices infrastructure through an environment API. |
| Your goal is to: |
| 1. DIAGNOSE the root cause of the incident |
| 2. REMEDIATE the issue (fix it) |
| 3. DECLARE the root cause when confident |
| |
| ## Available Actions |
| You must respond with a single JSON object containing your chosen action: |
| |
| DIAGNOSTIC (information gathering): |
| - {"action_type": "view_alerts"} β See all firing alerts |
| - {"action_type": "query_logs", "target_service": "<name>", "parameters": {"level": "ERROR"}} β Query logs |
| - {"action_type": "check_metrics", "target_service": "<name>"} β Get metric timeseries |
| - {"action_type": "check_dependencies", "target_service": "<name>"} β View dependency graph |
| - {"action_type": "check_deploy_history", "target_service": "<name>"} β Recent deploys |
| - {"action_type": "run_health_check", "target_service": "<name>"} β Ping a service |
| |
| REMEDIATION (fix actions): |
| - {"action_type": "restart_service", "target_service": "<name>"} β Restart a service |
| - {"action_type": "rollback_deploy", "target_service": "<name>"} β Rollback to previous deploy |
| - {"action_type": "scale_service", "target_service": "<name>", "parameters": {"replicas": 5}} β Scale replicas |
| |
| DECLARATION: |
| - {"action_type": "declare_root_cause", "parameters": {"root_cause": "<your diagnosis>"}} |
| |
| ## Available services: api_gateway, auth, orders, payment, cache, database, queue |
| |
| ## Strategy |
| 1. Start by viewing alerts to understand the scope |
| 2. Check metrics and logs for the most affected services |
| 3. Check dependency graphs to trace upstream causes |
| 4. Check deploy history for recently changed services |
| 5. Apply remediation to the root cause service FIRST |
| 6. Declare root cause when confident |
| |
| IMPORTANT: Respond with ONLY a valid JSON object. No explanation, no markdown, just the JSON action. |
| """ |
|
|
|
|
| |
| |
| |
|
|
class EnvClient:
    """Thin HTTP client for the OpenEnv environment server.

    Wraps the three endpoints (/reset, /step, /state) over a shared
    requests.Session so the TCP connection is reused across calls.
    """

    # Fix: the original calls had no timeout, so a hung server would block
    # this script forever. Episodes are interactive, so 60s is generous.
    TIMEOUT = 60

    def __init__(self, base_url: str):
        self.base_url = base_url.rstrip("/")
        self.session = requests.Session()

    def reset(self, task_name: str, seed: int = 42) -> Dict[str, Any]:
        """Start a new episode for *task_name*; return the initial payload."""
        resp = self.session.post(
            f"{self.base_url}/reset",
            json={"task_name": task_name, "seed": seed},
            timeout=self.TIMEOUT,
        )
        resp.raise_for_status()
        return resp.json()

    def step(self, action: Dict[str, Any]) -> Dict[str, Any]:
        """Apply one action; return the observation/reward/done/info payload."""
        resp = self.session.post(
            f"{self.base_url}/step", json=action, timeout=self.TIMEOUT
        )
        resp.raise_for_status()
        return resp.json()

    def state(self) -> Dict[str, Any]:
        """Fetch the environment's current state snapshot."""
        resp = self.session.get(f"{self.base_url}/state", timeout=self.TIMEOUT)
        resp.raise_for_status()
        return resp.json()
|
|
|
|
| |
| |
| |
|
|
def create_openai_client() -> OpenAI:
    """Build the OpenAI-compatible client from environment configuration.

    API-key precedence: API_KEY, then HF_TOKEN, then a "no-key" placeholder
    (for local servers that ignore auth). When only an HF token is present
    and no explicit API_BASE_URL is configured, route requests through the
    Hugging Face Inference API endpoint for MODEL_NAME.
    """
    base_url = os.environ.get("API_BASE_URL")
    if not base_url and HF_TOKEN and not API_KEY:
        base_url = f"https://api-inference.huggingface.co/models/{MODEL_NAME}/v1"
    return OpenAI(api_key=API_KEY or HF_TOKEN or "no-key", base_url=base_url)
|
|
|
|
def parse_llm_action(response_text: str) -> Dict[str, Any]:
    """Extract a JSON action object from an LLM response.

    Tolerates markdown code fences and surrounding prose by parsing the
    outermost ``{...}`` span of the text.

    Raises:
        ValueError: when no parseable JSON object is present. (Fix: the
            original let ``json.JSONDecodeError`` escape with a confusing
            message when the braces held malformed JSON; we now always
            raise the documented ValueError, chained to the decode error.)
    """
    text = response_text.strip()

    # Strip markdown fences (``` / ```json) if the model wrapped its answer.
    if text.startswith("```"):
        lines = [l for l in text.split("\n") if not l.strip().startswith("```")]
        text = "\n".join(lines).strip()

    # Take the outermost {...} span so leading/trailing prose is ignored.
    start = text.find("{")
    end = text.rfind("}") + 1
    if start >= 0 and end > start:
        try:
            return json.loads(text[start:end])
        except json.JSONDecodeError as exc:
            raise ValueError(
                f"Could not parse action from: {response_text[:200]}"
            ) from exc

    raise ValueError(f"Could not parse action from: {response_text[:200]}")
|
|
|
|
def summarize_observation(obs: Dict[str, Any]) -> str:
    """Render an observation dict as compact, readable text for the LLM context."""
    g = obs.get
    lines = [
        f"Incident: {g('incident_summary', 'N/A')}",
        f"Severity: {g('severity', 'N/A')}",
        f"Time: {g('time_elapsed_minutes', 0)}/{g('time_budget_minutes', 30)} min",
        f"Steps: {g('steps_taken', 0)}/{g('max_steps', 20)}",
        f"Reward: {g('current_reward', 0)} (cumulative: {g('cumulative_reward', 0)})",
    ]

    statuses = g("service_statuses", {})
    if statuses:
        joined = ", ".join(f"{name}: {status}" for name, status in statuses.items())
        lines.append(f"Services: {joined}")

    lines.append(f"Alerts: {g('active_alerts_count', 0)} active")
    lines.append(f"Action result: {g('action_message', 'N/A')}")

    # Raw per-action payload, truncated so the prompt stays bounded in size.
    payload = g("action_result", {})
    if payload:
        dumped = json.dumps(payload, indent=2, default=str)
        if len(dumped) > 2000:
            dumped = dumped[:2000] + "\n... (truncated)"
        lines.append(f"Data:\n{dumped}")

    return "\n".join(lines)
|
|
|
|
def run_episode(
    env: EnvClient,
    llm: OpenAI,
    task_name: str,
    seed: int = 42,
) -> Dict[str, Any]:
    """Run a single episode and return results.

    Drives one incident scenario end to end: at each step the LLM chooses a
    JSON action, the action is applied to the environment, and the resulting
    observation summary is appended to the chat history so evidence
    accumulates across steps. Emits [START]/[STEP]/[END] lines on stdout in
    the format the automated evaluator expects.
    """

    print(f"[START] task={task_name}")

    result = env.reset(task_name, seed)
    obs = result["observation"]

    # The conversation history doubles as a soft belief state: every
    # observation summary is appended, so the LLM sees all prior evidence.
    messages: List[Dict[str, str]] = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"INCIDENT TRIGGERED:\n{summarize_observation(obs)}"},
    ]

    episode_reward = 0.0
    final_info = {}

    for step_num in range(1, MAX_STEPS + 1):
        try:
            # Ask the LLM for the next action given the full history so far.
            completion = llm.chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                temperature=TEMPERATURE,
                max_tokens=256,
            )
            llm_response = completion.choices[0].message.content or ""

            action = parse_llm_action(llm_response)

            print(f"[STEP] step={step_num} action={json.dumps(action)}")

            step_result = env.step(action)
            obs = step_result["observation"]
            reward = step_result.get("reward", 0.0)
            done = step_result.get("done", False)
            info = step_result.get("info", {})
            episode_reward += reward

            # Record both sides of the exchange so the next LLM call sees
            # its own action and the environment's response to it.
            messages.append({"role": "assistant", "content": llm_response})
            messages.append({
                "role": "user",
                "content": f"Step {step_num} result (reward={reward}):\n{summarize_observation(obs)}"
            })

            if done:
                final_info = info
                break

        except Exception as e:
            # LLM/parse/HTTP failure: log to stderr and fall back to a
            # harmless diagnostic action so the episode keeps progressing.
            # NOTE(review): the fallback step's result is NOT appended to
            # `messages`, so the LLM never sees it — confirm this is intended.
            print(f"[STEP] step={step_num} error={str(e)}", file=sys.stderr)
            action = {"action_type": "view_alerts"}
            step_result = env.step(action)
            obs = step_result["observation"]
            reward = step_result.get("reward", 0.0)
            done = step_result.get("done", False)
            episode_reward += reward
            if done:
                final_info = step_result.get("info", {})
                break

    # Score comes from the terminal `info` dict; 0.01 is the same fallback
    # value used elsewhere in this script when no score was reported.
    final_state = env.state()
    score = final_info.get("score", 0.01)

    print(f"[END] task={task_name} "
          f"score={score:.3f} "
          f"reward={episode_reward:.3f} "
          f"steps={final_state.get('step_count', 0)}")

    return {
        "task_name": task_name,
        "score": score,
        "cumulative_reward": episode_reward,
        "steps": final_state.get("step_count", 0),
        "declared_root_cause": final_state.get("declared_root_cause"),
    }
|
|
|
|
| |
| |
| |
|
|
def main():
    """Run all three incident scenarios and print a results summary.

    Episodes are run sequentially; a crash in one task records a floor
    score of 0.01 and moves on, so the summary always covers every task.
    """
    tasks = ["memory_leak", "cascading_failure", "distributed_deadlock"]

    print("=" * 60)
    # Fix: the title and task separators contained mojibake ("β");
    # restored to the intended em-dash / box-drawing characters.
    print("SRE Incident Response — OpenEnv Inference")
    print(f"Model: {MODEL_NAME}")
    print(f"Environment: {ENV_BASE_URL}")
    print("=" * 60)

    env = EnvClient(ENV_BASE_URL)
    llm = create_openai_client()

    results = []
    for task in tasks:
        print(f"\n{'─' * 40}")
        print(f"Task: {task}")
        print(f"{'─' * 40}")

        try:
            result = run_episode(env, llm, task)
            results.append(result)
        except Exception as e:
            # Keep going on failure: record the floor score for this task.
            print(f"[ERROR] Task {task} failed: {e}", file=sys.stderr)
            traceback.print_exc()
            results.append({
                "task_name": task,
                "score": 0.01,
                "cumulative_reward": 0.0,
                "steps": 0,
                "error": str(e),
            })

    print(f"\n{'=' * 60}")
    print("RESULTS SUMMARY")
    print(f"{'=' * 60}")
    for r in results:
        score = r.get("score", 0.01)
        print(f"  {r['task_name']:30s} score={score:.3f} "
              f"steps={r.get('steps', 0):2d} "
              f"root_cause={r.get('declared_root_cause', 'N/A')}")

    avg_score = sum(r.get("score", 0) for r in results) / len(results)
    print(f"\n  {'AVERAGE':30s} score={avg_score:.3f}")
    print(f"{'=' * 60}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|