File size: 6,160 Bytes
8b92d51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
"""
Oracle, heuristic, and random baselines for Stack Doctor.

Used to validate the reward function: random < heuristic < oracle must hold.
"""

from __future__ import annotations

import json
import random

from .scenarios import (
    ROOT_CAUSE_TO_FIX,
    ROOT_CAUSES,
    FIXES,
    SPECIALISTS,
    Scenario,
    SCENARIOS,
    TRAIN_SCENARIOS,
    EVAL_SCENARIOS,
)


def oracle_policy(scenario: Scenario) -> list[dict]:
    """Ideal baseline: immediately submit the ground-truth diagnosis.

    Produces a single-step trajectory containing only a correct ``submit``
    action, which upper-bounds the score any learned policy can achieve.
    """
    submit_action = {
        "type": "submit",
        "root_cause": scenario.root_cause,
        "fix": scenario.correct_fix,
        "justification": f"Root cause is {scenario.root_cause}, applying the correct fix.",
    }
    return [submit_action]


def heuristic_policy(scenario: Scenario) -> list[dict]:
    """
    Reasonable heuristic baseline.

    Inspects the logs, consults the specialist with the highest confidence,
    guesses the root cause by keyword matching over the ticket, initial log,
    and that specialist's opinion, then applies and submits the mapped fix.
    """
    # Consult the specialist whose opinion carries the most confidence.
    spec_name, spec_opinion = max(
        scenario.specialist_opinions.items(),
        key=lambda item: item[1].confidence,
    )

    # Pool every available piece of text and guess the root cause from it.
    evidence = " ".join(
        [scenario.incident_ticket, scenario.initial_log, spec_opinion.opinion]
    ).lower()
    guess = _keyword_guess(evidence)
    chosen_fix = ROOT_CAUSE_TO_FIX[guess]

    # Fixed action script: inspect, ask, apply, submit.
    return [
        {"type": "inspect", "target": "logs"},
        {"type": "ask_specialist", "specialist": spec_name},
        {"type": "apply_fix", "fix": chosen_fix},
        {"type": "submit", "root_cause": guess, "fix": chosen_fix},
    ]


def random_policy(scenario: Scenario) -> list[dict]:
    """Weak baseline: a few random exploratory actions, then a random submit."""
    n_steps = random.randint(1, 5)
    actions: list[dict] = []

    # Exploratory prefix: every step except the final one.
    while len(actions) < n_steps - 1:
        if random.choice(["inspect", "ask_specialist"]) == "inspect":
            target = random.choice(["logs", "config", "snippet", "metrics"])
            actions.append({"type": "inspect", "target": target})
        else:
            actions.append({
                "type": "ask_specialist",
                "specialist": random.choice(SPECIALISTS),
            })

    # Final step: submit a uniformly random root cause with its mapped fix.
    guessed = random.choice(ROOT_CAUSES)
    actions.append({
        "type": "submit",
        "root_cause": guessed,
        "fix": ROOT_CAUSE_TO_FIX[guessed],
    })

    return actions


def _keyword_guess(text: str) -> str:
    """Guess root cause from keyword presence in text."""
    scores = {
        "arch_guard": 0,
        "backend_whitelist": 0,
        "runtime_loader": 0,
        "backend_selector": 0,
        "model_config": 0,
        "weight_layout": 0,
    }

    # arch_guard keywords
    for kw in ["arch", "architecture", "sm_12", "sm_120", "sm_121", "supported_arch", "capability", "is_supported"]:
        if kw in text:
            scores["arch_guard"] += 1

    # backend_whitelist keywords
    for kw in ["whitelist", "supported_gpu", "not in", "marlin", "awq", "gpu name"]:
        if kw in text:
            scores["backend_whitelist"] += 1

    # runtime_loader keywords
    for kw in ["runtime", "libcuda", "ld_library", "cuda_home", "symlink", "shared object", "rocm_path", "hipError"]:
        if kw in text:
            scores["runtime_loader"] += 1

    # backend_selector keywords
    for kw in ["backend", "selector", "xformers", "flash_attn", "latency", "slow", "e4m3fn", "fp8 format"]:
        if kw in text:
            scores["backend_selector"] += 1

    # model_config keywords
    for kw in ["config", "num_expert", "shape mismatch", "rope", "checkpoint", "config.json"]:
        if kw in text:
            scores["model_config"] += 1

    # weight_layout keywords
    for kw in ["weight", "mapping", "swap", "gate_proj", "up_proj", "convert", "layout", "qkv"]:
        if kw in text:
            scores["weight_layout"] += 1

    return max(scores, key=scores.get)


def evaluate_policy(policy_fn, scenarios: list[Scenario], n_runs: int = 1) -> dict:
    """
    Run a policy across scenarios and compute metrics.

    Args:
        policy_fn: callable mapping a Scenario to a list of action dicts.
        scenarios: scenarios to evaluate on.
        n_runs: number of passes over the scenario list (useful for
            stochastic policies).

    Returns dict with:
      - rc_accuracy: fraction of correct root cause submissions
      - fix_accuracy: fraction of correct fix submissions
      - avg_steps: average steps to resolution
      - avg_reward: average cumulative reward
      - n_episodes: total episodes evaluated
    """
    # Local imports to avoid a circular dependency at module load time.
    from .stack_doctor_environment import StackDoctorEnvironment
    from models import StackDoctorAction

    total_rc_correct = 0
    total_fix_correct = 0
    total_steps = 0
    total_reward = 0.0
    total_episodes = 0

    for _ in range(n_runs):
        for scenario in scenarios:
            env = StackDoctorEnvironment()
            env.reset(scenario_id=scenario.id)

            actions = policy_fn(scenario)
            cumulative = 0.0
            steps = 0
            last_executed: dict = {}

            for action_dict in actions:
                obs = env.step(StackDoctorAction(message=json.dumps(action_dict)))
                cumulative += obs.reward
                steps += 1
                last_executed = action_dict
                if obs.done:
                    break

            # Score accuracy from the last action that actually reached the
            # environment. Previously this looked at actions[-1], which
            # over-credited policies whose episode terminated early: a
            # planned-but-never-executed submit still counted as correct.
            if last_executed.get("type") == "submit":
                if last_executed["root_cause"] == scenario.root_cause:
                    total_rc_correct += 1
                if last_executed["fix"] == scenario.correct_fix:
                    total_fix_correct += 1

            total_steps += steps
            total_reward += cumulative
            total_episodes += 1

    return {
        "rc_accuracy": total_rc_correct / total_episodes if total_episodes else 0,
        "fix_accuracy": total_fix_correct / total_episodes if total_episodes else 0,
        "avg_steps": total_steps / total_episodes if total_episodes else 0,
        "avg_reward": total_reward / total_episodes if total_episodes else 0,
        "n_episodes": total_episodes,
    }