sentinel_env / inference.py
NoNameFound's picture
Upload folder using huggingface_hub
2585713 verified
"""Baseline inference script for OpenEnv-Sentinel.
Drives an LLM agent through all 3 SRE incident triage tasks.
Environment variables:
ENV_URL - Sentinel environment server URL (e.g. http://localhost:8000)
API_BASE_URL - LLM API base URL (default: https://router.huggingface.co/v1)
MODEL_NAME - Model or deployment name (default: openai/gpt-oss-120b:novita)
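API_KEY - LLM API key (optional; takes precedence over HF_TOKEN)
HF_TOKEN - Hugging Face token (used as the API key when API_KEY is unset; OPENAI_API_KEY is the final fallback)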
LOCAL_IMAGE_NAME - Docker image name when using from_docker_image() (optional)
"""
import asyncio
import json
import os
import re
import sys
import time
from openai import OpenAI
# ── configuration ───────────────────────────────────────────────────
# Every setting below is read from the process environment once, at import time.
# Environment server URL (where the Sentinel env is running)
ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")
# LLM configuration (aligned with official OpenEnv inference examples)
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-120b:novita")
# None when HF_TOKEN is not set; only used below as one of the API-key candidates.
HF_TOKEN = os.getenv("HF_TOKEN")
# Optional — if you use from_docker_image():
# NOTE(review): not referenced by any other code visible in this file — confirm it is still needed.
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
# API key: first non-empty value of API_KEY, HF_TOKEN, OPENAI_API_KEY (in that order);
# resolves to "" when none of them is set.
API_KEY = os.getenv("API_KEY") or HF_TOKEN or os.getenv("OPENAI_API_KEY", "")
# Wall-clock budget, in seconds, that main() passes to asyncio.wait_for for each task.
TASK_TIMEOUT = 360 # 6 minutes per task
# Number of LLM calls run_task makes per step before falling back to an invalid action.
MAX_PARSE_RETRIES = 3
# Passed as max_completion_tokens on every chat-completion request.
MAX_COMPLETION_TOKENS = 16384 # reasoning models need room for chain-of-thought
# System message that opens every conversation in run_task. It lists the JSON shape
# of each tool call, a suggested investigation order, and the rules for the final
# submit_resolution call. The text is sent to the model verbatim.
SYSTEM_PROMPT = """You are an expert SRE agent triaging a production incident.
You have access to diagnostic tools. Respond with ONLY a single JSON object — no markdown, no explanation, no extra text.
Available tools:
- get_service_status: {"tool_name": "get_service_status", "parameters": {"service": "<name>"}}
- query_logs: {"tool_name": "query_logs", "parameters": {"service": "<name>", "query": "<text filter, use empty string for all logs>"}}
- query_metrics: {"tool_name": "query_metrics", "parameters": {"service": "<name>", "metric": "<cpu|memory|error_rate|latency|connections>"}}
- get_dependency_map: {"tool_name": "get_dependency_map", "parameters": {"service": "<name or omit for full map>"}}
- consult_runbook: {"tool_name": "consult_runbook", "parameters": {"topic": "<search_topic>"}}
- check_recent_changes: {"tool_name": "check_recent_changes", "parameters": {"service": "<name or omit for all>"}}
- submit_resolution: {"tool_name": "submit_resolution", "parameters": {"root_cause": "<detailed explanation>", "affected_service": "<primary ROOT CAUSE service>", "recommendation": "<specific actionable fix>"}}
INVESTIGATION PLAN — you have only 20 steps total, be extremely efficient:
Step 1: get_dependency_map (no service param) to see full architecture
Step 2: check_recent_changes (no service param) to see all recent deploys and changes
Step 3-4: get_service_status for the UNHEALTHY/DEGRADED services mentioned in the incident
Step 5-6: query_logs for unhealthy services (use "" as query to get all logs)
Step 7-8: query_metrics for the suspicious root-cause service (error_rate, memory, connections)
Step 9: submit_resolution with your findings
CRITICAL RULES:
- The ROOT CAUSE is often UPSTREAM — a dependency of the symptomatic service, not the alerted service itself
- Look for: bad deployments, missing env vars, OOM/memory issues, connection pool exhaustion, long-running queries
- affected_service MUST be the root-cause service, NOT the symptom service
- root_cause must mention specific service names, error types, versions, and technical details
- recommendation must be specific and actionable (e.g. rollback, increase memory limit, kill query, set timeout)
- Do NOT repeat the same tool call — you already have that data
- You MUST call submit_resolution by step 10 at the latest — do not keep investigating
- Respond with ONLY a JSON object. No markdown fences, no explanation."""
# Extra system message that run_task appends to the outgoing request (not to the
# stored history) once the step counter reaches max_steps - 8, telling the model
# to stop investigating and submit.
FORCE_RESOLUTION_PROMPT = """URGENT: You MUST call submit_resolution NOW. No more investigation.
Synthesize everything you have gathered. Your response MUST be ONLY:
{"tool_name": "submit_resolution", "parameters": {"root_cause": "<detailed with service names, errors, versions>", "affected_service": "<the root cause service>", "recommendation": "<specific fix>"}}
Do NOT call any other tool. Submit NOW."""
# ── action parsing ──────────────────────────────────────────────────
def parse_action(text: str) -> dict | None:
"""Parse LLM output into an action dict with multiple fallbacks."""
# 1. Direct JSON parse
try:
obj = json.loads(text.strip())
if isinstance(obj, dict) and "tool_name" in obj:
return obj
except json.JSONDecodeError:
pass
# 2. Extract from markdown code fence
fence_match = re.search(r"```(?:json)?\s*\n?(.*?)\n?\s*```", text, re.DOTALL)
if fence_match:
try:
obj = json.loads(fence_match.group(1).strip())
if isinstance(obj, dict) and "tool_name" in obj:
return obj
except json.JSONDecodeError:
pass
# 3. Regex: first {...} block
brace_match = re.search(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", text, re.DOTALL)
if brace_match:
try:
obj = json.loads(brace_match.group(0))
if isinstance(obj, dict) and "tool_name" in obj:
return obj
except json.JSONDecodeError:
pass
return None
# ── history management ──────────────────────────────────────────────
def build_initial_prompt(observation: dict) -> str:
    """Compose the opening user message from the observation returned by reset.

    The message contains the incident summary, one line per entry of
    ``tool_descriptions`` (when present and non-empty), the step counter,
    and a request for the first action. Sections are joined with newlines.
    """
    summary = observation.get("incident_summary", "")
    sections = [f"INCIDENT: {summary}"]

    descriptions = observation.get("tool_descriptions")
    if descriptions:
        sections.append("\n--- AVAILABLE TOOL PARAMETERS ---")
        sections.extend(
            f" {name}: {json.dumps(spec)}" for name, spec in descriptions.items()
        )

    current = observation.get("step_number", 0)
    limit = observation.get("max_steps", 20)
    sections.append(f"\nStep {current}/{limit}")
    sections.append("\nBegin your investigation. Respond with your first action as JSON:")
    return "\n".join(sections)
def build_tool_response_prompt(observation: dict) -> str:
    """Compose the user message that follows a tool call.

    Includes the tool output and any action error reported in the observation,
    a progress line (step counter and cumulative reward), at most one urgency
    nudge chosen by how close the step counter is to ``max_steps``, and a
    request for the next action. Sections are joined with newlines.
    """
    lines = []

    output = observation.get("tool_output", "")
    if output:
        lines.append(f"Tool output:\n{output}")

    error = observation.get("last_action_error")
    if error:
        lines.append(f"\n⚠ ERROR: {error}")

    step_num = observation.get("step_number", 0)
    max_steps = observation.get("max_steps", 20)
    reward_total = observation.get("cumulative_reward", 0.0)
    lines.append(f"\nStep {step_num}/{max_steps} | Cumulative reward: {reward_total:.2f}")

    # Urgency ladder, strongest first: (steps remaining threshold, message).
    nudges = (
        (5, "\n⚠⚠ CRITICAL: You MUST call submit_resolution NOW! No more investigation!"),
        (8, "\n⚠ WARNING: Submit your resolution NOW. Call submit_resolution with your best analysis."),
        (12, "\nNote: Start forming your resolution. You should submit within the next 2-3 steps."),
    )
    for margin, message in nudges:
        if step_num >= max_steps - margin:
            lines.append(message)
            break

    lines.append("\nRespond with your next action as JSON:")
    return "\n".join(lines)
# ── main loop ───────────────────────────────────────────────────────
def _as_float(value, default: float = 0.0) -> float:
    """Coerce a decoded JSON value to float; null or non-numeric values give *default*."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _normalize_action(action: dict) -> dict:
    """Make an LLM-produced action safe to send: dict parameters, no "all" placeholders."""
    params = action.get("parameters")
    if not isinstance(params, dict):
        # The model sometimes emits null/str/list here; downstream code needs a dict.
        params = {}
        action["parameters"] = params
    # Normalize "all" query param to empty string (handler uses it as substring filter)
    if params.get("query") == "all":
        params["query"] = ""
    if params.get("severity") == "all":
        params.pop("severity", None)
    return action


async def run_task(task_id: int, base_url: str, client: OpenAI) -> float:
    """Run a single task against the environment via WebSocket and return the grader score.

    Protocol, as used here: one ``{"type": "reset"}`` message, then one
    ``{"type": "step"}`` message per agent action until the environment
    reports ``done``, then one ``{"type": "state"}`` message to read
    ``final_score``. Each reply is JSON with its payload under ``"data"``.

    Structured ``[START]``, ``[STEP]`` and ``[END]`` lines are printed to
    stdout. ``[END]`` is emitted from a ``finally`` block so that it is
    printed even when the task raises or is cancelled by the caller's timeout.

    Args:
        task_id: Task identifier sent to the environment in the reset message.
        base_url: HTTP(S) base URL of the environment; rewritten to ws(s)://…/ws.
        client: OpenAI-compatible client used for chat completions.

    Returns:
        ``final_score`` from the environment state, or 0.0 when it is absent or null.

    Raises:
        Whatever the WebSocket connection or JSON decoding raises; LLM errors
        are logged to stderr and retried up to MAX_PARSE_RETRIES times per step.
    """
    import websockets

    ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
    ws_url = ws_url.rstrip("/") + "/ws"

    async with websockets.connect(ws_url, ping_interval=120, ping_timeout=300) as ws:
        # Reset
        await ws.send(json.dumps({"type": "reset", "data": {"task_id": task_id}}))
        resp = json.loads(await ws.recv())
        data = resp["data"]
        # The observation may be nested under "observation" or be the payload itself.
        observation = data.get("observation", data)
        done = bool(data.get("done", False))

        # [START] — mandatory structured log
        print(f"[START] task=task_{task_id} env=sentinel_env model={MODEL_NAME}")

        # Build multi-turn conversation
        messages: list[dict] = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": build_initial_prompt(observation)},
        ]
        final_score = 0.0
        local_step = 0  # track client-side loop iterations
        rewards_list: list[str] = []  # collect per-step rewards for [END] line

        try:
            while not done:
                local_step += 1
                step_num = observation.get("step_number", local_step)
                max_steps = observation.get("max_steps", 20)

                # Safety: break if we've looped too many times without env advancing
                if local_step > max_steps + 10:
                    print(f" Exceeded max loop iterations ({local_step}), breaking", flush=True)
                    break

                print(f" Step {step_num} (iter {local_step}): calling LLM...", flush=True)

                # Force resolution when nearing step limit
                force_resolution = step_num >= max_steps - 8
                if force_resolution:
                    print(f" Step {step_num}: FORCING resolution submission", flush=True)
                    # The nudge goes only into this request, not the stored history.
                    force_msg = {"role": "system", "content": FORCE_RESOLUTION_PROMPT}
                    call_messages = messages + [force_msg]
                else:
                    call_messages = messages

                # Trim conversation if too long (keep system + first user + last 20 messages)
                if len(call_messages) > 30:
                    call_messages = call_messages[:2] + call_messages[-20:]

                # Try to get a valid action from the LLM
                action_dict = None
                for attempt in range(MAX_PARSE_RETRIES):
                    try:
                        # Run sync LLM call in a thread so the event loop can
                        # still handle WebSocket pings during long reasoning calls
                        response = await asyncio.to_thread(
                            client.chat.completions.create,
                            model=MODEL_NAME,
                            messages=call_messages,
                            max_completion_tokens=MAX_COMPLETION_TOKENS,
                        )
                        raw = response.choices[0].message.content or ""
                        action_dict = parse_action(raw)
                        if action_dict:
                            # Store the assistant response in conversation
                            messages.append({"role": "assistant", "content": raw})
                            print(f" Step {step_num}: action={action_dict.get('tool_name', '?')}", flush=True)
                            break
                        print(f" Step {step_num}: parse failed (attempt {attempt + 1}), raw={raw[:200]}", flush=True)
                        # On parse failure, add a nudge and retry
                        if attempt < MAX_PARSE_RETRIES - 1:
                            call_messages = call_messages + [
                                {"role": "assistant", "content": raw},
                                {"role": "user", "content": "That was not valid JSON. Respond with ONLY a JSON object like {\"tool_name\": \"...\", \"parameters\": {...}}"},
                            ]
                    except Exception as e:
                        # Best effort: log and use the remaining attempts.
                        print(f" LLM error (attempt {attempt + 1}): {e}", file=sys.stderr)

                if action_dict is None:
                    # Fallback: send an invalid action to let the env handle it
                    action_dict = {"tool_name": "_invalid_", "parameters": {}}
                    messages.append({"role": "assistant", "content": json.dumps(action_dict)})

                action_dict = _normalize_action(action_dict)

                # Step the environment via WebSocket
                await ws.send(json.dumps({"type": "step", "data": action_dict}))
                resp = json.loads(await ws.recv())
                data = resp["data"]
                observation = data.get("observation", data)
                # "done" may be reported on the envelope or inside the observation.
                done = bool(data.get("done", False)) or bool(observation.get("done"))

                # Same for "reward"; a null reward counts as 0.0 instead of
                # crashing the float formatting below.
                raw_reward = observation.get("reward")
                if raw_reward is None:
                    raw_reward = data.get("reward")
                step_reward = _as_float(raw_reward)
                rewards_list.append(f"{step_reward:.2f}")

                # [STEP] — mandatory structured log
                print(f"[STEP] step={step_num} action={action_dict.get('tool_name')} "
                      f"reward={step_reward:.2f} done={str(done).lower()} "
                      f"error={observation.get('last_action_error') or 'null'}")

                # Add tool response as user message in the conversation
                if not done:
                    messages.append({"role": "user", "content": build_tool_response_prompt(observation)})

            # Get final state for score
            await ws.send(json.dumps({"type": "state"}))
            resp = json.loads(await ws.recv())
            state_data = resp["data"]
            final_score = _as_float(state_data.get("final_score"))
        finally:
            # [END] — mandatory structured log; also printed on failure/cancellation.
            # "steps" counts environment steps that actually completed.
            print(f"[END] success={str(final_score > 0).lower()} steps={len(rewards_list)} "
                  f"score={final_score:.2f} rewards={','.join(rewards_list)}")
        return final_score
async def main() -> None:
    """Run tasks 1-3 in sequence and print per-task and average scores.

    Each task is bounded by TASK_TIMEOUT seconds. A task that times out or
    raises is reported on stderr and scored 0.0; the remaining tasks still run.
    """
    rule = "=" * 50
    task_ids = (1, 2, 3)

    llm_client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    print(f"Using LLM API: {API_BASE_URL} / model={MODEL_NAME}")
    print(f"Environment URL: {ENV_URL}")

    scores: dict[int, float] = {}
    for task_id in task_ids:
        print(f"\n{rule}")
        print(f"Running Task {task_id}...")
        print(rule)
        try:
            outcome = await asyncio.wait_for(
                run_task(task_id, ENV_URL, llm_client),
                timeout=TASK_TIMEOUT,
            )
        except asyncio.TimeoutError:
            print(f" Task {task_id} timed out after {TASK_TIMEOUT}s", file=sys.stderr)
            outcome = 0.0
        except Exception as e:
            # Top-level boundary: one failing task must not abort the others.
            print(f" Task {task_id} failed: {e}", file=sys.stderr)
            outcome = 0.0
        scores[task_id] = outcome
        print(f"Task {task_id}: {scores[task_id]:.2f}")

    avg = sum(scores.values()) / len(scores) if scores else 0.0

    # Summary table
    print(f"\n{rule}")
    for task_id in task_ids:
        print(f"Task {task_id}: {scores.get(task_id, 0.0):.2f}")
    print(f"Average: {avg:.2f}")
    print(rule)


if __name__ == "__main__":
    asyncio.run(main())