incident_commander_env / scripts /evaluate_baselines.py
Dar4devil's picture
Build incident commander OpenEnv arena
015ca5f
Raw
History Blame Contribute Delete
5.53 kB
from __future__ import annotations
import argparse
import json
from pathlib import Path
from statistics import mean
import sys
ROOT = Path(__file__).resolve().parents[1]

# Prepend the directory two levels above this file (the project root that
# contains ``scripts/``) to the import path, so the ``incident_commander_env``
# and ``scripts`` imports that follow resolve when the file is run directly.
_ROOT_ENTRY = str(ROOT)
if _ROOT_ENTRY not in sys.path:
    sys.path.insert(0, _ROOT_ENTRY)
from incident_commander_env.models import IncidentAction
from incident_commander_env.scenarios import FAMILIES
from incident_commander_env.server.incident_commander_environment import IncidentCommanderEnvironment
from scripts.manual_episode import SOLUTIONS, first_relevant_metric
def step(env: IncidentCommanderEnvironment, tool_name: str, **arguments):
    """Issue a single tool call against ``env``.

    Every keyword argument is packed, unchanged, into the ``arguments`` mapping
    of an ``IncidentAction`` for ``tool_name``; whatever ``env.step`` returns is
    passed straight back to the caller.
    """
    action = IncidentAction(tool_name=tool_name, arguments=arguments)
    return env.step(action)
def weak_policy(env: IncidentCommanderEnvironment):
    """Run the scripted weak baseline for the env's current scenario family.

    The env is reset to its own scenario family, then the policy makes one
    ``query_logs`` call on the first listed service and immediately submits a
    ``final_report`` with fixed, scenario-agnostic guesses for root cause and
    mitigation.

    Returns:
        The value returned by the ``final_report`` step (``run_policy`` reads
        ``reward``, ``score_breakdown`` and ``turn`` from it).
    """
    # Only the side effect of resetting is needed; the initial return value
    # was previously bound to an unused local.
    env.reset(scenario_id=env.scenario.family)
    service = env.scenario.services[0]
    step(env, "query_logs", service=service, query="error", minutes=15)
    return step(
        env,
        "final_report",
        root_cause="database is slow",
        mitigation="restart the database",
        customer_update="We are investigating customer impact and will update soon.",
    )
def oracle_policy(env: IncidentCommanderEnvironment):
    """Run the scripted reference trace built from the family's known solution.

    After resetting the env to its own scenario family, the policy runs these
    steps in order:

    1. Query logs on the first listed service.
    2. Inspect the metric chosen by ``first_relevant_metric``.
    3. Read the solution's runbook topic.
    4. Test the solution's root-cause hypothesis.
    5. Apply its mitigation.
    6. Submit a ``final_report`` that restates the root cause and mitigation
       with a customer update.

    Returns:
        The value returned by the ``final_report`` step.

    Raises:
        KeyError: if ``SOLUTIONS`` has no entry for the scenario family.
    """
    # Reset for its side effect only; the initial return value is not needed.
    env.reset(scenario_id=env.scenario.family)
    solution = SOLUTIONS[env.scenario.family]
    first = env.scenario.services[0]
    step(env, "query_logs", service=first, query="error timeout failed warn", minutes=20)
    metric_service, metric = first_relevant_metric(env)
    step(env, "inspect_metric", service=metric_service, metric=metric, minutes=30)
    step(env, "read_runbook", topic=solution["topic"])
    step(env, "test_hypothesis", root_cause=solution["root"])
    step(env, "apply_mitigation", mitigation=solution["mitigation"])
    return step(
        env,
        "final_report",
        root_cause=solution["root"],
        mitigation=solution["mitigation"],
        customer_update=(
            "Customers are impacted by elevated failures. We applied mitigation and "
            "expect recovery within 15 minutes; next update in 10 minutes."
        ),
    )
def run_policy(policy_name: str, policy_fn):
    """Evaluate ``policy_fn`` once on every scenario family and summarise it.

    A fresh ``IncidentCommanderEnvironment`` is built per family and handed to
    ``policy_fn``, whose return value must expose ``reward``, ``turn`` and a
    ``score_breakdown`` mapping.

    Each per-scenario row has these fields:

    - ``scenario``: the family name.
    - ``reward``: the final reward.
    - ``root_cause_ok``: True if the breakdown's ``root_cause`` score is > 0.
    - ``mitigation_ok``: True if the breakdown's ``mitigation`` score is > 0.
    - ``unsafe``: True if ``unsafe_penalty`` is < 0.
    - ``tool_calls``: the final ``turn``.

    Missing breakdown keys count as 0.0.

    Returns:
        A dict holding the policy name, rounded means of every row field, and
        the raw rows under ``"scenarios"``.
    """
    scenario_rows = []
    for family in FAMILIES:
        episode_env = IncidentCommanderEnvironment(scenario_id=family)
        last = policy_fn(episode_env)
        scores = last.score_breakdown
        scenario_rows.append(
            dict(
                scenario=family,
                reward=last.reward,
                root_cause_ok=scores.get("root_cause", 0.0) > 0,
                mitigation_ok=scores.get("mitigation", 0.0) > 0,
                unsafe=scores.get("unsafe_penalty", 0.0) < 0,
                tool_calls=last.turn,
            )
        )

    def average(field, digits=4):
        # Booleans average to the fraction of scenarios where they are True.
        return round(mean(row[field] for row in scenario_rows), digits)

    return {
        "policy": policy_name,
        "mean_reward": average("reward"),
        "root_cause_accuracy": average("root_cause_ok"),
        "mitigation_accuracy": average("mitigation_ok"),
        "unsafe_rate": average("unsafe"),
        "avg_tool_calls": average("tool_calls", 2),
        "scenarios": scenario_rows,
    }
# Pixel size (width, height) shared by both generated figures.
_CANVAS_SIZE = (900, 520)


def _bar_length(reward: float, max_bar: int) -> int:
    """Return the pixel length of a bar for ``reward``, clamped to ``[0, max_bar]``.

    The reward is clamped to ``[0, 1]`` before scaling. A mean reward can
    presumably go below 0 when penalties apply, or above 1. An unclamped
    negative length would give a rectangle with ``x1 < x0``, which recent
    Pillow releases reject with ``ValueError``. A length above 1 would overrun
    the axis.
    """
    return int(max_bar * min(max(reward, 0.0), 1.0))


def _draw_baseline_chart(image_module, draw_module, results):
    """Return a horizontal bar chart of each result's ``mean_reward``."""
    image = image_module.new("RGB", _CANVAS_SIZE, "white")
    draw = draw_module.Draw(image)
    draw.text((30, 25), "Incident Commander RL Arena: Baseline vs Oracle", fill=(20, 20, 20))
    left, max_bar = 220, 620
    y = 110
    colors = [(180, 70, 70), (40, 125, 85)]
    for idx, result in enumerate(results):
        reward = result["mean_reward"]
        bar_len = _bar_length(reward, max_bar)
        draw.rectangle((left, y, left + bar_len, y + 44), fill=colors[idx % len(colors)])
        draw.text((30, y + 12), result["policy"], fill=(20, 20, 20))
        # The label shows the true (unclamped) value next to the bar.
        draw.text((left + 10 + bar_len, y + 12), f"{reward:.2f}", fill=(20, 20, 20))
        y += 80
    # Keep the axis below the last bar; this stays at 310 px for up to two policies.
    axis_y = max(310, y + 40)
    draw.line((left, axis_y, left + max_bar, axis_y), fill=(80, 80, 80), width=2)
    for tick in range(6):
        x = left + tick * max_bar // 5
        draw.line((x, axis_y - 6, x, axis_y + 6), fill=(80, 80, 80), width=2)
        draw.text((x - 10, axis_y + 12), f"{tick / 5:.1f}", fill=(80, 80, 80))
    return image


def _draw_reward_curve_template(image_module, draw_module):
    """Return a placeholder reward-curve figure meant to be replaced by real training output."""
    curve = image_module.new("RGB", _CANVAS_SIZE, "white")
    draw = draw_module.Draw(curve)
    draw.text((30, 25), "Reward curve template - replace with GRPO run output", fill=(20, 20, 20))
    draw.line((80, 440, 830, 440), fill=(60, 60, 60), width=2)  # x axis
    draw.line((80, 80, 80, 440), fill=(60, 60, 60), width=2)  # y axis
    # Synthetic rising-then-flattening curve sampled at 11 evenly spaced steps.
    points = [(80 + i * 75, 410 - int((i / 10) ** 0.8 * 260)) for i in range(11)]
    draw.line(points, fill=(40, 125, 85), width=4)
    draw.text((390, 462), "training step", fill=(60, 60, 60))
    draw.text((15, 70), "reward", fill=(60, 60, 60))
    return curve


def write_assets(results: list[dict]) -> None:
    """Render the evaluation figures into ``./assets`` (best effort).

    Writes ``baseline_vs_oracle.png``, with one bar per policy's
    ``mean_reward``, and ``reward_curve_template.png``, a placeholder curve.
    Pillow is optional. If it is not installed, a notice goes to stderr and
    nothing is written.

    Args:
        results: Policy summaries as returned by ``run_policy``; each entry
            needs ``"policy"`` and ``"mean_reward"`` keys.
    """
    try:
        from PIL import Image, ImageDraw
    except ImportError:
        # Stay optional, but say so instead of silently producing no files.
        print("Pillow is not installed; skipping asset generation.", file=sys.stderr)
        return
    assets = Path("assets")
    assets.mkdir(exist_ok=True)
    _draw_baseline_chart(Image, ImageDraw, results).save(assets / "baseline_vs_oracle.png")
    _draw_reward_curve_template(Image, ImageDraw).save(assets / "reward_curve_template.png")
def main(argv: list[str] | None = None) -> None:
    """Evaluate the weak baseline and oracle policies and persist the results.

    Runs both scripted policies across every scenario family. The combined
    summary is written as JSON to ``<output-dir>/baseline_eval.json`` and
    echoed to stdout. With ``--write-assets``, PNG figures are also rendered
    via ``write_assets``.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, so running the script directly behaves as before.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate scripted baseline policies on every incident scenario family."
    )
    parser.add_argument(
        "--write-assets",
        action="store_true",
        help="also render PNG charts into ./assets (requires Pillow)",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("outputs/evals"),
        help="directory that receives baseline_eval.json (default: %(default)s)",
    )
    args = parser.parse_args(argv)
    results = [
        run_policy("weak_baseline", weak_policy),
        run_policy("oracle_trace", oracle_policy),
    ]
    # Serialise once so the saved file and the console echo are byte-identical.
    report = json.dumps(results, indent=2)
    output_dir = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "baseline_eval.json").write_text(report, encoding="utf-8")
    print(report)
    if args.write_assets:
        write_assets(results)
# Script entry point: evaluation only runs when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()