""" Baseline inference script for SRE OpenEnv. This script runs an AI agent (via OpenAI API) against the SRE environment. It implements a standard agentic loop with OpenEnv 0.1 spec compliance: 1. Initialize environment with a task (easy, medium, hard). 2. Format observation and system prompt for the LLM. 3. Use OpenAI Tool Calling (Structured Outputs) for reliable action parsing. 4. Execute action via StepResult and repeat until terminal or max steps. """ import os import sys import json import logging import time from typing import List, Optional from openai import OpenAI from pydantic import ValidationError from dotenv import load_dotenv # Load environment variables from .env file load_dotenv() # Add project root to sys.path to allow absolute imports sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from client import SREEnv from models import SREAction, SREObservation # Configuration MODEL_NAME = "gpt-5.4" MAX_STEPS = 60 # Local limit for inference script # Configure logging logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) def get_agent_response(client: OpenAI, model: str, messages: List[dict]) -> Optional[SREAction]: """ Calls the LLM via tool calling to strictly enforce SREAction schema. Includes retry logic for rate limits (429) - 5s, 10s, 20s. """ retry_delays = [5, 10, 20] # Define the tool based on SREAction schema tools = [ { "type": "function", "function": { "name": "execute_sre_action", "description": "Execute a shell command or patch a file in the SRE environment", "parameters": { "type": "object", "properties": { "action_type": { "type": "string", "enum": ["run_shell", "patch_file"], "description": "The type of action to perform" }, "command": { "type": "string", "description": "The shell command to run (if action_type is run_shell)" }, "file_path": { "type": "string", "description": "The path to the file to patch (if action_type is patch_file)" }, "content": { "type": "string", "description": "The new content for the file (if action_type is patch_file)" } }, "required": ["action_type"] } } } ] for i, delay in enumerate(retry_delays + [None]): try: response = client.chat.completions.create( model=model, messages=messages, tools=tools, tool_choice="auto", temperature=0.0, ) # Check if there are tool calls to avoid 'NoneType' errors choice = response.choices[0] if choice.message.tool_calls: tool_call = choice.message.tool_calls[0] args = json.loads(tool_call.function.arguments) return SREAction(**args) # Fallback: Parse from content if choices[0].message.content has a JSON block content = choice.message.content or "" if "{" in content and "}" in content: try: # Try to find the first '{' and last '}' start = content.find("{") end = content.rfind("}") + 1 json_str = content[start:end] # Clean up common LLM artifacts like ```json ... ``` json_str = json_str.replace("```json", "").replace("```", "").strip() args = json.loads(json_str) # Support various JSON formats: # 1. Direct object: {"action_type": "...", ...} # 2. Wrapped object: {"action": {"action_type": "...", ...}} # 3. Hallucinated "action" as action_type: {"action": "run_shell", ...} if "action_type" not in args: if "action" in args: if isinstance(args["action"], dict): args = args["action"] else: # Hallucination: {"action": "run_shell", "command": "..."} args["action_type"] = args.pop("action") # Ensure it matches the SREAction schema (which uses action_type) return SREAction(**args) except (ValueError, json.JSONDecodeError): pass logger.warning(f"Response did not contain valid tool calls or JSON. Content: {content}") return None except Exception as e: err_str = str(e) if "429" in err_str and delay is not None: logger.warning(f"Rate limit hit (429). Retrying in {delay} seconds... (Attempt {i+1}/{len(retry_delays)})") time.sleep(delay) continue logger.error(f"Error getting agent response via tool calling: {e}") return None return None def run_episode(task_id: str, model: str = None): """ Runs a single agent episode for a given task. """ if model is None: model = MODEL_NAME # OpenAI GPT-4o-mini api_key = os.getenv("OPENAI_API_KEY") client = OpenAI(api_key=api_key) env = SREEnv() logger.info(f"Starting episode for task: {task_id}") # OpenEnv reset() returns a StepResult object result = env.reset(task_id=task_id) obs = result.observation # state is a property, not a method state = env.state system_prompt = f""" You are a Senior Site Reliability Engineer. Your goal is to solve the following infrastructure task: Task: {state.task_name} Difficulty: {state.difficulty} Description: {state.description} You can interact with the environment using the 'execute_sre_action' tool. Your actions are: 1. "run_shell": Execute a bash command and observe its output. 2. "patch_file": Overwrite or create a file with new content. Rules: - THE SHELL IS STATELESS: Every 'run_shell' call is a NEW isolated process. - 'cd' and 'export' will NOT persist to the next step. - You MUST use absolute paths or combine commands: `cd /path && run_cmd`. - Be efficient and minimize steps. - Investigate logs and system state (ls, ps, cat) before patching. - Fix the root cause precisely. """ # Help the agent by providing initial directory structure ls_result = env.step(SREAction(action_type="run_shell", command=f"ls -R /tmp/sre_tasks/{task_id}/")) messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": f"Initial state observed. Files:\n{ls_result.observation.stdout}\n\nstdout: {obs.stdout}\nstderr: {obs.stderr}"} ] total_reward = 0.0 step_count = 0 max_steps = state.max_steps while step_count < max_steps: step_count += 1 logger.info(f"Step {step_count}/{max_steps}") action = get_agent_response(client, model, messages) if not action: logger.warning("No valid action returned by agent. Terminating.") break logger.info(f"Agent Action: {action.action_type} - {action.command if action.action_type == 'run_shell' else action.file_path}") # OpenEnv step() returns a StepResult object result = env.step(action) obs = result.observation total_reward += result.reward done = result.done # Update conversation history # Support cross-version serialization from dataclasses import asdict as dc_asdict if hasattr(action, "model_dump_json"): action_json = action.model_dump_json() elif hasattr(action, "model_dump"): action_json = json.dumps(action.model_dump()) else: action_json = json.dumps(dc_asdict(action)) obs_desc = f"Observation: stdout='{obs.stdout}', stderr='{obs.stderr}', exit_code={obs.exit_code}" messages.append({"role": "assistant", "content": f"Action call: {action_json}"}) messages.append({"role": "user", "content": obs_desc}) if done: logger.info("Episode finished.") break # state is a property final_state = env.state logger.info(f"Result for {task_id}: Final Reward = {final_state.current_reward:.2f}, Steps = {step_count}") return final_state.current_reward if __name__ == "__main__": # Ensure environment is selected for all tasks tasks = ["medium_build"] scores = {} for tid in tasks: score = run_episode(tid) scores[tid] = score print("\n" + "="*30) print("Baseline Inference Results (OpenEnv 0.1)") print("="*30) for tid, score in scores.items(): print(f"{tid:15}: {score:.2f}") print("="*30)