# Multi-Agentic / ER_MAP/evaluate.py
# Author: Uddiii
# Last commit 63726b6: feat: add support for lowercase Hugging Face Space secrets
"""
ER_MAP/evaluate.py
==================
Run N episodes with an LLM Doctor brain, show full conversations,
collect metrics, and plot reward curves.
Usage:
cd d:/Meta_Finals
python -u -m ER_MAP.evaluate --episodes 30
"""
import json
import os
import sys
import time
import argparse
from typing import Dict, Any, List, Optional
# Line-buffer stdout so transcripts stream live even when output is piped
# (this is line buffering, not fully unbuffered output). Guarded because
# some replacement stdout objects (e.g. certain notebook/IDE stream wrappers)
# do not implement TextIOWrapper.reconfigure(), and sys.stdout may be None
# when no console is attached; in those cases the default buffering is kept.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(line_buffering=True)
# ---------------------------------------------------------------------------
# Doctor LLM Brain
# ---------------------------------------------------------------------------
# System prompt for the Doctor LLM. DoctorBrain installs it as the first
# ("system") message of its chat history and always keeps it when trimming
# old turns. It lists the five tool-call JSON shapes the Doctor may emit
# (speak_to, order_lab, read_soap, update_soap, terminal_discharge), a
# suggested work-up order, and a JSON-only output requirement.
# NOTE(review): the tool names and field names here presumably have to match
# what TriageEnv.step() parses; confirm against ER_MAP.envs.triage_env before
# editing this text.
DOCTOR_SYSTEM_PROMPT = """You are an expert emergency room doctor performing triage. You must diagnose and treat the patient.
## Available Tools (respond with STRICT JSON)
1. speak_to: {"thought":"...","tool":"speak_to","target":"nurse or patient","message":"..."}
2. order_lab: {"thought":"...","tool":"order_lab","target":"nurse","test_name":"lab name"}
3. read_soap: {"thought":"...","tool":"read_soap","section":"Subjective or Objective or ALL"}
4. update_soap: {"thought":"...","tool":"update_soap","section":"Assessment","content":"your diagnosis"}
5. terminal_discharge: {"thought":"...","tool":"terminal_discharge","treatment":"your treatment plan"}
## Strategy
- First: Use read_soap to review the patient's HPI, medical history, allergies, and physical exam
- Ask nurse to assess patient and get vitals
- Order relevant labs based on symptoms (e.g. troponin, D-dimer, BMP, ABG, CBC, ECG, CXR, CSF, tryptase, urine_tox, CT_head, CT_abdomen, CT_angio, CK, peak_flow)
- Update Assessment with your working diagnosis before discharge
- Check Allergies before prescribing medications
- Discharge with treatment when you have enough evidence
- Be concise with patients. Use simple language.
RESPOND ONLY WITH VALID JSON."""
class DoctorBrain:
    """
    Resilient Doctor LLM client backed by one or more Groq (key, model) pairs.

    - Accepts a single ``api_key``/``model`` (legacy) and/or an explicit
      ``fallback_chain`` of ``{"key", "model"[, "label"]}`` dicts. Entries
      with an empty key and duplicate (key, model) pairs are skipped.
    - On an error whose text looks like a Groq 401 (invalid key) or 429
      (rate-limited), marks that pair as dead for the lifetime of this object
      (``reset()`` does not revive it) and tries the next pair in the same
      turn.
    - Any other API error abandons the current turn without marking the key
      dead.
    - Whenever a turn yields no LLM response, falls back to a deterministic
      clinical decision tree (`_smart_fallback_action`) that walks toward a
      discharge instead of spamming "Update me on the patient" and burning
      Nurse/Patient tokens.
    """

    # Number of non-system messages kept in the rolling chat window.
    _HISTORY_WINDOW = 16

    def __init__(self, api_key: str = "", model: str = "llama-3.1-8b-instant",
                 fallback_chain: Optional[List[Dict[str, str]]] = None):
        """
        Args:
            api_key: Optional single Groq key; appended after any
                ``fallback_chain`` entries.
            model: Model used with ``api_key``.
            fallback_chain: Ordered ``{"key", "model"[, "label"]}`` dicts to
                try first. The caller's list is not modified.

        Raises:
            ValueError: If no usable (non-empty key) entry remains.
        """
        from groq import Groq
        self._Groq = Groq
        # Build the (key, model) chain. The Doctor's *primary* model is
        # 8B-Instant: it has its own daily TPD pool, separate from the
        # 70B pool used by Nurse/Patient/Judges. Rotating across both
        # pools effectively gives ~5x more headroom than a single key.
        # Work on a copy so appending api_key never mutates the caller's list.
        entries: List[Dict[str, str]] = list(fallback_chain or [])
        if api_key:
            entries.append({"key": api_key, "model": model})
        self._chain: List[Dict[str, Any]] = []
        seen = set()
        for entry in entries:
            k = (entry["key"], entry["model"])
            if not entry["key"] or k in seen:
                continue
            seen.add(k)
            self._chain.append({
                "key": entry["key"],
                "model": entry["model"],
                "client": Groq(api_key=entry["key"]),
                "dead": False,
                "label": entry.get("label", entry["key"][-4:]),
            })
        if not self._chain:
            raise ValueError("DoctorBrain: empty fallback chain")
        # Keep .client / .model for backward compat with any caller
        # that still pokes at them (rare, but safer to expose).
        self.client = self._chain[0]["client"]
        self.model = self._chain[0]["model"]
        self.history = [{"role": "system", "content": DOCTOR_SYSTEM_PROMPT}]
        self._consecutive_failures = 0

    def reset(self):
        """Start a fresh conversation; dead (key, model) pairs stay dead."""
        self.history = [{"role": "system", "content": DOCTOR_SYSTEM_PROMPT}]
        self._consecutive_failures = 0

    def _alive_clients(self) -> List[Dict[str, Any]]:
        """Return the chain entries not yet marked dead, in chain order."""
        return [c for c in self._chain if not c["dead"]]

    @staticmethod
    def _is_dead_error(err: Exception) -> bool:
        """Detect Groq 401 (invalid key) and 429 (rate-limited) by error text."""
        s = str(err)
        return "401" in s or "429" in s or "rate_limit" in s.lower() \
            or "invalid_api_key" in s.lower()

    def _smart_fallback_action(self) -> str:
        """
        Deterministic clinical decision tree used when a turn got no LLM
        response. The step chosen depends on the consecutive-failure count:
        read SOAP -> nurse vitals -> broad labs -> assessment -> discharge.
        Returns the action as a JSON string.
        """
        depth = self._consecutive_failures
        if depth <= 1:
            action = {
                "thought": "Fallback (no LLM available): start by reading the SOAP note",
                "tool": "read_soap", "section": "ALL",
            }
        elif depth == 2:
            action = {
                "thought": "Fallback: ask nurse for vitals and a focused exam",
                "tool": "speak_to", "target": "nurse",
                "message": "Please get full vitals (HR/BP/RR/SpO2/Temp) and report any focal findings.",
            }
        elif depth == 3:
            action = {
                "thought": "Fallback: order a broad initial lab panel",
                "tool": "order_lab", "target": "nurse",
                "test_name": "CBC, BMP, lactate, troponin, ECG",
            }
        elif depth == 4:
            action = {
                "thought": "Fallback: document working assessment before discharge",
                "tool": "update_soap", "section": "Assessment",
                "content": "Working dx pending; treating empirically based on vitals + chief complaint.",
            }
        else:
            # After 5+ consecutive failures, end the episode rather than
            # waste any more Nurse/Patient tokens. Use a safe empirical
            # treatment that covers the most common emergent diagnoses.
            action = {
                "thought": "Fallback: empirical discharge to terminate stuck episode",
                "tool": "terminal_discharge",
                "treatment": "Empirical: O2 + IV fluids + monitor; ICU admit if unstable.",
            }
        return json.dumps(action)

    def decide(self, observation: str) -> str:
        """
        Return the Doctor's next action as a JSON string.

        Appends the observation as a user turn, trims the rolling window
        (system prompt always kept), then tries each live (key, model) pair in
        order. Errors flagged by `_is_dead_error` kill that pair and move on;
        any other error abandons the turn. Without a response, the
        deterministic `_smart_fallback_action` is used. The chosen response is
        appended to history as an assistant turn.
        """
        self.history.append({"role": "user", "content": f"Observation:\n{observation}"})
        if len(self.history) > self._HISTORY_WINDOW + 1:
            tail = self.history[-self._HISTORY_WINDOW:]
            # History alternates user/assistant and ends on a user turn, so an
            # even-sized tail always starts with an assistant turn. Drop it so
            # the first message after the system prompt is a user message.
            if tail and tail[0]["role"] == "assistant":
                tail = tail[1:]
            self.history = [self.history[0]] + tail
        response = None
        for entry in self._alive_clients():
            try:
                completion = entry["client"].chat.completions.create(
                    model=entry["model"],
                    messages=self.history,
                    temperature=0.6,
                    max_tokens=300,
                    response_format={"type": "json_object"},
                )
                response = completion.choices[0].message.content or ""
                self._consecutive_failures = 0
                break
            except Exception as e:
                if self._is_dead_error(e):
                    print(f" [Doctor: key=...{entry['label']} "
                          f"model={entry['model']} -> DEAD ({type(e).__name__}); "
                          f"trying next]", flush=True)
                    entry["dead"] = True
                    continue
                # Non-fatal error (network blip, JSON parse, etc.): give
                # up on this turn but DON'T mark the key dead.
                print(f" [Doctor API Error: {e}]", flush=True)
                break
        if response is None:
            self._consecutive_failures += 1
            alive = len(self._alive_clients())
            # Report the real reason: a transient error leaves clients alive.
            if alive:
                reason = (f"no response this turn ({alive}/{len(self._chain)} "
                          f"clients still alive)")
            else:
                reason = f"all {len(self._chain)} clients dead"
            print(f" [Doctor: {reason}. Smart fallback depth="
                  f"{self._consecutive_failures}]", flush=True)
            response = self._smart_fallback_action()
        self.history.append({"role": "assistant", "content": response})
        return response
# ---------------------------------------------------------------------------
# Conversation Printer
# ---------------------------------------------------------------------------
def _action_text(action: Dict[str, Any], key: str, limit: Optional[int] = None) -> str:
    """Return ``action[key]`` as display text (missing/None -> ""), optionally truncated."""
    value = action.get(key)
    text = "" if value is None else str(value)
    return text if limit is None else text[:limit]


def print_doctor_action(action_str: str, step: int):
    """
    Pretty-print one Doctor action for the episode transcript.

    Prints the (truncated) thought, then a tool-specific detail line for
    every tool listed in the Doctor's prompt. Null or non-string fields are
    rendered as text instead of raising; anything that is not a JSON object
    is reported rather than crashing the episode loop.

    Args:
        action_str: Raw action string returned by the Doctor brain.
        step: Step index; accepted for call-site compatibility but not printed
            (the caller prints its own "Step N:" header).
    """
    try:
        action = json.loads(action_str)
    except (json.JSONDecodeError, TypeError):
        print(f" DOCTOR: [invalid JSON]", flush=True)
        return
    if not isinstance(action, dict):
        # Valid JSON, but a bare list/string/number cannot be a tool call.
        print(" DOCTOR: [invalid action: not a JSON object]", flush=True)
        return
    tool = action.get("tool", "?")
    print(f" DOCTOR | {_action_text(action, 'thought', 80)}", flush=True)
    if tool == "speak_to":
        print(f" | -> {_action_text(action, 'target')}: \"{_action_text(action, 'message')}\"", flush=True)
    elif tool == "order_lab":
        print(f" | -> order_lab: {_action_text(action, 'test_name')}", flush=True)
    elif tool == "read_soap":
        print(f" | -> read_soap: {_action_text(action, 'section')}", flush=True)
    elif tool == "update_soap":
        print(f" | -> update_soap [{_action_text(action, 'section')}]: "
              f"{_action_text(action, 'content', 100)}", flush=True)
    elif tool == "terminal_discharge":
        print(f" | -> DISCHARGE: {_action_text(action, 'treatment', 100)}", flush=True)
def _obs_text(value: Any, limit: Optional[int] = None) -> str:
    """Return *value* as display text (None -> ""), optionally truncated to *limit* chars."""
    text = "" if value is None else str(value)
    return text if limit is None else text[:limit]


def print_observation(obs_str: str, indent=" "):
    """
    Pretty-print one environment observation for the episode transcript.

    Input that is not JSON, or JSON that is not an object, is echoed raw
    (first 100 characters). Known ``event`` types are rendered with a speaker
    tag; unrecognised events print nothing. Missing or null text fields are
    shown as empty strings instead of raising.

    Args:
        obs_str: Observation from the environment, expected to be a JSON string.
        indent: Prefix printed before every line.
    """
    try:
        obs = json.loads(obs_str)
    except (json.JSONDecodeError, TypeError):
        print(f"{indent}ENV: {_obs_text(obs_str, 100)}", flush=True)
        return
    if not isinstance(obs, dict):
        print(f"{indent}ENV: {_obs_text(obs_str, 100)}", flush=True)
        return
    event = obs.get("event", "unknown")
    if event == "episode_start":
        print(f"{indent}ENV | New case. Nurse: {obs.get('nurse_experience')}", flush=True)
    elif event == "nurse_report":
        print(f"{indent}NURSE | \"{_obs_text(obs.get('nurse_message'), 120)}\"", flush=True)
        print(f"{indent} | nurse_status={obs.get('nurse_status','')} patient_status={obs.get('patient_status','')}", flush=True)
        exchanges = obs.get("internal_exchanges")
        # A missing, null or malformed exchange list is treated as empty.
        if not isinstance(exchanges, list):
            exchanges = []
        for ex in exchanges:
            if not isinstance(ex, dict):
                continue
            if "nurse_said" in ex:
                print(f"{indent} N->P | \"{_obs_text(ex.get('nurse_said'), 100)}\"", flush=True)
                print(f"{indent} P->N | \"{_obs_text(ex.get('patient_said'), 100)}\"", flush=True)
            elif "nurse_action" in ex:
                print(f"{indent} N-act | {ex.get('nurse_action','')} -> {_obs_text(ex.get('result'), 80)}", flush=True)
    elif event == "patient_response":
        print(f"{indent}PATIENT | \"{_obs_text(obs.get('patient_message'), 120)}\"", flush=True)
        print(f"{indent} | status={obs.get('patient_status','')}", flush=True)
    elif event == "lab_result":
        tag = " (DUP)" if obs.get("redundant") else ""
        print(f"{indent}LAB | [{obs.get('test_name','')}]{tag}: {_obs_text(obs.get('result'), 100)}", flush=True)
    elif event == "terminal_win":
        print(f"{indent}RESULT | >>> WIN! Patient stabilized. <<<", flush=True)
    elif event == "terminal_fatal":
        print(f"{indent}RESULT | >>> FATAL! Patient died. <<<", flush=True)
    elif event == "terminal_incorrect":
        print(f"{indent}RESULT | >>> WRONG treatment. Correct: {_obs_text(obs.get('correct_treatment'), 80)} <<<", flush=True)
    elif event == "terminal_ama":
        print(f"{indent}RESULT | >>> AMA! Patient left: \"{_obs_text(obs.get('patient_message'), 80)}\" <<<", flush=True)
    elif event == "system_error":
        print(f"{indent}ERROR | {_obs_text(obs.get('message'), 100)}", flush=True)
# ---------------------------------------------------------------------------
# Evaluation Runner
# ---------------------------------------------------------------------------
def _classify_terminal(obs: str) -> str:
    """
    Map the final observation of a finished episode to an outcome label.

    Returns "WIN", "FATAL", "AMA" or "WRONG" when the observation's ``event``
    contains "win", "fatal", "ama" or "incorrect" (checked in that order),
    the raw event string otherwise, and "done" when the observation is not a
    JSON object with a string ``event``.
    """
    try:
        event = json.loads(obs).get("event", "")
    except (json.JSONDecodeError, TypeError, AttributeError):
        return "done"
    if not isinstance(event, str):
        return "done"
    for needle, label in (("win", "WIN"), ("fatal", "FATAL"),
                          ("ama", "AMA"), ("incorrect", "WRONG")):
        if needle in event:
            return label
    return event


def run_episode(env, doctor, episode_num: int, *, max_steps: int = 30,
                step_delay: float = 1.2) -> Dict[str, Any]:
    """
    Play one episode with *doctor* in *env*, printing the full transcript.

    Args:
        env: Environment exposing ``reset() -> (obs, info)``,
            ``step(action) -> (obs, reward, done, truncated, info)`` and a
            ``ground_truth`` dict with "patient", "nurse" and "disease" entries.
        doctor: Object with ``reset()`` and ``decide(observation) -> str``.
        episode_num: Episode index recorded in the result.
        max_steps: Doctor turns allowed before the episode is cut off with
            outcome "MAX_STEPS".
        step_delay: Seconds slept before each Doctor call (presumably to
            throttle LLM API usage); values <= 0 disable the pause.

    Returns:
        Dict with episode metadata, ``outcome``, rounded ``total_reward`` and
        ``steps``.
    """
    doctor.reset()
    obs, info = env.reset()
    gt = env.ground_truth
    disease = info.get("ground_truth_disease", "???")
    difficulty = gt.get("difficulty", "random")
    p = gt["patient"]
    n = gt["nurse"]
    # Print episode header
    print(f" Disease: {disease}", flush=True)
    print(f" Difficulty: {difficulty}", flush=True)
    print(f" Patient: compliance={p['compliance']}, comm={p['communication']}, literacy={p['literacy']}", flush=True)
    print(f" Nurse: exp={n['experience']}, bandwidth={n['bandwidth']}, empathy={n['empathy']}", flush=True)
    print(f" Correct Tx: {gt['disease']['correct_treatment'][:80]}", flush=True)
    print(f" {'~'*60}", flush=True)
    print_observation(obs)

    total_reward = 0.0
    steps = 0
    while True:
        steps += 1
        if step_delay > 0:
            time.sleep(step_delay)
        action_str = doctor.decide(obs)
        print(f" Step {steps}:", flush=True)
        print_doctor_action(action_str, steps)
        obs, reward, done, truncated, _ = env.step(action_str)
        total_reward += reward
        print(f" REWARD | {reward:+.2f} (total: {total_reward:+.2f})", flush=True)
        print_observation(obs)
        if done:
            outcome = _classify_terminal(obs)
            break
        if truncated:
            outcome = "TRUNCATED"
            break
        if steps >= max_steps:
            outcome = "MAX_STEPS"
            break
    return {
        "episode": episode_num, "disease": disease,
        "difficulty": difficulty, "compliance": p["compliance"],
        "communication": p["communication"], "outcome": outcome,
        "total_reward": round(total_reward, 2), "steps": steps,
    }
def plot_reward_curve(results: List[Dict], output_path: str,
                      title: str = "ER-MAP: LLM Doctor Reward Curve (Baseline - No RL Training)"):
    """
    Save a two-panel reward figure to *output_path*.

    Top panel: per-episode total reward bars coloured by outcome, plus a
    rolling mean over up to 5 episodes. Bottom panel: horizontal outcome
    histogram; outcomes outside the named buckets are counted as "OTHER" so
    the bars always account for every episode. Prints a message and skips
    when *results* is empty or matplotlib is unavailable.

    Args:
        results: Episode dicts with at least "episode", "total_reward" and
            "outcome" keys.
        output_path: Destination image path (its directory must exist).
        title: Title of the reward panel.
    """
    if not results:
        print(" No results to plot. Skipping plot.", flush=True)
        return
    try:
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except ImportError:
        print(" matplotlib not installed. Skipping plot.", flush=True)
        return

    episodes = [r["episode"] for r in results]
    rewards = [r["total_reward"] for r in results]
    outcomes = [r["outcome"] for r in results]

    window = min(5, len(rewards))
    rolling_avg = []
    for i in range(len(rewards)):
        chunk = rewards[max(0, i - window + 1): i + 1]
        rolling_avg.append(sum(chunk) / len(chunk))

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10), gridspec_kw={"height_ratios": [3, 1]})
    fig.patch.set_facecolor("#0d1117")

    # --- Reward panel ---
    ax1.set_facecolor("#161b22")
    bar_palette = {"WIN": "#2ea043", "AMA": "#f0883e", "FATAL": "#f85149", "WRONG": "#f85149"}
    colors = [bar_palette.get(o, "#8b949e") for o in outcomes]
    ax1.bar(episodes, rewards, color=colors, alpha=0.6, width=0.8, label="Episode Reward")
    ax1.plot(episodes, rolling_avg, color="#58a6ff", linewidth=2.5, label=f"Rolling Avg (window={window})", zorder=5)
    ax1.axhline(y=0, color="#484f58", linewidth=1, linestyle="--")
    ax1.axhline(y=2.0, color="#2ea043", linewidth=1, linestyle=":", alpha=0.5, label="Win threshold (+2.0)")
    ax1.axhline(y=-1.5, color="#f85149", linewidth=1, linestyle=":", alpha=0.5, label="AMA penalty (-1.5)")
    ax1.set_xlabel("Episode", color="#c9d1d9", fontsize=12)
    ax1.set_ylabel("Total Reward", color="#c9d1d9", fontsize=12)
    ax1.set_title(title, color="#f0f6fc", fontsize=14, fontweight="bold", pad=15)
    ax1.legend(loc="upper left", facecolor="#21262d", edgecolor="#484f58", labelcolor="#c9d1d9")
    ax1.tick_params(colors="#8b949e")
    for spine in ax1.spines.values():
        spine.set_color("#484f58")

    # --- Outcome distribution panel ---
    ax2.set_facecolor("#161b22")
    outcome_types = ["WIN", "AMA", "WRONG", "FATAL", "TRUNCATED", "MAX_STEPS", "ERROR", "OTHER"]
    outcome_colors = ["#2ea043", "#f0883e", "#f85149", "#da3633", "#8b949e", "#6e7681", "#a371f7", "#d29922"]
    # Every outcome lands in exactly one bucket (unknown labels -> OTHER), so
    # the histogram always sums to len(results).
    tally = dict.fromkeys(outcome_types, 0)
    for o in outcomes:
        tally[o if o in tally else "OTHER"] += 1
    outcome_counts = [tally[t] for t in outcome_types]
    bars = ax2.barh(outcome_types, outcome_counts, color=outcome_colors, alpha=0.8)
    for bar, count in zip(bars, outcome_counts):
        if count > 0:
            ax2.text(bar.get_width() + 0.15, bar.get_y() + bar.get_height() / 2,
                     str(count), va="center", color="#c9d1d9", fontsize=11, fontweight="bold")
    ax2.set_xlabel("Count", color="#c9d1d9", fontsize=11)
    ax2.set_title("Outcome Distribution", color="#c9d1d9", fontsize=12, pad=10)
    ax2.tick_params(colors="#8b949e")
    for spine in ax2.spines.values():
        spine.set_color("#484f58")

    fig.tight_layout(pad=2.0)
    fig.savefig(output_path, dpi=150, bbox_inches="tight", facecolor="#0d1117")
    # Close this specific figure (not just the "current" one) to free memory.
    plt.close(fig)
    print(f"\n Reward curve saved to: {output_path}", flush=True)
def print_summary(results: List[Dict]):
    """
    Print aggregate outcome rates and a per-disease breakdown.

    Episodes are bucketed as wins ("WIN"), AMA ("AMA") and wrong/fatal
    ("WRONG" or "FATAL"); any other outcome only counts toward the totals.
    Prints a short notice instead of dividing by zero when *results* is empty.

    Args:
        results: Episode dicts with "disease", "outcome", "total_reward" and
            "steps" keys.
    """
    total = len(results)
    if total == 0:
        print(flush=True)
        print(" No episodes to summarize.", flush=True)
        return
    wins = sum(1 for r in results if r["outcome"] == "WIN")
    ama = sum(1 for r in results if r["outcome"] == "AMA")
    wrong = sum(1 for r in results if r["outcome"] in ("WRONG", "FATAL"))
    avg_reward = sum(r["total_reward"] for r in results) / total
    avg_steps = sum(r["steps"] for r in results) / total
    print(flush=True)
    print("=" * 70, flush=True)
    print(f" EVALUATION SUMMARY ({total} episodes)", flush=True)
    print("=" * 70, flush=True)
    print(f" Win Rate: {wins}/{total} ({100*wins/total:.0f}%)", flush=True)
    print(f" AMA Rate: {ama}/{total} ({100*ama/total:.0f}%)", flush=True)
    print(f" Wrong/Fatal: {wrong}/{total} ({100*wrong/total:.0f}%)", flush=True)
    print(f" Avg Reward: {avg_reward:+.2f}", flush=True)
    print(f" Avg Steps: {avg_steps:.1f}", flush=True)
    print("=" * 70, flush=True)
    print(flush=True)

    # Disease names are stringified so None/non-str values neither break the
    # ":35s" format spec nor make sorting compare mixed types.
    per_disease: Dict[str, Dict[str, Any]] = {}
    for r in results:
        stats = per_disease.setdefault(str(r["disease"]), {"wins": 0, "total": 0, "reward_sum": 0})
        stats["total"] += 1
        stats["reward_sum"] += r["total_reward"]
        if r["outcome"] == "WIN":
            stats["wins"] += 1

    print(" PER-DISEASE BREAKDOWN:", flush=True)
    print(f" {'Disease':35s} {'Win':>5s} {'Total':>5s} {'Rate':>6s} {'Avg Rwd':>8s}", flush=True)
    print(" " + "-" * 62, flush=True)
    for name, stats in sorted(per_disease.items()):
        # Every bucket has total >= 1 because it is created on first sight.
        rate = f"{100*stats['wins']/stats['total']:.0f}%"
        avg = stats["reward_sum"] / stats["total"]
        print(f" {name:35s} {stats['wins']:>5d} {stats['total']:>5d} {rate:>6s} {avg:>+8.2f}", flush=True)
    print(flush=True)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """
    CLI entry point: run the evaluation, save raw results, print a summary
    and plot the reward curve.

    API keys come from ``GROQ_NURSE_API_KEY`` / ``GROQ_PATIENT_API_KEY`` /
    ``GROQ_DOCTOR_API_KEY``, with the lowercase ``nurse`` / ``patient`` /
    ``doctor`` variables accepted as alternatives. The Doctor reuses the
    patient key when it has none of its own.

    Returns:
        Process exit code: 0 on success, 1 when required keys are missing.
        Invalid CLI arguments exit via argparse (status 2).
    """
    parser = argparse.ArgumentParser(description="ER-MAP Evaluation Runner")
    parser.add_argument("--episodes", type=int, default=30, help="Number of episodes")
    parser.add_argument("--output", type=str, default="reward_curve.png", help="Output plot path")
    args = parser.parse_args()
    if args.episodes < 1:
        # Zero episodes would otherwise only fail later, when summarizing.
        parser.error("--episodes must be at least 1")

    from ER_MAP.envs.triage_env import TriageEnv

    nurse_key = os.environ.get("GROQ_NURSE_API_KEY") or os.environ.get("nurse", "")
    patient_key = os.environ.get("GROQ_PATIENT_API_KEY") or os.environ.get("patient", "")
    doctor_key = (os.environ.get("GROQ_DOCTOR_API_KEY")
                  or os.environ.get("doctor")
                  or patient_key)
    if not nurse_key or not patient_key:
        print("ERROR: Set GROQ_NURSE_API_KEY and GROQ_PATIENT_API_KEY "
              "(or the lowercase 'nurse' / 'patient' variables)", flush=True)
        return 1

    print(flush=True)
    print("=" * 70, flush=True)
    print(f" ER-MAP EVALUATION: {args.episodes} episodes with LLM Doctor", flush=True)
    print("=" * 70, flush=True)
    print(f" Doctor: Llama-3.1-8B (unmodified baseline)", flush=True)
    print(f" Nurse: Llama-3.1-8B (LIVE)", flush=True)
    print(f" Patient: Llama-3.1-8B (LIVE)", flush=True)
    print(f" Diseases: 15 | Persona combos: 933,120", flush=True)
    print("=" * 70, flush=True)

    outcome_icons = {"WIN": "[OK]", "AMA": "[!!]", "WRONG": "[XX]", "FATAL": "[XX]"}
    results = []
    env = TriageEnv(nurse_api_key=nurse_key, patient_api_key=patient_key)
    # Always release the environment, even if the Doctor cannot be built or
    # the run is interrupted.
    try:
        doctor = DoctorBrain(api_key=doctor_key)
        for ep in range(1, args.episodes + 1):
            print(flush=True)
            print(f" {'='*60}", flush=True)
            print(f" EPISODE {ep}/{args.episodes}", flush=True)
            print(f" {'='*60}", flush=True)
            try:
                result = run_episode(env, doctor, ep)
            except Exception as e:
                # Per-episode boundary: record the failure and keep evaluating.
                print(f" [ERR] Episode {ep} failed: {e}", flush=True)
                results.append({
                    "episode": ep, "disease": "ERROR", "difficulty": "?",
                    "compliance": "?", "communication": "?",
                    "outcome": "ERROR", "total_reward": -2.0, "steps": 0
                })
                continue
            results.append(result)
            icon = outcome_icons.get(result["outcome"], "[--]")
            print(f" {icon} OUTCOME: {result['outcome']:8s} | Reward: {result['total_reward']:+.2f} | Steps: {result['steps']}", flush=True)
    finally:
        env.close()

    # Save results next to the plot; create the directory if needed.
    out_dir = os.path.dirname(args.output) or "."
    os.makedirs(out_dir, exist_ok=True)
    results_path = os.path.join(out_dir, "eval_results.json")
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print(f"\n Raw results saved to: {results_path}", flush=True)
    print_summary(results)
    plot_reward_curve(results, args.output)
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())