File size: 5,530 Bytes
015ca5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from __future__ import annotations

import argparse
import json
from pathlib import Path
from statistics import mean
import sys

# Repository root: the grandparent of this file's resolved location (this
# script sits one directory below the root).
ROOT = Path(__file__).resolve().parents[1]

# Put the repository root at the front of the import path so the project
# packages imported below resolve when this file is run directly as a script.
if not any(entry == str(ROOT) for entry in sys.path):
    sys.path.insert(0, str(ROOT))

from incident_commander_env.models import IncidentAction
from incident_commander_env.scenarios import FAMILIES
from incident_commander_env.server.incident_commander_environment import IncidentCommanderEnvironment
from scripts.manual_episode import SOLUTIONS, first_relevant_metric


def step(env: IncidentCommanderEnvironment, tool_name: str, **arguments):
    """Issue a single tool call against ``env`` and return the step result.

    Every keyword argument is forwarded unchanged as the tool's argument
    mapping of the ``IncidentAction``.
    """
    action = IncidentAction(tool_name=tool_name, arguments=arguments)
    return env.step(action)


def weak_policy(env: IncidentCommanderEnvironment):
    """Play a deliberately naive episode that serves as the evaluation baseline.

    Resets ``env`` to its current scenario family, runs one generic log query
    against the first listed service, then files a fixed final report whose
    root cause, mitigation and customer update are hard-coded strings that do
    not depend on anything the episode observed.

    Args:
        env: Environment whose ``scenario.family`` selects the scenario to replay.

    Returns:
        The step result produced by the ``final_report`` tool call.
    """
    # The reset observation is not used by this policy, so it is not bound.
    env.reset(scenario_id=env.scenario.family)
    service = env.scenario.services[0]
    step(env, "query_logs", service=service, query="error", minutes=15)
    return step(
        env,
        "final_report",
        root_cause="database is slow",
        mitigation="restart the database",
        customer_update="We are investigating customer impact and will update soon.",
    )


def oracle_policy(env: IncidentCommanderEnvironment):
    """Play a scripted episode that follows the known solution for the scenario.

    After resetting ``env`` to its current scenario family, the policy:
    queries logs on the first listed service, inspects the metric chosen by
    ``first_relevant_metric``, reads the solution's runbook topic, tests the
    known root cause, applies the known mitigation, and finally files a report
    built from that same solution.

    Args:
        env: Environment whose ``scenario.family`` keys into ``SOLUTIONS``.

    Returns:
        The step result produced by the ``final_report`` tool call.

    Raises:
        KeyError: if ``SOLUTIONS`` has no entry for the scenario family.
    """
    # The reset observation is not used by this policy, so it is not bound.
    env.reset(scenario_id=env.scenario.family)
    solution = SOLUTIONS[env.scenario.family]
    first = env.scenario.services[0]
    step(env, "query_logs", service=first, query="error timeout failed warn", minutes=20)
    metric_service, metric = first_relevant_metric(env)
    step(env, "inspect_metric", service=metric_service, metric=metric, minutes=30)
    step(env, "read_runbook", topic=solution["topic"])
    step(env, "test_hypothesis", root_cause=solution["root"])
    step(env, "apply_mitigation", mitigation=solution["mitigation"])
    return step(
        env,
        "final_report",
        root_cause=solution["root"],
        mitigation=solution["mitigation"],
        customer_update=(
            "Customers are impacted by elevated failures. We applied mitigation and "
            "expect recovery within 15 minutes; next update in 10 minutes."
        ),
    )


def run_policy(policy_name: str, policy_fn, families=None):
    """Run ``policy_fn`` once per scenario family and aggregate its scores.

    Args:
        policy_name: Label stored under ``"policy"`` in the returned summary.
        policy_fn: Callable that receives a newly constructed environment and
            returns the final step result; that result must expose
            ``reward``, ``turn`` and a ``score_breakdown`` mapping.
        families: Scenario families to evaluate. ``None`` (the default)
            evaluates every entry in ``FAMILIES``.

    Returns:
        A dict with the policy name, mean reward, root-cause and mitigation
        accuracy, unsafe rate, average tool-call count, and the per-scenario
        rows under ``"scenarios"``.

    Raises:
        ValueError: if no scenario families were evaluated.
    """
    selected = FAMILIES if families is None else families
    rows = []
    for family in selected:
        # A new environment is built for each family.
        env = IncidentCommanderEnvironment(scenario_id=family)
        final = policy_fn(env)
        breakdown = final.score_breakdown
        rows.append(
            {
                "scenario": family,
                "reward": final.reward,
                # A component counts as correct when it earned any positive credit.
                "root_cause_ok": breakdown.get("root_cause", 0.0) > 0,
                "mitigation_ok": breakdown.get("mitigation", 0.0) > 0,
                # Any negative unsafe penalty marks the episode as unsafe.
                "unsafe": breakdown.get("unsafe_penalty", 0.0) < 0,
                "tool_calls": final.turn,
            }
        )
    if not rows:
        # statistics.mean would otherwise fail with an opaque StatisticsError.
        raise ValueError(f"no scenario families to evaluate for policy {policy_name!r}")
    # mean() over booleans yields the fraction of True values.
    return {
        "policy": policy_name,
        "mean_reward": round(mean(row["reward"] for row in rows), 4),
        "root_cause_accuracy": round(mean(row["root_cause_ok"] for row in rows), 4),
        "mitigation_accuracy": round(mean(row["mitigation_ok"] for row in rows), 4),
        "unsafe_rate": round(mean(row["unsafe"] for row in rows), 4),
        "avg_tool_calls": round(mean(row["tool_calls"] for row in rows), 2),
        "scenarios": rows,
    }


def write_assets(results: list[dict], assets_dir: str | Path = "assets") -> None:
    """Render PNG charts summarising evaluation results.

    Writes two images into ``assets_dir`` (created if missing):

    * ``baseline_vs_oracle.png``: one horizontal bar per policy showing its
      ``mean_reward`` against a 0-1 axis.
    * ``reward_curve_template.png``: a placeholder training-curve image drawn
      from a fixed synthetic curve; it does not use ``results``.

    Rendering needs Pillow. When Pillow cannot be imported, a note is printed
    to stderr and nothing is written, so the evaluation itself still succeeds.

    Args:
        results: Policy summaries (as built by ``run_policy``); only the
            ``"policy"`` and ``"mean_reward"`` keys are read.
        assets_dir: Output directory. Defaults to ``assets`` relative to the
            current working directory.
    """
    try:
        from PIL import Image, ImageDraw
    except ImportError:
        # Charts are optional; report the skip instead of hiding it.
        print("Pillow is not installed; skipping chart assets.", file=sys.stderr)
        return

    assets = Path(assets_dir)
    assets.mkdir(parents=True, exist_ok=True)
    width, height = 900, 520

    chart = Image.new("RGB", (width, height), "white")
    _draw_reward_bars(ImageDraw.Draw(chart), results)
    chart.save(assets / "baseline_vs_oracle.png")

    curve = Image.new("RGB", (width, height), "white")
    _draw_curve_template(ImageDraw.Draw(curve))
    curve.save(assets / "reward_curve_template.png")


def _draw_reward_bars(draw, results: list[dict]) -> None:
    """Draw one labelled horizontal bar per policy plus a 0-1 reward axis."""
    draw.text((30, 25), "Incident Commander RL Arena: Baseline vs Oracle", fill=(20, 20, 20))
    max_bar = 620
    colors = [(180, 70, 70), (40, 125, 85)]
    y = 110
    for idx, result in enumerate(results):
        reward = result["mean_reward"]
        # Clamp the bar to the drawn 0-1 axis: a negative reward would give
        # x1 < x0 (rejected by recent Pillow versions), and a reward above 1
        # would run past the end of the axis.
        bar_len = int(max_bar * min(max(reward, 0.0), 1.0))
        draw.rectangle((220, y, 220 + bar_len, y + 44), fill=colors[idx % len(colors)])
        draw.text((30, y + 12), result["policy"], fill=(20, 20, 20))
        # The label keeps the unclamped value so the true reward stays visible.
        draw.text((230 + bar_len, y + 12), f"{reward:.2f}", fill=(20, 20, 20))
        y += 80
    # NOTE: the axis sits at a fixed y=310, which leaves room for two bars;
    # a third bar (y 270-314) would overlap it.
    draw.line((220, 310, 840, 310), fill=(80, 80, 80), width=2)
    for tick in range(6):
        x = 220 + tick * max_bar // 5
        draw.line((x, 304, x, 316), fill=(80, 80, 80), width=2)
        draw.text((x - 10, 322), f"{tick / 5:.1f}", fill=(80, 80, 80))


def _draw_curve_template(draw) -> None:
    """Draw a placeholder reward-vs-training-step curve from synthetic points."""
    draw.text((30, 25), "Reward curve template - replace with GRPO run output", fill=(20, 20, 20))
    draw.line((80, 440, 830, 440), fill=(60, 60, 60), width=2)  # horizontal axis
    draw.line((80, 80, 80, 440), fill=(60, 60, 60), width=2)  # vertical axis
    # Eleven evenly spaced points on a rising, flattening curve (illustrative only).
    points = [(80 + i * 75, 410 - int((i / 10) ** 0.8 * 260)) for i in range(11)]
    draw.line(points, fill=(40, 125, 85), width=4)
    draw.text((390, 462), "training step", fill=(60, 60, 60))
    draw.text((15, 70), "reward", fill=(60, 60, 60))


def main(argv: list[str] | None = None) -> None:
    """Evaluate the weak baseline and oracle policies and report the results.

    Runs both policies over every scenario family, writes the summaries as
    JSON to ``outputs/evals/baseline_eval.json`` (relative to the current
    working directory), echoes the same JSON to stdout, and optionally
    renders chart assets.

    Args:
        argv: Command-line arguments to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Score the weak baseline and oracle policies across all scenario families."
    )
    parser.add_argument(
        "--write-assets",
        action="store_true",
        help="also render PNG charts into ./assets (requires Pillow)",
    )
    args = parser.parse_args(argv)

    results = [
        run_policy("weak_baseline", weak_policy),
        run_policy("oracle_trace", oracle_policy),
    ]
    # Serialize once; the same text goes to the report file and to stdout.
    report = json.dumps(results, indent=2)
    output_dir = Path("outputs/evals")
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "baseline_eval.json").write_text(report, encoding="utf-8")
    print(report)
    if args.write_assets:
        write_assets(results)


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()