# sre-openenv / inference.py
# Author: Dragonfire146
# Commit 017aa35 — feat: add required [START]/[STEP]/[END] structured stdout output blocks
"""
Baseline inference script for SRE OpenEnv.
This script runs an AI agent (via OpenAI API) against the SRE environment.
It implements a standard agentic loop with OpenEnv 0.1 spec compliance:
1. Initialize environment with a task (easy, medium, hard).
2. Format observation and system prompt for the LLM.
3. Use OpenAI Tool Calling (Structured Outputs) for reliable action parsing.
4. Execute action via StepResult and repeat until terminal or max steps.
"""
import os
import sys
import json
import logging
import time
from typing import List, Optional
from openai import OpenAI
from pydantic import ValidationError
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
# Add project root to sys.path to allow absolute imports
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from client import SREEnv
from models import SREAction, SREObservation
# ── Configuration (competition requirements) ────────────────────────────────
# API_BASE_URL and MODEL_NAME may fall back to defaults. HF_TOKEN has no
# default and is mandatory: importing this module fails fast without it.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api-inference.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "meta-llama/Llama-3.3-70B-Instruct")
HF_TOKEN = os.environ.get("HF_TOKEN")  # No default - must be set in the environment.

if not HF_TOKEN:
    raise EnvironmentError(
        "HF_TOKEN is not set. Please add your Hugging Face API token to your .env file."
    )

# Intended local cap on agent steps for this inference script.
MAX_STEPS = 60

# Root logging: timestamped records at INFO level and above.
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)
# OpenAI tool definition mirroring the SREAction schema. Built once at import
# time instead of on every request.
_SRE_ACTION_TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "execute_sre_action",
            "description": "Execute a shell command or patch a file in the SRE environment",
            "parameters": {
                "type": "object",
                "properties": {
                    "action_type": {
                        "type": "string",
                        "enum": ["run_shell", "patch_file"],
                        "description": "The type of action to perform",
                    },
                    "command": {
                        "type": "string",
                        "description": "The shell command to run (if action_type is run_shell)",
                    },
                    "file_path": {
                        "type": "string",
                        "description": "The path to the file to patch (if action_type is patch_file)",
                    },
                    "content": {
                        "type": "string",
                        "description": "The new content for the file (if action_type is patch_file)",
                    },
                },
                "required": ["action_type"],
            },
        },
    }
]


def _is_rate_limit_error(exc: Exception) -> bool:
    """Return True if *exc* looks like an HTTP 429 (rate limit) from the API."""
    # openai's APIStatusError exposes ``status_code``. Keep the substring check
    # for providers/proxies whose errors only mention the code in the message.
    return getattr(exc, "status_code", None) == 429 or "429" in str(exc)


def _normalize_action_args(args: object) -> dict:
    """Map common LLM variants of an action payload onto SREAction kwargs.

    Accepted shapes:
      1. Direct object:     {"action_type": "...", ...}  -> returned unchanged
      2. Wrapped object:    {"action": {"action_type": "...", ...}}  -> inner dict
      3. Hallucinated key:  {"action": "run_shell", ...}  -> "action" renamed to "action_type"

    Raises:
        TypeError: if *args* is not a JSON object (dict).
    """
    if not isinstance(args, dict):
        raise TypeError(f"Action payload must be a JSON object, got {type(args).__name__}")
    if "action_type" in args or "action" not in args:
        return args
    inner = args["action"]
    if isinstance(inner, dict):
        return inner
    # Build a new dict rather than mutating the caller's parsed payload.
    normalized = {key: value for key, value in args.items() if key != "action"}
    normalized["action_type"] = inner
    return normalized


def _action_args_from_content(content: str) -> Optional[dict]:
    """Decode the first JSON object embedded in free-form *content*, or return None.

    Decoding starts at the first '{' and stops at the end of that object. Any
    surrounding prose or Markdown fences are ignored, and text inside the JSON
    (e.g. a patched file that itself contains ``` fences) is left untouched.
    """
    start = content.find("{")
    if start == -1:
        return None
    try:
        parsed, _ = json.JSONDecoder().raw_decode(content, start)
    except json.JSONDecodeError:
        return None
    return parsed


def _action_from_message(message):
    """Turn an assistant chat message into an SREAction, or None.

    Tool calls take priority. If their arguments cannot be decoded or
    validated, the exception propagates to the caller. Without a tool call,
    a JSON object embedded in the text content is tried. None is returned,
    with a warning, when that also fails.
    """
    if message.tool_calls:
        args = json.loads(message.tool_calls[0].function.arguments)
        return SREAction(**_normalize_action_args(args))

    # Fallback: some models reply with a JSON object in plain content.
    content = message.content or ""
    args = _action_args_from_content(content)
    if args is not None:
        try:
            return SREAction(**_normalize_action_args(args))
        except (ValueError, TypeError):
            # Schema validation errors (presumably pydantic ValidationError, a
            # ValueError subclass) fall through to the warning below.
            pass
    logger.warning(f"Response did not contain valid tool calls or JSON. Content: {content}")
    return None


def get_agent_response(client: OpenAI, model: str, messages: List[dict]) -> Optional[SREAction]:
    """
    Ask the LLM for the next action, using tool calling to enforce the SREAction schema.

    Rate-limit (429) failures of the API call are retried after 5s, 10s and 20s.
    Any other API error, or a reply that cannot be turned into a valid
    SREAction, is logged and yields None.

    Args:
        client: OpenAI-compatible client used for the chat completion.
        model: Model name passed to the API.
        messages: Conversation history in chat-completions message format.

    Returns:
        The parsed SREAction, or None if no valid action could be obtained.
    """
    retry_delays = [5, 10, 20]
    response = None
    for i, delay in enumerate(retry_delays + [None]):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                tools=_SRE_ACTION_TOOLS,
                tool_choice="auto",
                temperature=0.0,
            )
            break
        except Exception as e:
            # Only the API call is inside this try. Text such as "429" inside a
            # model-generated command therefore can never be mistaken for a
            # rate limit.
            if _is_rate_limit_error(e) and delay is not None:
                logger.warning(f"Rate limit hit (429). Retrying in {delay} seconds... (Attempt {i+1}/{len(retry_delays)})")
                time.sleep(delay)
                continue
            logger.error(f"Error getting agent response via tool calling: {e}")
            return None

    try:
        return _action_from_message(response.choices[0].message)
    except Exception as e:
        # Best-effort contract: malformed tool-call arguments, empty choices,
        # etc. are reported rather than raised. The caller treats None as "stop".
        logger.error(f"Error getting agent response via tool calling: {e}")
        return None
def _action_target(action) -> Optional[str]:
    """Return the command (run_shell) or the file path (any other action type) of *action*."""
    return action.command if action.action_type == "run_shell" else action.file_path


def _serialize_action(action) -> str:
    """Serialize *action* to JSON, trying model_dump_json, then model_dump, then dataclasses.asdict."""
    if hasattr(action, "model_dump_json"):
        return action.model_dump_json()
    if hasattr(action, "model_dump"):
        return json.dumps(action.model_dump())
    from dataclasses import asdict

    return json.dumps(asdict(action))


def run_episode(task_id: str, model: Optional[str] = None):
    """
    Run a single agent episode for *task_id* and return the environment's final reward.

    Emits the required structured blocks on stdout: "[START]", one "[STEP]" per
    agent action and "[END]". Each marker is followed by one JSON line.

    Args:
        task_id: Task identifier passed to ``env.reset``.
        model: Model name to use. Defaults to the module-level MODEL_NAME.

    Returns:
        ``env.state.current_reward`` once the episode ends.
    """
    selected_model = model or MODEL_NAME
    # OpenAI client configured using competition variables.
    client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)

    # The environment server URL.
    # Defaults to the HF Space deployment; override with SRE_ENV_URL for local.
    env_url = os.getenv("SRE_ENV_URL", "https://dragonfire146-sre-openenv.hf.space")
    logger.info(f"Connecting to SRE environment at: {env_url}")
    env = SREEnv(base_url=env_url)

    logger.info(f"Starting episode for task: {task_id} using model: {selected_model}")
    # OpenEnv reset() returns a StepResult object.
    result = env.reset(task_id=task_id)
    obs = result.observation
    # state is a property, not a method.
    state = env.state

    system_prompt = f"""
You are a Senior Site Reliability Engineer. Your goal is to solve the following infrastructure task:
Task: {state.task_name}
Difficulty: {state.difficulty}
Description: {state.description}
You can interact with the environment using the 'execute_sre_action' tool.
Your actions are:
1. "run_shell": Execute a bash command and observe its output.
2. "patch_file": Overwrite or create a file with new content.
Rules:
- THE SHELL IS STATELESS: Every 'run_shell' call is a NEW isolated process.
- 'cd' and 'export' will NOT persist to the next step.
- You MUST use absolute paths or combine commands: `cd /path && run_cmd`.
- Be efficient and minimize steps.
- Investigate logs and system state (ls, ps, cat) before patching.
- Fix the root cause precisely.
"""

    # Help the agent by providing the initial directory structure.
    ls_result = env.step(SREAction(action_type="run_shell", command=f"ls -R /tmp/sre_tasks/{task_id}/"))
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Initial state observed. Files:\n{ls_result.observation.stdout}\n\nstdout: {obs.stdout}\nstderr: {obs.stderr}"},
    ]

    step_count = 0
    # Apply the script's own MAX_STEPS cap on top of the task's limit.
    max_steps = min(state.max_steps, MAX_STEPS)

    # ── Required structured output block: episode start ──────────────────────
    print("[START]", flush=True)
    print(json.dumps({
        "task_id": task_id,
        "model": selected_model,
        "max_steps": max_steps,
    }), flush=True)

    while step_count < max_steps:
        step_count += 1
        logger.info(f"Step {step_count}/{max_steps}")

        action = get_agent_response(client, selected_model, messages)
        if not action:
            logger.warning("No valid action returned by agent. Terminating.")
            break
        logger.info(f"Agent Action: {action.action_type} - {_action_target(action)}")

        # OpenEnv step() returns a StepResult object.
        result = env.step(action)
        obs = result.observation
        done = result.done

        # Update the conversation history.
        obs_desc = f"Observation: stdout='{obs.stdout}', stderr='{obs.stderr}', exit_code={obs.exit_code}"
        messages.append({"role": "assistant", "content": f"Action call: {_serialize_action(action)}"})
        messages.append({"role": "user", "content": obs_desc})

        # ── Required structured output block: per-step ────────────────────────
        print("[STEP]", flush=True)
        print(json.dumps({
            "step": step_count,
            "action_type": action.action_type,
            "action": _action_target(action),
            "reward": result.reward,
            "done": done,
            "exit_code": obs.exit_code,
        }), flush=True)

        if done:
            logger.info("Episode finished.")
            break

    # state is a property.
    final_state = env.state
    logger.info(f"Result for {task_id}: Final Reward = {final_state.current_reward:.2f}, Steps = {step_count}")

    # ── Required structured output block: episode end ─────────────────────────
    print("[END]", flush=True)
    print(json.dumps({
        "task_id": task_id,
        "total_steps": step_count,
        "final_reward": final_state.current_reward,
    }), flush=True)
    return final_state.current_reward
if __name__ == "__main__":
    # Task IDs may be given on the command line:
    #     python inference.py <task_id> [<task_id> ...]
    # With no arguments, only the default task below is run.
    tasks = sys.argv[1:] or ["medium_build"]

    scores = {}
    for tid in tasks:
        scores[tid] = run_episode(tid)

    print("\n" + "=" * 30)
    print("Baseline Inference Results (OpenEnv 0.1)")
    print("=" * 30)
    for tid, score in scores.items():
        print(f"{tid:15}: {score:.2f}")
    print("=" * 30)