Spaces:
Sleeping
Sleeping
"""
Baseline inference script for SRE OpenEnv.

This script runs an AI agent (via an OpenAI-compatible API) against the SRE
environment. It implements a standard agentic loop following the OpenEnv 0.1
spec:

1. Initialize the environment with a task (easy, medium, or hard).
2. Format the observation and system prompt for the LLM.
3. Use OpenAI-style tool calling (with a JSON-in-content fallback) for
   reliable action parsing.
4. Execute the action, read the returned StepResult, and repeat until the
   episode is done or the step limit is reached.
"""
import json
import logging
import os
import shlex
import sys
import time
from dataclasses import asdict
from typing import List, Optional

from dotenv import load_dotenv
from openai import OpenAI
from pydantic import ValidationError

# Load environment variables from .env file
load_dotenv()

# Add project root to sys.path to allow absolute imports
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from client import SREEnv
from models import SREAction, SREObservation
# Runtime configuration required by the competition harness.
# API_BASE_URL and MODEL_NAME may fall back to defaults; HF_TOKEN has no
# default and must be provided (directly or via the .env file).
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api-inference.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "meta-llama/Llama-3.3-70B-Instruct")
HF_TOKEN = os.environ.get("HF_TOKEN")

# Fail fast at import time: an unset or empty token makes every API call fail.
if HF_TOKEN is None or HF_TOKEN == "":
    raise OSError(
        "HF_TOKEN is not set. Please add your Hugging Face API token to your .env file."
    )

# Intended local step limit for this inference script.
MAX_STEPS = 60

# Timestamped INFO-level logging for the whole script.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# OpenAI tool definition mirroring the SREAction schema. Built once at import
# time rather than on every request.
_SRE_ACTION_TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "execute_sre_action",
            "description": "Execute a shell command or patch a file in the SRE environment",
            "parameters": {
                "type": "object",
                "properties": {
                    "action_type": {
                        "type": "string",
                        "enum": ["run_shell", "patch_file"],
                        "description": "The type of action to perform",
                    },
                    "command": {
                        "type": "string",
                        "description": "The shell command to run (if action_type is run_shell)",
                    },
                    "file_path": {
                        "type": "string",
                        "description": "The path to the file to patch (if action_type is patch_file)",
                    },
                    "content": {
                        "type": "string",
                        "description": "The new content for the file (if action_type is patch_file)",
                    },
                },
                "required": ["action_type"],
            },
        },
    }
]

# Back-off delays (seconds) before each retry after an HTTP 429 rate-limit error.
_RETRY_DELAYS = (5, 10, 20)


def _is_rate_limit_error(exc: Exception) -> bool:
    """Return True when *exc* looks like an HTTP 429 rate-limit error."""
    # openai's APIStatusError subclasses expose ``status_code``; the message
    # check keeps compatibility with wrappers that only mention the code.
    return getattr(exc, "status_code", None) == 429 or "429" in str(exc)


def _create_completion(client: OpenAI, model: str, messages: List[dict]):
    """Request a tool-enabled chat completion; retry on 429, return None on other errors."""
    for attempt, delay in enumerate((*_RETRY_DELAYS, None), start=1):
        try:
            return client.chat.completions.create(
                model=model,
                messages=messages,
                tools=_SRE_ACTION_TOOLS,
                tool_choice="auto",
                temperature=0.0,
            )
        except Exception as e:  # API boundary: every failure is logged and mapped to None
            # ``delay`` is None on the final attempt, so no further retries happen.
            if delay is not None and _is_rate_limit_error(e):
                logger.warning(
                    f"Rate limit hit (429). Retrying in {delay} seconds... "
                    f"(Attempt {attempt}/{len(_RETRY_DELAYS)})"
                )
                time.sleep(delay)
                continue
            logger.error(f"Error getting agent response via tool calling: {e}")
            return None


def _normalize_action_args(args: dict) -> dict:
    """Map the JSON shapes LLMs commonly emit onto SREAction keyword arguments.

    Accepted shapes:
      1. Direct:  {"action_type": "...", ...}             -> unchanged
      2. Wrapped: {"action": {"action_type": "...", ...}} -> the inner object
      3. Renamed: {"action": "run_shell", ...}            -> "action" renamed to "action_type"
    """
    if "action_type" in args or "action" not in args:
        return args
    if isinstance(args["action"], dict):
        return args["action"]
    normalized = dict(args)
    normalized["action_type"] = normalized.pop("action")
    return normalized


def _extract_action_args(content: str) -> Optional[dict]:
    """Return the first JSON object embedded in *content* (normalized), or None.

    Decoding is attempted at each ``{`` in turn, so leading prose, Markdown code
    fences, or trailing text after the object do not prevent parsing.
    """
    decoder = json.JSONDecoder()
    start = content.find("{")
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(content, start)
        except json.JSONDecodeError:
            start = content.find("{", start + 1)
            continue
        # A successful decode starting at '{' always yields a dict.
        return _normalize_action_args(obj)
    return None


def get_agent_response(client: OpenAI, model: str, messages: List[dict]) -> Optional[SREAction]:
    """
    Ask the LLM for the next action, preferring tool calls over free-form JSON.

    The model is offered the ``execute_sre_action`` tool. If it returns a tool
    call, the call's arguments are turned into an SREAction. If those arguments
    are invalid, or no tool call is present, the first JSON object found in the
    message content is tried instead. Rate-limit (429) errors from the API are
    retried after 5s, 10s and 20s.

    Args:
        client: OpenAI-compatible client used for chat completions.
        model: Model name passed to the completions API.
        messages: Chat history to send.

    Returns:
        The parsed SREAction, or None when the request failed or no valid
        action could be parsed. API errors are logged, not raised.
    """
    response = _create_completion(client, model, messages)
    if response is None:
        return None
    if not response.choices:
        logger.warning("Response contained no choices.")
        return None

    message = response.choices[0].message

    if message.tool_calls:
        raw_arguments = message.tool_calls[0].function.arguments
        try:
            return SREAction(**json.loads(raw_arguments))
        except (TypeError, ValueError) as e:
            # JSONDecodeError (and pydantic's ValidationError) subclass ValueError;
            # TypeError covers None or non-object arguments.
            logger.warning(f"Invalid tool call arguments ({e}); falling back to message content.")

    content = message.content or ""
    args = _extract_action_args(content)
    if args is not None:
        try:
            return SREAction(**args)
        except (TypeError, ValueError):
            pass
    logger.warning(f"Response did not contain valid tool calls or JSON. Content: {content}")
    return None
def _build_system_prompt(state) -> str:
    """Render the system prompt from ``state.task_name``, ``state.difficulty`` and ``state.description``."""
    return f"""
You are a Senior Site Reliability Engineer. Your goal is to solve the following infrastructure task:
Task: {state.task_name}
Difficulty: {state.difficulty}
Description: {state.description}
You can interact with the environment using the 'execute_sre_action' tool.
Your actions are:
1. "run_shell": Execute a bash command and observe its output.
2. "patch_file": Overwrite or create a file with new content.
Rules:
- THE SHELL IS STATELESS: Every 'run_shell' call is a NEW isolated process.
- 'cd' and 'export' will NOT persist to the next step.
- You MUST use absolute paths or combine commands: `cd /path && run_cmd`.
- Be efficient and minimize steps.
- Investigate logs and system state (ls, ps, cat) before patching.
- Fix the root cause precisely.
"""


def _action_target(action):
    """Return the command for run_shell actions, otherwise the action's file path."""
    return action.command if action.action_type == "run_shell" else action.file_path


def _serialize_action(action) -> str:
    """Serialize *action* to JSON for the conversation history.

    Uses pydantic's ``model_dump_json``/``model_dump`` when available and falls
    back to ``dataclasses.asdict`` for dataclass-based models.
    """
    if hasattr(action, "model_dump_json"):
        return action.model_dump_json()
    if hasattr(action, "model_dump"):
        return json.dumps(action.model_dump())
    return json.dumps(asdict(action))


def _emit_block(tag: str, payload: dict) -> None:
    """Print a required structured-output marker line followed by its JSON payload."""
    print(tag, flush=True)
    print(json.dumps(payload), flush=True)


def run_episode(task_id: str, model: Optional[str] = None):
    """
    Run a single agent episode for *task_id* and return the final reward.

    Connects to the SRE environment (``SRE_ENV_URL``, defaulting to the HF
    Space), resets it to the task, and lets the LLM act until the environment
    reports ``done``, the agent yields no valid action, or the step budget is
    spent. Emits the required ``[START]``, ``[STEP]`` and ``[END]`` blocks on
    stdout.

    Args:
        task_id: Environment task identifier, e.g. ``"medium_build"``.
        model: Model name; falls back to ``MODEL_NAME`` when omitted.

    Returns:
        ``env.state.current_reward`` after the episode.
    """
    selected_model = model or MODEL_NAME
    client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)

    # Defaults to the HF Space deployment; override with SRE_ENV_URL for local runs.
    env_url = os.getenv("SRE_ENV_URL", "https://dragonfire146-sre-openenv.hf.space")
    logger.info(f"Connecting to SRE environment at: {env_url}")
    env = SREEnv(base_url=env_url)

    logger.info(f"Starting episode for task: {task_id} using model: {selected_model}")
    result = env.reset(task_id=task_id)
    obs = result.observation
    state = env.state  # a property, not a method

    system_prompt = _build_system_prompt(state)

    # Give the agent the task directory layout up front. This consumes one
    # env.step() that is not counted toward the agent step budget below.
    task_dir = shlex.quote(f"/tmp/sre_tasks/{task_id}/")
    ls_result = env.step(SREAction(action_type="run_shell", command=f"ls -R {task_dir}"))
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Initial state observed. Files:\n{ls_result.observation.stdout}\n\nstdout: {obs.stdout}\nstderr: {obs.stderr}"},
    ]

    # Honour both the environment's limit and this script's local MAX_STEPS cap.
    env_limit = state.max_steps
    max_steps = MAX_STEPS if env_limit is None else min(env_limit, MAX_STEPS)

    total_reward = 0.0
    step_count = 0  # number of actions actually executed

    _emit_block("[START]", {
        "task_id": task_id,
        "model": selected_model,
        "max_steps": max_steps,
    })

    while step_count < max_steps:
        logger.info(f"Step {step_count + 1}/{max_steps}")
        action = get_agent_response(client, selected_model, messages)
        if action is None:
            logger.warning("No valid action returned by agent. Terminating.")
            break

        # Count the step only once there is an action to execute.
        step_count += 1
        target = _action_target(action)
        logger.info(f"Agent Action: {action.action_type} - {target}")

        result = env.step(action)
        obs = result.observation
        # NOTE(review): treat a missing reward as 0 so summation cannot crash.
        total_reward += result.reward if result.reward is not None else 0.0
        done = result.done

        messages.append({"role": "assistant", "content": f"Action call: {_serialize_action(action)}"})
        messages.append({
            "role": "user",
            "content": f"Observation: stdout='{obs.stdout}', stderr='{obs.stderr}', exit_code={obs.exit_code}",
        })

        _emit_block("[STEP]", {
            "step": step_count,
            "action_type": action.action_type,
            "action": target,
            "reward": result.reward,
            "done": done,
            "exit_code": obs.exit_code,
        })

        if done:
            logger.info("Episode finished.")
            break

    final_state = env.state
    logger.info(
        f"Result for {task_id}: Final Reward = {final_state.current_reward:.2f}, "
        f"Steps = {step_count}, Summed Step Reward = {total_reward:.2f}"
    )

    _emit_block("[END]", {
        "task_id": task_id,
        "total_steps": step_count,
        "final_reward": final_state.current_reward,
    })
    return final_state.current_reward
# Task ids run when none are given on the command line.
DEFAULT_TASKS = ["medium_build"]


def select_tasks(argv: Optional[List[str]] = None) -> List[str]:
    """Return the task ids to run.

    Non-blank entries of *argv* are used in order. If there are none, a fresh
    copy of ``DEFAULT_TASKS`` is returned so callers cannot mutate the default.
    """
    requested = [arg for arg in (argv or []) if arg.strip()]
    return requested or list(DEFAULT_TASKS)


def main(argv: Optional[List[str]] = None) -> dict:
    """Run each selected task, print a summary table, and return the scores.

    Args:
        argv: Task ids to run; defaults to ``sys.argv[1:]``.

    Returns:
        Mapping of task id to the value returned by ``run_episode``.
    """
    tasks = select_tasks(sys.argv[1:] if argv is None else argv)
    scores = {}
    for tid in tasks:
        scores[tid] = run_episode(tid)

    print("\n" + "=" * 30)
    print("Baseline Inference Results (OpenEnv 0.1)")
    print("=" * 30)
    for tid, score in scores.items():
        # Guard the format spec in case the environment reports no reward.
        shown = f"{score:.2f}" if score is not None else "n/a"
        print(f"{tid:15}: {shown}")
    print("=" * 30)
    return scores


if __name__ == "__main__":
    main()