# Spaces:
# Sleeping
# Sleeping
# File size: 5,530 Bytes
# 015ca5f
from __future__ import annotations
import argparse
import json
from pathlib import Path
from statistics import mean
import sys
# Repository root: the directory one level above the folder holding this file.
ROOT = Path(__file__).resolve().parents[1]

# Put the root at the front of sys.path (once) so the project imports that
# follow resolve when this file is executed directly as a script.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
from incident_commander_env.models import IncidentAction
from incident_commander_env.scenarios import FAMILIES
from incident_commander_env.server.incident_commander_environment import IncidentCommanderEnvironment
from scripts.manual_episode import SOLUTIONS, first_relevant_metric
def step(env: IncidentCommanderEnvironment, tool_name: str, **arguments):
    """Issue a single tool call against ``env``.

    The keyword arguments are collected into a dict and wrapped, together
    with ``tool_name``, in an ``IncidentAction``.

    Returns:
        Whatever ``env.step`` returns for that action.
    """
    action = IncidentAction(tool_name=tool_name, arguments=arguments)
    return env.step(action)
def weak_policy(env: IncidentCommanderEnvironment):
    """Run a minimal-effort episode and return the final-report result.

    The policy resets the environment to its current scenario family, makes
    one log query against the scenario's first service, and then files a
    fixed report whose root cause and mitigation do not depend on the
    scenario.

    Args:
        env: Environment to drive; it must already expose ``env.scenario``.

    Returns:
        The value returned by the ``final_report`` step.
    """
    # Only the reset side effect is needed; its return value is never read.
    env.reset(scenario_id=env.scenario.family)

    service = env.scenario.services[0]
    step(env, "query_logs", service=service, query="error", minutes=15)

    # Same canned answer for every scenario.
    return step(
        env,
        "final_report",
        root_cause="database is slow",
        mitigation="restart the database",
        customer_update="We are investigating customer impact and will update soon.",
    )
def oracle_policy(env: IncidentCommanderEnvironment):
    """Run a scripted episode driven by the known solution for the scenario.

    After resetting the environment to its current scenario family, the
    policy looks up that family's entry in ``SOLUTIONS`` and performs, in
    order: a log query on the first service, a metric inspection using the
    (service, metric) pair from ``first_relevant_metric``, a runbook read, a
    hypothesis test, a mitigation, and the final report.

    Args:
        env: Environment to drive; it must already expose ``env.scenario``.

    Returns:
        The value returned by the ``final_report`` step.
    """
    # Only the reset side effect is needed; its return value is never read.
    env.reset(scenario_id=env.scenario.family)

    solution = SOLUTIONS[env.scenario.family]
    first = env.scenario.services[0]

    step(env, "query_logs", service=first, query="error timeout failed warn", minutes=20)

    metric_service, metric = first_relevant_metric(env)
    step(env, "inspect_metric", service=metric_service, metric=metric, minutes=30)

    step(env, "read_runbook", topic=solution["topic"])
    step(env, "test_hypothesis", root_cause=solution["root"])
    step(env, "apply_mitigation", mitigation=solution["mitigation"])

    return step(
        env,
        "final_report",
        root_cause=solution["root"],
        mitigation=solution["mitigation"],
        customer_update=(
            "Customers are impacted by elevated failures. We applied mitigation and "
            "expect recovery within 15 minutes; next update in 10 minutes."
        ),
    )
def run_policy(policy_name: str, policy_fn):
    """Evaluate ``policy_fn`` once per scenario family and summarise the results.

    A fresh ``IncidentCommanderEnvironment`` is built for every entry of
    ``FAMILIES`` and handed to ``policy_fn``; the object it returns supplies
    the reward, the score breakdown and the turn count for that scenario.

    Args:
        policy_name: Label stored under ``"policy"`` in the summary.
        policy_fn: Callable taking an environment and returning the final
            step result.

    Returns:
        A dict with the policy label, the rounded averages over all
        scenarios, and the per-scenario rows under ``"scenarios"``.
    """

    def evaluate(family):
        # One isolated environment per scenario family.
        outcome = policy_fn(IncidentCommanderEnvironment(scenario_id=family))
        scores = outcome.score_breakdown
        return {
            "scenario": family,
            "reward": outcome.reward,
            # Missing breakdown entries count as 0.0, i.e. "not achieved".
            "root_cause_ok": scores.get("root_cause", 0.0) > 0,
            "mitigation_ok": scores.get("mitigation", 0.0) > 0,
            # A negative penalty entry marks the episode as unsafe.
            "unsafe": scores.get("unsafe_penalty", 0.0) < 0,
            "tool_calls": outcome.turn,
        }

    rows = [evaluate(family) for family in FAMILIES]

    def averaged(key, digits=4):
        # Boolean columns average to the fraction of True rows.
        return round(mean(row[key] for row in rows), digits)

    return {
        "policy": policy_name,
        "mean_reward": averaged("reward"),
        "root_cause_accuracy": averaged("root_cause_ok"),
        "mitigation_accuracy": averaged("mitigation_ok"),
        "unsafe_rate": averaged("unsafe"),
        "avg_tool_calls": averaged("tool_calls", 2),
        "scenarios": rows,
    }
def write_assets(results: list[dict]) -> None:
    """Render the policy comparison chart and a placeholder reward curve.

    Two PNG files are written into an ``assets`` directory relative to the
    current working directory (created if missing):
    ``baseline_vs_oracle.png`` and ``reward_curve_template.png``.

    Rendering is best effort: if Pillow cannot be imported, a note is printed
    to stderr and nothing is written.

    Args:
        results: Policy summaries; each entry must provide ``"policy"`` and
            ``"mean_reward"``.
    """
    try:
        from PIL import Image, ImageDraw
    except ImportError:
        # Pillow is optional; say why no images appear instead of failing silently.
        print("Pillow is not installed; skipping asset generation.", file=sys.stderr)
        return

    assets = Path("assets")
    assets.mkdir(exist_ok=True)
    width, height = 900, 520
    text_color = (20, 20, 20)
    axis_color = (80, 80, 80)

    # --- Bar chart: one horizontal bar per policy ---------------------------
    image = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(image)
    draw.text((30, 25), "Incident Commander RL Arena: Baseline vs Oracle", fill=text_color)

    bar_left = 220
    max_bar = 620  # pixel length of a bar at reward 1.0
    y = 110
    colors = [(180, 70, 70), (40, 125, 85)]
    for idx, result in enumerate(results):
        reward = result["mean_reward"]
        # The axis drawn below spans 0.0-1.0, so clamp the drawn length to it.
        # A negative length would give a rectangle with x1 < x0, which recent
        # Pillow releases reject. The label still shows the true value.
        bar_len = int(max_bar * min(max(reward, 0.0), 1.0))
        draw.rectangle(
            (bar_left, y, bar_left + bar_len, y + 44),
            fill=colors[idx % len(colors)],
        )
        draw.text((30, y + 12), result["policy"], fill=text_color)
        draw.text((bar_left + 10 + bar_len, y + 12), f"{reward:.2f}", fill=text_color)
        y += 80

    # Horizontal axis with six ticks labelled 0.0 through 1.0.
    axis_y = 310
    draw.line((bar_left, axis_y, bar_left + max_bar, axis_y), fill=axis_color, width=2)
    for tick in range(0, 6):
        x = bar_left + tick * max_bar // 5
        draw.line((x, axis_y - 6, x, axis_y + 6), fill=axis_color, width=2)
        draw.text((x - 10, axis_y + 12), f"{tick / 5:.1f}", fill=axis_color)
    image.save(assets / "baseline_vs_oracle.png")

    # --- Placeholder reward curve (synthetic shape, not measured data) ------
    curve = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(curve)
    draw.text((30, 25), "Reward curve template - replace with GRPO run output", fill=text_color)
    draw.line((80, 440, 830, 440), fill=(60, 60, 60), width=2)
    draw.line((80, 80, 80, 440), fill=(60, 60, 60), width=2)
    points = [(80 + i * 75, 410 - int((i / 10) ** 0.8 * 260)) for i in range(11)]
    draw.line(points, fill=(40, 125, 85), width=4)
    draw.text((390, 462), "training step", fill=(60, 60, 60))
    draw.text((15, 70), "reward", fill=(60, 60, 60))
    curve.save(assets / "reward_curve_template.png")
def main() -> None:
    """Evaluate both built-in policies, persist the results, and print them.

    The results are written as indented JSON to
    ``outputs/evals/baseline_eval.json`` (relative to the current working
    directory) and echoed to stdout. With ``--write-assets`` the summary
    images are rendered as well.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--write-assets", action="store_true")
    args = parser.parse_args()

    policies = (
        ("weak_baseline", weak_policy),
        ("oracle_trace", oracle_policy),
    )
    results = [run_policy(name, policy) for name, policy in policies]

    # Serialise once; the same text goes to the file and to stdout.
    report = json.dumps(results, indent=2)

    output_dir = Path("outputs/evals")
    output_dir.mkdir(parents=True, exist_ok=True)
    report_path = output_dir / "baseline_eval.json"
    report_path.write_text(report, encoding="utf-8")
    print(report)

    if args.write_assets:
        write_assets(results)
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|