stack_doctor / server /stack_doctor_environment.py
bledden's picture
Upload folder using huggingface_hub
8b92d51 verified
"""
Stack Doctor Environment.
An overseer LLM diagnoses sick inference stacks by probing subsystems,
reconciling conflicting specialist-agent reports, and selecting the
minimal correct fix.
Inspired by real SM12x enablement bugs across vLLM, FlashInfer, SGLang,
CUTLASS, and Flash-Attention.
"""
from __future__ import annotations
import json
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
from models import StackDoctorAction, StackDoctorObservation
from .scenarios import (
ROOT_CAUSE_TO_FIX,
FIX_TO_ROOT_CAUSE,
ROOT_CAUSES,
FIXES,
SPECIALISTS,
Scenario,
get_scenario,
)
# Maximum number of steps an agent may take in one episode.
MAX_STEPS = 6

# Artifacts an "inspect" action is allowed to request.
INSPECT_TARGETS = {"logs", "config", "snippet", "metrics"}

# Set views of the scenario catalogues, used for O(1) membership checks
# when validating apply_fix / submit payloads.
VALID_FIXES = {*FIXES}
VALID_ROOT_CAUSES = {*ROOT_CAUSES}
class EpisodeState:
    """Per-episode mutable bookkeeping, kept server-side and never shown to the agent."""

    def __init__(self, scenario: Scenario):
        # The incident being diagnosed in this episode.
        self.scenario = scenario

        # Progress tracking.
        self.step_count = 0
        self.done = False

        # Reward accounting and a log of every action the agent took.
        self.cumulative_reward = 0.0
        self.actions_taken: list[dict] = []

        # apply_fix bookkeeping: fix_was_correct stays None until a fix is applied.
        self.fix_applied = False
        self.fix_was_correct: bool | None = None
class StackDoctorEnvironment(Environment):
    """
    Stack Doctor: incident-response RL environment for
    inference-stack diagnosis.

    Each episode serves one scenario obtained from ``get_scenario``. The agent
    sends JSON-encoded actions in ``StackDoctorAction.message`` and has
    ``MAX_STEPS`` steps in which to inspect artifacts, consult specialists,
    optionally apply a single fix, and submit a diagnosis.

    Reward schedule:
        inspect / ask_specialist   -0.25
        apply_fix                  +3 if correct, -2 otherwise (allowed once)
        invalid action             -2 (still consumes a step)
        submit                     +8 / -4 for the root cause,
                                   +8 / -4 for the fix,
                                   +2 if both correct within 4 steps,
                                   +1 for a justification of >= 10 chars
        step budget exhausted      extra -4 and the episode ends
    The running total is exposed as ``metadata["cumulative_reward"]``.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # A fully correct submission made within this many steps earns +2.
    _EFFICIENCY_STEP_LIMIT = 4
    # Extra penalty when the step budget runs out without a submission.
    _TIMEOUT_PENALTY = 4.0
    # Minimum stripped length for a justification to earn the +1 bonus.
    _MIN_JUSTIFICATION_CHARS = 10

    def __init__(self):
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._episode: EpisodeState | None = None

    def reset(self, seed=None, episode_id=None, **kwargs) -> StackDoctorObservation:
        """Start a new incident and return the opening observation.

        Args:
            seed: Accepted for interface compatibility; not used here.
            episode_id: Id for the new episode; a UUID4 is generated if omitted.
            **kwargs: ``scenario_id`` (forwarded to ``get_scenario``) and
                ``split`` (defaults to ``"train"``).
        """
        scenario_id = kwargs.get("scenario_id")
        split = kwargs.get("split", "train")
        scenario = get_scenario(scenario_id, split=split)

        self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
        self._episode = EpisodeState(scenario)

        # Only the headline opinion and confidence are revealed up front;
        # the follow-up text requires an ask_specialist action.
        specialist_obs = {
            name: {"opinion": op.opinion, "confidence": op.confidence}
            for name, op in scenario.specialist_opinions.items()
        }

        return StackDoctorObservation(
            output=(
                "STACK DOCTOR — New incident assigned.\n"
                "Diagnose the root cause, optionally apply a fix, then submit your diagnosis.\n"
                f"You have {MAX_STEPS} steps. Use them wisely.\n\n"
                "Available actions (send as JSON):\n"
                ' {"type":"inspect","target":"logs|config|snippet|metrics"}\n'
                ' {"type":"ask_specialist","specialist":"runtime|dispatch|kernel|loader"}\n'
                ' {"type":"apply_fix","fix":"relax_arch_check|add_whitelist_entry|fix_runtime_path|switch_backend|update_model_config|fix_weight_mapping"}\n'
                ' {"type":"submit","root_cause":"...","fix":"...","justification":"reason for diagnosis"}\n'
            ),
            incident_ticket=scenario.incident_ticket,
            hardware=scenario.hardware,
            model_name=scenario.model_name,
            backend=scenario.backend,
            log_excerpt=scenario.initial_log,
            code_snippet=scenario.initial_snippet,
            specialist_opinions=specialist_obs,
            steps_remaining=MAX_STEPS,
            fix_used=False,
            done=False,
            reward=0.0,
        )

    def step(self, action: StackDoctorAction, **kwargs) -> StackDoctorObservation:
        """Apply one JSON-encoded action and return the resulting observation.

        Every call on a live episode consumes a step, including malformed or
        invalid actions. Calls made before ``reset()`` or after the episode
        has ended return a terminal observation with zero reward.
        """
        ep = self._episode
        if ep is None or ep.done:
            return self._terminal_obs("Episode is over. Call reset() to start a new incident.", 0.0)

        self._state.step_count += 1
        ep.step_count += 1

        message = action.message
        try:
            parsed = json.loads(message)
        except (ValueError, TypeError):
            # ValueError covers JSONDecodeError and undecodable bytes; TypeError
            # covers non-text payloads such as None. str() keeps the preview
            # itself from raising on those payloads.
            return self._handle_invalid(ep, f"Invalid JSON: {str(message)[:200]}")
        if not isinstance(parsed, dict):
            # Valid JSON that is not an object (list, number, string) has no
            # "type" field and would otherwise crash on .get().
            return self._handle_invalid(ep, "Invalid action: expected a JSON object.")

        action_type = parsed.get("type")
        if action_type == "inspect":
            return self._handle_inspect(ep, parsed)
        if action_type == "ask_specialist":
            return self._handle_ask_specialist(ep, parsed)
        if action_type == "apply_fix":
            return self._handle_apply_fix(ep, parsed)
        if action_type == "submit":
            return self._handle_submit(ep, parsed)
        return self._handle_invalid(ep, f"Unknown action type: {action_type}")

    @property
    def state(self) -> State:
        return self._state

    def _handle_inspect(self, ep: EpisodeState, parsed: dict) -> StackDoctorObservation:
        """Reveal one diagnostic artifact (logs/config/snippet/metrics) at a cost of -0.25."""
        target = parsed.get("target")
        # The isinstance guard matters: unhashable JSON values (lists, objects)
        # would make the set membership test raise TypeError.
        if not isinstance(target, str) or target not in INSPECT_TARGETS:
            return self._handle_invalid(ep, f"Invalid inspect target: {target}. Use: {sorted(INSPECT_TARGETS)}")

        reward = -0.25
        ep.cumulative_reward += reward
        ep.actions_taken.append({"type": "inspect", "target": target})

        ir = ep.scenario.inspect_results
        result_map = {"logs": ir.logs, "config": ir.config, "snippet": ir.snippet, "metrics": ir.metrics}
        return self._step_obs(ep, output=f"[INSPECT {target.upper()}]\n{result_map[target]}", reward=reward)

    def _handle_ask_specialist(self, ep: EpisodeState, parsed: dict) -> StackDoctorObservation:
        """Return a specialist's follow-up note at a cost of -0.25."""
        specialist = parsed.get("specialist")
        if not isinstance(specialist, str) or specialist not in SPECIALISTS:
            return self._handle_invalid(ep, f"Invalid specialist: {specialist}. Use: {SPECIALISTS}")

        reward = -0.25
        ep.cumulative_reward += reward
        ep.actions_taken.append({"type": "ask_specialist", "specialist": specialist})

        followup = ep.scenario.specialist_followups.get(specialist, "No additional information.")
        return self._step_obs(ep, output=f"[SPECIALIST: {specialist.upper()}]\n{followup}", reward=reward)

    def _handle_apply_fix(self, ep: EpisodeState, parsed: dict) -> StackDoctorObservation:
        """Apply a fix once per episode: +3 if it is the scenario's correct fix, -2 otherwise."""
        if ep.fix_applied:
            return self._handle_invalid(ep, "apply_fix already used this episode. You can only apply one fix.")

        fix = parsed.get("fix")
        if not isinstance(fix, str) or fix not in VALID_FIXES:
            return self._handle_invalid(ep, f"Invalid fix: {fix}. Use one of: {sorted(VALID_FIXES)}")

        is_correct = fix == ep.scenario.correct_fix
        ep.fix_applied = True
        ep.fix_was_correct = is_correct

        reward = 3.0 if is_correct else -2.0
        ep.cumulative_reward += reward
        ep.actions_taken.append({"type": "apply_fix", "fix": fix, "correct": is_correct})

        if is_correct:
            output = f"[FIX APPLIED: {fix}] Fix applied successfully. Systems recovering. Now submit your diagnosis."
        else:
            output = f"[FIX APPLIED: {fix}] Fix applied but the issue persists. Consider your diagnosis carefully."
        return self._step_obs(ep, output=output, reward=reward)

    def _handle_submit(self, ep: EpisodeState, parsed: dict) -> StackDoctorObservation:
        """Score a final diagnosis and end the episode.

        ``root_cause`` and ``fix`` must be known identifiers; otherwise the
        submission counts as an invalid action and the episode continues.
        """
        root_cause = parsed.get("root_cause")
        fix = parsed.get("fix")
        justification = parsed.get("justification", "")
        if not isinstance(justification, str):
            # null / numbers / lists cannot be stripped; treat them as missing.
            justification = ""

        if not isinstance(root_cause, str) or root_cause not in VALID_ROOT_CAUSES:
            return self._handle_invalid(ep, f"Invalid root_cause: {root_cause}. Use one of: {sorted(VALID_ROOT_CAUSES)}")
        if not isinstance(fix, str) or fix not in VALID_FIXES:
            return self._handle_invalid(ep, f"Invalid fix: {fix}. Use one of: {sorted(VALID_FIXES)}")

        ep.done = True
        correct_rc = ep.scenario.root_cause
        correct_fix = ep.scenario.correct_fix
        rc_correct = root_cause == correct_rc
        fix_correct = fix == correct_fix
        has_justification = len(justification.strip()) >= self._MIN_JUSTIFICATION_CHARS
        efficient = rc_correct and fix_correct and ep.step_count <= self._EFFICIENCY_STEP_LIMIT

        reward = (8.0 if rc_correct else -4.0) + (8.0 if fix_correct else -4.0)
        if efficient:
            reward += 2.0
        if has_justification:
            reward += 1.0
        ep.cumulative_reward += reward

        ep.actions_taken.append({
            "type": "submit", "root_cause": root_cause, "fix": fix,
            "justification": justification,
            "rc_correct": rc_correct, "fix_correct": fix_correct,
            "has_justification": has_justification,
        })

        rc_verdict = "CORRECT" if rc_correct else f"WRONG (was: {correct_rc})"
        fix_verdict = "CORRECT" if fix_correct else f"WRONG (was: {correct_fix})"
        output_lines = [
            "[DIAGNOSIS SUBMITTED]",
            f" Root cause: {root_cause} — {rc_verdict}",
            f" Fix: {fix} — {fix_verdict}",
        ]
        if has_justification:
            output_lines.append(f" Justification: {justification.strip()}")
            output_lines.append(" JUSTIFICATION BONUS: +1")
        else:
            output_lines.append(" No justification provided (missed +1 bonus)")
        output_lines.append(f" Steps used: {ep.step_count}/{MAX_STEPS}")
        if efficient:
            output_lines.append(f" EFFICIENCY BONUS: +2 (solved in <= {self._EFFICIENCY_STEP_LIMIT} steps)")
        output_lines.append(f" Episode reward: {ep.cumulative_reward:.2f}")
        return self._terminal_obs("\n".join(output_lines), reward)

    def _handle_invalid(self, ep: EpisodeState, msg: str) -> StackDoctorObservation:
        """Penalize a malformed or disallowed action (-2). It still consumes a step.

        The result goes through _step_obs, so an invalid action on the last
        step triggers the same out-of-steps auto-fail penalty as any other
        non-submitting action. Without that, emitting garbage on the final
        step would be cheaper than a valid move.
        """
        reward = -2.0
        ep.cumulative_reward += reward
        ep.actions_taken.append({"type": "invalid", "message": msg})
        return self._step_obs(ep, output=f"[INVALID ACTION] {msg}", reward=reward)

    def _step_obs(self, ep: EpisodeState, output: str, reward: float) -> StackDoctorObservation:
        """Build a mid-episode observation, auto-failing the episode if no steps remain.

        On exhaustion the timeout penalty is added both to this step's reward
        and to the cumulative total, and the episode is marked done.
        """
        remaining = MAX_STEPS - ep.step_count
        if remaining <= 0 and not ep.done:
            ep.done = True
            reward -= self._TIMEOUT_PENALTY
            ep.cumulative_reward -= self._TIMEOUT_PENALTY
            output += (
                "\n\n[EPISODE OVER] Max steps reached without submission. "
                f"Auto-fail. Reward: -{self._TIMEOUT_PENALTY:g}"
            )
        return StackDoctorObservation(
            output=output,
            incident_ticket=ep.scenario.incident_ticket,
            hardware=ep.scenario.hardware,
            model_name=ep.scenario.model_name,
            backend=ep.scenario.backend,
            log_excerpt="",
            code_snippet="",
            specialist_opinions={},
            steps_remaining=max(remaining, 0),
            fix_used=ep.fix_applied,
            done=ep.done,
            reward=reward,
            metadata={"cumulative_reward": ep.cumulative_reward, "step": ep.step_count, "scenario_id": ep.scenario.id},
        )

    def _terminal_obs(self, output: str, reward: float) -> StackDoctorObservation:
        """Build a done=True observation; works even when no episode has been reset yet."""
        ep = self._episode
        has_ep = ep is not None
        return StackDoctorObservation(
            output=output,
            incident_ticket=ep.scenario.incident_ticket if has_ep else "",
            hardware=ep.scenario.hardware if has_ep else "",
            model_name=ep.scenario.model_name if has_ep else "",
            backend=ep.scenario.backend if has_ep else "",
            log_excerpt="",
            code_snippet="",
            specialist_opinions={},
            steps_remaining=0,
            fix_used=ep.fix_applied if has_ep else False,
            done=True,
            reward=reward,
            metadata={
                "cumulative_reward": ep.cumulative_reward if has_ep else 0.0,
                "step": ep.step_count if has_ep else 0,
                "scenario_id": ep.scenario.id if has_ep else "",
            },
        )