""" Oracle, heuristic, and random baselines for Stack Doctor. Used to validate the reward function: random < heuristic < oracle must hold. """ from __future__ import annotations import json import random from .scenarios import ( ROOT_CAUSE_TO_FIX, ROOT_CAUSES, FIXES, SPECIALISTS, Scenario, SCENARIOS, TRAIN_SCENARIOS, EVAL_SCENARIOS, ) def oracle_policy(scenario: Scenario) -> list[dict]: """Perfect policy: submit correct answer in 1 step.""" return [ { "type": "submit", "root_cause": scenario.root_cause, "fix": scenario.correct_fix, "justification": f"Root cause is {scenario.root_cause}, applying the correct fix.", } ] def heuristic_policy(scenario: Scenario) -> list[dict]: """ Reasonable heuristic: inspect logs, ask the highest-confidence specialist, then submit based on clues. Uses keyword matching on specialist opinions and logs to guess root cause. """ actions = [] # Step 1: inspect logs actions.append({"type": "inspect", "target": "logs"}) # Step 2: ask the highest-confidence specialist best_spec = max( scenario.specialist_opinions.items(), key=lambda kv: kv[1].confidence, ) actions.append({"type": "ask_specialist", "specialist": best_spec[0]}) # Step 3: heuristic root-cause guess from keywords combined_text = ( scenario.incident_ticket + " " + scenario.initial_log + " " + best_spec[1].opinion ).lower() guess = _keyword_guess(combined_text) # Step 4: apply fix actions.append({"type": "apply_fix", "fix": ROOT_CAUSE_TO_FIX[guess]}) # Step 5: submit actions.append({ "type": "submit", "root_cause": guess, "fix": ROOT_CAUSE_TO_FIX[guess], }) return actions def random_policy(scenario: Scenario) -> list[dict]: """Random policy: random actions, random submit.""" actions = [] n_steps = random.randint(1, 5) for _ in range(n_steps - 1): choice = random.choice(["inspect", "ask_specialist"]) if choice == "inspect": actions.append({ "type": "inspect", "target": random.choice(["logs", "config", "snippet", "metrics"]), }) else: actions.append({ "type": "ask_specialist", "specialist": random.choice(SPECIALISTS), }) # Final: random submit rc = random.choice(ROOT_CAUSES) actions.append({ "type": "submit", "root_cause": rc, "fix": ROOT_CAUSE_TO_FIX[rc], }) return actions def _keyword_guess(text: str) -> str: """Guess root cause from keyword presence in text.""" scores = { "arch_guard": 0, "backend_whitelist": 0, "runtime_loader": 0, "backend_selector": 0, "model_config": 0, "weight_layout": 0, } # arch_guard keywords for kw in ["arch", "architecture", "sm_12", "sm_120", "sm_121", "supported_arch", "capability", "is_supported"]: if kw in text: scores["arch_guard"] += 1 # backend_whitelist keywords for kw in ["whitelist", "supported_gpu", "not in", "marlin", "awq", "gpu name"]: if kw in text: scores["backend_whitelist"] += 1 # runtime_loader keywords for kw in ["runtime", "libcuda", "ld_library", "cuda_home", "symlink", "shared object", "rocm_path", "hipError"]: if kw in text: scores["runtime_loader"] += 1 # backend_selector keywords for kw in ["backend", "selector", "xformers", "flash_attn", "latency", "slow", "e4m3fn", "fp8 format"]: if kw in text: scores["backend_selector"] += 1 # model_config keywords for kw in ["config", "num_expert", "shape mismatch", "rope", "checkpoint", "config.json"]: if kw in text: scores["model_config"] += 1 # weight_layout keywords for kw in ["weight", "mapping", "swap", "gate_proj", "up_proj", "convert", "layout", "qkv"]: if kw in text: scores["weight_layout"] += 1 return max(scores, key=scores.get) def evaluate_policy(policy_fn, scenarios: list[Scenario], n_runs: int = 1) -> dict: """ Run a policy across scenarios and compute metrics. Returns dict with: - rc_accuracy: fraction of correct root cause submissions - fix_accuracy: fraction of correct fix submissions - avg_steps: average steps to resolution - avg_reward: average cumulative reward """ from .stack_doctor_environment import StackDoctorEnvironment from models import StackDoctorAction total_rc_correct = 0 total_fix_correct = 0 total_steps = 0 total_reward = 0.0 total_episodes = 0 for _ in range(n_runs): for scenario in scenarios: env = StackDoctorEnvironment() env.reset(scenario_id=scenario.id) actions = policy_fn(scenario) cumulative = 0.0 steps = 0 for action_dict in actions: obs = env.step(StackDoctorAction(message=json.dumps(action_dict))) cumulative += obs.reward steps += 1 if obs.done: break # Check if submit happened last_action = actions[-1] if actions else {} if last_action.get("type") == "submit": if last_action["root_cause"] == scenario.root_cause: total_rc_correct += 1 if last_action["fix"] == scenario.correct_fix: total_fix_correct += 1 total_steps += steps total_reward += cumulative total_episodes += 1 return { "rc_accuracy": total_rc_correct / total_episodes if total_episodes else 0, "fix_accuracy": total_fix_correct / total_episodes if total_episodes else 0, "avg_steps": total_steps / total_episodes if total_episodes else 0, "avg_reward": total_reward / total_episodes if total_episodes else 0, "n_episodes": total_episodes, }