| """Baseline inference script for the Blood Pathology LIMS Environment. |
| |
| Runs an LLM agent through easy/medium/hard clinical diagnostic scenarios. |
| Uses structured tool definitions and handles JSON parsing robustly. |
| |
| Log format follows the mandatory [START], [STEP], [END] plain-text specification |
| from the hackathon guidelines. |
| """ |
| import asyncio |
| import os |
| import json |
| import re |
| from typing import List |
| from client import PathologyEnv |
| from models import PathologyAction |
| from openai import OpenAI, AsyncOpenAI |
|
|
| |
| |
| |
| |
# --- LLM endpoint settings, read once from the process environment ---
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "google/gemma-4-26B-A4B-it:novita")

# The Hugging Face token doubles as the API key; importing this module fails
# fast when the token is missing or empty.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable is required")
API_KEY = HF_TOKEN

# --- Environment backend: a local Docker image when set, otherwise the HF Space ---
LOCAL_IMAGE_NAME = os.environ.get("LOCAL_IMAGE_NAME")
HF_SPACE_URL = os.environ.get("HF_SPACE_URL", "https://yatin-superintelligence-pathology-env.hf.space")

# --- Episode and scoring parameters ---
BENCHMARK = "pathology_env"      # value printed as env= on the [START] line
MAX_STEPS = 20                   # upper bound on agent steps per task
MAX_TOTAL_REWARD = 1.0           # divisor used to normalise the summed rewards
SUCCESS_SCORE_THRESHOLD = 0.8    # minimum score reported as success
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
def log_start(task: str, env: str, model: str):
    """Print the ``[START]`` line that opens an episode log.

    Args:
        task: Task identifier written as ``task=``.
        env: Benchmark/environment name written as ``env=``.
        model: Model name written as ``model=``.
    """
    line = "[START] task={} env={} model={}".format(task, env, model)
    # Flush immediately so the line is visible even if the process dies later.
    print(line, flush=True)
|
|
def log_step(step: int, action: str, reward: float, done: bool, error=None):
    """Print one ``[STEP]`` line describing a single agent step.

    Args:
        step: 1-based step index.
        action: Already-formatted action string.
        reward: Reward value, rendered with two decimals.
        done: Episode-finished flag, rendered as ``true``/``false``.
        error: Optional error; any falsy value is rendered as ``null``.
    """
    fields = [
        "[STEP]",
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        "done=" + ("true" if done else "false"),
        "error=" + ("null" if not error else str(error)),
    ]
    print(" ".join(fields), flush=True)
|
|
def log_end(success: bool, steps: int, score: float, rewards: List[float]):
    """Print the ``[END]`` line that closes an episode log.

    Args:
        success: Rendered as ``true``/``false``.
        steps: Number of steps taken.
        score: Final score, rendered with two decimals.
        rewards: Per-step rewards, rendered with two decimals and comma-joined
            (an empty list yields an empty ``rewards=`` field).
    """
    outcome = "true" if success else "false"
    formatted = [format(value, ".2f") for value in rewards]
    print(
        "[END] success={} steps={} score={:.2f} rewards={}".format(
            outcome, steps, score, ",".join(formatted)
        ),
        flush=True,
    )
|
|
|
|
| |
| |
| |
# Tool catalogue embedded (as JSON) in the system prompt. Each entry keeps the
# key order name -> args -> description, because that order is what the model
# sees after serialisation.
TOOLS = [
    dict(
        name="list_pending_cases",
        args={},
        description="List all cases assigned to the AI agent.",
    ),
    dict(
        name="get_patient_demographics",
        args={"patient_id": "int"},
        description="Get patient info including age, sex, medical history, and clinical flags.",
    ),
    dict(
        name="get_medications",
        args={"patient_id": "int"},
        description="Get active medications for a patient. Critical for drug-lab interactions.",
    ),
    dict(
        name="get_lab_orders",
        args={"patient_id": "int"},
        description="List all lab orders (panels) for a patient.",
    ),
    dict(
        name="get_lab_results",
        args={"order_id": "string"},
        description="Get detailed results for a specific lab order.",
    ),
    dict(
        name="get_previous_results",
        args={"patient_id": "int", "analyte": "string (optional)"},
        description="Get historical lab values for delta/trend comparison.",
    ),
    dict(
        name="query_reference_ranges",
        args={"analyte": "string", "context": {"sex": "M/F", "flags": "[...]"}},
        description="Get reference ranges adjusted for sex, pregnancy, and therapeutic context.",
    ),
    dict(
        name="flag_critical_value",
        args={"patient_id": "int", "analyte": "string", "value": "float", "alert_level": "string"},
        description="Flag a critical lab value and notify the attending physician.",
    ),
    dict(
        name="submit_diagnostic_report",
        args={
            "patient_id": "int",
            "icd_code": "string",
            "severity": "INFO|WARNING|HIGH|CRITICAL",
            "clinical_notes": "string",
        },
        description="Submit final diagnosis with ICD-10 code, severity, and clinical reasoning.",
    ),
]
|
|
# System prompt sent as the first chat message. It is assembled line by line;
# the tool catalogue is serialised once, at import time, with 2-space indent.
SYSTEM_PROMPT = "\n".join([
    "You are an AI Pathology Agent in a hospital Laboratory Information Management System (LIMS).",
    "You must review pending cases, query patient data, interpret lab results using",
    "proper clinical context (demographics, medications, reference ranges, previous results),",
    "flag any critical values, and submit a final diagnostic report with an ICD-10 code.",
    "",
    "AVAILABLE TOOLS:",
    json.dumps(TOOLS, indent=2),
    "",
    "CLINICAL WORKFLOW:",
    "1. list_pending_cases → identify STAT cases first",
    "2. get_patient_demographics → check age, sex, flags (pregnancy, warfarin, etc.)",
    "3. get_medications → identify drugs that affect lab interpretation",
    "4. get_lab_orders → find all panels ordered",
    "5. get_lab_results → review each panel's results",
    "6. get_previous_results → check trends/deltas for concerning analytes",
    "7. query_reference_ranges → get context-adjusted ranges (pass sex + flags from demographics)",
    "8. flag_critical_value → alert if any value exceeds critical thresholds",
    "9. submit_diagnostic_report → final diagnosis with ICD-10 code and clinical notes",
    "",
    "RULES:",
    "- Always check demographics BEFORE interpreting results (sex/age/pregnancy affect ranges)",
    "- Always check medications (some drugs cause expected lab abnormalities)",
    "- Check previous results for delta changes when available",
    "- Flag critical values BEFORE submitting your report",
    "- Use proper ICD-10 codes. If all normal for context, use icd_code='NORMAL'",
    "- severity: INFO | WARNING | HIGH | CRITICAL",
    "- Focus on the STAT priority case(s) first",
    "",
    'Response format: {"command": "...", "arguments": {...}}',
    "Respond ONLY with raw JSON, no markdown fences or explanation.",
])
|
|
|
|
def parse_llm_response(raw: str) -> dict:
    """Extract the JSON command object from an LLM reply.

    Handles raw JSON, a reply wrapped in a markdown code fence, JSON surrounded
    by explanatory text, nested objects, and braces that appear inside JSON
    string values.

    Args:
        raw: The model's reply text.

    Returns:
        The first JSON *object* found in the reply, as a dict. When the reply
        is not a string, contains no JSON object, or only contains non-object
        JSON (a list, number, string, ...), the placeholder
        ``{"command": "invalid", "arguments": {}}`` is returned, so callers can
        always use ``.get`` on the result.
    """
    fallback = {"command": "invalid", "arguments": {}}
    if not isinstance(raw, str):
        return fallback

    # Drop a leading ```/```json fence and a trailing ``` fence, if present.
    text = raw.strip()
    text = re.sub(r'^```(?:json)?\s*', '', text)
    text = re.sub(r'\s*```$', '', text)
    text = text.strip()

    # Fast path: the whole reply is a JSON document.
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    # Slow path: try to decode an object starting at each '{'. raw_decode
    # understands string quoting, so braces inside string values (or stray
    # closing braces before the object) do not confuse the extraction.
    decoder = json.JSONDecoder()
    pos = text.find("{")
    while pos != -1:
        try:
            candidate, _ = decoder.raw_decode(text, pos)
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        pos = text.find("{", pos + 1)

    return fallback
|
|
|
|
async def _connect_env(level: str):
    """Open a PathologyEnv session on the configured backend.

    Uses the Docker image named by LOCAL_IMAGE_NAME when it is set, otherwise
    connects to the Hugging Face Space at HF_SPACE_URL.
    """
    if LOCAL_IMAGE_NAME:
        print(f"[DEBUG] Using Docker image: {LOCAL_IMAGE_NAME}", flush=True)
        try:
            return await PathologyEnv.from_docker_image(LOCAL_IMAGE_NAME, env_vars={"TASK_LEVEL": level})
        except TypeError:
            # Presumably an older client whose from_docker_image() does not
            # accept env_vars; retry without it.
            return await PathologyEnv.from_docker_image(LOCAL_IMAGE_NAME)

    print(f"[DEBUG] No LOCAL_IMAGE_NAME set, connecting to HF Space: {HF_SPACE_URL}", flush=True)
    env = PathologyEnv(HF_SPACE_URL, connect_timeout_s=30, message_timeout_s=120)
    await env.connect()
    return env


async def run_task(level: str) -> float:
    """Run a single diagnostic task at the given difficulty level.

    Drives the chat model through at most MAX_STEPS environment steps, printing
    [START]/[STEP]/[END] lines as it goes, and saves a JSON trajectory report
    once the episode completes without an exception.

    Args:
        level: Task difficulty passed to the environment (e.g. "easy").

    Returns:
        The normalised episode score, clamped to the range [0.01, 0.99].

    Raises:
        Exception: Anything raised while connecting, resetting, or (after one
            reconnect attempt) stepping the environment propagates to the
            caller. The [END] line is still printed in that case, but no
            trajectory report is written.
    """
    client = AsyncOpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    env = await _connect_env(level)

    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False
    last_reward = 0.0
    trajectory = []

    log_start(task=level, env=BENCHMARK, model=MODEL_NAME)

    try:
        result = await env.reset(task_level=level)
        current_obs = result.observation.output

        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"Environment ready.\n{current_obs}\n\nBegin by listing pending cases."},
        ]

        for step in range(1, MAX_STEPS + 1):
            if result.done:
                break

            # Ask the model for the next command; on any failure fall back to
            # a harmless listing command so the episode can continue.
            try:
                response = await client.chat.completions.create(
                    model=MODEL_NAME, messages=messages, max_tokens=512, temperature=0.0
                )
                raw = response.choices[0].message.content.strip()
            except Exception as exc:
                print(f"[DEBUG] Model request failed: {exc}", flush=True)
                raw = '{"command": "list_pending_cases", "arguments": {}}'

            # Normalise the parsed reply before building the action: the model
            # may emit non-object JSON, a null command, or null/non-object
            # arguments, none of which should abort the episode.
            data = parse_llm_response(raw)
            if not isinstance(data, dict):
                data = {}
            command = data.get("command", "")
            if not isinstance(command, str):
                command = "" if command is None else str(command)
            arguments = data.get("arguments")
            if not isinstance(arguments, dict):
                arguments = {}
            action = PathologyAction(command=command, arguments=arguments)

            messages.append({"role": "assistant", "content": raw})

            try:
                result = await env.step(action)
            except Exception as ws_err:
                print(f"[DEBUG] WebSocket error, reconnecting: {ws_err}", flush=True)
                try:
                    await env.close()
                except Exception:
                    # Best effort: the old session is already broken.
                    pass
                # Reconnect to the same backend the episode started on (Docker
                # image or HF Space), restart the task and replay this action.
                env = await _connect_env(level)
                result = await env.reset(task_level=level)
                result = await env.step(action)

            obs = result.observation
            reward = result.reward or 0.0
            done = result.done or False
            error = obs.error if obs.error else None

            # Record the change in reported reward since the previous step;
            # the deltas are summed into the final score below.
            delta = reward - last_reward
            rewards.append(delta)
            last_reward = reward
            steps_taken = step

            action_str = f"{action.command}({json.dumps(action.arguments)})"
            log_step(step=step, action=action_str, reward=reward, done=done, error=error)

            # Truncated copies of the outputs keep the saved report small.
            trajectory.append({
                "step": step,
                "command": action.command,
                "arguments": action.arguments,
                "reward": round(reward, 4),
                "delta": round(delta, 4),
                "done": done,
                "error": error,
                "env_output": obs.output[:500] if obs.output else "",
                "raw_llm": raw[:300],
            })

            feedback = f"Output: {obs.output}"
            if obs.error:
                feedback += f"\nError: {obs.error}"
            feedback += f"\nReward: {reward} | Done: {done}"
            messages.append({"role": "user", "content": feedback})

            if done:
                break

        score = sum(rewards) / MAX_TOTAL_REWARD if MAX_TOTAL_REWARD > 0 else 0.0
        score = min(max(score, 0.01), 0.99)
        success = score >= SUCCESS_SCORE_THRESHOLD

    finally:
        try:
            await env.close()
        except Exception as e:
            print(f"[DEBUG] env.close() error (container cleanup): {e}", flush=True)
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)

    _save_trajectory_report(level, score, steps_taken, trajectory)
    return score
|
|
|
|
def _save_trajectory_report(level: str, score: float, steps: int, trajectory: list):
    """Save a full trajectory report as JSON for post-run analysis.

    The file is written to ``trajectories/<timestamp>_<model>_<level>.json``
    relative to the current working directory; the directory is created when
    missing.

    Args:
        level: Task difficulty, used in the file name and the report.
        score: Final episode score (rounded to 4 decimals in the report).
        steps: Number of steps the agent actually used.
        trajectory: Per-step records; must be JSON-serialisable.
    """
    from datetime import datetime

    os.makedirs("trajectories", exist_ok=True)
    # '/' and ':' appear in model identifiers but are unsafe in file names.
    model_short = MODEL_NAME.replace("/", "_").replace(":", "-")
    timestamp = datetime.now().strftime("%Y%m%dT%H%M%S")
    filename = f"trajectories/{timestamp}_{model_short}_{level}.json"

    report = {
        "model": MODEL_NAME,
        "task_level": level,
        "score": round(score, 4),
        "steps_used": steps,
        "max_steps": MAX_STEPS,
        "timestamp": timestamp,
        "trajectory": trajectory,
    }

    # ensure_ascii=False writes non-ASCII text verbatim, so the file must be
    # opened as UTF-8 explicitly; the platform default encoding may not be
    # able to represent model or environment output.
    with open(filename, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2, ensure_ascii=False)
    print(f"[DEBUG] Trajectory report saved: {filename}", flush=True)
|
|
|
|
async def main():
    """Run the easy, medium and hard tasks in order, then print a score table.

    Each level is awaited to completion before the next one starts. The
    summary shows a pass/fail marker per level (relative to
    SUCCESS_SCORE_THRESHOLD) followed by the mean score.
    """
    scores = {}
    for level in ("easy", "medium", "hard"):
        scores[level] = await run_task(level)

    divider = "=" * 60
    print("\n" + divider, flush=True)
    print(f"MODEL: {MODEL_NAME}", flush=True)
    print(divider, flush=True)
    for level, score in scores.items():
        if score >= SUCCESS_SCORE_THRESHOLD:
            status = "✅"
        else:
            status = "❌"
        print(f" {status} {level:8s}: {score:.2f}", flush=True)
    if scores:
        avg = sum(scores.values()) / len(scores)
    else:
        avg = 0
    print(f" Average: {avg:.2f}", flush=True)
    print(divider, flush=True)
|
|
|
|
if __name__ == "__main__":
    # Script entry point. If the run aborts with an exception, print a failing
    # [END] line carrying the error text before re-raising it.
    try:
        asyncio.run(main())
    except Exception as failure:
        end_line = "[END] success=false steps=0 score=0.00 rewards= error={}".format(failure)
        print(end_line, flush=True)
        raise
|
|