File size: 16,231 Bytes
8b92d51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
"""
Stack Doctor MCP Environment.

Wraps the core Stack Doctor environment with MCP tools that agents
can discover and invoke. This is the agent-facing interface —
agents call tools like read_log(), query_specialist(), submit_diagnosis()
instead of constructing JSON action strings.

The training (WebSocket) API still works through _step_impl().
"""

from __future__ import annotations

import json
from typing import Any, Optional
from uuid import uuid4

from mcp.server.fastmcp import FastMCP
from openenv.core.env_server.mcp_environment import MCPEnvironment
from openenv.core.env_server.types import Action, Observation, State

from models import StackDoctorAction, StackDoctorObservation
from .scenarios import (
    ROOT_CAUSE_TO_FIX,
    FIX_TO_ROOT_CAUSE,
    ROOT_CAUSES,
    FIXES,
    SPECIALISTS,
    Scenario,
    get_scenario,
)

# Step budget per incident; once it is used up the episode is over.
MAX_STEPS = 6
# Set views of the scenario vocabularies, used for O(1) validation of
# agent-supplied fix / root-cause names.
VALID_FIXES = {*FIXES}
VALID_ROOT_CAUSES = {*ROOT_CAUSES}


class StackDoctorMCPEnvironment(MCPEnvironment):
    """
    Stack Doctor with MCP tool interface for agent interaction.

    Agents discover available tools (read_log, check_config, view_code,
    run_diagnostic, query_specialist, apply_fix, submit_diagnosis) and
    call them to investigate incidents and submit diagnoses.

    Both interfaces share one episode state: MCP tools call the ``_do_*``
    handlers directly, and ``_step_impl()`` parses JSON action strings and
    dispatches them to the same handlers.

    Per-step rewards:
      * inspect / query specialist: -0.25
      * apply_fix (once per episode): +3.0 if correct, -2.0 otherwise
      * submit: +8 / -4 for the root cause, +8 / -4 for the fix, +2 when
        both are correct within 4 steps, +1 for a justification of at
        least 10 characters
      * invalid action (bad name/argument, second fix, unknown type): -2.0
    An episode ends on submit or once MAX_STEPS steps have been taken.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        mcp = FastMCP("stack_doctor")
        self._state_obj = State(episode_id=str(uuid4()), step_count=0)
        self._scenario: Scenario | None = None
        self._step_count = 0
        self._fix_applied = False
        self._fix_was_correct: bool | None = None
        self._done = False
        self._cumulative_reward = 0.0
        # Reward charged by the most recent _record_step(); _step_impl()
        # reports it so the reward rules live only in the handlers.
        self._last_reward = 0.0
        self._actions_taken: list[dict] = []

        env = self  # capture for closures

        @mcp.tool()
        def read_log() -> str:
            """Read system and application logs for the current incident.
            Returns log output from the affected inference stack including
            error messages, warnings, and system state information.
            Costs 1 step (-0.25 reward)."""
            return env._do_inspect("logs")

        @mcp.tool()
        def check_config() -> str:
            """Check configuration files for the current incident.
            Returns relevant configuration parameters including GPU settings,
            backend configuration, model parameters, and environment variables.
            Costs 1 step (-0.25 reward)."""
            return env._do_inspect("config")

        @mcp.tool()
        def view_code() -> str:
            """View relevant source code snippets for the current incident.
            Returns code from the affected component showing the likely
            location of the bug or misconfiguration.
            Costs 1 step (-0.25 reward)."""
            return env._do_inspect("snippet")

        @mcp.tool()
        def run_diagnostic() -> str:
            """Run performance diagnostics and metrics collection.
            Returns metrics like latency, throughput, GPU utilization,
            error rates, and memory usage for the affected system.
            Costs 1 step (-0.25 reward)."""
            return env._do_inspect("metrics")

        @mcp.tool()
        def query_specialist(specialist: str) -> str:
            """Ask a specialist for their analysis of the incident.
            Specialists: 'runtime', 'dispatch', 'kernel', 'loader'.
            WARNING: At least one specialist gives wrong advice per incident.
            Cross-verify specialist opinions before trusting them.
            Costs 1 step (-0.25 reward)."""
            return env._do_ask_specialist(specialist)

        @mcp.tool()
        def apply_fix(fix: str) -> str:
            """Apply a fix to the system. Can only be used ONCE per incident.
            Available fixes: 'relax_arch_check', 'add_whitelist_entry',
            'fix_runtime_path', 'switch_backend', 'update_model_config',
            'fix_weight_mapping'.
            Correct fix: +3 reward. Wrong fix: -2 reward."""
            return env._do_apply_fix(fix)

        @mcp.tool()
        def submit_diagnosis(root_cause: str, fix: str, justification: str = "") -> str:
            """Submit your final diagnosis. This ends the episode.
            Root causes: 'arch_guard', 'backend_whitelist', 'runtime_loader',
            'backend_selector', 'model_config', 'weight_layout'.
            Fixes: 'relax_arch_check', 'add_whitelist_entry', 'fix_runtime_path',
            'switch_backend', 'update_model_config', 'fix_weight_mapping'.
            justification: A short sentence explaining WHY you chose this root cause
            and fix based on the evidence you gathered. Bonus +1 if it is at
            least 10 characters long.
            Correct root_cause: +8. Wrong: -4. Correct fix: +8. Wrong: -4.
            Bonus +2 if both are correct and solved in 4 or fewer steps
            (including this submission). Bonus +1 for justification."""
            return env._do_submit(root_cause, fix, justification)

        super().__init__(mcp)

    # ------------------------------------------------------------------
    # MCP tool implementations
    # ------------------------------------------------------------------

    def _check_episode(self) -> str | None:
        """Return an error message if no action may be taken now, else None.

        As a safety net this also marks the episode done when the step
        budget is already exhausted.
        """
        if self._scenario is None:
            return "No active incident. Call reset() first."
        if self._done:
            return "Episode is over. Call reset() to start a new incident."
        if self._step_count >= MAX_STEPS:
            self._done = True
            return "Max steps reached. Episode over."
        return None

    def _record_step(self, reward: float, action: dict) -> None:
        """Charge one step: bump counters, accumulate ``reward``, log ``action``.

        ``reward`` is also remembered as the most recent step's reward, and
        the episode is ended as soon as the step budget is used up so the
        very observation for the last step already reports ``done=True``.
        """
        self._step_count += 1
        self._state_obj.step_count = self._step_count
        self._cumulative_reward += reward
        self._last_reward = reward
        self._actions_taken.append(action)
        if self._step_count >= MAX_STEPS:
            self._done = True

    def _status_line(self, reward_label: str) -> str:
        """Format the trailing '[Steps remaining | Reward | Cumulative]' line."""
        remaining = MAX_STEPS - self._step_count
        return (
            f"[Steps remaining: {remaining} | Reward: {reward_label} | "
            f"Cumulative: {self._cumulative_reward:.2f}]"
        )

    def _do_inspect(self, target: str) -> str:
        """Return one of the scenario's inspect results ('logs', 'config',
        'snippet', 'metrics') for -0.25; an unknown target costs -2.0."""
        err = self._check_episode()
        if err:
            return err

        ir = self._scenario.inspect_results
        result_map = {
            "logs": ir.logs,
            "config": ir.config,
            "snippet": ir.snippet,
            "metrics": ir.metrics,
        }

        # Validate before indexing: JSON actions can carry any target, and a
        # bad one must not raise after the step has been charged.
        if not isinstance(target, str) or target not in result_map:
            self._record_step(-2.0, {"type": "invalid", "message": f"Invalid inspect target: {target}"})
            return f"Invalid inspect target '{target}'. Available: {sorted(result_map)}. Penalty: -2.0"

        self._record_step(-0.25, {"type": "inspect", "target": target})
        return (
            f"[INSPECT {target.upper()}]\n"
            f"{result_map[target]}\n\n"
            f"{self._status_line('-0.25')}"
        )

    def _do_ask_specialist(self, specialist: str) -> str:
        """Return the scenario's follow-up for ``specialist`` for -0.25;
        an unknown specialist costs -2.0."""
        err = self._check_episode()
        if err:
            return err

        # isinstance guard: non-string JSON values may be unhashable.
        if not isinstance(specialist, str) or specialist not in SPECIALISTS:
            self._record_step(-2.0, {"type": "invalid", "message": f"Unknown specialist: {specialist}"})
            return f"Invalid specialist '{specialist}'. Available: {SPECIALISTS}. Penalty: -2.0"

        followup = self._scenario.specialist_followups.get(specialist, "No additional information.")
        self._record_step(-0.25, {"type": "ask_specialist", "specialist": specialist})

        return (
            f"[SPECIALIST: {specialist.upper()}]\n"
            f"{followup}\n\n"
            f"{self._status_line('-0.25')}"
        )

    def _do_apply_fix(self, fix: str) -> str:
        """Apply ``fix`` once per episode: +3.0 if it matches the scenario's
        correct fix, -2.0 otherwise; a second or unknown fix costs -2.0."""
        err = self._check_episode()
        if err:
            return err

        if self._fix_applied:
            self._record_step(-2.0, {"type": "invalid", "message": "Fix already applied"})
            return "You already applied a fix this episode. Only one fix allowed. Penalty: -2.0"

        if not isinstance(fix, str) or fix not in VALID_FIXES:
            self._record_step(-2.0, {"type": "invalid", "message": f"Invalid fix: {fix}"})
            return f"Invalid fix '{fix}'. Available: {sorted(VALID_FIXES)}. Penalty: -2.0"

        self._fix_applied = True
        is_correct = fix == self._scenario.correct_fix
        self._fix_was_correct = is_correct
        reward = 3.0 if is_correct else -2.0
        self._record_step(reward, {"type": "apply_fix", "fix": fix, "correct": is_correct})

        if is_correct:
            return (
                f"[FIX APPLIED: {fix}] Fix applied successfully. Systems recovering.\n"
                f"Now submit your diagnosis with submit_diagnosis().\n\n"
                f"{self._status_line('+3.0')}"
            )
        return (
            f"[FIX APPLIED: {fix}] Fix applied but the issue persists.\n"
            f"Consider your diagnosis carefully.\n\n"
            f"{self._status_line('-2.0')}"
        )

    def _do_submit(self, root_cause: str, fix: str, justification: str = "") -> str:
        """Score the final diagnosis and end the episode.

        Invalid ``root_cause`` / ``fix`` names cost -2.0 and leave the episode
        running. Otherwise the reward is +8/-4 per field, +2 if both are
        correct within 4 steps (this submission included), and +1 when the
        stripped justification has at least 10 characters.
        """
        err = self._check_episode()
        if err:
            return err

        if not isinstance(root_cause, str) or root_cause not in VALID_ROOT_CAUSES:
            self._record_step(-2.0, {"type": "invalid", "message": f"Invalid root_cause: {root_cause}"})
            return f"Invalid root_cause '{root_cause}'. Available: {sorted(VALID_ROOT_CAUSES)}. Penalty: -2.0"

        if not isinstance(fix, str) or fix not in VALID_FIXES:
            self._record_step(-2.0, {"type": "invalid", "message": f"Invalid fix: {fix}"})
            return f"Invalid fix '{fix}'. Available: {sorted(VALID_FIXES)}. Penalty: -2.0"

        if not isinstance(justification, str):
            # JSON actions may send null or non-string values; treat as absent.
            justification = ""

        self._done = True
        rc_correct = root_cause == self._scenario.root_cause
        fix_correct = fix == self._scenario.correct_fix
        stripped_justification = justification.strip()
        has_justification = len(stripped_justification) >= 10
        # The submission itself counts as a step, hence the +1.
        efficiency_bonus = rc_correct and fix_correct and self._step_count + 1 <= 4

        reward = (8.0 if rc_correct else -4.0) + (8.0 if fix_correct else -4.0)
        if efficiency_bonus:
            reward += 2.0
        if has_justification:
            reward += 1.0

        self._record_step(reward, {
            "type": "submit", "root_cause": root_cause, "fix": fix,
            "justification": justification,
            "rc_correct": rc_correct, "fix_correct": fix_correct,
            "has_justification": has_justification,
        })

        rc_verdict = "CORRECT" if rc_correct else f"WRONG (was: {self._scenario.root_cause})"
        fix_verdict = "CORRECT" if fix_correct else f"WRONG (was: {self._scenario.correct_fix})"
        lines = [
            "[DIAGNOSIS SUBMITTED]",
            f"  Root cause: {root_cause} — {rc_verdict}",
            f"  Fix: {fix} — {fix_verdict}",
        ]
        if has_justification:
            lines.append(f"  Justification: {stripped_justification}")
            lines.append("  JUSTIFICATION BONUS: +1")
        elif stripped_justification:
            lines.append("  Justification too short (< 10 characters, missed +1 bonus)")
        else:
            lines.append("  No justification provided (missed +1 bonus)")
        lines.append(f"  Steps used: {self._step_count}/{MAX_STEPS}")
        if efficiency_bonus:
            lines.append("  EFFICIENCY BONUS: +2 (solved in <= 4 steps)")
        lines.append(f"  Episode reward: {self._cumulative_reward:.2f}")

        return "\n".join(lines)

    # ------------------------------------------------------------------
    # OpenEnv Environment interface (for training / WebSocket API)
    # ------------------------------------------------------------------

    def reset(self, seed=None, episode_id=None, **kwargs) -> StackDoctorObservation:
        """Start a new incident and return the initial observation.

        ``kwargs`` may carry ``scenario_id`` and ``split`` (default "train"),
        both forwarded to ``get_scenario``. ``seed`` is accepted for interface
        compatibility but is not used here.
        """
        scenario_id = kwargs.get("scenario_id")
        split = kwargs.get("split", "train")
        self._scenario = get_scenario(scenario_id, split=split)

        self._state_obj = State(
            episode_id=episode_id or str(uuid4()),
            step_count=0,
        )
        self._step_count = 0
        self._fix_applied = False
        self._fix_was_correct = None
        self._done = False
        self._cumulative_reward = 0.0
        self._last_reward = 0.0
        self._actions_taken = []

        specialist_obs = {
            name: {"opinion": op.opinion, "confidence": op.confidence}
            for name, op in self._scenario.specialist_opinions.items()
        }

        return StackDoctorObservation(
            output=(
                "STACK DOCTOR — New incident assigned.\n"
                "Investigate using the available tools: read_log(), check_config(), "
                "view_code(), run_diagnostic(), query_specialist(name).\n"
                "When ready, apply_fix(fix) and/or submit_diagnosis(root_cause, fix).\n"
                f"You have {MAX_STEPS} steps. At least one specialist is WRONG — cross-verify.\n"
            ),
            incident_ticket=self._scenario.incident_ticket,
            hardware=self._scenario.hardware,
            model_name=self._scenario.model_name,
            backend=self._scenario.backend,
            log_excerpt=self._scenario.initial_log,
            code_snippet=self._scenario.initial_snippet,
            specialist_opinions=specialist_obs,
            steps_remaining=MAX_STEPS,
            fix_used=False,
            done=False,
            reward=0.0,
        )

    def _step_impl(
        self,
        action: Action,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> Observation:
        """Handle non-MCP actions (JSON action strings for training).

        The message must be a JSON object with a ``type`` of "inspect",
        "ask_specialist", "apply_fix" or "submit" plus that action's fields.
        The observation's reward is exactly what the dispatched handler
        charged for this call, or 0.0 when no step was taken (e.g. the
        episode is already over).
        """
        if not isinstance(action, StackDoctorAction):
            return self._make_obs("Invalid action type.", -2.0)

        try:
            parsed = json.loads(action.message)
        except (ValueError, TypeError):
            # ValueError covers JSONDecodeError and UnicodeDecodeError for
            # bytes payloads; str() keeps the echo safe for None messages.
            return self._make_obs(f"Invalid JSON: {str(action.message)[:200]}", -2.0)

        if not isinstance(parsed, dict):
            return self._make_obs(
                f"Invalid action: expected a JSON object, got {type(parsed).__name__}.", -2.0
            )

        # Cleared here so a handler that records no step reports 0.0 instead
        # of a stale reward from an earlier call.
        self._last_reward = 0.0
        action_type = parsed.get("type")

        if action_type == "inspect":
            result = self._do_inspect(parsed.get("target", "logs"))
        elif action_type == "ask_specialist":
            result = self._do_ask_specialist(parsed.get("specialist", ""))
        elif action_type == "apply_fix":
            result = self._do_apply_fix(parsed.get("fix", ""))
        elif action_type == "submit":
            result = self._do_submit(
                parsed.get("root_cause", ""),
                parsed.get("fix", ""),
                parsed.get("justification", ""),
            )
        else:
            err = self._check_episode()
            if err:
                # No active episode: do not charge a step to a finished one.
                result = err
            else:
                self._record_step(-2.0, {"type": "invalid", "message": f"Unknown: {action_type}"})
                result = f"Unknown action type: {action_type}. Penalty: -2.0"

        return self._make_obs(result, self._last_reward)

    def _make_obs(self, output: str, reward: float) -> StackDoctorObservation:
        """Build a step observation carrying ``output`` and ``reward``.

        Scenario fields are empty strings when no scenario is loaded; the
        initial log/snippet and specialist opinions are only sent by reset().
        """
        remaining = MAX_STEPS - self._step_count
        return StackDoctorObservation(
            output=output,
            incident_ticket=self._scenario.incident_ticket if self._scenario else "",
            hardware=self._scenario.hardware if self._scenario else "",
            model_name=self._scenario.model_name if self._scenario else "",
            backend=self._scenario.backend if self._scenario else "",
            log_excerpt="",
            code_snippet="",
            specialist_opinions={},
            steps_remaining=remaining,
            fix_used=self._fix_applied,
            done=self._done,
            reward=reward,
            metadata={
                "cumulative_reward": self._cumulative_reward,
                "step": self._step_count,
                "scenario_id": self._scenario.id if self._scenario else "",
            },
        )

    @property
    def state(self) -> State:
        """Current episode State (episode_id and step_count)."""
        return self._state_obj