""" ER_MAP/evaluate.py ================== Run N episodes with an LLM Doctor brain, show full conversations, collect metrics, and plot reward curves. Usage: cd d:/Meta_Finals python -u -m ER_MAP.evaluate --episodes 30 """ import json import os import sys import time import argparse from typing import Dict, Any, List, Optional # Force unbuffered output sys.stdout.reconfigure(line_buffering=True) # --------------------------------------------------------------------------- # Doctor LLM Brain # --------------------------------------------------------------------------- DOCTOR_SYSTEM_PROMPT = """You are an expert emergency room doctor performing triage. You must diagnose and treat the patient. ## Available Tools (respond with STRICT JSON) 1. speak_to: {"thought":"...","tool":"speak_to","target":"nurse or patient","message":"..."} 2. order_lab: {"thought":"...","tool":"order_lab","target":"nurse","test_name":"lab name"} 3. read_soap: {"thought":"...","tool":"read_soap","section":"Subjective or Objective or ALL"} 4. update_soap: {"thought":"...","tool":"update_soap","section":"Assessment","content":"your diagnosis"} 5. terminal_discharge: {"thought":"...","tool":"terminal_discharge","treatment":"your treatment plan"} ## Strategy - First: Use read_soap to review the patient's HPI, medical history, allergies, and physical exam - Ask nurse to assess patient and get vitals - Order relevant labs based on symptoms (e.g. troponin, D-dimer, BMP, ABG, CBC, ECG, CXR, CSF, tryptase, urine_tox, CT_head, CT_abdomen, CT_angio, CK, peak_flow) - Update Assessment with your working diagnosis before discharge - Check Allergies before prescribing medications - Discharge with treatment when you have enough evidence - Be concise with patients. Use simple language. RESPOND ONLY WITH VALID JSON.""" class DoctorBrain: """ Resilient Doctor LLM client. - Accepts a single key (legacy) OR a list of (api_key, model) tuples. - On 401 (invalid key) or 429 (rate-limited), marks that (key, model) pair as dead and silently advances to the next pair so the episode keeps progressing instead of looping on stale data. - When *every* pair is dead, falls back to a deterministic clinical decision tree (`_smart_fallback_action`) that drives the episode toward a sensible discharge instead of spamming "Update me on the patient" 30 times and burning Nurse/Patient tokens. """ def __init__(self, api_key: str = "", model: str = "llama-3.1-8b-instant", fallback_chain: Optional[List[Dict[str, str]]] = None): from groq import Groq self._Groq = Groq # Build the (key, model) chain. The Doctor's *primary* model is # 8B-Instant: it has its own daily TPD pool, separate from the # 70B pool used by Nurse/Patient/Judges. Rotating across both # pools effectively gives ~5x more headroom than a single key. if fallback_chain is None: fallback_chain = [] if api_key: fallback_chain.append({"key": api_key, "model": model}) self._chain: List[Dict[str, Any]] = [] seen = set() for entry in fallback_chain: k = (entry["key"], entry["model"]) if not entry["key"] or k in seen: continue seen.add(k) self._chain.append({ "key": entry["key"], "model": entry["model"], "client": Groq(api_key=entry["key"]), "dead": False, "label": entry.get("label", entry["key"][-4:]), }) if not self._chain: raise ValueError("DoctorBrain: empty fallback chain") # Keep .client / .model for backward compat with any caller # that still pokes at them (rare, but safer to expose). self.client = self._chain[0]["client"] self.model = self._chain[0]["model"] self.history = [{"role": "system", "content": DOCTOR_SYSTEM_PROMPT}] self._consecutive_failures = 0 def reset(self): self.history = [{"role": "system", "content": DOCTOR_SYSTEM_PROMPT}] self._consecutive_failures = 0 def _alive_clients(self) -> List[Dict[str, Any]]: return [c for c in self._chain if not c["dead"]] @staticmethod def _is_dead_error(err: Exception) -> bool: """Detect Groq 401 (invalid key) and 429 (rate-limited).""" s = str(err) return "401" in s or "429" in s or "rate_limit" in s.lower() \ or "invalid_api_key" in s.lower() def _smart_fallback_action(self) -> str: """ Deterministic clinical decision tree used when every Groq client in the chain is dead. Drives the episode toward a sensible terminal state instead of looping on "Give me an update". """ depth = self._consecutive_failures if depth <= 1: action = { "thought": "Fallback (no LLM available): start by reading the SOAP note", "tool": "read_soap", "section": "ALL", } elif depth == 2: action = { "thought": "Fallback: ask nurse for vitals and a focused exam", "tool": "speak_to", "target": "nurse", "message": "Please get full vitals (HR/BP/RR/SpO2/Temp) and report any focal findings.", } elif depth == 3: action = { "thought": "Fallback: order a broad initial lab panel", "tool": "order_lab", "target": "nurse", "test_name": "CBC, BMP, lactate, troponin, ECG", } elif depth == 4: action = { "thought": "Fallback: document working assessment before discharge", "tool": "update_soap", "section": "Assessment", "content": "Working dx pending; treating empirically based on vitals + chief complaint.", } else: # After 5+ consecutive failures, end the episode rather than # waste any more Nurse/Patient tokens. Use a safe empirical # treatment that covers the most common emergent diagnoses. action = { "thought": "Fallback: empirical discharge to terminate stuck episode", "tool": "terminal_discharge", "treatment": "Empirical: O2 + IV fluids + monitor; ICU admit if unstable.", } return json.dumps(action) def decide(self, observation: str) -> str: self.history.append({"role": "user", "content": f"Observation:\n{observation}"}) if len(self.history) > 17: self.history = [self.history[0]] + self.history[-16:] response = None for entry in self._alive_clients(): try: completion = entry["client"].chat.completions.create( model=entry["model"], messages=self.history, temperature=0.6, max_tokens=300, response_format={"type": "json_object"}, ) response = completion.choices[0].message.content or "" self._consecutive_failures = 0 break except Exception as e: if self._is_dead_error(e): print(f" [Doctor: key=...{entry['label']} " f"model={entry['model']} -> DEAD ({type(e).__name__}); " f"trying next]", flush=True) entry["dead"] = True continue # Non-fatal error (network blip, JSON parse, etc.) — give # up on this turn but DON'T mark the key dead. print(f" [Doctor API Error: {e}]", flush=True) break if response is None: self._consecutive_failures += 1 alive = len(self._alive_clients()) print(f" [Doctor: all {len(self._chain)} clients dead " f"({alive} alive). Smart fallback depth={self._consecutive_failures}]", flush=True) response = self._smart_fallback_action() self.history.append({"role": "assistant", "content": response}) return response # --------------------------------------------------------------------------- # Conversation Printer # --------------------------------------------------------------------------- def print_doctor_action(action_str: str, step: int): try: a = json.loads(action_str) except json.JSONDecodeError: print(f" DOCTOR: [invalid JSON]", flush=True) return tool = a.get("tool", "?") print(f" DOCTOR | {a.get('thought', '')[:80]}", flush=True) if tool == "speak_to": print(f" | -> {a.get('target','')}: \"{a.get('message','')}\"", flush=True) elif tool == "order_lab": print(f" | -> order_lab: {a.get('test_name','')}", flush=True) elif tool == "terminal_discharge": print(f" | -> DISCHARGE: {a.get('treatment','')[:100]}", flush=True) def print_observation(obs_str: str, indent=" "): try: obs = json.loads(obs_str) except json.JSONDecodeError: print(f"{indent}ENV: {obs_str[:100]}", flush=True) return event = obs.get("event", "unknown") if event == "episode_start": print(f"{indent}ENV | New case. Nurse: {obs.get('nurse_experience')}", flush=True) elif event == "nurse_report": print(f"{indent}NURSE | \"{obs.get('nurse_message', '')[:120]}\"", flush=True) print(f"{indent} | nurse_status={obs.get('nurse_status','')} patient_status={obs.get('patient_status','')}", flush=True) for ex in obs.get("internal_exchanges", []): if "nurse_said" in ex: print(f"{indent} N->P | \"{ex.get('nurse_said','')[:100]}\"", flush=True) print(f"{indent} P->N | \"{ex.get('patient_said','')[:100]}\"", flush=True) elif "nurse_action" in ex: print(f"{indent} N-act | {ex.get('nurse_action','')} -> {ex.get('result','')[:80]}", flush=True) elif event == "patient_response": print(f"{indent}PATIENT | \"{obs.get('patient_message', '')[:120]}\"", flush=True) print(f"{indent} | status={obs.get('patient_status','')}", flush=True) elif event == "lab_result": tag = " (DUP)" if obs.get("redundant") else "" print(f"{indent}LAB | [{obs.get('test_name','')}]{tag}: {obs.get('result','')[:100]}", flush=True) elif event == "terminal_win": print(f"{indent}RESULT | >>> WIN! Patient stabilized. <<<", flush=True) elif event == "terminal_fatal": print(f"{indent}RESULT | >>> FATAL! Patient died. <<<", flush=True) elif event == "terminal_incorrect": print(f"{indent}RESULT | >>> WRONG treatment. Correct: {obs.get('correct_treatment','')[:80]} <<<", flush=True) elif event == "terminal_ama": print(f"{indent}RESULT | >>> AMA! Patient left: \"{obs.get('patient_message','')[:80]}\" <<<", flush=True) elif event == "system_error": print(f"{indent}ERROR | {obs.get('message','')[:100]}", flush=True) # --------------------------------------------------------------------------- # Evaluation Runner # --------------------------------------------------------------------------- def run_episode(env, doctor, episode_num: int) -> Dict[str, Any]: doctor.reset() obs, info = env.reset() gt = env.ground_truth disease = info.get("ground_truth_disease", "???") difficulty = gt.get("difficulty", "random") p = gt["patient"] n = gt["nurse"] # Print episode header print(f" Disease: {disease}", flush=True) print(f" Difficulty: {difficulty}", flush=True) print(f" Patient: compliance={p['compliance']}, comm={p['communication']}, literacy={p['literacy']}", flush=True) print(f" Nurse: exp={n['experience']}, bandwidth={n['bandwidth']}, empathy={n['empathy']}", flush=True) print(f" Correct Tx: {gt['disease']['correct_treatment'][:80]}", flush=True) print(f" {'~'*60}", flush=True) print_observation(obs) total_reward = 0.0 steps = 0 outcome = "truncated" while True: steps += 1 time.sleep(1.2) action_str = doctor.decide(obs) print(f" Step {steps}:", flush=True) print_doctor_action(action_str, steps) obs, reward, done, truncated, step_info = env.step(action_str) total_reward += reward print(f" REWARD | {reward:+.2f} (total: {total_reward:+.2f})", flush=True) print_observation(obs) if done: try: obs_data = json.loads(obs) event = obs_data.get("event", "") if "win" in event: outcome = "WIN" elif "fatal" in event: outcome = "FATAL" elif "ama" in event: outcome = "AMA" elif "incorrect" in event: outcome = "WRONG" else: outcome = event except: outcome = "done" break if truncated: outcome = "TRUNCATED" break if steps >= 30: outcome = "MAX_STEPS" break return { "episode": episode_num, "disease": disease, "difficulty": difficulty, "compliance": p["compliance"], "communication": p["communication"], "outcome": outcome, "total_reward": round(total_reward, 2), "steps": steps, } def plot_reward_curve(results: List[Dict], output_path: str): try: import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt except ImportError: print(" matplotlib not installed. Skipping plot.", flush=True) return episodes = [r["episode"] for r in results] rewards = [r["total_reward"] for r in results] outcomes = [r["outcome"] for r in results] window = min(5, len(rewards)) rolling_avg = [] for i in range(len(rewards)): start = max(0, i - window + 1) rolling_avg.append(sum(rewards[start:i+1]) / (i - start + 1)) fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10), gridspec_kw={"height_ratios": [3, 1]}) fig.patch.set_facecolor("#0d1117") ax1.set_facecolor("#161b22") colors = [] for o in outcomes: if o == "WIN": colors.append("#2ea043") elif o == "AMA": colors.append("#f0883e") elif o in ("FATAL", "WRONG"): colors.append("#f85149") else: colors.append("#8b949e") ax1.bar(episodes, rewards, color=colors, alpha=0.6, width=0.8, label="Episode Reward") ax1.plot(episodes, rolling_avg, color="#58a6ff", linewidth=2.5, label=f"Rolling Avg (window={window})", zorder=5) ax1.axhline(y=0, color="#484f58", linewidth=1, linestyle="--") ax1.axhline(y=2.0, color="#2ea043", linewidth=1, linestyle=":", alpha=0.5, label="Win threshold (+2.0)") ax1.axhline(y=-1.5, color="#f85149", linewidth=1, linestyle=":", alpha=0.5, label="AMA penalty (-1.5)") ax1.set_xlabel("Episode", color="#c9d1d9", fontsize=12) ax1.set_ylabel("Total Reward", color="#c9d1d9", fontsize=12) ax1.set_title("ER-MAP: LLM Doctor Reward Curve (Baseline - No RL Training)", color="#f0f6fc", fontsize=14, fontweight="bold", pad=15) ax1.legend(loc="upper left", facecolor="#21262d", edgecolor="#484f58", labelcolor="#c9d1d9") ax1.tick_params(colors="#8b949e") for spine in ax1.spines.values(): spine.set_color("#484f58") ax2.set_facecolor("#161b22") outcome_types = ["WIN", "AMA", "WRONG", "FATAL", "TRUNCATED", "MAX_STEPS"] outcome_colors = ["#2ea043", "#f0883e", "#f85149", "#da3633", "#8b949e", "#6e7681"] outcome_counts = [sum(1 for o in outcomes if o == t) for t in outcome_types] bars = ax2.barh(outcome_types, outcome_counts, color=outcome_colors, alpha=0.8) for bar, count in zip(bars, outcome_counts): if count > 0: ax2.text(bar.get_width() + 0.15, bar.get_y() + bar.get_height()/2, str(count), va="center", color="#c9d1d9", fontsize=11, fontweight="bold") ax2.set_xlabel("Count", color="#c9d1d9", fontsize=11) ax2.set_title("Outcome Distribution", color="#c9d1d9", fontsize=12, pad=10) ax2.tick_params(colors="#8b949e") for spine in ax2.spines.values(): spine.set_color("#484f58") plt.tight_layout(pad=2.0) plt.savefig(output_path, dpi=150, bbox_inches="tight", facecolor="#0d1117") plt.close() print(f"\n Reward curve saved to: {output_path}", flush=True) def print_summary(results: List[Dict]): total = len(results) wins = sum(1 for r in results if r["outcome"] == "WIN") ama = sum(1 for r in results if r["outcome"] == "AMA") wrong = sum(1 for r in results if r["outcome"] in ("WRONG", "FATAL")) avg_reward = sum(r["total_reward"] for r in results) / total avg_steps = sum(r["steps"] for r in results) / total print(flush=True) print("=" * 70, flush=True) print(f" EVALUATION SUMMARY ({total} episodes)", flush=True) print("=" * 70, flush=True) print(f" Win Rate: {wins}/{total} ({100*wins/total:.0f}%)", flush=True) print(f" AMA Rate: {ama}/{total} ({100*ama/total:.0f}%)", flush=True) print(f" Wrong/Fatal: {wrong}/{total} ({100*wrong/total:.0f}%)", flush=True) print(f" Avg Reward: {avg_reward:+.2f}", flush=True) print(f" Avg Steps: {avg_steps:.1f}", flush=True) print("=" * 70, flush=True) print(flush=True) diseases = {} for r in results: d = r["disease"] if d not in diseases: diseases[d] = {"wins": 0, "total": 0, "reward_sum": 0} diseases[d]["total"] += 1 diseases[d]["reward_sum"] += r["total_reward"] if r["outcome"] == "WIN": diseases[d]["wins"] += 1 print(" PER-DISEASE BREAKDOWN:", flush=True) print(f" {'Disease':35s} {'Win':>5s} {'Total':>5s} {'Rate':>6s} {'Avg Rwd':>8s}", flush=True) print(" " + "-" * 62, flush=True) for d, stats in sorted(diseases.items()): rate = f"{100*stats['wins']/stats['total']:.0f}%" if stats["total"] > 0 else "N/A" avg = stats["reward_sum"] / stats["total"] print(f" {d:35s} {stats['wins']:>5d} {stats['total']:>5d} {rate:>6s} {avg:>+8.2f}", flush=True) print(flush=True) # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main(): parser = argparse.ArgumentParser(description="ER-MAP Evaluation Runner") parser.add_argument("--episodes", type=int, default=30, help="Number of episodes") parser.add_argument("--output", type=str, default="reward_curve.png", help="Output plot path") args = parser.parse_args() from ER_MAP.envs.triage_env import TriageEnv nurse_key = ((os.environ.get("GROQ_NURSE_API_KEY") or os.environ.get("nurse")) or os.environ.get("nurse", "")) patient_key = ((os.environ.get("GROQ_PATIENT_API_KEY") or os.environ.get("patient")) or os.environ.get("patient", "")) doctor_key = ((os.environ.get("GROQ_DOCTOR_API_KEY") or os.environ.get("doctor")) or os.environ.get("doctor", "")) or patient_key if not nurse_key or not patient_key: print("ERROR: Set GROQ_NURSE_API_KEY and GROQ_PATIENT_API_KEY", flush=True) return 1 print(flush=True) print("=" * 70, flush=True) print(f" ER-MAP EVALUATION: {args.episodes} episodes with LLM Doctor", flush=True) print("=" * 70, flush=True) print(f" Doctor: Llama-3.1-8B (unmodified baseline)", flush=True) print(f" Nurse: Llama-3.1-8B (LIVE)", flush=True) print(f" Patient: Llama-3.1-8B (LIVE)", flush=True) print(f" Diseases: 15 | Persona combos: 933,120", flush=True) print("=" * 70, flush=True) env = TriageEnv(nurse_api_key=nurse_key, patient_api_key=patient_key) doctor = DoctorBrain(api_key=doctor_key) results = [] for ep in range(1, args.episodes + 1): print(flush=True) print(f" {'='*60}", flush=True) print(f" EPISODE {ep}/{args.episodes}", flush=True) print(f" {'='*60}", flush=True) try: result = run_episode(env, doctor, ep) results.append(result) icon = {"WIN": "[OK]", "AMA": "[!!]", "WRONG": "[XX]", "FATAL": "[XX]"}.get(result["outcome"], "[--]") print(f" {icon} OUTCOME: {result['outcome']:8s} | Reward: {result['total_reward']:+.2f} | Steps: {result['steps']}", flush=True) except Exception as e: print(f" [ERR] Episode {ep} failed: {e}", flush=True) results.append({ "episode": ep, "disease": "ERROR", "difficulty": "?", "compliance": "?", "communication": "?", "outcome": "ERROR", "total_reward": -2.0, "steps": 0 }) env.close() # Save results out_dir = os.path.dirname(args.output) or "." results_path = os.path.join(out_dir, "eval_results.json") with open(results_path, "w") as f: json.dump(results, f, indent=2) print(f"\n Raw results saved to: {results_path}", flush=True) print_summary(results) plot_reward_curve(results, args.output) return 0 if __name__ == "__main__": sys.exit(main())