Spaces:
Sleeping
Sleeping
| """ | |
| Stack Doctor Environment. | |
| An overseer LLM diagnoses sick inference stacks by probing subsystems, | |
| reconciling conflicting specialist-agent reports, and selecting the | |
| minimal correct fix. | |
| Inspired by real SM12x enablement bugs across vLLM, FlashInfer, SGLang, | |
| CUTLASS, and Flash-Attention. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| from models import StackDoctorAction, StackDoctorObservation | |
| from .scenarios import ( | |
| ROOT_CAUSE_TO_FIX, | |
| FIX_TO_ROOT_CAUSE, | |
| ROOT_CAUSES, | |
| FIXES, | |
| SPECIALISTS, | |
| Scenario, | |
| get_scenario, | |
| ) | |
# Number of actions an agent may take in one episode; the environment ends
# the episode once this many steps have been consumed.
MAX_STEPS = 6

# Artifact names accepted by the "inspect" action.
INSPECT_TARGETS = {"logs", "config", "snippet", "metrics"}

# Set forms of the scenario catalogs, used for membership validation of
# agent-supplied fix and root-cause identifiers.
VALID_FIXES = {*FIXES}
VALID_ROOT_CAUSES = {*ROOT_CAUSES}
class EpisodeState:
    """Internal mutable episode state (not exposed to agent).

    One instance is created per episode and holds the active scenario
    together with the counters and flags updated as actions are handled.
    """

    def __init__(self, scenario: Scenario):
        # Scenario this episode is built around.
        self.scenario = scenario

        # Progress tracking: actions consumed so far and termination flag.
        self.step_count = 0
        self.done = False

        # Fix bookkeeping: whether a fix has been applied and, once it has,
        # whether it was the right one (None until a fix is applied).
        self.fix_applied = False
        self.fix_was_correct: bool | None = None

        # Running reward total and a chronological log of handled actions.
        self.cumulative_reward = 0.0
        self.actions_taken: list[dict] = []
class StackDoctorEnvironment(Environment):
    """
    Stack Doctor: incident-response RL environment for
    inference-stack diagnosis.

    Each episode presents one scenario. The agent sends JSON-encoded actions
    (``inspect``, ``ask_specialist``, ``apply_fix``, ``submit``) and receives
    a ``StackDoctorObservation`` with a per-step reward:

    * ``inspect`` / ``ask_specialist``: -0.25 each.
    * ``apply_fix`` (at most once per episode): +3 if correct, -2 otherwise.
    * ``submit``: +8 / -4 for the root cause, +8 / -4 for the fix, +2 when both
      are correct within ``_EFFICIENCY_STEP_LIMIT`` steps, +1 for a
      justification of at least ``_MIN_JUSTIFICATION_CHARS`` characters.
    * Any malformed or unknown action: -2.
    * Using up ``MAX_STEPS`` on a valid non-submit action: an extra -4.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # A fully correct submission within this many steps earns the +2 bonus.
    _EFFICIENCY_STEP_LIMIT = 4
    # Minimum stripped justification length that earns the +1 bonus.
    _MIN_JUSTIFICATION_CHARS = 10

    def __init__(self):
        self._state = State(episode_id=str(uuid4()), step_count=0)
        # No episode exists until reset() is called.
        self._episode: EpisodeState | None = None

    def reset(self, seed=None, episode_id=None, **kwargs) -> StackDoctorObservation:
        """Start a new episode and return its opening observation.

        Args:
            seed: Accepted for interface compatibility; not used here.
            episode_id: Identifier for the new episode; a fresh UUID is
                generated when omitted or falsy.
            **kwargs: ``scenario_id`` and ``split`` (default ``"train"``) are
                forwarded to ``get_scenario``.

        Returns:
            The initial observation: instructions, incident details and the
            specialists' initial opinions.
        """
        scenario = get_scenario(
            kwargs.get("scenario_id"),
            split=kwargs.get("split", "train"),
        )
        self._state = State(
            episode_id=episode_id or str(uuid4()),
            step_count=0,
        )
        self._episode = EpisodeState(scenario)

        specialist_obs = {
            name: {"opinion": op.opinion, "confidence": op.confidence}
            for name, op in scenario.specialist_opinions.items()
        }

        return StackDoctorObservation(
            output=(
                "STACK DOCTOR — New incident assigned.\n"
                "Diagnose the root cause, optionally apply a fix, then submit your diagnosis.\n"
                # Derived from MAX_STEPS so the prompt cannot drift from the limit.
                f"You have {MAX_STEPS} steps. Use them wisely.\n\n"
                "Available actions (send as JSON):\n"
                ' {"type":"inspect","target":"logs|config|snippet|metrics"}\n'
                ' {"type":"ask_specialist","specialist":"runtime|dispatch|kernel|loader"}\n'
                ' {"type":"apply_fix","fix":"relax_arch_check|add_whitelist_entry|fix_runtime_path|switch_backend|update_model_config|fix_weight_mapping"}\n'
                ' {"type":"submit","root_cause":"...","fix":"...","justification":"reason for diagnosis"}\n'
            ),
            incident_ticket=scenario.incident_ticket,
            hardware=scenario.hardware,
            model_name=scenario.model_name,
            backend=scenario.backend,
            log_excerpt=scenario.initial_log,
            code_snippet=scenario.initial_snippet,
            specialist_opinions=specialist_obs,
            steps_remaining=MAX_STEPS,
            fix_used=False,
            done=False,
            reward=0.0,
        )

    def step(self, action: StackDoctorAction, **kwargs) -> StackDoctorObservation:
        """Apply one agent action and return the resulting observation.

        ``action.message`` must be a JSON object whose ``"type"`` selects the
        handler. Every call made while an episode is active consumes a step,
        including malformed ones. Calls made before ``reset()`` or after the
        episode ended return a terminal observation with zero reward.
        """
        ep = self._episode
        if ep is None or ep.done:
            return self._terminal_obs("Episode is over. Call reset() to start a new incident.", 0.0)

        self._state.step_count += 1
        ep.step_count += 1

        message = action.message
        try:
            parsed = json.loads(message)
        except (json.JSONDecodeError, TypeError):
            # str() keeps the error path safe when the message is not
            # sliceable (e.g. None), which is what raises TypeError above.
            return self._handle_invalid(ep, f"Invalid JSON: {str(message)[:200]}")

        # Valid JSON is not necessarily an object ("[]", "3", "null", ...).
        if not isinstance(parsed, dict):
            return self._handle_invalid(
                ep, f"Action must be a JSON object, got {type(parsed).__name__}"
            )

        handlers = {
            "inspect": self._handle_inspect,
            "ask_specialist": self._handle_ask_specialist,
            "apply_fix": self._handle_apply_fix,
            "submit": self._handle_submit,
        }
        action_type = parsed.get("type")
        # Only strings are looked up: a list/dict "type" is unhashable and
        # would raise inside dict.get().
        handler = handlers.get(action_type) if isinstance(action_type, str) else None
        if handler is None:
            return self._handle_invalid(ep, f"Unknown action type: {action_type}")
        return handler(ep, parsed)

    # NOTE(review): this is a plain method, not a property — confirm that
    # matches how the Environment base class and its callers access state.
    def state(self) -> State:
        """Return the public ``State`` (episode id and step count)."""
        return self._state

    @staticmethod
    def _is_choice(value, choices) -> bool:
        """Return True when ``value`` is a string contained in ``choices``.

        The type check comes first because JSON lists/dicts are unhashable
        and would raise ``TypeError`` in a set membership test.
        """
        return isinstance(value, str) and value in choices

    def _handle_inspect(self, ep: EpisodeState, parsed: dict) -> StackDoctorObservation:
        """Reveal one of the scenario's inspection artifacts for -0.25."""
        target = parsed.get("target")
        if not self._is_choice(target, INSPECT_TARGETS):
            # sorted() gives a stable message; raw set order varies per process.
            return self._handle_invalid(
                ep, f"Invalid inspect target: {target}. Use: {sorted(INSPECT_TARGETS)}"
            )

        reward = -0.25
        ep.cumulative_reward += reward
        ep.actions_taken.append({"type": "inspect", "target": target})

        ir = ep.scenario.inspect_results
        result_map = {"logs": ir.logs, "config": ir.config, "snippet": ir.snippet, "metrics": ir.metrics}
        return self._step_obs(ep, output=f"[INSPECT {target.upper()}]\n{result_map[target]}", reward=reward)

    def _handle_ask_specialist(self, ep: EpisodeState, parsed: dict) -> StackDoctorObservation:
        """Return a specialist's follow-up text for -0.25."""
        specialist = parsed.get("specialist")
        if not self._is_choice(specialist, SPECIALISTS):
            return self._handle_invalid(ep, f"Invalid specialist: {specialist}. Use: {SPECIALISTS}")

        reward = -0.25
        ep.cumulative_reward += reward
        ep.actions_taken.append({"type": "ask_specialist", "specialist": specialist})

        followup = ep.scenario.specialist_followups.get(specialist, "No additional information.")
        return self._step_obs(ep, output=f"[SPECIALIST: {specialist.upper()}]\n{followup}", reward=reward)

    def _handle_apply_fix(self, ep: EpisodeState, parsed: dict) -> StackDoctorObservation:
        """Apply the episode's single allowed fix: +3 if correct, -2 if not."""
        if ep.fix_applied:
            return self._handle_invalid(ep, "apply_fix already used this episode. You can only apply one fix.")

        fix = parsed.get("fix")
        if not self._is_choice(fix, VALID_FIXES):
            return self._handle_invalid(ep, f"Invalid fix: {fix}. Use one of: {sorted(VALID_FIXES)}")

        is_correct = fix == ep.scenario.correct_fix
        ep.fix_applied = True
        ep.fix_was_correct = is_correct

        reward = 3.0 if is_correct else -2.0
        ep.cumulative_reward += reward
        ep.actions_taken.append({"type": "apply_fix", "fix": fix, "correct": is_correct})

        if is_correct:
            output = f"[FIX APPLIED: {fix}] Fix applied successfully. Systems recovering. Now submit your diagnosis."
        else:
            output = f"[FIX APPLIED: {fix}] Fix applied but the issue persists. Consider your diagnosis carefully."
        return self._step_obs(ep, output=output, reward=reward)

    def _handle_submit(self, ep: EpisodeState, parsed: dict) -> StackDoctorObservation:
        """Score the final diagnosis and end the episode.

        An unknown ``root_cause`` or ``fix`` is handled as an invalid action
        and does not end the episode by itself.
        """
        root_cause = parsed.get("root_cause")
        fix = parsed.get("fix")
        raw_justification = parsed.get("justification", "")
        # A non-string justification (null, number, ...) counts as missing
        # rather than crashing on .strip().
        justification = raw_justification if isinstance(raw_justification, str) else ""

        if not self._is_choice(root_cause, VALID_ROOT_CAUSES):
            return self._handle_invalid(ep, f"Invalid root_cause: {root_cause}. Use one of: {sorted(VALID_ROOT_CAUSES)}")
        if not self._is_choice(fix, VALID_FIXES):
            return self._handle_invalid(ep, f"Invalid fix: {fix}. Use one of: {sorted(VALID_FIXES)}")

        ep.done = True
        correct_rc = ep.scenario.root_cause
        correct_fix = ep.scenario.correct_fix
        rc_correct = root_cause == correct_rc
        fix_correct = fix == correct_fix
        has_justification = len(justification.strip()) >= self._MIN_JUSTIFICATION_CHARS
        efficient = rc_correct and fix_correct and ep.step_count <= self._EFFICIENCY_STEP_LIMIT

        reward = 0.0
        reward += 8.0 if rc_correct else -4.0
        reward += 8.0 if fix_correct else -4.0
        if efficient:
            reward += 2.0
        if has_justification:
            reward += 1.0
        ep.cumulative_reward += reward

        ep.actions_taken.append({
            "type": "submit", "root_cause": root_cause, "fix": fix,
            "justification": justification,
            "rc_correct": rc_correct, "fix_correct": fix_correct,
            "has_justification": has_justification,
        })

        output_lines = ["[DIAGNOSIS SUBMITTED]"]
        output_lines.append(f" Root cause: {root_cause} — {'CORRECT' if rc_correct else 'WRONG (was: ' + correct_rc + ')'}")
        output_lines.append(f" Fix: {fix} — {'CORRECT' if fix_correct else 'WRONG (was: ' + correct_fix + ')'}")
        if has_justification:
            output_lines.append(f" Justification: {justification.strip()}")
            output_lines.append(" JUSTIFICATION BONUS: +1")
        else:
            output_lines.append(" No justification provided (missed +1 bonus)")
        output_lines.append(f" Steps used: {ep.step_count}/{MAX_STEPS}")
        if efficient:
            output_lines.append(f" EFFICIENCY BONUS: +2 (solved in <= {self._EFFICIENCY_STEP_LIMIT} steps)")
        output_lines.append(f" Episode reward: {ep.cumulative_reward:.2f}")
        return self._terminal_obs("\n".join(output_lines), reward)

    def _handle_invalid(self, ep: EpisodeState, msg: str) -> StackDoctorObservation:
        """Penalise a malformed or unknown action (-2) and log it.

        If the step budget is now spent, the episode ends here.
        """
        reward = -2.0
        ep.cumulative_reward += reward
        ep.actions_taken.append({"type": "invalid", "message": msg})

        if ep.step_count >= MAX_STEPS:
            # NOTE(review): this path ends the episode without the extra -4
            # that _step_obs applies on timeout — confirm that is intended.
            ep.done = True
            return self._terminal_obs(f"[INVALID ACTION] {msg}\n[EPISODE OVER] Max steps reached. Auto-fail.", reward)
        return self._step_obs(ep, output=f"[INVALID ACTION] {msg}", reward=reward)

    def _step_obs(self, ep: EpisodeState, output: str, reward: float) -> StackDoctorObservation:
        """Build a non-terminal observation, enforcing the step budget.

        When no steps remain and the episode is not yet done, the episode is
        ended, 4 is subtracted from both the step reward and the cumulative
        reward, and an explanatory line is appended to ``output``.
        """
        remaining = MAX_STEPS - ep.step_count
        if remaining <= 0 and not ep.done:
            ep.done = True
            reward -= 4.0
            ep.cumulative_reward += -4.0
            output += "\n\n[EPISODE OVER] Max steps reached without submission. Auto-fail. Reward: -4"

        scenario = ep.scenario
        return StackDoctorObservation(
            output=output,
            incident_ticket=scenario.incident_ticket,
            hardware=scenario.hardware,
            model_name=scenario.model_name,
            backend=scenario.backend,
            log_excerpt="",
            code_snippet="",
            specialist_opinions={},
            steps_remaining=remaining,
            fix_used=ep.fix_applied,
            done=ep.done,
            reward=reward,
            metadata={
                "cumulative_reward": ep.cumulative_reward,
                "step": ep.step_count,
                "scenario_id": scenario.id,
            },
        )

    def _terminal_obs(self, output: str, reward: float) -> StackDoctorObservation:
        """Build a ``done=True`` observation with zero steps remaining.

        Also works when no episode exists (``step`` called before ``reset``):
        scenario fields fall back to empty strings and counters to zero.
        """
        ep = self._episode
        scenario = ep.scenario if ep else None
        return StackDoctorObservation(
            output=output,
            incident_ticket=scenario.incident_ticket if ep else "",
            hardware=scenario.hardware if ep else "",
            model_name=scenario.model_name if ep else "",
            backend=scenario.backend if ep else "",
            log_excerpt="",
            code_snippet="",
            specialist_opinions={},
            steps_remaining=0,
            fix_used=ep.fix_applied if ep else False,
            done=True,
            reward=reward,
            metadata={
                "cumulative_reward": ep.cumulative_reward if ep else 0.0,
                "step": ep.step_count if ep else 0,
                "scenario_id": scenario.id if ep else "",
            },
        )