sre-incident-env / scripts /test_scenario.py
Maverick98's picture
V3: reward rewrite + 8-scenario minimal set with tuned difficulty
71168a0
#!/usr/bin/env python3
"""Test scenario(s) with full trace logging.
Saves complete trace per scenario to outputs/ directory.
All config via CLI args.
Usage:
# Single scenario
python scripts/test_scenario.py --scenario kafka_partition_rebalance_storm_001 --difficulty medium
# Multiple scenarios
python scripts/test_scenario.py \\
--scenario cert_expiry_mutual_tls_001 \\
--scenario kernel_tcp_rmem_silent_drop_001 \\
--difficulty expert
# All scenarios from a tier
python scripts/test_scenario.py --tier hard
# Custom config
python scripts/test_scenario.py \\
--tier medium \\
--model gpt-5.4 \\
--api-base https://api.openai.com/v1 \\
--api-key $OPENAI_API_KEY \\
--scenario-file scenarios/incidents_v3.jsonl \\
--server-url http://127.0.0.1:8766 \\
--max-steps 200 \\
--runs 2 \\
--output-dir outputs/manual_test
"""
import argparse
import asyncio
import json
import os
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
from openai import OpenAI
# Make project root importable
sys.path.insert(0, str(Path(__file__).parent.parent))
from client import SREIncidentEnvHTTP
# System prompt for every episode: run_scenario sends it as the first chat
# message and records it as the first entry of each saved trace. It refers to
# tools by name (list_services, read_logs, get_service_info, execute_runbook,
# verify_resolution); these presumably must match the tool names the env
# server exposes via list_tools — keep them in sync if the server changes.
SYSTEM_PROMPT = """You are an expert on-call Site Reliability Engineer responding to a production incident.
# MISSION
Investigate the incident, identify the root cause, FIX the problem, and verify resolution.
This is a LIVE system — your actions have real consequences. Wrong fixes can make things worse.
# PROTOCOL
1. ORIENT: list_services to see the topology.
2. INVESTIGATE: read_logs across services to find which are affected and trace the error chain upstream.
3. DISCOVER: get_service_info on suspected services to learn what actions, config params, and recent deploys are available.
4. REMEDIATE: Apply a fix based on your diagnosis.
5. OBSERVE: After ANY remediation, call read_logs to see what changed.
6. VERIFY: When the system is healthy, call verify_resolution with your diagnosis.
# CRITICAL RULES
- DISCOVER BEFORE FIXING: Call get_service_info before using execute_runbook.
- OBSERVE AFTER FIXING: Always read_logs after a remediation to check the outcome.
- CAUSE ≠ EFFECT: The service with the most errors is usually a VICTIM, not the cause.
"""
def mcp_tools_to_openai(tools_raw):
    """Convert MCP tool descriptors into OpenAI function-calling ``tools`` entries.

    Each element of ``tools_raw`` may be a plain dict (keys ``name``,
    ``description``, ``inputSchema``) or an object exposing ``name``,
    ``description`` and either ``input_schema`` or ``inputSchema``.

    Only a subset of each property's JSON Schema is forwarded: ``type``
    (defaulting to ``"string"`` when absent), plus ``description``, ``items``
    and ``enum`` when present. ``items`` matters because the OpenAI API rejects
    ``"type": "array"`` parameters that have no ``items`` schema.

    Args:
        tools_raw: Iterable of MCP tool descriptors.

    Returns:
        A list of ``{"type": "function", "function": {...}}`` dicts suitable
        for the ``tools`` argument of ``chat.completions.create``.
    """
    passthrough_keys = ("description", "items", "enum")
    out = []
    for t in tools_raw:
        if isinstance(t, dict):
            name = t["name"]
            desc = t.get("description") or ""
            schema = t.get("inputSchema") or {}
        else:
            name = t.name
            desc = getattr(t, "description", None) or ""
            # Accept both the snake_case attribute and the MCP wire-format
            # camelCase spelling.
            schema = (
                getattr(t, "input_schema", None)
                or getattr(t, "inputSchema", None)
                or {}
            )

        properties = {}
        required = []
        raw_props = schema.get("properties") if isinstance(schema, dict) else None
        if isinstance(raw_props, dict):
            for prop_name, prop_schema in raw_props.items():
                # JSON Schema allows boolean sub-schemas; treat anything that
                # is not an object as "no constraints".
                if not isinstance(prop_schema, dict):
                    prop_schema = {}
                prop = {"type": prop_schema.get("type", "string")}
                for key in passthrough_keys:
                    if key in prop_schema:
                        prop[key] = prop_schema[key]
                properties[prop_name] = prop
            # Copy so the returned structure never aliases the caller's schema.
            required = list(schema.get("required") or [])

        out.append({
            "type": "function",
            "function": {
                "name": name,
                "description": desc,
                "parameters": {"type": "object", "properties": properties, "required": required},
            },
        })
    return out
def _parse_tool_args(raw_arguments: Optional[str]) -> Dict[str, Any]:
    """Decode a tool call's JSON ``arguments`` payload into a kwargs dict.

    Invalid JSON, a ``None`` payload, or JSON that is not an object all yield
    ``{}`` instead of raising, so the call is still forwarded to the env.
    """
    try:
        parsed = json.loads(raw_arguments)
    except (TypeError, json.JSONDecodeError):
        return {}
    return parsed if isinstance(parsed, dict) else {}


def _extract_outcome(result: str) -> Optional[Any]:
    """Return the ``outcome`` field of a tool result for inline display.

    Returns ``None`` when ``result`` is not JSON or decodes to something other
    than an object (e.g. a list) instead of raising.
    """
    try:
        parsed = json.loads(result)
    except json.JSONDecodeError:
        return None
    return parsed.get("outcome") if isinstance(parsed, dict) else None


def _format_reward(reward: Any) -> str:
    """Format a reward to 4 decimals, falling back to ``str()`` (e.g. for ``None``)."""
    try:
        return f"{reward:.4f}"
    except (TypeError, ValueError):
        return str(reward)


def _episode_summary(
    scenario_id: str,
    difficulty: str,
    model: str,
    run_num: int,
    *,
    reward: Any,
    done: bool,
    steps: int,
    elapsed: float,
    error: Optional[str],
    max_steps_hit: bool,
) -> Dict[str, Any]:
    """Build the per-episode summary dict saved alongside each trace."""
    return {
        "scenario_id": scenario_id,
        "difficulty": difficulty,
        "model": model,
        "run": run_num,
        "reward": reward,
        "done": done,
        "steps": steps,
        "elapsed_seconds": round(elapsed, 1),
        "error": error,
        "max_steps_hit": max_steps_hit,
    }


async def run_scenario(
    llm: OpenAI,
    server_url: str,
    scenario_id: str,
    difficulty: str,
    model: str,
    max_steps: int,
    run_num: int,
) -> Dict[str, Any]:
    """Run one scenario episode against the env server and return its trace.

    The model is driven in a tool-calling loop for at most ``max_steps`` turns.
    The loop stops early when the env reports the episode done
    (``env._last_done``) or when an LLM API call fails.

    Args:
        llm: OpenAI-compatible client used for chat completions.
        server_url: Base URL of the SRE incident environment server.
        scenario_id: Scenario passed to ``env.reset``.
        difficulty: Difficulty tier passed to ``env.reset``.
        model: Model name; names containing ``o3``/``o4``/``gpt-5`` get the
            reasoning-model request parameters.
        max_steps: Maximum number of LLM turns.
        run_num: Run index, echoed into the summary.

    Returns:
        ``{"summary": {...}, "trace": [...]}``. ``trace`` holds every message
        exchanged (assistant/tool entries tagged with ``_step``). Exceptions
        are reported in ``summary["error"]`` rather than raised.
    """
    trace: List[Dict[str, Any]] = []
    started = time.time()
    is_reasoning = any(x in model for x in ["o3", "o4", "gpt-5"])
    # Tracked outside the session so a session error still reports progress.
    steps_taken = 0
    try:
        async with SREIncidentEnvHTTP(base_url=server_url, timeout=120) as env:
            obs = await env.reset(difficulty=difficulty, scenario_id=scenario_id)
            tools = mcp_tools_to_openai(await env.list_tools())
            if isinstance(obs, dict):
                alert = obs.get("message", "Production incident detected.")
            else:
                alert = "Production incident."
            trace.append({"role": "system", "content": SYSTEM_PROMPT})
            trace.append({"role": "user", "content": alert})
            messages = list(trace)
            done = False
            final_reward = None
            error = None
            for step in range(1, max_steps + 1):
                steps_taken = step
                create_kwargs: Dict[str, Any] = {
                    "model": model,
                    "messages": messages,
                    "tools": tools,
                    "tool_choice": "auto",
                }
                if is_reasoning:
                    # Reasoning models take max_completion_tokens and no temperature.
                    create_kwargs["max_completion_tokens"] = 2000
                else:
                    create_kwargs["temperature"] = 0.1
                    create_kwargs["max_tokens"] = 500
                # NOTE: synchronous client call; it blocks the event loop while
                # the request is in flight (episodes run serially here).
                try:
                    resp = llm.chat.completions.create(**create_kwargs, timeout=90)
                except Exception as e:
                    error = f"API error: {str(e)[:200]}"
                    print(f" T{step}: API ERROR - {error[:80]}", flush=True)
                    break
                msg = resp.choices[0].message
                if msg.tool_calls:
                    # Only the first call is executed; the recorded assistant
                    # message carries just that call, so every tool_call_id in
                    # the history gets exactly one tool reply.
                    tc = msg.tool_calls[0]
                    tool_name = tc.function.name
                    tool_args = _parse_tool_args(tc.function.arguments)
                    args_str = json.dumps(tool_args)
                    print(f" T{step}: {tool_name}({args_str[:120]})", end="", flush=True)
                    asst_msg = {
                        "role": "assistant",
                        "content": None,
                        "tool_calls": [{
                            "id": tc.id,
                            "type": "function",
                            "function": {"name": tool_name, "arguments": tc.function.arguments},
                        }],
                    }
                    messages.append(asst_msg)
                    trace.append({**asst_msg, "_step": step})
                    try:
                        result = await env.call_tool(tool_name, **tool_args)
                    except Exception as e:
                        result = json.dumps({"error": f"transport_error: {str(e)[:100]}"})
                    if not isinstance(result, str):
                        result = json.dumps(result)
                    outcome = _extract_outcome(result)
                    if outcome:
                        print(f" [{outcome}]", end="")
                    tool_msg = {"role": "tool", "tool_call_id": tc.id, "content": result}
                    messages.append(tool_msg)
                    trace.append({**tool_msg, "_step": step})
                    if env._last_done:
                        final_reward = env._last_reward
                        print(f" → reward={_format_reward(final_reward)}", flush=True)
                        done = True
                        break
                    print(flush=True)
                else:
                    if msg.content:
                        text_msg = {"role": "assistant", "content": msg.content}
                        messages.append(text_msg)
                        trace.append({**text_msg, "_step": step})
                    # Nudge back toward tool use. Without this, an empty reply
                    # would resend the identical request on the next step.
                    messages.append({"role": "user", "content": "Please use a tool."})
        return {
            "summary": _episode_summary(
                scenario_id, difficulty, model, run_num,
                reward=final_reward,
                done=done,
                steps=steps_taken,
                elapsed=time.time() - started,
                error=error,
                max_steps_hit=not done and not error,
            ),
            "trace": trace,
        }
    except Exception as e:
        return {
            "summary": _episode_summary(
                scenario_id, difficulty, model, run_num,
                reward=None,
                done=False,
                steps=steps_taken,
                elapsed=time.time() - started,
                error=f"session_error: {str(e)[:200]}",
                max_steps_hit=False,
            ),
            "trace": trace,
        }
def list_scenarios_in_file(scenario_file: str) -> List[Dict[str, str]]:
    """Read a JSONL scenario file and return ``[{"id": ..., "difficulty": ...}, ...]``.

    Lines are processed in file order. Blank lines, lines that are not valid
    JSON, and records that are not JSON objects or lack an ``id`` or
    ``difficulty`` key are skipped, so one bad record does not abort a batch.

    Args:
        scenario_file: Path to the ``.jsonl`` file (one scenario object per line).

    Returns:
        A list of ``{"id", "difficulty"}`` dicts; all other fields are dropped.

    Raises:
        OSError: If the file cannot be opened.
    """
    out = []
    # Explicit UTF-8 so parsing does not depend on the platform locale encoding.
    with open(scenario_file, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                record = json.loads(line)
                out.append({"id": record["id"], "difficulty": record["difficulty"]})
            except (json.JSONDecodeError, KeyError, TypeError):
                # Malformed line, non-object JSON, or missing required key.
                continue
    return out
def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser; all run configuration comes from these flags."""
    parser = argparse.ArgumentParser(description="Test SRE incident scenarios with full trace logging")
    parser.add_argument("--scenario", action="append", default=[], help="Scenario ID (repeatable)")
    parser.add_argument("--tier", choices=["easy", "medium", "hard", "expert"], help="Run all scenarios in this tier")
    parser.add_argument("--difficulty", help="Override difficulty for --scenario (auto-detected if omitted)")
    parser.add_argument("--scenario-file", default="scenarios/incidents_v3.jsonl", help="Scenarios JSONL file")
    parser.add_argument("--model", default="gpt-5.4", help="LLM model name")
    parser.add_argument("--api-base", default=os.getenv("API_BASE_URL", "https://api.openai.com/v1"))
    parser.add_argument("--api-key", default=os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN"))
    parser.add_argument("--server-url", default="http://127.0.0.1:8766", help="SRE env server URL")
    parser.add_argument("--max-steps", type=int, default=200)
    parser.add_argument("--runs", type=int, default=1, help="Runs per scenario")
    parser.add_argument("--output-dir", default=None, help="Output dir (default: outputs/test_TIMESTAMP)")
    return parser


def _resolve_targets(args: argparse.Namespace, all_scenarios: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Select scenarios from ``--tier`` and/or ``--scenario``.

    Tier matches come first (file order), then explicitly named scenarios in
    command-line order. ``--difficulty`` overrides the file's difficulty for
    named scenarios only. Exact duplicates are dropped; unknown IDs print a
    warning and are skipped.
    """
    targets = []
    if args.tier:
        targets = [s for s in all_scenarios if s["difficulty"] == args.tier]
    for sid in args.scenario:
        match = next((s for s in all_scenarios if s["id"] == sid), None)
        if match is None:
            print(f"WARNING: scenario {sid} not found in {args.scenario_file}")
            continue
        if args.difficulty:
            match = {"id": sid, "difficulty": args.difficulty}
        if match not in targets:
            targets.append(match)
    return targets


def _format_status(r: Dict[str, Any]) -> str:
    """Render one episode summary as the short status shown in the final table."""
    if r["reward"] is not None:
        return f"reward={r['reward']:.4f} steps={r['steps']}"
    if r["max_steps_hit"]:
        return f"max_steps ({r['steps']})"
    if r["error"]:
        return f"error: {r['error'][:50]}"
    return "unknown"


def _write_combined_summary(out_dir: Path, args: argparse.Namespace, all_results: List[Dict[str, Any]]) -> None:
    """(Re)write ``_summary.json`` with the run config and all results so far."""
    combined = {
        "model": args.model,
        "scenario_file": args.scenario_file,
        "max_steps": args.max_steps,
        "timestamp": datetime.now().isoformat(),
        "results": all_results,
    }
    with open(out_dir / "_summary.json", "w") as f:
        json.dump(combined, f, indent=2)


async def main_async():
    """Parse CLI args, run each selected scenario ``--runs`` times, and save traces.

    Writes one ``<scenario_id>_run<N>.json`` (summary + full trace) per episode
    and a combined ``_summary.json`` into the output directory, then prints a
    per-episode status table. Exits with status 1 when no API key is set, the
    scenario file cannot be read, or no scenarios were selected. Invalid
    ``--runs``/``--max-steps`` values are rejected via ``argparse`` (status 2).
    """
    parser = _build_parser()
    args = parser.parse_args()
    if args.runs < 1:
        parser.error("--runs must be >= 1")
    if args.max_steps < 1:
        parser.error("--max-steps must be >= 1")
    if not args.api_key:
        print("ERROR: --api-key required (or set OPENAI_API_KEY / HF_TOKEN env var)")
        sys.exit(1)

    # Resolve scenarios
    try:
        all_scenarios = list_scenarios_in_file(args.scenario_file)
    except OSError as e:
        print(f"ERROR: cannot read scenario file {args.scenario_file}: {e}")
        sys.exit(1)
    targets = _resolve_targets(args, all_scenarios)
    if not targets:
        print("ERROR: No scenarios to run. Use --scenario or --tier")
        sys.exit(1)

    # Setup output dir
    out_dir = Path(args.output_dir) if args.output_dir else Path(f"outputs/test_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
    out_dir.mkdir(parents=True, exist_ok=True)
    print(f"Model: {args.model}")
    print(f"Server: {args.server_url}")
    print(f"Scenario file: {args.scenario_file}")
    print(f"Max steps: {args.max_steps}")
    print(f"Runs per scenario: {args.runs}")
    print(f"Output: {out_dir}")
    print(f"Targets: {len(targets)} scenario(s) x {args.runs} runs = {len(targets) * args.runs} episodes")
    print("=" * 70)

    llm = OpenAI(base_url=args.api_base, api_key=args.api_key)
    all_results: List[Dict[str, Any]] = []
    for sc in targets:
        for run_num in range(1, args.runs + 1):
            print(f"\n=== {sc['id']} ({sc['difficulty']}) RUN {run_num} ===")
            result = await run_scenario(
                llm=llm,
                server_url=args.server_url,
                scenario_id=sc["id"],
                difficulty=sc["difficulty"],
                model=args.model,
                max_steps=args.max_steps,
                run_num=run_num,
            )
            all_results.append(result["summary"])
            # Save per-scenario trace
            with open(out_dir / f"{sc['id']}_run{run_num}.json", "w") as f:
                json.dump(result, f, indent=2)
            # Refresh the combined summary after every episode so an
            # interrupted batch still leaves a usable _summary.json behind.
            _write_combined_summary(out_dir, args, all_results)

    # Summary
    print(f"\n{'=' * 70}")
    print("SUMMARY")
    print(f"{'=' * 70}")
    for r in all_results:
        print(f" [{r['difficulty']:6s}] {r['scenario_id'][:40]:40s} R{r['run']}: {_format_status(r)}")
    print(f"\nFull traces saved to: {out_dir}")
def main():
    """Synchronous console entry point: drive ``main_async`` on a fresh event loop."""
    coroutine = main_async()
    asyncio.run(coroutine)


# Only run the CLI when executed as a script, not when imported.
if __name__ == "__main__":
    main()