"""Evaluate a weak baseline and an oracle policy on every scenario family.

Each policy is run once per entry in ``FAMILIES``.  The aggregated results are
written to ``outputs/evals/baseline_eval.json`` (relative to the current
working directory) and echoed to stdout.  With ``--write-assets`` two PNG
charts are additionally rendered into ``assets/`` when Pillow is available.
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from statistics import mean
import sys

# Make the repository root importable so the project packages below resolve
# when this file is executed directly as a script.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from incident_commander_env.models import IncidentAction
from incident_commander_env.scenarios import FAMILIES
from incident_commander_env.server.incident_commander_environment import IncidentCommanderEnvironment
from scripts.manual_episode import SOLUTIONS, first_relevant_metric


def step(env: IncidentCommanderEnvironment, tool_name: str, **arguments):
    """Send one tool call to the environment.

    Args:
        env: Environment to act on.
        tool_name: Name of the tool to invoke.
        **arguments: Keyword arguments forwarded as the tool's arguments.

    Returns:
        Whatever ``env.step`` returns for the constructed ``IncidentAction``.
    """
    return env.step(IncidentAction(tool_name=tool_name, arguments=arguments))


def weak_policy(env: IncidentCommanderEnvironment):
    """Run a deliberately weak episode: one log query, then a canned report.

    The final report always blames a slow database regardless of the
    scenario, which makes this a low-effort baseline.

    Returns:
        The result of the ``final_report`` step.
    """
    env.reset(scenario_id=env.scenario.family)
    service = env.scenario.services[0]
    step(env, "query_logs", service=service, query="error", minutes=15)
    return step(
        env,
        "final_report",
        root_cause="database is slow",
        mitigation="restart the database",
        customer_update="We are investigating customer impact and will update soon.",
    )


def oracle_policy(env: IncidentCommanderEnvironment):
    """Run an episode driven by the known answer in ``SOLUTIONS``.

    The policy queries logs, inspects the metric chosen by
    ``first_relevant_metric``, reads the runbook, tests the hypothesis and
    applies the mitigation before filing the final report, taking the root
    cause, mitigation and runbook topic from ``SOLUTIONS`` for the scenario
    family.

    Returns:
        The result of the ``final_report`` step.
    """
    env.reset(scenario_id=env.scenario.family)
    solution = SOLUTIONS[env.scenario.family]
    first = env.scenario.services[0]
    step(env, "query_logs", service=first, query="error timeout failed warn", minutes=20)
    metric_service, metric = first_relevant_metric(env)
    step(env, "inspect_metric", service=metric_service, metric=metric, minutes=30)
    step(env, "read_runbook", topic=solution["topic"])
    step(env, "test_hypothesis", root_cause=solution["root"])
    step(env, "apply_mitigation", mitigation=solution["mitigation"])
    return step(
        env,
        "final_report",
        root_cause=solution["root"],
        mitigation=solution["mitigation"],
        customer_update=(
            "Customers are impacted by elevated failures. We applied mitigation and "
            "expect recovery within 15 minutes; next update in 10 minutes."
        ),
    )


def run_policy(policy_name: str, policy_fn) -> dict:
    """Run ``policy_fn`` once per scenario family and aggregate the outcomes.

    Args:
        policy_name: Label stored under the ``"policy"`` key of the result.
        policy_fn: Callable taking a fresh ``IncidentCommanderEnvironment``
            and returning the final step result of the episode.

    Returns:
        A dict with the policy name, the mean reward, root-cause / mitigation
        accuracy, unsafe rate and average tool calls across all families,
        plus the per-scenario rows under ``"scenarios"``.
    """
    rows = []
    for family in FAMILIES:
        env = IncidentCommanderEnvironment(scenario_id=family)
        final = policy_fn(env)
        breakdown = final.score_breakdown
        rows.append(
            {
                "scenario": family,
                "reward": final.reward,
                # A component counts as correct when it earned positive score.
                "root_cause_ok": breakdown.get("root_cause", 0.0) > 0,
                "mitigation_ok": breakdown.get("mitigation", 0.0) > 0,
                # A negative unsafe_penalty marks the episode as unsafe.
                "unsafe": breakdown.get("unsafe_penalty", 0.0) < 0,
                "tool_calls": final.turn,
            }
        )
    return {
        "policy": policy_name,
        "mean_reward": round(mean(row["reward"] for row in rows), 4),
        # Averaging the boolean flags yields the fraction of True rows.
        "root_cause_accuracy": round(mean(row["root_cause_ok"] for row in rows), 4),
        "mitigation_accuracy": round(mean(row["mitigation_ok"] for row in rows), 4),
        "unsafe_rate": round(mean(row["unsafe"] for row in rows), 4),
        "avg_tool_calls": round(mean(row["tool_calls"] for row in rows), 2),
        "scenarios": rows,
    }


def write_assets(results: list[dict]) -> None:
    """Render the comparison bar chart and a placeholder reward curve as PNGs.

    Both images are saved into ``assets/`` relative to the current working
    directory.  Rendering is best-effort: when Pillow cannot be imported a
    warning is printed to stderr and nothing is written.

    Args:
        results: Policy summaries as produced by ``run_policy``; only the
            ``"policy"`` and ``"mean_reward"`` keys are read.
    """
    try:
        from PIL import Image, ImageDraw
    except ImportError as exc:
        # stderr keeps the JSON emitted on stdout machine-readable.
        print(f"Pillow could not be imported; skipping asset generation: {exc}", file=sys.stderr)
        return

    assets = Path("assets")
    assets.mkdir(exist_ok=True)
    width, height = 900, 520

    # --- Chart 1: mean reward per policy as horizontal bars ---------------
    image = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(image)
    draw.text((30, 25), "Incident Commander RL Arena: Baseline vs Oracle", fill=(20, 20, 20))
    max_bar = 620
    y = 110
    colors = [(180, 70, 70), (40, 125, 85)]
    for idx, result in enumerate(results):
        reward = result["mean_reward"]
        # Clamp to the 0.0-1.0 axis: a negative reward would otherwise give a
        # rectangle with x1 < x0, and a reward above 1 would overshoot the
        # axis.  The text label still reports the unclamped value.
        bar_len = int(max_bar * min(max(reward, 0.0), 1.0))
        draw.rectangle((220, y, 220 + bar_len, y + 44), fill=colors[idx % len(colors)])
        draw.text((30, y + 12), result["policy"], fill=(20, 20, 20))
        draw.text((230 + bar_len, y + 12), f"{reward:.2f}", fill=(20, 20, 20))
        y += 80
    # Axis with six ticks labelled 0.0 through 1.0.
    draw.line((220, 310, 840, 310), fill=(80, 80, 80), width=2)
    for tick in range(0, 6):
        x = 220 + tick * max_bar // 5
        draw.line((x, 304, x, 316), fill=(80, 80, 80), width=2)
        draw.text((x - 10, 322), f"{tick / 5:.1f}", fill=(80, 80, 80))
    image.save(assets / "baseline_vs_oracle.png")

    # --- Chart 2: synthetic curve, a template rather than measured data ---
    curve = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(curve)
    draw.text((30, 25), "Reward curve template - replace with GRPO run output", fill=(20, 20, 20))
    draw.line((80, 440, 830, 440), fill=(60, 60, 60), width=2)
    draw.line((80, 80, 80, 440), fill=(60, 60, 60), width=2)
    points = [(80 + i * 75, 410 - int((i / 10) ** 0.8 * 260)) for i in range(11)]
    draw.line(points, fill=(40, 125, 85), width=4)
    draw.text((390, 462), "training step", fill=(60, 60, 60))
    draw.text((15, 70), "reward", fill=(60, 60, 60))
    curve.save(assets / "reward_curve_template.png")


def main() -> None:
    """Evaluate both policies, persist and print the JSON, optionally draw charts."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--write-assets", action="store_true")
    args = parser.parse_args()

    results = [
        run_policy("weak_baseline", weak_policy),
        run_policy("oracle_trace", oracle_policy),
    ]

    output_dir = Path("outputs/evals")
    output_dir.mkdir(parents=True, exist_ok=True)
    # Serialise once so the file on disk and the console output are identical.
    payload = json.dumps(results, indent=2)
    (output_dir / "baseline_eval.json").write_text(payload, encoding="utf-8")
    print(payload)

    if args.write_assets:
        write_assets(results)


if __name__ == "__main__":
    main()