Spaces:
Sleeping
Sleeping
File size: 13,564 Bytes
27158b3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 | """
inference.py — Baseline AI agent for MediRoute OpenEnv.
Connects to any OpenAI-compatible endpoint (including Hugging Face TGI,
vLLM, or the official OpenAI API) and runs the agent across all three
difficulty tasks, printing structured logs.
Environment variables (set before running):
OPENAI_API_KEY — API key (use 'EMPTY' for local / HF endpoints)
API_BASE_URL — Base URL, e.g. https://api-inference.huggingface.co/v1
MODEL_NAME — Model identifier, e.g. mistralai/Mistral-7B-Instruct-v0.3
HF_TOKEN — (Optional) Hugging Face token for gated models; used as the API key when OPENAI_API_KEY is unset or 'EMPTY'
Usage:
python inference.py
python inference.py --difficulty easy
python inference.py --difficulty all
python inference.py --agent rules   (offline deterministic baseline; no API access needed)
"""
from __future__ import annotations
import argparse
import json
import os
import re
import sys
import time
from typing import Any, Dict, List, Optional, TYPE_CHECKING
# OpenAI SDK is only required for the LLM agent mode.
if TYPE_CHECKING:
from openai import OpenAI # pragma: no cover
else:
OpenAI = Any # type: ignore[misc,assignment]
from environment import MediRouteEnv
from models import Action, VALID_ACTION_TYPES
# ---------------------------------------------
# Configuration from environment variables
# ---------------------------------------------
# Endpoint and model selection; the defaults target the hosted OpenAI API.
API_BASE_URL: str = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME: str = os.getenv("MODEL_NAME", "gpt-4o-mini")

# Credentials. "EMPTY" is the placeholder key used for local / HF endpoints.
API_KEY: str = os.getenv("OPENAI_API_KEY", "EMPTY")
HF_TOKEN: str = os.getenv("HF_TOKEN", "")
# Fall back to the Hugging Face token only when no explicit API key was given.
if API_KEY == "EMPTY" and HF_TOKEN:
    API_KEY = HF_TOKEN

# Episode / task settings.
MAX_STEPS_PER_EPISODE: int = 8  # upper bound on agent steps per episode
ALL_DIFFICULTIES: List[str] = [
    "easy",
    "medium",
    "hard",
]
# ---------------------------------------------
# Prompt engineering
# ---------------------------------------------
# System prompt placed first in the message list of every chat-completion
# request in LLM mode. It fixes the reply schema ("action_type" + "target")
# that parse_action() reads, and the routing rules the model should follow.
# NOTE(review): "SpOβ" and the "β" before call_ambulance look like mis-decoded
# characters (presumably a subscript-2 and a right arrow); confirm the file's
# encoding, since this text is sent to the model verbatim.
SYSTEM_PROMPT = """You are MediRoute, an AI medical triage and routing agent.
Your goal is to help patients by:
1. Analysing their symptoms and lab reports to determine severity.
2. Recommending the correct medical specialist.
3. Selecting the best nearby hospital.
4. Booking appointments or dispatching ambulances as appropriate.
You must respond with a single JSON object containing exactly two keys:
"action_type" : one of [analyze_symptoms, request_more_info, recommend_specialist,
select_hospital, book_appointment, call_ambulance,
provide_temp_guidance]
"target" : a string value relevant to the action, or null
Severity levels for analyze_symptoms: low | moderate | high | critical
Rules:
- For life-threatening emergencies (SpOβ < 85 %, unconscious, etc.) β call_ambulance.
- Do NOT book an appointment in a critical emergency.
- Pick the FIRST hospital in the nearby_hospitals list as the best option.
- Stop after taking a terminal action (book_appointment or call_ambulance).
- Never repeat the same action twice.
"""
def build_user_message(obs, step: int) -> str:
    """Render the observation *obs* as the user-turn text for agent step *step*.

    Reads ``symptoms``, ``lab_report_summary`` (pretty-printed as JSON),
    ``severity_score``, ``location``, ``nearby_hospitals`` (numbered from 1),
    ``available_specialists`` (bulleted) and ``previous_actions`` ("(none)"
    when empty). The text ends with a request for a JSON-only reply and a
    trailing newline.
    """
    labs_json = json.dumps(obs.lab_report_summary, indent=2)
    hospital_lines = "\n".join(
        f" {rank}. {name}" for rank, name in enumerate(obs.nearby_hospitals, start=1)
    )
    specialist_lines = "\n".join(f" - {spec}" for spec in obs.available_specialists)
    history = obs.previous_actions or "(none)"

    lines = [
        f"Step {step} β Patient Status:",
        f"Symptoms : {obs.symptoms}",
        f"Lab Results : {labs_json}",
        f"Severity Score : {obs.severity_score:.2f}",
        f"Location : {obs.location}",
        "Nearby Hospitals (in order of proximity/quality):",
        hospital_lines,
        "Available Specialists:",
        specialist_lines,
        f"Actions already taken: {history}",
        "What is your next action? Respond ONLY with valid JSON.",
    ]
    return "\n".join(lines) + "\n"
# ---------------------------------------------
# Agent loop
# ---------------------------------------------
def _decode_first_json_object(text: str) -> Dict[str, Any]:
    """Return the first complete JSON object embedded in *text*.

    Models often wrap the object in markdown code fences or surround it with
    prose, so each ``{`` is tried in order with ``json.JSONDecoder.raw_decode``,
    which decodes a single value starting at an index and ignores whatever
    follows it (unlike a greedy ``{...}`` regex, which spans to the LAST brace).

    Raises:
        ValueError: if no ``{`` in *text* starts a decodable JSON object
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    decoder = json.JSONDecoder()
    first_error: Optional[ValueError] = None
    start = text.find("{")
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError as exc:
            if first_error is None:
                first_error = exc
        else:
            # A JSON value that begins with "{" is always an object, i.e. a dict.
            return obj
        start = text.find("{", start + 1)
    if first_error is not None:
        raise first_error
    raise ValueError("no JSON object found in model response")


def parse_action(response_text: str) -> Optional[Action]:
    """Extract a valid Action from the model's raw response text.

    Args:
        response_text: raw assistant message content (may contain code fences
            or extra prose around the JSON object).

    Returns:
        ``None`` when the text contains no decodable JSON object, including
        JSON that is valid but not an object, such as ``[1, 2]`` or ``null``.
        In that case a ``parse_error`` event is logged so the caller can ask
        the model to retry. Otherwise returns an ``Action``:
        - an ``action_type`` outside ``VALID_ACTION_TYPES`` becomes
          ``request_more_info`` with no target;
        - a non-null, non-string ``target`` is converted with ``str()``.
    """
    text = (response_text or "").strip()
    try:
        data = _decode_first_json_object(text)
    except ValueError as exc:
        # Keep logs compatible with strict parsers: no free-form prefixes.
        print(f"[STEP] event=parse_error detail={str(exc)[:120]!r}")
        return None
    action_type = str(data.get("action_type", "request_more_info")).strip()
    target = data.get("target")
    if action_type not in VALID_ACTION_TYPES:
        return Action(action_type="request_more_info", target=None)
    if target is not None and not isinstance(target, str):
        # Keep schema strict: targets are strings or null.
        target = str(target)
    return Action(action_type=action_type, target=target)
# SpO2 readings strictly below this percentage are treated as a life-threatening
# emergency, matching the "< 85 %" rule stated in SYSTEM_PROMPT.
_SPO2_CRITICAL_BELOW = 85.0
# A number directly followed by "%"; the lookbehind skips digits glued to
# letters (e.g. the "2" in "SpO2").
_SPO2_PERCENT_RE = re.compile(r"(?<![A-Za-z])(\d+(?:\.\d+)?)\s*%")
_SPO2_NUMBER_RE = re.compile(r"(?<![A-Za-z])(\d+(?:\.\d+)?)")


def _spo2_percent(raw: Any) -> Optional[float]:
    """Parse an SpO2 lab value such as "82%", "82 % on room air" or "0.82".

    A number followed by "%" wins, so "on 4 L/min O2: 95%" reads as 95.
    Otherwise the first standalone number is used, and a bare value <= 1 is
    read as a fraction and scaled by 100. Returns None when no number is found.
    """
    text = str(raw)
    percent = _SPO2_PERCENT_RE.search(text)
    if percent is not None:
        return float(percent.group(1))
    bare = _SPO2_NUMBER_RE.search(text)
    if bare is None:
        return None
    value = float(bare.group(1))
    return value * 100.0 if value <= 1.0 else value


def _action_taken(previous_actions: Any, action_type: str) -> bool:
    """Return True if any entry of *previous_actions* starts with "<action_type>:"."""
    # Assumes entries are "<action_type>:<target>" strings - confirm against the env.
    prefix = f"{action_type}:"
    return any(entry.startswith(prefix) for entry in previous_actions or ())


def rules_agent(obs) -> Action:
    """
    Deterministic baseline policy.
    Designed to be fully offline and reproducible for judge evaluation.

    Returns one action per call, following a fixed script and skipping any
    stage whose action type already appears in ``obs.previous_actions``:
      1. analyze_symptoms: "critical" (emergency), "high" (cardiac signs) or "low";
      2. recommend_specialist: Emergency Doctor / Cardiologist / General Physician;
      3. select_hospital: the first entry of ``obs.nearby_hospitals``;
      4. call_ambulance for emergencies, otherwise book_appointment.

    An emergency is any of:
      - SpO2 below ``_SPO2_CRITICAL_BELOW`` percent;
      - an "unresponsive" consciousness finding;
      - "cyanotic" or "collapse" in the symptoms text.
    """
    labs = obs.lab_report_summary or {}
    symptoms = (obs.symptoms or "").lower()
    taken = obs.previous_actions or []

    # Emergency detection (previously only SpO2 substrings "78".."84" matched,
    # so readings such as "70%" were missed).
    spo2 = _spo2_percent(labs.get("spo2", ""))
    consciousness = str(labs.get("consciousness", "")).lower()
    emergency = (
        (spo2 is not None and spo2 < _SPO2_CRITICAL_BELOW)
        or "unresponsive" in consciousness
        or "cyanotic" in symptoms
        or "collapse" in symptoms
    )

    # STEMI-ish / cardiology cues, computed once for both stages that use them.
    ecg = str(labs.get("ecg_finding", "")).lower()
    troponin = str(labs.get("troponin_i", "")).lower()
    # NOTE(review): the substring test also matches "not elevated"; confirm env wording.
    cardiac = "st-segment elevation" in ecg or "elevated" in troponin

    # 1) Severity
    if not _action_taken(taken, "analyze_symptoms"):
        if emergency:
            severity = "critical"
        elif cardiac:
            severity = "high"
        else:
            severity = "low"  # default outpatient
        return Action(action_type="analyze_symptoms", target=severity)

    # 2) Route to specialist
    if not _action_taken(taken, "recommend_specialist"):
        if emergency:
            specialist = "Emergency Doctor"
        elif cardiac:
            specialist = "Cardiologist"
        else:
            specialist = "General Physician"
        return Action(action_type="recommend_specialist", target=specialist)

    # 3) Choose hospital (prefer first listed)
    if not _action_taken(taken, "select_hospital"):
        best = obs.nearby_hospitals[0] if obs.nearby_hospitals else "General Hospital"
        return Action(action_type="select_hospital", target=best)

    # 4) Close episode
    if emergency:
        return Action(action_type="call_ambulance", target=None)
    return Action(action_type="book_appointment", target=None)
def _choose_llm_action(
    client: Optional[OpenAI],
    conversation: List[Dict[str, str]],
    obs,
    step: int,
    max_attempts: int = 2,
) -> Action:
    """Ask the model for this step's action, with a corrective retry on bad JSON.

    Mutates *conversation* in place:
      - appends this step's user turn;
      - appends every assistant reply, even an empty one, so roles keep
        alternating user/assistant;
      - appends a corrective user turn only when another attempt follows.

    If the API call raises, everything added for this step is rolled back, so
    the next step does not start with two consecutive user turns. The next
    observation still carries the action history.

    Returns:
        The parsed action, or ``request_more_info`` when the call fails or no
        attempt yields a parseable action. This fallback is explicit rather
        than reusing a stale action from a previous step.
    """
    checkpoint = len(conversation)
    conversation.append({"role": "user", "content": build_user_message(obs, step)})
    for attempt in range(max_attempts):
        try:
            if client is None:
                raise RuntimeError("OpenAI client not initialized.")
            response = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[{"role": "system", "content": SYSTEM_PROMPT}] + conversation,
                temperature=0.0,  # deterministic (to the extent the backend supports it)
                max_tokens=256,
            )
            assistant_text = response.choices[0].message.content or ""
        except Exception as exc:  # any backend failure degrades to the fallback action
            print(f"[STEP] step={step} event=llm_error detail={str(exc)[:160]!r}")
            del conversation[checkpoint:]
            break
        conversation.append({"role": "assistant", "content": assistant_text})
        action = parse_action(assistant_text)
        if action is not None:
            return action
        if attempt + 1 < max_attempts:
            # One corrective retry: ask for strict JSON only.
            conversation.append(
                {
                    "role": "user",
                    "content": "Your last response was invalid. Respond with ONLY a JSON object with keys action_type and target.",
                }
            )
    return Action(action_type="request_more_info", target=None)


def run_episode(client: Optional[OpenAI], difficulty: str, agent: str) -> Dict[str, Any]:
    """Run a complete episode for the given difficulty and return the summary.

    Args:
        client: OpenAI-compatible client; only used when ``agent != "rules"``.
        difficulty: task level passed to ``MediRouteEnv.reset``.
        agent: ``"rules"`` for the offline deterministic policy; any other
            value queries the LLM.

    Returns:
        A dict with these keys:
          - ``difficulty``;
          - ``score`` (the env's final ``total_reward``);
          - ``passed``;
          - ``steps``;
          - ``elapsed_seconds``;
          - ``breakdown``.
        It prints ``[START]``, one ``[STEP n]`` per step, and an ``[END]`` line.
        The episode stops when the env reports ``done`` or after
        ``MAX_STEPS_PER_EPISODE`` steps.
    """
    env = MediRouteEnv()
    obs = env.reset(difficulty=difficulty)
    conversation: List[Dict[str, str]] = []
    final_info: Dict[str, Any] = {}
    step = 0
    episode_start = time.perf_counter()  # monotonic clock for elapsed time
    print(f"[START] difficulty={difficulty.upper()} agent={agent} symptoms={obs.symptoms!r}")
    while step < MAX_STEPS_PER_EPISODE:
        step += 1
        if agent == "rules":
            action = rules_agent(obs)
        else:
            action = _choose_llm_action(client, conversation, obs, step)
        # -- Step environment --------------------------------------------------
        result = env.step(action)
        final_info = result.info
        reward_sign = "+" if result.reward >= 0 else ""
        print(
            f"[STEP {step}] action={action.action_type} "
            f"target={action.target!r} "
            f"reward={reward_sign}{result.reward:.2f} "
            f"total={result.info.get('total_reward', 0):.2f} "
            f"done={result.done}"
        )
        obs = result.observation
        if result.done:
            break
    elapsed = time.perf_counter() - episode_start
    episode_summary = final_info.get("episode_summary", {})
    total_reward = final_info.get("total_reward", 0.0)
    print(f"[END] difficulty={difficulty.upper()} agent={agent} "
          f"score={total_reward:.4f} "
          f"passed={episode_summary.get('passed', False)} "
          f"steps={step} "
          f"elapsed={elapsed:.1f}s "
          f"breakdown={json.dumps(episode_summary.get('breakdown', {}))}")
    return {
        "difficulty": difficulty,
        "score": total_reward,
        "passed": episode_summary.get("passed", False),
        "steps": step,
        "elapsed_seconds": round(elapsed, 2),
        "breakdown": episode_summary.get("breakdown", {}),
    }
# ---------------------------------------------
# Main
# ---------------------------------------------
def main() -> None:
    """Command-line entry point: run the chosen agent on one or all difficulty levels.

    Options:
      --difficulty {easy,medium,hard,all}: which task(s) to run (default: all).
      --agent {llm,rules}: OpenAI-compatible model or offline rules policy
        (default: llm).

    Per-episode logs come from run_episode(). A final "[END] summary" line
    then reports the average score and the per-episode results as JSON. The
    process exits with status 1 if ``--agent llm`` is requested and the
    ``openai`` package cannot be imported.
    """
    cli = argparse.ArgumentParser(description="MediRoute OpenEnv β Baseline Inference")
    cli.add_argument(
        "--difficulty",
        choices=["easy", "medium", "hard", "all"],
        default="all",
        help="Which task(s) to run (default: all)",
    )
    cli.add_argument(
        "--agent",
        choices=["llm", "rules"],
        default="llm",
        help="Agent policy: llm (OpenAI-compatible) or rules (offline deterministic baseline).",
    )
    options = cli.parse_args()

    # Keep output machine-parseable: rely on [START]/[STEP]/[END] markers.
    client: Optional[OpenAI] = None
    if options.agent == "llm":
        # Imported lazily so the rules agent works without the openai package.
        try:
            from openai import OpenAI as OpenAIClient  # type: ignore
        except ImportError:
            print("[ERROR] openai package not found. Install it or run with: --agent rules")
            sys.exit(1)
        client = OpenAIClient(api_key=API_KEY, base_url=API_BASE_URL)

    selected = [options.difficulty] if options.difficulty != "all" else ALL_DIFFICULTIES
    episode_results = [run_episode(client, level, agent=options.agent) for level in selected]

    # Final leaderboard: one [END] summary line for strict log parsers.
    mean_score = sum(entry["score"] for entry in episode_results) / len(episode_results)
    print(f"[END] summary average_score={mean_score:.4f} results={json.dumps(episode_results)}")


if __name__ == "__main__":
    main()
|