File size: 13,564 Bytes
27158b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
"""
inference.py β€” Baseline AI agent for MediRoute OpenEnv.

Connects to any OpenAI-compatible endpoint (including Hugging Face TGI,
vLLM, or the official OpenAI API) and runs the agent across all three
difficulty tasks, printing structured logs.

Environment variables (set before running):
    OPENAI_API_KEY   – API key (use 'EMPTY' for local / HF endpoints)
    API_BASE_URL     – Base URL, e.g. https://api-inference.huggingface.co/v1
    MODEL_NAME       – Model identifier, e.g. mistralai/Mistral-7B-Instruct-v0.3
    HF_TOKEN         – (Optional) Hugging Face token for gated models

Usage:
    python inference.py
    python inference.py --difficulty easy
    python inference.py --difficulty all
"""

from __future__ import annotations

import argparse
import json
import os
import re
import sys
import time
from typing import Any, Dict, List, Optional, TYPE_CHECKING

# OpenAI SDK is only required for the LLM agent mode.
if TYPE_CHECKING:
    from openai import OpenAI  # pragma: no cover
else:
    OpenAI = Any  # type: ignore[misc,assignment]

from environment import MediRouteEnv
from models import Action, VALID_ACTION_TYPES

# ─────────────────────────────────────────────
#  Configuration from environment variables
# ─────────────────────────────────────────────

# Endpoint and credential settings, resolved once when the module is imported.
HF_TOKEN: str = os.environ.get("HF_TOKEN", "")
MODEL_NAME: str = os.environ.get("MODEL_NAME", "gpt-4o-mini")
API_BASE_URL: str = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
API_KEY: str = os.environ.get("OPENAI_API_KEY", "EMPTY")

# "EMPTY" is the placeholder meaning no explicit key was supplied; in that
# case a Hugging Face token, when present, is used as the API key instead.
if API_KEY == "EMPTY" and HF_TOKEN:
    API_KEY = HF_TOKEN

# Task names covered by a full run, and the per-episode step budget.
ALL_DIFFICULTIES: List[str] = ["easy", "medium", "hard"]
MAX_STEPS_PER_EPISODE: int = 8


# ─────────────────────────────────────────────
#  Prompt engineering
# ─────────────────────────────────────────────

# System message placed first in every chat-completion request. It defines the
# agent's role, the exact JSON reply schema (action_type + target), the allowed
# action names and severity levels, and the routing rules the model must obey.
# The text is sent to the model verbatim, so it must stay correctly encoded
# (SpO₂ and → below are real Unicode characters, not mojibake).
SYSTEM_PROMPT = """You are MediRoute, an AI medical triage and routing agent.
Your goal is to help patients by:
1. Analysing their symptoms and lab reports to determine severity.
2. Recommending the correct medical specialist.
3. Selecting the best nearby hospital.
4. Booking appointments or dispatching ambulances as appropriate.

You must respond with a single JSON object containing exactly two keys:
  "action_type" : one of [analyze_symptoms, request_more_info, recommend_specialist,
                           select_hospital, book_appointment, call_ambulance,
                           provide_temp_guidance]
  "target"      : a string value relevant to the action, or null

Severity levels for analyze_symptoms: low | moderate | high | critical

Rules:
- For life-threatening emergencies (SpO₂ < 85 %, unconscious, etc.) → call_ambulance.
- Do NOT book an appointment in a critical emergency.
- Pick the FIRST hospital in the nearby_hospitals list as the best option.
- Stop after taking a terminal action (book_appointment or call_ambulance).
- Never repeat the same action twice.
"""


def build_user_message(obs, step: int) -> str:
    """Render the current observation as the user-turn prompt for the model.

    Args:
        obs: Observation object exposing ``symptoms``, ``lab_report_summary``,
            ``severity_score``, ``location``, ``nearby_hospitals``,
            ``available_specialists`` and ``previous_actions``.
        step: Step number shown in the heading.

    Returns:
        A multi-line prompt that lists the patient status and ends with an
        instruction to reply with JSON only.
    """
    # Every list entry carries the same two-space indent; an empty list is
    # shown explicitly instead of leaving a whitespace-only line.
    hospital_lines = "\n".join(
        f"  {rank}. {name}"
        for rank, name in enumerate(obs.nearby_hospitals or [], start=1)
    ) or "  (none)"
    specialist_lines = "\n".join(
        f"  - {name}" for name in (obs.available_specialists or [])
    ) or "  (none)"
    labs_json = json.dumps(obs.lab_report_summary, indent=2)

    return f"""Step {step} — Patient Status:

Symptoms       : {obs.symptoms}
Lab Results    : {labs_json}
Severity Score : {obs.severity_score:.2f}
Location       : {obs.location}
Nearby Hospitals (in order of proximity/quality):
{hospital_lines}
Available Specialists:
{specialist_lines}
Actions already taken: {obs.previous_actions or '(none)'}

What is your next action? Respond ONLY with valid JSON.
"""


# ─────────────────────────────────────────────
#  Agent loop
# ─────────────────────────────────────────────

def parse_action(response_text: str) -> Optional[Action]:
    """Extract a valid Action from the model's raw reply.

    The reply may be wrapped in a Markdown code fence and may carry prose
    before or after the JSON. The first JSON object in the text is decoded
    and anything after it is ignored.

    Args:
        response_text: Raw text returned by the model.

    Returns:
        The parsed ``Action``. An unknown ``action_type`` is replaced by
        ``request_more_info`` with no target. ``None`` is returned (after
        logging a ``parse_error`` event) when no JSON object can be decoded.
    """
    text = response_text.strip()

    # Strip markdown code fences if present (drop the opening fence line and,
    # when the last line is a bare closing fence, that one too).
    if text.startswith("```"):
        lines = text.splitlines()
        closed = bool(lines) and lines[-1].strip() == "```"
        text = "\n".join(lines[1:-1] if closed else lines[1:])

    # Decode starting at the first "{" and stop at the end of that object, so
    # trailing prose or a second object (even one containing braces) does not
    # break parsing the way a greedy "{...}" match would.
    start = text.find("{")
    try:
        if start == -1:
            data = json.loads(text)
        else:
            data, _ = json.JSONDecoder().raw_decode(text, start)
    except (json.JSONDecodeError, ValueError) as exc:
        # Keep logs compatible with strict parsers: no free-form prefixes.
        print(f"[STEP] event=parse_error detail={str(exc)[:120]!r}")
        return None

    # Valid JSON that is not an object (list, string, number, null) has no
    # action fields; treat it as a parse failure instead of crashing on .get().
    if not isinstance(data, dict):
        detail = f"expected a JSON object, got {type(data).__name__}"
        print(f"[STEP] event=parse_error detail={detail!r}")
        return None

    action_type = str(data.get("action_type", "request_more_info")).strip()
    target = data.get("target")
    if action_type not in VALID_ACTION_TYPES:
        return Action(action_type="request_more_info", target=None)
    if target is not None and not isinstance(target, str):
        # Keep schema strict: targets are strings or null.
        target = str(target)
    return Action(action_type=action_type, target=target)


# SpO2 percentage below which a reading counts as an emergency.
_SPO2_EMERGENCY_BELOW = 85.0
# Substrings the original heuristic looked for; kept so that every reading it
# flagged is still flagged even when the value cannot be parsed as a number.
_LEGACY_SPO2_MARKERS = ("78", "79", "80", "81", "82", "83", "84")


def _spo2_percent(raw: str) -> Optional[float]:
    """Return the SpO2 reading found in *raw*, or None when none is recognised.

    Recognises a number followed by ``%`` anywhere in the text, or a string
    that consists of a bare number only.
    """
    match = re.search(r"(\d+(?:\.\d+)?)\s*%", raw) or re.fullmatch(
        r"\s*(\d+(?:\.\d+)?)\s*", raw
    )
    return float(match.group(1)) if match else None


def _has_emergency_signals(labs: Dict[str, Any], symptoms: str) -> bool:
    """Tell whether the labs or the lower-cased symptom text indicate an emergency."""
    spo2_raw = str(labs.get("spo2", "")).lower()
    consciousness = str(labs.get("consciousness", "")).lower()

    reading = _spo2_percent(spo2_raw)
    # Values <= 1 are not treated as percentages (they may be fractions).
    low_spo2 = reading is not None and 1.0 < reading < _SPO2_EMERGENCY_BELOW
    legacy_hit = any(marker in spo2_raw for marker in _LEGACY_SPO2_MARKERS)

    return (
        low_spo2
        or legacy_hit
        or "unresponsive" in consciousness
        or "cyanotic" in symptoms
        or "collapse" in symptoms
    )


def rules_agent(obs) -> Action:
    """
    Deterministic baseline policy.
    Designed to be fully offline and reproducible for judge evaluation.

    The policy walks a fixed sequence, skipping any stage already present in
    ``obs.previous_actions`` (entries are matched by their ``"<action>:"``
    prefix): analyse severity, recommend a specialist, select a hospital,
    then close the episode with an ambulance call or an appointment.

    Args:
        obs: Observation exposing ``lab_report_summary``, ``symptoms``,
            ``previous_actions`` and ``nearby_hospitals``.

    Returns:
        The next ``Action`` to take.
    """
    labs = obs.lab_report_summary or {}
    symptoms = (obs.symptoms or "").lower()
    taken = obs.previous_actions or []

    emergency_signals = _has_emergency_signals(labs, symptoms)

    # STEMI-ish / cardiology cues, shared by the severity and specialist stages.
    ecg = str(labs.get("ecg_finding", "")).lower()
    troponin = str(labs.get("troponin_i", "")).lower()
    cardiac_signals = "st-segment elevation" in ecg or "elevated" in troponin

    def already_done(action_type: str) -> bool:
        prefix = f"{action_type}:"
        return any(entry.startswith(prefix) for entry in taken)

    # 1) Emergency detection / severity inference
    if not already_done("analyze_symptoms"):
        if emergency_signals:
            severity = "critical"
        elif cardiac_signals:
            severity = "high"
        else:
            severity = "low"  # default outpatient
        return Action(action_type="analyze_symptoms", target=severity)

    # 2) Route to specialist
    if not already_done("recommend_specialist"):
        if emergency_signals:
            specialist = "Emergency Doctor"
        elif cardiac_signals:
            specialist = "Cardiologist"
        else:
            specialist = "General Physician"
        return Action(action_type="recommend_specialist", target=specialist)

    # 3) Choose hospital (prefer first listed)
    if not already_done("select_hospital"):
        best = obs.nearby_hospitals[0] if obs.nearby_hospitals else "General Hospital"
        return Action(action_type="select_hospital", target=best)

    # 4) Close episode
    if emergency_signals:
        return Action(action_type="call_ambulance", target=None)
    return Action(action_type="book_appointment", target=None)


def _next_llm_action(
    client: Optional[OpenAI],
    conversation: List[Dict[str, str]],
    obs,
    step: int,
) -> Action:
    """Ask the model for the next action, allowing one corrective retry.

    Appends the user prompt and the model's replies to *conversation* in
    place. Falls back to ``request_more_info`` when the call fails or no
    reply can be parsed.
    """
    conversation.append({"role": "user", "content": build_user_message(obs, step)})

    # Reset per step so a failed call can never reuse the previous step's action.
    action: Optional[Action] = None
    assistant_text = ""
    for attempt in range(2):
        try:
            if client is None:
                raise RuntimeError("OpenAI client not initialized.")
            response = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[{"role": "system", "content": SYSTEM_PROMPT}] + conversation,
                temperature=0.0,   # deterministic (to the extent the backend supports it)
                max_tokens=256,
            )
            assistant_text = response.choices[0].message.content or ""
        except Exception as exc:
            print(f"[STEP] step={step} event=llm_error detail={str(exc)[:160]!r}")
            assistant_text = ""
            break

        action = parse_action(assistant_text)
        if action is not None or attempt == 1:
            # Parsed, or the single retry is used up; the final reply is
            # recorded once below rather than twice.
            break

        # One corrective retry: ask for strict JSON only.
        conversation.append({"role": "assistant", "content": assistant_text})
        conversation.append(
            {
                "role": "user",
                "content": "Your last response was invalid. Respond with ONLY a JSON object with keys action_type and target.",
            }
        )

    if assistant_text:
        conversation.append({"role": "assistant", "content": assistant_text})

    if action is None:
        action = Action(action_type="request_more_info", target=None)
    return action


def run_episode(client: Optional[OpenAI], difficulty: str, agent: str) -> Dict[str, Any]:
    """Run a complete episode for the given difficulty and return the summary.

    Args:
        client: OpenAI-compatible client used when *agent* is not ``"rules"``;
            may be ``None`` for the rules agent.
        difficulty: Task name passed to ``MediRouteEnv.reset``.
        agent: ``"rules"`` for the deterministic policy, anything else for the LLM.

    Returns:
        Dict with ``difficulty``, ``score``, ``passed``, ``steps``,
        ``elapsed_seconds`` and ``breakdown``.
    """
    env = MediRouteEnv()
    obs = env.reset(difficulty=difficulty)
    conversation: List[Dict[str, str]] = []
    step = 0
    result = None
    episode_start = time.time()

    print(f"[START] difficulty={difficulty.upper()} agent={agent} symptoms={obs.symptoms!r}")

    while step < MAX_STEPS_PER_EPISODE:
        step += 1
        if agent == "rules":
            action = rules_agent(obs)
        else:
            action = _next_llm_action(client, conversation, obs, step)

        # ── Step environment ──────────────────────────────────────────────────
        result = env.step(action)

        reward_sign = "+" if result.reward >= 0 else ""
        print(
            f"[STEP {step}] action={action.action_type}  "
            f"target={action.target!r}  "
            f"reward={reward_sign}{result.reward:.2f}  "
            f"total={result.info.get('total_reward', 0):.2f}  "
            f"done={result.done}"
        )

        obs = result.observation

        if result.done:
            break

    elapsed = time.time() - episode_start

    # `result` stays None only when the step budget is zero.
    final_info = result.info if result is not None else {}
    episode_summary = final_info.get("episode_summary", {})
    total_reward = final_info.get("total_reward", 0.0)

    print(f"[END] difficulty={difficulty.upper()} agent={agent}  "
          f"score={total_reward:.4f}  "
          f"passed={episode_summary.get('passed', False)}  "
          f"steps={step}  "
          f"elapsed={elapsed:.1f}s "
          f"breakdown={json.dumps(episode_summary.get('breakdown', {}))}")

    return {
        "difficulty": difficulty,
        "score": total_reward,
        "passed": episode_summary.get("passed", False),
        "steps": step,
        "elapsed_seconds": round(elapsed, 2),
        "breakdown": episode_summary.get("breakdown", {}),
    }


# ─────────────────────────────────────────────
#  Main
# ─────────────────────────────────────────────

def main() -> None:
    """Parse CLI options, run the selected episode(s) and print a summary line.

    ``--difficulty`` selects one task or ``all``; ``--agent`` selects the LLM
    policy or the offline rules policy. For the LLM policy the ``openai``
    package is imported lazily; when it is missing the process exits with
    status 1.
    """
    parser = argparse.ArgumentParser(description="MediRoute OpenEnv — Baseline Inference")
    parser.add_argument(
        "--difficulty",
        choices=["easy", "medium", "hard", "all"],
        default="all",
        help="Which task(s) to run (default: all)",
    )
    parser.add_argument(
        "--agent",
        choices=["llm", "rules"],
        default="llm",
        help="Agent policy: llm (OpenAI-compatible) or rules (offline deterministic baseline).",
    )
    args = parser.parse_args()

    # Keep output machine-parseable: rely on [START]/[STEP]/[END] markers.

    client: Optional[OpenAI] = None
    if args.agent == "llm":
        # Imported here so the rules agent runs without the SDK installed.
        try:
            from openai import OpenAI as OpenAIClient  # type: ignore
        except ImportError:
            print("[ERROR] openai package not found. Install it or run with: --agent rules")
            sys.exit(1)
        client = OpenAIClient(api_key=API_KEY, base_url=API_BASE_URL)

    difficulties = ALL_DIFFICULTIES if args.difficulty == "all" else [args.difficulty]
    results = []

    for diff in difficulties:
        result = run_episode(client, diff, agent=args.agent)
        results.append(result)

    # ── Final leaderboard ─────────────────────────────────────────────────────
    # `difficulties` always holds at least one entry, so the division is safe.
    avg_score = sum(r["score"] for r in results) / len(results)

    # Emit one final [END] summary line for strict log parsers.
    print(f"[END] summary average_score={avg_score:.4f} results={json.dumps(results)}")


if __name__ == "__main__":
    main()