Spaces:
Sleeping
Sleeping
| """ | |
inference.py - Baseline agent for OnCallEnv.

Uses the OpenAI Python client to run an LLM against all tasks. Works with
OpenAI, Gemini, and any other OpenAI-compatible API.

Environment variables:
    API_BASE_URL  - LLM API endpoint (default: Gemini's OpenAI-compatible endpoint)
    MODEL_NAME    - Model identifier (default: gemini-2.0-flash)
    HF_TOKEN      - API key; falls back to OPENAI_API_KEY, then GEMINI_API_KEY
    ENV_URL       - OnCallEnv server URL (default: http://localhost:7860)

Usage:
    export API_BASE_URL=https://generativelanguage.googleapis.com/v1beta/openai/
    export MODEL_NAME=gemini-2.0-flash
    export HF_TOKEN=your-api-key
    python inference.py
| """ | |
| from __future__ import annotations | |
| import os | |
| import sys | |
| import json | |
| import time | |
| import textwrap | |
| import requests | |
| from typing import Any | |
| from openai import OpenAI | |
# -- Configuration -------------------------------------------------------------
# Each setting may be overridden from the environment. The defaults point at
# Gemini's OpenAI-compatible endpoint and a locally running OnCallEnv server.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://generativelanguage.googleapis.com/v1beta/openai/")
MODEL_NAME = os.environ.get("MODEL_NAME", "gemini-2.0-flash")
HF_TOKEN = os.environ.get("HF_TOKEN")

# The first non-empty key wins: HF_TOKEN, then OPENAI_API_KEY, then
# GEMINI_API_KEY. Falls back to "" when none of them is set.
API_KEY = next(
    (
        candidate
        for candidate in (
            HF_TOKEN,
            os.environ.get("OPENAI_API_KEY"),
            os.environ.get("GEMINI_API_KEY"),
        )
        if candidate
    ),
    "",
)

ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")

# Task ids, executed in this order by main().
TASKS = [
    "easy_memory_leak",
    "medium_cascading_failure",
    "hard_cache_degradation",
    "medium_dns_misconfiguration",
    "hard_replication_lag",
    "expert_multi_root_cause",
]
# -- System prompt -------------------------------------------------------------
# Sent as the first (system) message of every task conversation.
# Deliberately pure ASCII: the dashes and arrows previously in this text had
# been corrupted into stray "β" characters, which the model receives verbatim.
SYSTEM_PROMPT = textwrap.dedent("""\
    You are an expert Site Reliability Engineer (SRE) responding to a production incident.
    You must diagnose the root cause and fix the issue as efficiently as possible.

    INVESTIGATION COMMANDS:
      check_metrics <service> - View CPU, memory, latency, error rates, custom metrics
      check_logs <service> - View recent log entries
      check_config <service> - View service configuration
      check_dependencies <service> - View service dependency graph
      check_deploy_history <service> - View recent deployments and version changes

    REMEDIATION COMMANDS:
      restart_service <service> - Restart a service (use for OOM/crash issues)
      rollback_deploy <service> - Roll back to previous deployment version
      scale_service <service> <replicas> - Scale replicas up/down
      update_config <service> <key> <value> - Update a config parameter
      mark_resolved <root_cause_description> - REQUIRED: Mark incident resolved with root cause

    STRATEGY:
    1. Read the alerts carefully to identify ALL affected services
    2. CRITICAL: Before applying ANY remediation, you MUST investigate at least ALL alerted
       services and their direct dependencies. Check a minimum of 3-4 different services.
    3. For each alerted service: check_metrics -> check_logs -> check_config or check_deploy_history
    4. Follow the dependency chain - symptoms often appear upstream of the root cause
    5. Look for recent changes (deploys, config changes) that correlate with the incident
    6. There may be MULTIPLE independent root causes - fix ALL of them before marking resolved
    7. Apply the correct remediation:
       - OOM/memory issues -> restart_service
       - Bad deployment -> rollback_deploy
       - Wrong config value -> update_config
    8. AFTER fixing ALL issues, you MUST call mark_resolved with a description that includes:
       - The name(s) of the root cause service(s)
       - What went wrong (e.g., "memory leak", "config change", "bad deployment")
       This step is CRITICAL for full credit.

    IMPORTANT RULES:
    - Respond with EXACTLY ONE command per turn
    - No explanation, no markdown, no extra text - just the command string
    - Do NOT restart or rollback healthy services - you will be penalized
    - After the final remediation, ALWAYS call mark_resolved immediately
    - If there are multiple issues, fix each one before calling mark_resolved

    Example responses:
    check_logs payment-service
    restart_service payment-service
    update_config order-service db_pool_size 50
    rollback_deploy cache-service
    mark_resolved payment-service memory leak due to unbounded transaction cache causing OOM
""")
# -- Environment client --------------------------------------------------------
class EnvClient:
    """HTTP client for the OnCallEnv API.

    Wraps the server's three endpoints: ``POST /reset``, ``POST /step`` and
    ``GET /state``. Every call raises ``requests.HTTPError`` on a non-2xx
    response and otherwise returns the decoded JSON body.
    """

    def __init__(self, base_url: str, timeout: float = 30.0):
        """Create a client.

        Args:
            base_url: Root URL of the environment server. A trailing "/" is removed.
            timeout: Seconds to wait on each request before ``requests`` raises a
                timeout error. Without a timeout, ``requests`` can block
                indefinitely on an unresponsive server.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout

    def reset(self, task_id: str) -> dict:
        """POST ``{"task_id": task_id}`` to ``/reset`` and return the JSON body."""
        return self._request("POST", "/reset", json={"task_id": task_id})

    def step(self, command: str) -> dict:
        """POST ``{"command": command}`` to ``/step`` and return the JSON body."""
        return self._request("POST", "/step", json={"command": command})

    def state(self) -> dict:
        """GET ``/state`` and return the JSON body."""
        return self._request("GET", "/state")

    def _request(self, method: str, path: str, **kwargs: Any) -> dict:
        """Send *method* to ``base_url + path`` with the client timeout; return the JSON body."""
        resp = requests.request(method, f"{self.base_url}{path}", timeout=self.timeout, **kwargs)
        resp.raise_for_status()
        return resp.json()
# -- Prompt builder ------------------------------------------------------------
def build_user_prompt(obs: dict) -> str:
    """Format the current observation into the user message for the LLM.

    Requires ``step``, ``max_steps`` and ``goal`` keys in *obs*.

    ``alerts``, ``services``, ``last_action``, ``last_action_result`` and
    ``last_action_error`` are optional. Missing or ``None`` values render as
    empty sections or placeholders instead of raising. Each alert is expected
    to be a dict with ``severity``, ``service`` and ``message``; missing
    fields fall back to placeholders.
    """
    alerts_text = "".join(
        f" [{str(alert.get('severity', 'unknown')).upper()}] "
        f"{alert.get('service', '?')}: {alert.get('message', '')}\n"
        for alert in obs.get("alerts") or []
    )
    services = ", ".join(obs.get("services") or [])

    parts = [
        f"Step {obs['step']}/{obs['max_steps']}\n\n",
        f"GOAL: {obs['goal']}\n\n",
        f"ACTIVE ALERTS:\n{alerts_text}\n",
        f"AVAILABLE SERVICES: {services}\n\n",
    ]
    if obs.get("last_action"):
        result = obs.get("last_action_result")
        parts.append(f"LAST ACTION: {obs['last_action']}\n")
        # Render an explicit None like a missing result instead of the text "None".
        parts.append(f"RESULT:\n{'N/A' if result is None else result}\n")
        if obs.get("last_action_error"):
            parts.append("STATUS: ERROR - command failed\n")
        parts.append("\n")
    parts.append("What is your next command? Reply with EXACTLY ONE command.")
    return "".join(parts)
def parse_action(response_text: str) -> str:
    """Extract a single environment command from a free-form LLM reply.

    Each line is cleaned of common decoration before matching: surrounding
    backticks, ``*`` bold markers, ``-`` bullets, ``>`` quotes, quote
    characters, and leading ``1.`` / ``1)`` list numbering.

    The first cleaned line whose first word is a known command verb
    (case-insensitive) is returned. The verb is lower-cased and the
    arguments are left as written.

    If no line starts with a known verb, the first non-empty line that is not
    a ``#`` or ``//`` comment is returned as-is, so the environment can report
    the bad command. An empty reply yields ``"check_metrics api-gateway"``.
    """
    import re

    fallback = "check_metrics api-gateway"
    if not response_text:
        return fallback

    valid_cmds = frozenset({
        "check_metrics", "check_logs", "check_config", "check_dependencies",
        "check_deploy_history", "restart_service", "rollback_deploy",
        "scale_service", "update_config", "mark_resolved",
    })
    # Markdown / quoting noise the model sometimes wraps a command in.
    decoration = " \t`*->'\""
    list_number = re.compile(r"^\d+[.)]\s+")

    def clean(line: str) -> str:
        # Strip twice so decoration inside a numbered item ("1. `cmd`") is removed too.
        line = line.strip(decoration)
        return list_number.sub("", line).strip(decoration)

    cleaned = [clean(line) for line in response_text.strip().splitlines()]

    # Preferred: a line whose first whole word is a known command verb.
    for line in cleaned:
        words = line.split(maxsplit=1)
        if words and words[0].lower() in valid_cmds:
            words[0] = words[0].lower()
            return " ".join(words)

    # Otherwise pass the first meaningful line through unchanged.
    for line in cleaned:
        if line and not line.startswith(("#", "//")):
            return line
    return fallback
# -- Main loop -----------------------------------------------------------------
def _query_llm(client: OpenAI, messages: list[dict[str, str]], attempts: int = 3) -> str:
    """Return the model's raw reply to *messages*, retrying failed calls.

    Sleeps 1s, 2s, ... between attempts. If every attempt raises, returns the
    fallback command ``"check_metrics api-gateway"`` so the episode can
    continue. Each failure is reported on stderr.
    """
    for attempt in range(attempts):
        try:
            response = client.chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                temperature=0.0,
                max_tokens=1024,
            )
            return response.choices[0].message.content or ""
        except Exception as exc:  # API, network or malformed-response errors all get a retry
            print(f"[WARN] LLM call failed (attempt {attempt + 1}/{attempts}): {exc!r}", file=sys.stderr)
            if attempt < attempts - 1:
                time.sleep(2 ** attempt)
    return "check_metrics api-gateway"


def _single_line(value: object) -> str:
    """Flatten *value* onto one line so each structured log record stays one line."""
    return " ".join(str(value).splitlines())


def _print_end(success: bool, steps: int, score: float, rewards: list[str]) -> None:
    """Print the ``[END]`` record for a task."""
    print(f"[END] success={str(success).lower()} steps={steps} score={score:.2f} rewards={','.join(rewards)}")


def run_task(client: OpenAI, env: EnvClient, task_id: str) -> float:
    """Run the agent on a single task and return the final score.

    Prints one ``[START]`` line, one ``[STEP]`` line per environment step and
    a closing ``[END]`` line to stdout. The score is the reward of the step on
    which the environment reports ``done``.

    If ``done`` never arrives within the advertised ``max_steps`` plus a small
    margin, the loop stops and the score is read from ``env.state()``.
    """
    print(f"[START] task={task_id} env=OnCallEnv model={MODEL_NAME}")
    obs = env.reset(task_id)
    rewards_list: list[str] = []
    # Conversation history for multi-turn reasoning.
    messages: list[dict[str, str]] = [{"role": "system", "content": SYSTEM_PROMPT}]
    # Safety net so an env that never reports done=True cannot spin forever.
    # 50 is an assumed budget for observations without max_steps.
    step_cap = int(obs.get("max_steps") or 50) + 5

    step_num = 0
    while step_num < step_cap:
        step_num += 1
        messages.append({"role": "user", "content": build_user_prompt(obs)})
        action = parse_action(_query_llm(client, messages))
        messages.append({"role": "assistant", "content": action})

        result = env.step(action)
        obs = result["observation"]
        reward = result["reward"]
        done = result["done"]
        # Reward is either a plain number or a breakdown dict carrying "total".
        step_reward = reward["total"] if isinstance(reward, dict) and "total" in reward else float(reward)
        rewards_list.append(f"{step_reward:.2f}")

        last_error = obs.get("last_action_error")
        error_str = _single_line(last_error) if last_error else "null"
        print(f"[STEP] step={step_num} action={action} reward={step_reward:.2f} done={str(done).lower()} error={error_str}")

        if done:
            _print_end(step_reward > 0.0, step_num, step_reward, rewards_list)
            return step_reward

        # Keep the system prompt plus the most recent 16 user/assistant pairs.
        if len(messages) > 34:
            messages = [messages[0]] + messages[-32:]

    # Step cap reached without done=True: fall back to the env's reported score.
    final_score = float(env.state().get("score", 0.0))
    _print_end(final_score > 0.0, step_num, final_score, rewards_list)
    return final_score
def main() -> None:
    """Run the baseline agent on every task in ``TASKS``.

    Exits with status 1, after explaining why on stderr, if no API key is
    configured. A task that raises is scored 0.0 and reported with a failure
    ``[END]`` line, and the run continues with the next task. A per-task score
    summary goes to stderr, so stdout carries only the structured log lines.
    """
    if not API_KEY:
        print(
            "ERROR: no API key found; set HF_TOKEN (or OPENAI_API_KEY / GEMINI_API_KEY).",
            file=sys.stderr,
        )
        sys.exit(1)

    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    env = EnvClient(ENV_URL)

    scores: dict[str, float] = {}
    for task_id in TASKS:
        try:
            scores[task_id] = run_task(client, env, task_id)
        except Exception as exc:  # one broken task must not abort the whole run
            print(f"[ERROR] task={task_id} failed: {exc!r}", file=sys.stderr)
            print("[END] success=false steps=0 score=0.00 rewards=0.00")
            scores[task_id] = 0.0

    # Summary on stderr keeps stdout strictly in the [START]/[STEP]/[END] format.
    for task_id, score in scores.items():
        print(f"{task_id}: {score:.2f}", file=sys.stderr)
    if scores:
        print(f"average: {sum(scores.values()) / len(scores):.2f}", file=sys.stderr)


if __name__ == "__main__":
    main()