Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| from statistics import mean | |
| import sys | |
# Directory one level above the folder that contains this script. It is put at
# the front of ``sys.path`` (once) so the project imports that follow resolve
# when the script is run directly.
ROOT = Path(__file__).resolve().parents[1]
_root_entry = str(ROOT)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)
| from incident_commander_env.models import IncidentAction | |
| from incident_commander_env.scenarios import FAMILIES | |
| from incident_commander_env.server.incident_commander_environment import IncidentCommanderEnvironment | |
| from scripts.manual_episode import SOLUTIONS, first_relevant_metric | |
def step(env: IncidentCommanderEnvironment, tool_name: str, **arguments):
    """Send a single tool call to *env* and return the result of ``env.step``.

    Args:
        env: Environment that receives the action.
        tool_name: Name of the tool to invoke.
        **arguments: Keyword arguments forwarded as the action's ``arguments``
            mapping.

    Returns:
        Whatever ``env.step`` returns for the constructed action.
    """
    action = IncidentAction(tool_name=tool_name, arguments=arguments)
    return env.step(action)
def weak_policy(env: IncidentCommanderEnvironment):
    """Run a minimal episode: one log query followed by a fixed final report.

    The environment is reset to its current scenario family, the logs of the
    first listed service are queried once, and a hard-coded report (the same
    text for every scenario) is submitted.

    Args:
        env: Environment to drive; its ``scenario`` is read before and after
            the reset.

    Returns:
        The result of the ``final_report`` step.
    """
    # The observation returned by reset() is not used by this policy, so it is
    # intentionally not bound to a name.
    env.reset(scenario_id=env.scenario.family)
    service = env.scenario.services[0]
    step(env, "query_logs", service=service, query="error", minutes=15)
    return step(
        env,
        "final_report",
        root_cause="database is slow",
        mitigation="restart the database",
        customer_update="We are investigating customer impact and will update soon.",
    )
def oracle_policy(env: IncidentCommanderEnvironment):
    """Run a scripted episode that uses the known solution for the scenario.

    After resetting the environment to its current scenario family, the policy
    looks up that family's entry in ``SOLUTIONS`` and issues, in order: a log
    query on the first listed service, a metric inspection (service and metric
    chosen by ``first_relevant_metric``), a runbook read, a hypothesis test,
    a mitigation, and a final report built from the solution entry.

    Args:
        env: Environment to drive; its ``scenario`` is read before and after
            the reset.

    Returns:
        The result of the ``final_report`` step.

    Raises:
        KeyError: If ``SOLUTIONS`` has no entry for the scenario family or the
            entry lacks a ``"topic"``, ``"root"`` or ``"mitigation"`` key.
    """
    # The observation returned by reset() is not used by this policy, so it is
    # intentionally not bound to a name.
    env.reset(scenario_id=env.scenario.family)
    # Read the scenario only after the reset so the lookup reflects the
    # scenario the episode is actually running.
    solution = SOLUTIONS[env.scenario.family]
    first = env.scenario.services[0]
    step(env, "query_logs", service=first, query="error timeout failed warn", minutes=20)
    metric_service, metric = first_relevant_metric(env)
    step(env, "inspect_metric", service=metric_service, metric=metric, minutes=30)
    step(env, "read_runbook", topic=solution["topic"])
    step(env, "test_hypothesis", root_cause=solution["root"])
    step(env, "apply_mitigation", mitigation=solution["mitigation"])
    return step(
        env,
        "final_report",
        root_cause=solution["root"],
        mitigation=solution["mitigation"],
        customer_update=(
            "Customers are impacted by elevated failures. We applied mitigation and "
            "expect recovery within 15 minutes; next update in 10 minutes."
        ),
    )
def run_policy(policy_name: str, policy_fn):
    """Evaluate *policy_fn* once per scenario family and summarize the results.

    A fresh environment is created for every entry in ``FAMILIES`` and handed
    to *policy_fn*; the object it returns supplies ``reward``,
    ``score_breakdown`` and ``turn`` for that scenario's row.

    Args:
        policy_name: Label stored under the ``"policy"`` key of the summary.
        policy_fn: Callable taking an environment and returning the final
            step result of the episode.

    Returns:
        A dict with the policy label, the rounded per-metric averages across
        all scenarios, and the per-scenario rows under ``"scenarios"``.
    """

    def evaluate(family):
        # One isolated environment per scenario family.
        outcome = policy_fn(IncidentCommanderEnvironment(scenario_id=family))
        scores = outcome.score_breakdown
        return {
            "scenario": family,
            "reward": outcome.reward,
            "root_cause_ok": scores.get("root_cause", 0.0) > 0,
            "mitigation_ok": scores.get("mitigation", 0.0) > 0,
            "unsafe": scores.get("unsafe_penalty", 0.0) < 0,
            "tool_calls": outcome.turn,
        }

    per_scenario = [evaluate(family) for family in FAMILIES]

    def average(key, digits=4):
        # Boolean columns average to the fraction of scenarios that are True.
        return round(mean(entry[key] for entry in per_scenario), digits)

    return {
        "policy": policy_name,
        "mean_reward": average("reward"),
        "root_cause_accuracy": average("root_cause_ok"),
        "mitigation_accuracy": average("mitigation_ok"),
        "unsafe_rate": average("unsafe"),
        "avg_tool_calls": average("tool_calls", digits=2),
        "scenarios": per_scenario,
    }
def _bar_length(reward: float, max_bar: int) -> int:
    """Return the bar length in pixels for *reward* on a 0.0-1.0 axis.

    Rewards outside the axis range are clamped so a bar never extends left of
    the axis origin or past its right end.
    """
    return int(max_bar * min(max(reward, 0.0), 1.0))


def write_assets(results: list[dict]) -> None:
    """Render the evaluation charts as PNG files under ``assets/``.

    Two images are written relative to the current working directory:

    * ``assets/baseline_vs_oracle.png`` - one horizontal bar per entry of
      *results*, sized by its ``"mean_reward"`` on a 0.0-1.0 axis and labelled
      with its ``"policy"`` name and the unclamped reward value.
    * ``assets/reward_curve_template.png`` - a placeholder curve that does not
      depend on *results*.

    Pillow is an optional dependency: when it cannot be imported a notice is
    printed to stderr and nothing is written.

    Args:
        results: Policy summaries; each entry needs ``"policy"`` and
            ``"mean_reward"`` keys.
    """
    try:
        from PIL import Image, ImageDraw
    except ImportError:
        # Best-effort feature: report the skip instead of failing the run.
        print("Pillow is not installed; skipping asset generation.", file=sys.stderr)
        return

    results = list(results)
    assets = Path("assets")
    assets.mkdir(exist_ok=True)

    width, height = 900, 520
    text_color = (20, 20, 20)
    axis_color = (80, 80, 80)
    bar_left, max_bar = 220, 620
    bars_top, row_step, bar_height = 110, 80, 44
    colors = [(180, 70, 70), (40, 125, 85)]

    # The axis sits at y=310 for up to two bars (the original layout) and
    # moves down, growing the canvas if needed, so extra bars never overlap it.
    axis_y = max(310, bars_top + row_step * len(results) + 40)
    chart_height = max(height, axis_y + 50)

    image = Image.new("RGB", (width, chart_height), "white")
    draw = ImageDraw.Draw(image)
    draw.text((30, 25), "Incident Commander RL Arena: Baseline vs Oracle", fill=text_color)
    for idx, result in enumerate(results):
        y = bars_top + idx * row_step
        reward = result["mean_reward"]
        # Clamp only the drawn length; the label still shows the real value.
        # A negative length would produce x1 < x0, which Pillow rejects.
        bar_len = _bar_length(reward, max_bar)
        draw.rectangle(
            (bar_left, y, bar_left + bar_len, y + bar_height),
            fill=colors[idx % len(colors)],
        )
        draw.text((30, y + 12), result["policy"], fill=text_color)
        draw.text((bar_left + 10 + bar_len, y + 12), f"{reward:.2f}", fill=text_color)

    draw.line((bar_left, axis_y, bar_left + max_bar, axis_y), fill=axis_color, width=2)
    for tick in range(0, 6):
        x = bar_left + tick * max_bar // 5
        draw.line((x, axis_y - 6, x, axis_y + 6), fill=axis_color, width=2)
        draw.text((x - 10, axis_y + 12), f"{tick / 5:.1f}", fill=axis_color)
    image.save(assets / "baseline_vs_oracle.png")

    # Placeholder learning curve: a fixed concave shape, independent of results.
    curve = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(curve)
    draw.text((30, 25), "Reward curve template - replace with GRPO run output", fill=(20, 20, 20))
    draw.line((80, 440, 830, 440), fill=(60, 60, 60), width=2)
    draw.line((80, 80, 80, 440), fill=(60, 60, 60), width=2)
    points = [(80 + i * 75, 410 - int((i / 10) ** 0.8 * 260)) for i in range(11)]
    draw.line(points, fill=(40, 125, 85), width=4)
    draw.text((390, 462), "training step", fill=(60, 60, 60))
    draw.text((15, 70), "reward", fill=(60, 60, 60))
    curve.save(assets / "reward_curve_template.png")
def main() -> None:
    """Evaluate both built-in policies, save and print the JSON summary.

    The summary is written to ``outputs/evals/baseline_eval.json`` (relative
    to the current working directory) and echoed to stdout. With the
    ``--write-assets`` flag the chart images are rendered as well.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--write-assets", action="store_true")
    options = cli.parse_args()

    policies = (
        ("weak_baseline", weak_policy),
        ("oracle_trace", oracle_policy),
    )
    results = [run_policy(label, policy) for label, policy in policies]

    eval_dir = Path("outputs/evals")
    eval_dir.mkdir(parents=True, exist_ok=True)
    # Serialize once and reuse the same text for the file and the console.
    report = json.dumps(results, indent=2)
    eval_dir.joinpath("baseline_eval.json").write_text(report, encoding="utf-8")
    print(report)

    if options.write_assets:
        write_assets(results)


if __name__ == "__main__":
    main()