""" Stack Doctor MCP Environment. Wraps the core Stack Doctor environment with MCP tools that agents can discover and invoke. This is the agent-facing interface — agents call tools like read_log(), query_specialist(), submit_diagnosis() instead of constructing JSON action strings. The training (WebSocket) API still works through _step_impl(). """ from __future__ import annotations import json from typing import Any, Optional from uuid import uuid4 from mcp.server.fastmcp import FastMCP from openenv.core.env_server.mcp_environment import MCPEnvironment from openenv.core.env_server.types import Action, Observation, State from models import StackDoctorAction, StackDoctorObservation from .scenarios import ( ROOT_CAUSE_TO_FIX, FIX_TO_ROOT_CAUSE, ROOT_CAUSES, FIXES, SPECIALISTS, Scenario, get_scenario, ) MAX_STEPS = 6 VALID_FIXES = set(FIXES) VALID_ROOT_CAUSES = set(ROOT_CAUSES) class StackDoctorMCPEnvironment(MCPEnvironment): """ Stack Doctor with MCP tool interface for agent interaction. Agents discover available tools (read_log, check_config, view_code, run_diagnostic, query_specialist, apply_fix, submit_diagnosis) and call them to investigate incidents and submit diagnoses. """ SUPPORTS_CONCURRENT_SESSIONS: bool = True def __init__(self): mcp = FastMCP("stack_doctor") self._state_obj = State(episode_id=str(uuid4()), step_count=0) self._scenario: Scenario | None = None self._step_count = 0 self._fix_applied = False self._fix_was_correct: bool | None = None self._done = False self._cumulative_reward = 0.0 self._actions_taken: list[dict] = [] env = self # capture for closures @mcp.tool() def read_log() -> str: """Read system and application logs for the current incident. Returns log output from the affected inference stack including error messages, warnings, and system state information. Costs 1 step (-0.25 reward).""" return env._do_inspect("logs") @mcp.tool() def check_config() -> str: """Check configuration files for the current incident. Returns relevant configuration parameters including GPU settings, backend configuration, model parameters, and environment variables. Costs 1 step (-0.25 reward).""" return env._do_inspect("config") @mcp.tool() def view_code() -> str: """View relevant source code snippets for the current incident. Returns code from the affected component showing the likely location of the bug or misconfiguration. Costs 1 step (-0.25 reward).""" return env._do_inspect("snippet") @mcp.tool() def run_diagnostic() -> str: """Run performance diagnostics and metrics collection. Returns metrics like latency, throughput, GPU utilization, error rates, and memory usage for the affected system. Costs 1 step (-0.25 reward).""" return env._do_inspect("metrics") @mcp.tool() def query_specialist(specialist: str) -> str: """Ask a specialist for their analysis of the incident. Specialists: 'runtime', 'dispatch', 'kernel', 'loader'. WARNING: At least one specialist gives wrong advice per incident. Cross-verify specialist opinions before trusting them. Costs 1 step (-0.25 reward).""" return env._do_ask_specialist(specialist) @mcp.tool() def apply_fix(fix: str) -> str: """Apply a fix to the system. Can only be used ONCE per incident. Available fixes: 'relax_arch_check', 'add_whitelist_entry', 'fix_runtime_path', 'switch_backend', 'update_model_config', 'fix_weight_mapping'. Correct fix: +3 reward. Wrong fix: -2 reward.""" return env._do_apply_fix(fix) @mcp.tool() def submit_diagnosis(root_cause: str, fix: str, justification: str = "") -> str: """Submit your final diagnosis. This ends the episode. Root causes: 'arch_guard', 'backend_whitelist', 'runtime_loader', 'backend_selector', 'model_config', 'weight_layout'. Fixes: 'relax_arch_check', 'add_whitelist_entry', 'fix_runtime_path', 'switch_backend', 'update_model_config', 'fix_weight_mapping'. justification: A short sentence explaining WHY you chose this root cause and fix based on the evidence you gathered. Bonus +1 if provided. Correct root_cause: +8. Wrong: -4. Correct fix: +8. Wrong: -4. Bonus +2 if solved in 4 or fewer steps. Bonus +1 for justification.""" return env._do_submit(root_cause, fix, justification) super().__init__(mcp) # ------------------------------------------------------------------ # MCP tool implementations # ------------------------------------------------------------------ def _check_episode(self) -> str | None: """Return error message if episode is not active.""" if self._scenario is None: return "No active incident. Call reset() first." if self._done: return "Episode is over. Call reset() to start a new incident." if self._step_count >= MAX_STEPS: self._done = True return "Max steps reached. Episode over." return None def _record_step(self, reward: float, action: dict) -> None: self._step_count += 1 self._state_obj.step_count = self._step_count self._cumulative_reward += reward self._actions_taken.append(action) def _do_inspect(self, target: str) -> str: err = self._check_episode() if err: return err ir = self._scenario.inspect_results result_map = { "logs": ir.logs, "config": ir.config, "snippet": ir.snippet, "metrics": ir.metrics, } self._record_step(-0.25, {"type": "inspect", "target": target}) remaining = MAX_STEPS - self._step_count return ( f"[INSPECT {target.upper()}]\n" f"{result_map[target]}\n\n" f"[Steps remaining: {remaining} | Reward: -0.25 | Cumulative: {self._cumulative_reward:.2f}]" ) def _do_ask_specialist(self, specialist: str) -> str: err = self._check_episode() if err: return err if specialist not in SPECIALISTS: self._record_step(-2.0, {"type": "invalid", "message": f"Unknown specialist: {specialist}"}) return f"Invalid specialist '{specialist}'. Available: {SPECIALISTS}. Penalty: -2.0" followup = self._scenario.specialist_followups.get(specialist, "No additional information.") self._record_step(-0.25, {"type": "ask_specialist", "specialist": specialist}) remaining = MAX_STEPS - self._step_count return ( f"[SPECIALIST: {specialist.upper()}]\n" f"{followup}\n\n" f"[Steps remaining: {remaining} | Reward: -0.25 | Cumulative: {self._cumulative_reward:.2f}]" ) def _do_apply_fix(self, fix: str) -> str: err = self._check_episode() if err: return err if self._fix_applied: self._record_step(-2.0, {"type": "invalid", "message": "Fix already applied"}) return "You already applied a fix this episode. Only one fix allowed. Penalty: -2.0" if fix not in VALID_FIXES: self._record_step(-2.0, {"type": "invalid", "message": f"Invalid fix: {fix}"}) return f"Invalid fix '{fix}'. Available: {sorted(VALID_FIXES)}. Penalty: -2.0" self._fix_applied = True is_correct = fix == self._scenario.correct_fix self._fix_was_correct = is_correct reward = 3.0 if is_correct else -2.0 self._record_step(reward, {"type": "apply_fix", "fix": fix, "correct": is_correct}) remaining = MAX_STEPS - self._step_count if is_correct: return ( f"[FIX APPLIED: {fix}] Fix applied successfully. Systems recovering.\n" f"Now submit your diagnosis with submit_diagnosis().\n\n" f"[Steps remaining: {remaining} | Reward: +3.0 | Cumulative: {self._cumulative_reward:.2f}]" ) else: return ( f"[FIX APPLIED: {fix}] Fix applied but the issue persists.\n" f"Consider your diagnosis carefully.\n\n" f"[Steps remaining: {remaining} | Reward: -2.0 | Cumulative: {self._cumulative_reward:.2f}]" ) def _do_submit(self, root_cause: str, fix: str, justification: str = "") -> str: err = self._check_episode() if err: return err if root_cause not in VALID_ROOT_CAUSES: self._record_step(-2.0, {"type": "invalid", "message": f"Invalid root_cause: {root_cause}"}) return f"Invalid root_cause '{root_cause}'. Available: {sorted(VALID_ROOT_CAUSES)}. Penalty: -2.0" if fix not in VALID_FIXES: self._record_step(-2.0, {"type": "invalid", "message": f"Invalid fix: {fix}"}) return f"Invalid fix '{fix}'. Available: {sorted(VALID_FIXES)}. Penalty: -2.0" self._done = True rc_correct = root_cause == self._scenario.root_cause fix_correct = fix == self._scenario.correct_fix has_justification = len(justification.strip()) >= 10 reward = 0.0 reward += 8.0 if rc_correct else -4.0 reward += 8.0 if fix_correct else -4.0 if rc_correct and fix_correct and self._step_count + 1 <= 4: reward += 2.0 if has_justification: reward += 1.0 self._record_step(reward, { "type": "submit", "root_cause": root_cause, "fix": fix, "justification": justification, "rc_correct": rc_correct, "fix_correct": fix_correct, "has_justification": has_justification, }) lines = ["[DIAGNOSIS SUBMITTED]"] lines.append(f" Root cause: {root_cause} — {'CORRECT' if rc_correct else 'WRONG (was: ' + self._scenario.root_cause + ')'}") lines.append(f" Fix: {fix} — {'CORRECT' if fix_correct else 'WRONG (was: ' + self._scenario.correct_fix + ')'}") if has_justification: lines.append(f" Justification: {justification.strip()}") lines.append(" JUSTIFICATION BONUS: +1") else: lines.append(" No justification provided (missed +1 bonus)") lines.append(f" Steps used: {self._step_count}/{MAX_STEPS}") if rc_correct and fix_correct and self._step_count <= 4: lines.append(" EFFICIENCY BONUS: +2 (solved in <= 4 steps)") lines.append(f" Episode reward: {self._cumulative_reward:.2f}") return "\n".join(lines) # ------------------------------------------------------------------ # OpenEnv Environment interface (for training / WebSocket API) # ------------------------------------------------------------------ def reset(self, seed=None, episode_id=None, **kwargs) -> StackDoctorObservation: scenario_id = kwargs.get("scenario_id") split = kwargs.get("split", "train") self._scenario = get_scenario(scenario_id, split=split) self._state_obj = State( episode_id=episode_id or str(uuid4()), step_count=0, ) self._step_count = 0 self._fix_applied = False self._fix_was_correct = None self._done = False self._cumulative_reward = 0.0 self._actions_taken = [] specialist_obs = {} for name, op in self._scenario.specialist_opinions.items(): specialist_obs[name] = { "opinion": op.opinion, "confidence": op.confidence, } return StackDoctorObservation( output=( "STACK DOCTOR — New incident assigned.\n" "Investigate using the available tools: read_log(), check_config(), " "view_code(), run_diagnostic(), query_specialist(name).\n" "When ready, apply_fix(fix) and/or submit_diagnosis(root_cause, fix).\n" "You have 6 steps. At least one specialist is WRONG — cross-verify.\n" ), incident_ticket=self._scenario.incident_ticket, hardware=self._scenario.hardware, model_name=self._scenario.model_name, backend=self._scenario.backend, log_excerpt=self._scenario.initial_log, code_snippet=self._scenario.initial_snippet, specialist_opinions=specialist_obs, steps_remaining=MAX_STEPS, fix_used=False, done=False, reward=0.0, ) def _step_impl( self, action: Action, timeout_s: Optional[float] = None, **kwargs: Any, ) -> Observation: """Handle non-MCP actions (JSON action strings for training).""" if not isinstance(action, StackDoctorAction): return self._make_obs("Invalid action type.", -2.0) try: parsed = json.loads(action.message) except (json.JSONDecodeError, TypeError): return self._make_obs(f"Invalid JSON: {action.message[:200]}", -2.0) action_type = parsed.get("type") if action_type == "inspect": result = self._do_inspect(parsed.get("target", "logs")) elif action_type == "ask_specialist": result = self._do_ask_specialist(parsed.get("specialist", "")) elif action_type == "apply_fix": result = self._do_apply_fix(parsed.get("fix", "")) elif action_type == "submit": result = self._do_submit(parsed.get("root_cause", ""), parsed.get("fix", ""), parsed.get("justification", "")) else: self._record_step(-2.0, {"type": "invalid", "message": f"Unknown: {action_type}"}) result = f"Unknown action type: {action_type}. Penalty: -2.0" # Extract last reward from actions last_reward = 0.0 if self._actions_taken: last = self._actions_taken[-1] if last.get("type") == "submit": # Calculate submit reward rc_c = last.get("rc_correct", False) fx_c = last.get("fix_correct", False) last_reward = (8.0 if rc_c else -4.0) + (8.0 if fx_c else -4.0) if rc_c and fx_c and self._step_count <= 4: last_reward += 2.0 if last.get("has_justification", False): last_reward += 1.0 elif last.get("type") == "apply_fix": last_reward = 3.0 if last.get("correct") else -2.0 elif last.get("type") == "invalid": last_reward = -2.0 else: last_reward = -0.25 return self._make_obs(result, last_reward) def _make_obs(self, output: str, reward: float) -> StackDoctorObservation: remaining = MAX_STEPS - self._step_count return StackDoctorObservation( output=output, incident_ticket=self._scenario.incident_ticket if self._scenario else "", hardware=self._scenario.hardware if self._scenario else "", model_name=self._scenario.model_name if self._scenario else "", backend=self._scenario.backend if self._scenario else "", log_excerpt="", code_snippet="", specialist_opinions={}, steps_remaining=remaining, fix_used=self._fix_applied, done=self._done, reward=reward, metadata={ "cumulative_reward": self._cumulative_reward, "step": self._step_count, "scenario_id": self._scenario.id if self._scenario else "", }, ) @property def state(self) -> State: return self._state_obj